rewrite hooking and dispatcher

This commit is contained in:
Ryan Hitchman 2010-03-11 16:34:54 -07:00
parent 5200749c66
commit 1e6c08fb30
13 changed files with 232 additions and 136 deletions

6
bot.py
View File

@ -1,5 +1,7 @@
#!/usr/bin/python
a = 34123
import os
import Queue
import sys
@ -39,7 +41,7 @@ try:
except Exception, e:
print 'ERROR: malformed config file', Exception, e
sys.exit()
bot.persist_dir = os.path.abspath('persist')
if not os.path.exists(bot.persist_dir):
os.mkdir(bot.persist_dir)
@ -57,4 +59,4 @@ while True:
except Queue.Empty:
pass
while all(conn.out.empty() for conn in bot.conns.itervalues()):
time.sleep(.3)
time.sleep(.1)

View File

@ -2,7 +2,7 @@ import thread
import traceback
print thread.stack_size(1024 * 512) # reduce vm size
thread.stack_size(1024 * 512) # reduce vm size
class Input(dict):
@ -25,7 +25,7 @@ class Input(dict):
dict.__init__(self, conn=conn, raw=raw, prefix=prefix, command=command,
params=params, nick=nick, user=user, host=host,
paraml=paraml, msg=msg, server=conn.server, chan=chan,
say=say, reply=reply, pm=pm, bot=bot)
say=say, reply=reply, pm=pm, bot=bot, lastparam=paraml[-1])
def __getattr__(self, key):
return self[key]
@ -35,10 +35,15 @@ class Input(dict):
def run(func, input):
args = func._skybot_args
args = func._args
if 'inp' not in input:
input.inp = input.params
if args:
if 'db' in args:
input['db'] = get_db_connection(input['server'])
if 'input' in args:
input['input'] = input
if 0 in args:
out = func(input['inp'], **input)
else:
@ -50,21 +55,74 @@ def run(func, input):
input['reply'](unicode(out))
def main(conn, out):
for csig, func, args in bot.plugs['tee']:
input = Input(conn, *out)
func._iqueue.put((bot, input))
for csig, func, args in (bot.plugs['command'] + bot.plugs['event']):
input = Input(conn, *out)
for fsig, sieve in bot.plugs['sieve']:
try:
input = sieve(bot, input, func, args)
except Exception, e:
print 'sieve error',
traceback.print_exc(Exception)
input = None
if input == None:
def do_sieve(sieve, bot, input, func, type, args):
try:
return sieve(bot, input, func, type, args)
except Exception, e:
print 'sieve error',
traceback.print_exc(Exception)
return None
class Handler(object):
'''Runs plugins in their own threads (ensures order)'''
def __init__(self, func):
self.func = func
self.input_queue = Queue.Queue()
thread.start_new_thread(self.start, ())
def start(self):
while True:
input = self.input_queue.get()
if input == StopIteration:
break
run(self.func, input)
def stop(self):
self.input_queue.put(StopIteration)
def put(self, value):
self.input_queue.put(value)
def dispatch(input, kind, func, args):
for sieve, in bot.plugs['sieve']:
input = do_sieve(sieve, bot, input, func, kind, args)
if input == None:
continue
return
if func._thread:
bot.threads[func].put(input)
else:
thread.start_new_thread(run, (func, input))
def main(conn, out):
inp = Input(conn, *out)
# EVENTS
for func, args in bot.events[inp.command] + bot.events['*']:
dispatch(Input(conn, *out), "event", func, args)
# COMMANDS
if inp.command == 'PRIVMSG':
if inp.chan == inp.nick: # private message, no command prefix
prefix = r'^(?:[.!]?|'
else:
prefix = r'^(?:[.!]|'
command_re = prefix + inp.conn.nick + r'[:,]*\s+)(\w+)\s+(.*)$'
m = re.match(command_re, inp.lastparam)
if m:
command = m.group(1).lower()
if command in bot.commands:
input = Input(conn, *out)
input.inp_unstripped = m.group(2)
input.inp = m.group(2).strip()
func, args = bot.commands[command]
dispatch(input, "command", func, args)

View File

@ -1,9 +1,13 @@
import collections
import glob
import os
import Queue
import re
import sys
import thread
import traceback
if 'mtimes' not in globals():
mtimes = {}
@ -11,21 +15,38 @@ if 'lastfiles' not in globals():
lastfiles = set()
def format_plug(plug, lpad=0, width=40):
out = ' ' * lpad + '%s:%s:%s' % (plug[0])
if len(plug) == 3 and 'hook' in plug[2]:
out += '%s%s' % (' ' * (width - len(out)), plug[2]['hook'])
def make_signature(f):
return f.func_code.co_filename, f.func_name, f.func_code.co_firstlineno
def format_plug(plug, kind='', lpad=0, width=40):
out = ' ' * lpad + '%s:%s:%s' % make_signature(plug[0])
if kind == 'command':
out += ' ' * (50 - len(out)) + plug[1]['name']
if kind == 'event':
out += ' ' * (50 - len(out)) + ', '.join(plug[1]['events'])
return out
def reload(init=False):
if init:
bot.plugs = collections.defaultdict(lambda: [])
for filename in glob.glob(os.path.join("core", "*.py")):
def reload(init=False):
changed = False
if init:
bot.plugs = collections.defaultdict(list)
bot.threads = {}
core_fileset = set(glob.glob(os.path.join("core", "*.py")))
for filename in core_fileset:
mtime = os.stat(filename).st_mtime
if mtime != mtimes.get(filename):
mtimes[filename] = mtime
changed = True
try:
eval(compile(open(filename, 'U').read(), filename, 'exec'),
globals())
@ -39,14 +60,29 @@ def reload(init=False):
reload(init=init)
return
fileset = set(glob.glob(os.path.join('plugins', '*py')))
for name, data in bot.plugs.iteritems(): # remove deleted/moved plugins
bot.plugs[name] = filter(lambda x: x[0][0] in fileset, data)
fileset = set(glob.glob(os.path.join('plugins', '*.py')))
# remove deleted/moved plugins
for name, data in bot.plugs.iteritems():
bot.plugs[name] = [x for x in data if x[0]._filename in fileset]
for filename in list(mtimes):
if filename not in fileset and filename not in core_fileset:
mtimes.pop(filename)
for func, handler in list(bot.threads.iteritems()):
if func._filename not in fileset:
handler.stop()
del bot.threads[func]
# compile new plugins
for filename in fileset:
mtime = os.stat(filename).st_mtime
if mtime != mtimes.get(filename):
mtimes[filename] = mtime
changed = True
try:
code = compile(open(filename, 'U').read(), filename, 'exec')
namespace = {}
@ -57,28 +93,69 @@ def reload(init=False):
# remove plugins already loaded from this filename
for name, data in bot.plugs.iteritems():
bot.plugs[name] = [x for x in data
if x[0]._filename != filename]
if name == 'tee': # signal tee trampolines to stop
for csig, func, args in data:
if csig[0] == filename:
func._iqueue.put(StopIteration)
bot.plugs[name] = filter(lambda x: x[0][0] != filename, data)
for func, handler in list(bot.threads.iteritems()):
if func._filename == filename:
handler.stop()
del bot.threads[func]
for obj in namespace.itervalues():
if hasattr(obj, '_skybot_hook'): # check for magic
for type, data in obj._skybot_hook:
if hasattr(obj, '_hook'): # check for magic
if obj._thread:
bot.threads[obj] = Handler(obj)
for type, data in obj._hook:
bot.plugs[type] += [data]
if not init:
print '### new plugin (type: %s) loaded:' % \
type, format_plug(data)
if changed:
bot.commands = {}
for plug in bot.plugs['command']:
name = plug[1]['name'].lower()
if not re.match(r'^\w+$', name):
print '### ERROR: invalid command name "%s" (%s)' % (name,
format_plug(plug))
continue
if name in bot.commands:
print "### ERROR: command '%s' already registered (%s, %s)" % \
(name, format_plug(bot.commands[name]),
format_plug(plug))
continue
bot.commands[name] = plug
bot.events = collections.defaultdict(list)
for func, args in bot.plugs['event']:
for event in args['events']:
bot.events[event].append((func, args))
if init:
print ' plugin listing:'
for type, plugs in sorted(bot.plugs.iteritems()):
if bot.commands:
# hack to make commands with multiple aliases
# print nicely
print ' command:'
commands = collections.defaultdict(list)
for name, (func, args) in bot.commands.iteritems():
commands[make_signature(func)].append(name)
for sig, names in sorted(commands.iteritems()):
names.sort(key=lambda x: (-len(x), x)) # long names first
out = ' ' * 6 + '%s:%s:%s' % sig
out += ' ' * (50 - len(out)) + ', '.join(names)
print out
for kind, plugs in sorted(bot.plugs.iteritems()):
if kind == 'command':
continue
print ' %s:' % type
for plug in plugs:
out = ' %s:%s:%s' % (plug[0])
print format_plug(plug, lpad=6)
print format_plug(plug, kind=kind, lpad=6)
print

View File

@ -83,8 +83,9 @@ def get_log_fd(dir, server, chan):
return fd
@hook.tee
def log(bot, input):
@hook.thread
@hook.event('*')
def log(inp, input=None, bot=None):
with lock:
timestamp = gmtime(timestamp_format)

View File

@ -1,7 +1,9 @@
from util import hook
import socket
import time
socket.setdefaulttimeout(5) # global setting
from util import hook
socket.setdefaulttimeout(10) # global setting
#autorejoin channels
@ -22,11 +24,13 @@ def invite(inp, command='', conn=None):
#join channels when server says hello & identify bot
@hook.event('004')
def onjoin(inp, conn=None):
for channel in conn.channels:
conn.join(channel)
nickserv_password = conn.conf.get('nickserv_password', '')
nickserv_name = conn.conf.get('nickserv_name', 'nickserv')
nickserv_command = conn.conf.get('nickserv_command', 'IDENTIFY %s')
if nickserv_password:
conn.msg(nickserv_name, nickserv_command % nickserv_password)
time.sleep(1)
for channel in conn.channels:
conn.join(channel)
time.sleep(1) # don't flood JOINs

View File

@ -62,7 +62,7 @@ def forget(inp, chan='', db=None):
return "I don't know about that."
@hook.command(hook='\?(.+)', prefix=False)
@hook.event('PRIVMSG', hook=r'\?(.+)')
def question(inp, chan='', say=None, db=None):
"?<word> -- shows what data is associated with word"
db_init(db)

View File

@ -5,11 +5,9 @@ import time
from util import hook, timesince
@hook.tee
def seeninput(bot, input):
if input.command != 'PRIVMSG':
return
@hook.thread
@hook.event('PRIVMSG')
def seeninput(inp, input=None, bot=None):
db = bot.get_db_connection(input.server)
db_init(db)
db.execute("insert or replace into seen(name, time, quote, chan)"

View File

@ -5,26 +5,14 @@ from util import hook
@hook.sieve
def sieve_suite(bot, input, func, args):
def sieve_suite(bot, input, func, kind, args):
events = args.get('events', ['PRIVMSG'])
if input.command not in events and events != '*':
if input.command not in events and '*' not in events:
return None
if input.nick.lower()[-3:] == 'bot' and args.get('ignorebots', True):
return None
hook = args.get('hook', r'(.*)')
if args.get('prefix', True):
if input.chan == input.nick: # private message, prefix not required
prefix = r'^(?:[.!]?|'
else:
prefix = r'^(?:[.!]|'
hook = prefix + input.conn.nick + r'[:,]*\s)' + hook
input.re = re.match(hook, input.msg, flags=re.I)
if input.re is None:
if input.command == 'PRIVMSG' and input.nick.lower()[-3:] == 'bot' \
and args.get('ignorebots', True):
return None
acl = bot.config.get('acls', {}).get(func.__name__)
@ -38,7 +26,7 @@ def sieve_suite(bot, input, func, args):
if input.chan.lower() in denied_channels:
return None
input.inp_unstripped = ' '.join(input.re.groups())
input.inp = input.inp_unstripped.strip()
# input.inp_unstripped = ' '.join(input.re.groups())
# input.inp = input.inp_unstripped.strip()
return input

View File

@ -12,11 +12,9 @@ def get_tells(db, user_to, chan):
(user_to.lower(), chan)).fetchall()
@hook.tee
def tellinput(bot, input):
if input.command != 'PRIVMSG':
return
@hook.thread
@hook.event('PRIVMSG')
def tellinput(inp, input=None, bot=None):
if 'showtells' in input.msg.lower():
return

View File

@ -8,7 +8,7 @@ tinyurl_re = re.compile(r'http://(?:www\.)?tinyurl.com/([A-Za-z0-9\-]+)',
flags=re.IGNORECASE)
@hook.command(hook=r'(.*)', prefix=False)
@hook.event('PRIVMSG')
def tinyurl(inp):
tumatch = tinyurl_re.search(inp)
if tumatch:

View File

@ -67,7 +67,7 @@ def format_reply(history):
hour_span, nicklist(history), last)
@hook.command(hook=r'(.*)', prefix=False)
@hook.event('PRIVMSG')
def urlinput(inp, nick='', chan='', server='', reply=None, bot=None):
m = url_re.search(inp.encode('utf8'))
if not m:

View File

@ -1,20 +1,16 @@
import inspect
import thread
import traceback
import Queue
def _isfunc(x):
if type(x) == type(_isfunc):
return True
return False
def _hook_add(func, add, name=''):
if not hasattr(func, '_skybot_hook'):
func._skybot_hook = []
func._skybot_hook.append(add)
if not hasattr(func, '_skybot_args'):
if not hasattr(func, '_hook'):
func._hook = []
func._hook.append(add)
if not hasattr(func, '_filename'):
func._filename = func.func_code.co_filename
if not hasattr(func, '_args'):
argspec = inspect.getargspec(func)
if name:
n_args = len(argspec.args)
@ -36,39 +32,34 @@ def _hook_add(func, add, name=''):
end if end else None])
if argspec.keywords:
args.append(0) # means kwargs present
func._skybot_args = args
def _make_sig(f):
return f.func_code.co_filename, f.func_name, f.func_code.co_firstlineno
func._args = args
if not hasattr(func, '_skybot_thread'): # does function run in its own thread?
func._thread = False
def sieve(func):
if func.func_code.co_argcount != 4:
if func.func_code.co_argcount != 5:
raise ValueError(
'sieves must take 4 arguments: (bot, input, func, args)')
_hook_add(func, ['sieve', (_make_sig(func), func)])
'sieves must take 5 arguments: (bot, input, func, type, args)')
_hook_add(func, ['sieve', (func,)])
return func
def command(func=None, hook=None, **kwargs):
def command(arg, **kwargs):
args = {}
def command_wrapper(func):
args.setdefault('name', func.func_name)
args.setdefault('hook', args['name'] + r'(?:\s+|$)(.*)')
_hook_add(func, ['command', (_make_sig(func), func, args)], 'command')
_hook_add(func, ['command', (func, args)], 'command')
return func
if hook is not None or kwargs or not _isfunc(func):
if func is not None:
args['name'] = func
if hook is not None:
args['hook'] = hook
if kwargs or not inspect.isfunction(arg):
if arg is not None:
args['name'] = arg
args.update(kwargs)
return command_wrapper
else:
return command_wrapper(func)
return command_wrapper(arg)
def event(arg=None, **kwargs):
@ -76,12 +67,11 @@ def event(arg=None, **kwargs):
def event_wrapper(func):
args['name'] = func.func_name
args['prefix'] = False
args.setdefault('events', '*')
_hook_add(func, ['event', (_make_sig(func), func, args)], 'event')
args.setdefault('events', ['*'])
_hook_add(func, ['event', (func, args)], 'event')
return func
if _isfunc(arg):
if inspect.isfunction(arg):
return event_wrapper(arg, kwargs)
else:
if arg is not None:
@ -89,26 +79,6 @@ def event(arg=None, **kwargs):
return event_wrapper
def tee(func, **kwargs):
"passes _all_ input lines to function, in order (skips sieves)"
if func.func_code.co_argcount != 2:
raise ValueError('tees must take 2 arguments: (bot, input)')
_hook_add(func, ['tee', (_make_sig(func), func, kwargs)])
func._iqueue = Queue.Queue()
def trampoline(func):
input = None
while True:
input = func._iqueue.get()
if input == StopIteration:
return
try:
func(*input)
except Exception:
traceback.print_exc(Exception)
thread.start_new_thread(trampoline, (func,))
def thread(func):
func._thread = True
return func

View File

@ -54,7 +54,7 @@ def get_video_description(vid_id):
return out
@hook.command(hook=r'(.*)', prefix=False)
@hook.event('PRIVMSG')
def youtube_url(inp):
m = youtube_re.search(inp)
if m: