rewrite hooking and dispatcher

This commit is contained in:
Ryan Hitchman 2010-03-11 16:34:54 -07:00
parent 5200749c66
commit 1e6c08fb30
13 changed files with 232 additions and 136 deletions

6
bot.py
View File

@ -1,5 +1,7 @@
#!/usr/bin/python #!/usr/bin/python
a = 34123
import os import os
import Queue import Queue
import sys import sys
@ -39,7 +41,7 @@ try:
except Exception, e: except Exception, e:
print 'ERROR: malformed config file', Exception, e print 'ERROR: malformed config file', Exception, e
sys.exit() sys.exit()
bot.persist_dir = os.path.abspath('persist') bot.persist_dir = os.path.abspath('persist')
if not os.path.exists(bot.persist_dir): if not os.path.exists(bot.persist_dir):
os.mkdir(bot.persist_dir) os.mkdir(bot.persist_dir)
@ -57,4 +59,4 @@ while True:
except Queue.Empty: except Queue.Empty:
pass pass
while all(conn.out.empty() for conn in bot.conns.itervalues()): while all(conn.out.empty() for conn in bot.conns.itervalues()):
time.sleep(.3) time.sleep(.1)

View File

@ -2,7 +2,7 @@ import thread
import traceback import traceback
print thread.stack_size(1024 * 512) # reduce vm size thread.stack_size(1024 * 512) # reduce vm size
class Input(dict): class Input(dict):
@ -25,7 +25,7 @@ class Input(dict):
dict.__init__(self, conn=conn, raw=raw, prefix=prefix, command=command, dict.__init__(self, conn=conn, raw=raw, prefix=prefix, command=command,
params=params, nick=nick, user=user, host=host, params=params, nick=nick, user=user, host=host,
paraml=paraml, msg=msg, server=conn.server, chan=chan, paraml=paraml, msg=msg, server=conn.server, chan=chan,
say=say, reply=reply, pm=pm, bot=bot) say=say, reply=reply, pm=pm, bot=bot, lastparam=paraml[-1])
def __getattr__(self, key): def __getattr__(self, key):
return self[key] return self[key]
@ -35,10 +35,15 @@ class Input(dict):
def run(func, input): def run(func, input):
args = func._skybot_args args = func._args
if 'inp' not in input:
input.inp = input.params
if args: if args:
if 'db' in args: if 'db' in args:
input['db'] = get_db_connection(input['server']) input['db'] = get_db_connection(input['server'])
if 'input' in args:
input['input'] = input
if 0 in args: if 0 in args:
out = func(input['inp'], **input) out = func(input['inp'], **input)
else: else:
@ -50,21 +55,74 @@ def run(func, input):
input['reply'](unicode(out)) input['reply'](unicode(out))
def main(conn, out): def do_sieve(sieve, bot, input, func, type, args):
for csig, func, args in bot.plugs['tee']: try:
input = Input(conn, *out) return sieve(bot, input, func, type, args)
func._iqueue.put((bot, input)) except Exception, e:
for csig, func, args in (bot.plugs['command'] + bot.plugs['event']): print 'sieve error',
input = Input(conn, *out) traceback.print_exc(Exception)
for fsig, sieve in bot.plugs['sieve']: return None
try:
input = sieve(bot, input, func, args)
except Exception, e: class Handler(object):
print 'sieve error', '''Runs plugins in their own threads (ensures order)'''
traceback.print_exc(Exception) def __init__(self, func):
input = None self.func = func
if input == None: self.input_queue = Queue.Queue()
thread.start_new_thread(self.start, ())
def start(self):
while True:
input = self.input_queue.get()
if input == StopIteration:
break break
run(self.func, input)
def stop(self):
self.input_queue.put(StopIteration)
def put(self, value):
self.input_queue.put(value)
def dispatch(input, kind, func, args):
for sieve, in bot.plugs['sieve']:
input = do_sieve(sieve, bot, input, func, kind, args)
if input == None: if input == None:
continue return
if func._thread:
bot.threads[func].put(input)
else:
thread.start_new_thread(run, (func, input)) thread.start_new_thread(run, (func, input))
def main(conn, out):
inp = Input(conn, *out)
# EVENTS
for func, args in bot.events[inp.command] + bot.events['*']:
dispatch(Input(conn, *out), "event", func, args)
# COMMANDS
if inp.command == 'PRIVMSG':
if inp.chan == inp.nick: # private message, no command prefix
prefix = r'^(?:[.!]?|'
else:
prefix = r'^(?:[.!]|'
command_re = prefix + inp.conn.nick + r'[:,]*\s+)(\w+)\s+(.*)$'
m = re.match(command_re, inp.lastparam)
if m:
command = m.group(1).lower()
if command in bot.commands:
input = Input(conn, *out)
input.inp_unstripped = m.group(2)
input.inp = m.group(2).strip()
func, args = bot.commands[command]
dispatch(input, "command", func, args)

View File

@ -1,9 +1,13 @@
import collections import collections
import glob import glob
import os import os
import Queue
import re
import sys import sys
import thread
import traceback import traceback
if 'mtimes' not in globals(): if 'mtimes' not in globals():
mtimes = {} mtimes = {}
@ -11,21 +15,38 @@ if 'lastfiles' not in globals():
lastfiles = set() lastfiles = set()
def format_plug(plug, lpad=0, width=40): def make_signature(f):
out = ' ' * lpad + '%s:%s:%s' % (plug[0]) return f.func_code.co_filename, f.func_name, f.func_code.co_firstlineno
if len(plug) == 3 and 'hook' in plug[2]:
out += '%s%s' % (' ' * (width - len(out)), plug[2]['hook'])
def format_plug(plug, kind='', lpad=0, width=40):
out = ' ' * lpad + '%s:%s:%s' % make_signature(plug[0])
if kind == 'command':
out += ' ' * (50 - len(out)) + plug[1]['name']
if kind == 'event':
out += ' ' * (50 - len(out)) + ', '.join(plug[1]['events'])
return out return out
def reload(init=False):
if init:
bot.plugs = collections.defaultdict(lambda: [])
for filename in glob.glob(os.path.join("core", "*.py")): def reload(init=False):
changed = False
if init:
bot.plugs = collections.defaultdict(list)
bot.threads = {}
core_fileset = set(glob.glob(os.path.join("core", "*.py")))
for filename in core_fileset:
mtime = os.stat(filename).st_mtime mtime = os.stat(filename).st_mtime
if mtime != mtimes.get(filename): if mtime != mtimes.get(filename):
mtimes[filename] = mtime mtimes[filename] = mtime
changed = True
try: try:
eval(compile(open(filename, 'U').read(), filename, 'exec'), eval(compile(open(filename, 'U').read(), filename, 'exec'),
globals()) globals())
@ -39,14 +60,29 @@ def reload(init=False):
reload(init=init) reload(init=init)
return return
fileset = set(glob.glob(os.path.join('plugins', '*py'))) fileset = set(glob.glob(os.path.join('plugins', '*.py')))
for name, data in bot.plugs.iteritems(): # remove deleted/moved plugins
bot.plugs[name] = filter(lambda x: x[0][0] in fileset, data)
# remove deleted/moved plugins
for name, data in bot.plugs.iteritems():
bot.plugs[name] = [x for x in data if x[0]._filename in fileset]
for filename in list(mtimes):
if filename not in fileset and filename not in core_fileset:
mtimes.pop(filename)
for func, handler in list(bot.threads.iteritems()):
if func._filename not in fileset:
handler.stop()
del bot.threads[func]
# compile new plugins
for filename in fileset: for filename in fileset:
mtime = os.stat(filename).st_mtime mtime = os.stat(filename).st_mtime
if mtime != mtimes.get(filename): if mtime != mtimes.get(filename):
mtimes[filename] = mtime mtimes[filename] = mtime
changed = True
try: try:
code = compile(open(filename, 'U').read(), filename, 'exec') code = compile(open(filename, 'U').read(), filename, 'exec')
namespace = {} namespace = {}
@ -57,28 +93,69 @@ def reload(init=False):
# remove plugins already loaded from this filename # remove plugins already loaded from this filename
for name, data in bot.plugs.iteritems(): for name, data in bot.plugs.iteritems():
bot.plugs[name] = [x for x in data
if x[0]._filename != filename]
if name == 'tee': # signal tee trampolines to stop for func, handler in list(bot.threads.iteritems()):
for csig, func, args in data: if func._filename == filename:
if csig[0] == filename: handler.stop()
func._iqueue.put(StopIteration) del bot.threads[func]
bot.plugs[name] = filter(lambda x: x[0][0] != filename, data)
for obj in namespace.itervalues(): for obj in namespace.itervalues():
if hasattr(obj, '_skybot_hook'): # check for magic if hasattr(obj, '_hook'): # check for magic
for type, data in obj._skybot_hook: if obj._thread:
bot.threads[obj] = Handler(obj)
for type, data in obj._hook:
bot.plugs[type] += [data] bot.plugs[type] += [data]
if not init: if not init:
print '### new plugin (type: %s) loaded:' % \ print '### new plugin (type: %s) loaded:' % \
type, format_plug(data) type, format_plug(data)
if changed:
bot.commands = {}
for plug in bot.plugs['command']:
name = plug[1]['name'].lower()
if not re.match(r'^\w+$', name):
print '### ERROR: invalid command name "%s" (%s)' % (name,
format_plug(plug))
continue
if name in bot.commands:
print "### ERROR: command '%s' already registered (%s, %s)" % \
(name, format_plug(bot.commands[name]),
format_plug(plug))
continue
bot.commands[name] = plug
bot.events = collections.defaultdict(list)
for func, args in bot.plugs['event']:
for event in args['events']:
bot.events[event].append((func, args))
if init: if init:
print ' plugin listing:' print ' plugin listing:'
for type, plugs in sorted(bot.plugs.iteritems()):
if bot.commands:
# hack to make commands with multiple aliases
# print nicely
print ' command:'
commands = collections.defaultdict(list)
for name, (func, args) in bot.commands.iteritems():
commands[make_signature(func)].append(name)
for sig, names in sorted(commands.iteritems()):
names.sort(key=lambda x: (-len(x), x)) # long names first
out = ' ' * 6 + '%s:%s:%s' % sig
out += ' ' * (50 - len(out)) + ', '.join(names)
print out
for kind, plugs in sorted(bot.plugs.iteritems()):
if kind == 'command':
continue
print ' %s:' % type print ' %s:' % type
for plug in plugs: for plug in plugs:
out = ' %s:%s:%s' % (plug[0]) print format_plug(plug, kind=kind, lpad=6)
print format_plug(plug, lpad=6)
print print

View File

@ -83,8 +83,9 @@ def get_log_fd(dir, server, chan):
return fd return fd
@hook.tee @hook.thread
def log(bot, input): @hook.event('*')
def log(inp, input=None, bot=None):
with lock: with lock:
timestamp = gmtime(timestamp_format) timestamp = gmtime(timestamp_format)

View File

@ -1,7 +1,9 @@
from util import hook
import socket import socket
import time
socket.setdefaulttimeout(5) # global setting from util import hook
socket.setdefaulttimeout(10) # global setting
#autorejoin channels #autorejoin channels
@ -22,11 +24,13 @@ def invite(inp, command='', conn=None):
#join channels when server says hello & identify bot #join channels when server says hello & identify bot
@hook.event('004') @hook.event('004')
def onjoin(inp, conn=None): def onjoin(inp, conn=None):
for channel in conn.channels:
conn.join(channel)
nickserv_password = conn.conf.get('nickserv_password', '') nickserv_password = conn.conf.get('nickserv_password', '')
nickserv_name = conn.conf.get('nickserv_name', 'nickserv') nickserv_name = conn.conf.get('nickserv_name', 'nickserv')
nickserv_command = conn.conf.get('nickserv_command', 'IDENTIFY %s') nickserv_command = conn.conf.get('nickserv_command', 'IDENTIFY %s')
if nickserv_password: if nickserv_password:
conn.msg(nickserv_name, nickserv_command % nickserv_password) conn.msg(nickserv_name, nickserv_command % nickserv_password)
time.sleep(1)
for channel in conn.channels:
conn.join(channel)
time.sleep(1) # don't flood JOINs

View File

@ -62,7 +62,7 @@ def forget(inp, chan='', db=None):
return "I don't know about that." return "I don't know about that."
@hook.command(hook='\?(.+)', prefix=False) @hook.event('PRIVMSG', hook=r'\?(.+)')
def question(inp, chan='', say=None, db=None): def question(inp, chan='', say=None, db=None):
"?<word> -- shows what data is associated with word" "?<word> -- shows what data is associated with word"
db_init(db) db_init(db)

View File

@ -5,11 +5,9 @@ import time
from util import hook, timesince from util import hook, timesince
@hook.tee @hook.thread
def seeninput(bot, input): @hook.event('PRIVMSG')
if input.command != 'PRIVMSG': def seeninput(inp, input=None, bot=None):
return
db = bot.get_db_connection(input.server) db = bot.get_db_connection(input.server)
db_init(db) db_init(db)
db.execute("insert or replace into seen(name, time, quote, chan)" db.execute("insert or replace into seen(name, time, quote, chan)"

View File

@ -5,26 +5,14 @@ from util import hook
@hook.sieve @hook.sieve
def sieve_suite(bot, input, func, args): def sieve_suite(bot, input, func, kind, args):
events = args.get('events', ['PRIVMSG']) events = args.get('events', ['PRIVMSG'])
if input.command not in events and events != '*': if input.command not in events and '*' not in events:
return None return None
if input.nick.lower()[-3:] == 'bot' and args.get('ignorebots', True): if input.command == 'PRIVMSG' and input.nick.lower()[-3:] == 'bot' \
return None and args.get('ignorebots', True):
hook = args.get('hook', r'(.*)')
if args.get('prefix', True):
if input.chan == input.nick: # private message, prefix not required
prefix = r'^(?:[.!]?|'
else:
prefix = r'^(?:[.!]|'
hook = prefix + input.conn.nick + r'[:,]*\s)' + hook
input.re = re.match(hook, input.msg, flags=re.I)
if input.re is None:
return None return None
acl = bot.config.get('acls', {}).get(func.__name__) acl = bot.config.get('acls', {}).get(func.__name__)
@ -38,7 +26,7 @@ def sieve_suite(bot, input, func, args):
if input.chan.lower() in denied_channels: if input.chan.lower() in denied_channels:
return None return None
input.inp_unstripped = ' '.join(input.re.groups()) # input.inp_unstripped = ' '.join(input.re.groups())
input.inp = input.inp_unstripped.strip() # input.inp = input.inp_unstripped.strip()
return input return input

View File

@ -12,11 +12,9 @@ def get_tells(db, user_to, chan):
(user_to.lower(), chan)).fetchall() (user_to.lower(), chan)).fetchall()
@hook.tee @hook.thread
def tellinput(bot, input): @hook.event('PRIVMSG')
if input.command != 'PRIVMSG': def tellinput(inp, input=None, bot=None):
return
if 'showtells' in input.msg.lower(): if 'showtells' in input.msg.lower():
return return

View File

@ -8,7 +8,7 @@ tinyurl_re = re.compile(r'http://(?:www\.)?tinyurl.com/([A-Za-z0-9\-]+)',
flags=re.IGNORECASE) flags=re.IGNORECASE)
@hook.command(hook=r'(.*)', prefix=False) @hook.event('PRIVMSG')
def tinyurl(inp): def tinyurl(inp):
tumatch = tinyurl_re.search(inp) tumatch = tinyurl_re.search(inp)
if tumatch: if tumatch:

View File

@ -67,7 +67,7 @@ def format_reply(history):
hour_span, nicklist(history), last) hour_span, nicklist(history), last)
@hook.command(hook=r'(.*)', prefix=False) @hook.event('PRIVMSG')
def urlinput(inp, nick='', chan='', server='', reply=None, bot=None): def urlinput(inp, nick='', chan='', server='', reply=None, bot=None):
m = url_re.search(inp.encode('utf8')) m = url_re.search(inp.encode('utf8'))
if not m: if not m:

View File

@ -1,20 +1,16 @@
import inspect import inspect
import thread
import traceback import traceback
import Queue
def _isfunc(x):
if type(x) == type(_isfunc):
return True
return False
def _hook_add(func, add, name=''): def _hook_add(func, add, name=''):
if not hasattr(func, '_skybot_hook'): if not hasattr(func, '_hook'):
func._skybot_hook = [] func._hook = []
func._skybot_hook.append(add) func._hook.append(add)
if not hasattr(func, '_skybot_args'):
if not hasattr(func, '_filename'):
func._filename = func.func_code.co_filename
if not hasattr(func, '_args'):
argspec = inspect.getargspec(func) argspec = inspect.getargspec(func)
if name: if name:
n_args = len(argspec.args) n_args = len(argspec.args)
@ -36,39 +32,34 @@ def _hook_add(func, add, name=''):
end if end else None]) end if end else None])
if argspec.keywords: if argspec.keywords:
args.append(0) # means kwargs present args.append(0) # means kwargs present
func._skybot_args = args func._args = args
def _make_sig(f):
return f.func_code.co_filename, f.func_name, f.func_code.co_firstlineno
if not hasattr(func, '_skybot_thread'): # does function run in its own thread?
func._thread = False
def sieve(func): def sieve(func):
if func.func_code.co_argcount != 4: if func.func_code.co_argcount != 5:
raise ValueError( raise ValueError(
'sieves must take 4 arguments: (bot, input, func, args)') 'sieves must take 5 arguments: (bot, input, func, type, args)')
_hook_add(func, ['sieve', (_make_sig(func), func)]) _hook_add(func, ['sieve', (func,)])
return func return func
def command(func=None, hook=None, **kwargs): def command(arg, **kwargs):
args = {} args = {}
def command_wrapper(func): def command_wrapper(func):
args.setdefault('name', func.func_name) args.setdefault('name', func.func_name)
args.setdefault('hook', args['name'] + r'(?:\s+|$)(.*)') _hook_add(func, ['command', (func, args)], 'command')
_hook_add(func, ['command', (_make_sig(func), func, args)], 'command')
return func return func
if hook is not None or kwargs or not _isfunc(func): if kwargs or not inspect.isfunction(arg):
if func is not None: if arg is not None:
args['name'] = func args['name'] = arg
if hook is not None:
args['hook'] = hook
args.update(kwargs) args.update(kwargs)
return command_wrapper return command_wrapper
else: else:
return command_wrapper(func) return command_wrapper(arg)
def event(arg=None, **kwargs): def event(arg=None, **kwargs):
@ -76,12 +67,11 @@ def event(arg=None, **kwargs):
def event_wrapper(func): def event_wrapper(func):
args['name'] = func.func_name args['name'] = func.func_name
args['prefix'] = False args.setdefault('events', ['*'])
args.setdefault('events', '*') _hook_add(func, ['event', (func, args)], 'event')
_hook_add(func, ['event', (_make_sig(func), func, args)], 'event')
return func return func
if _isfunc(arg): if inspect.isfunction(arg):
return event_wrapper(arg, kwargs) return event_wrapper(arg, kwargs)
else: else:
if arg is not None: if arg is not None:
@ -89,26 +79,6 @@ def event(arg=None, **kwargs):
return event_wrapper return event_wrapper
def tee(func, **kwargs): def thread(func):
"passes _all_ input lines to function, in order (skips sieves)" func._thread = True
if func.func_code.co_argcount != 2:
raise ValueError('tees must take 2 arguments: (bot, input)')
_hook_add(func, ['tee', (_make_sig(func), func, kwargs)])
func._iqueue = Queue.Queue()
def trampoline(func):
input = None
while True:
input = func._iqueue.get()
if input == StopIteration:
return
try:
func(*input)
except Exception:
traceback.print_exc(Exception)
thread.start_new_thread(trampoline, (func,))
return func return func

View File

@ -54,7 +54,7 @@ def get_video_description(vid_id):
return out return out
@hook.command(hook=r'(.*)', prefix=False) @hook.event('PRIVMSG')
def youtube_url(inp): def youtube_url(inp):
m = youtube_re.search(inp) m = youtube_re.search(inp)
if m: if m: