diff --git a/bot.py b/bot.py index e19b51c..232281a 100755 --- a/bot.py +++ b/bot.py @@ -32,8 +32,17 @@ try: if name in bot.conns: print 'ERROR: more than one connection named "%s"' % name raise ValueError - bot.conns[name] = IRC(conf['server'], conf['nick'], - port=conf.get('port', 6667), channels=conf['channels'], conf=conf) + ssl = conf.get('ssl', False) + password = conf.get('password', None) + if ssl: + bot.conns[name] = SSLIRC(conf['server'], conf['nick'], + port=conf.get('port', 6667), channels=conf['channels'], conf=conf, + password=password, + ignoreCertificateErrors=conf.get('ignore_cert', True)) + else: + bot.conns[name] = IRC(conf['server'], conf['nick'], + port=conf.get('port', 6667), channels=conf['channels'], conf=conf, + password=password) except Exception, e: print 'ERROR: malformed config file', Exception, e sys.exit() diff --git a/core/irc.py b/core/irc.py index 7b6e591..fb32d46 100644 --- a/core/irc.py +++ b/core/irc.py @@ -4,6 +4,7 @@ import socket import time import thread import Queue +from ssl import wrap_socket, CERT_NONE, CERT_REQUIRED, SSLError def decode(txt): for codec in ('utf-8', 'iso-8859-1', 'shift_jis', 'cp1252'): @@ -13,7 +14,6 @@ def decode(txt): continue return txt.decode('utf-8', 'ignore') - class crlf_tcp(object): "Handles tcp connections that consist of utf-8 lines ending with crlf" @@ -22,21 +22,37 @@ class crlf_tcp(object): self.obuffer = "" self.oqueue = Queue.Queue() # lines to be sent out self.iqueue = Queue.Queue() # lines that were received - self.socket = socket.socket(socket.AF_INET, socket.TCP_NODELAY) + self.socket = self.create_socket() self.host = host self.port = port self.timeout = timeout + def create_socket(self): + return socket.socket(socket.AF_INET, socket.TCP_NODELAY) + def run(self): self.socket.connect((self.host, self.port)) thread.start_new_thread(self.recv_loop, ()) thread.start_new_thread(self.send_loop, ()) + def recv_from_socket(self, nbytes): + return self.socket.recv(nbytes) + + def get_timeout_exception_type(self): + return socket.timeout + + def handle_receive_exception(self, error, last_timestamp): + if time.time() - last_timestamp > self.timeout: + self.iqueue.put(StopIteration) + self.socket.close() + return True + return False + def recv_loop(self): last_timestamp = time.time() while True: try: - data = self.socket.recv(4096) + data = self.recv_from_socket(4096) self.ibuffer += data if data: last_timestamp = time.time() @@ -46,10 +62,8 @@ class crlf_tcp(object): self.socket.close() return time.sleep(1) - except socket.timeout, e: - if time.time() - last_timestamp > self.timeout: - self.iqueue.put(StopIteration) - self.socket.close() + except self.get_timeout_exception_type(), e: + if self.handle_receive_exception(e, last_timestamp): return continue @@ -66,6 +80,27 @@ class crlf_tcp(object): sent = self.socket.send(self.obuffer) self.obuffer = self.obuffer[sent:] +class crlf_ssl_tcp(crlf_tcp): + "Handles ssl tcp connetions that consist of utf-8 lines ending with crlf" + def __init__(self, host, port, ignoreCertErrors, timeout=300): + self.ignoreCertErrors = ignoreCertErrors + crlf_tcp.__init__(self, host, port, timeout) + + def create_socket(self): + return wrap_socket(crlf_tcp.create_socket(self), server_side=False, cert_reqs = CERT_NONE if self.ignoreCertErrors else CERT_REQUIRED ) + + def recv_from_socket(self, nbytes): + return self.socket.read(nbytes) + + def get_timeout_exception_type(self): + return SSLError + + def handle_receive_exception(self, error, last_timestamp): + # this is terrible + if not "timed out" in error.args[0]: + raise + return crlf_tcp.handle_receive_exception(self, error, last_timestamp) + irc_prefix_rem = re.compile(r'(.*?) (.*?) (.*)').match irc_noprefix_rem = re.compile(r'()(.*?) (.*)').match irc_netmask_rem = re.compile(r':?([^!@]*)!?([^@]*)@?(.*)').match @@ -75,12 +110,13 @@ irc_param_ref = re.compile(r'(?:^|(?<= ))(:.*|[^ ]+)').findall class IRC(object): "handles the IRC protocol" #see the docs/ folder for more information on the protocol - def __init__(self, server, nick, port=6667, channels=[], conf={}): + def __init__(self, server, nick, port=6667, channels=[], conf={}, password=None): self.channels = channels self.conf = conf self.server = server self.port = port self.nick = nick + self.password = password self.out = Queue.Queue() #responses from the server are placed here # format: [rawline, prefix, command, params, @@ -89,9 +125,13 @@ class IRC(object): thread.start_new_thread(self.parse_loop, ()) + def create_connection(self): + return crlf_tcp(self.server, self.port) + def connect(self): - self.conn = crlf_tcp(self.server, self.port) + self.conn = self.create_connection() thread.start_new_thread(self.conn.run, ()) + self.set_pass(self.password) self.set_nick(self.nick) self.cmd("USER", [conf.get('user', 'skybot'), "3", "*", ':' + conf.get('realname', @@ -119,6 +159,10 @@ class IRC(object): if command == "PING": self.cmd("PONG", [params]) + def set_pass(self, password): + if password: + self.cmd("PASS", [password]) + def set_nick(self, nick): self.cmd("NICK", [nick]) @@ -179,4 +223,12 @@ class FakeIRC(IRC): self.cmd("PONG", [params]) def cmd(self, command, params=None): - pass \ No newline at end of file + pass + +class SSLIRC(IRC): + def __init__(self, server, nick, port=6667, channels=[], conf={}, password=None, ignoreCertificateErrors=True): + self.ignoreCertErrors = ignoreCertificateErrors + IRC.__init__(self, server, nick, port, channels, conf, password) + + def create_connection(self): + return crlf_ssl_tcp(self.server, self.port, self.ignoreCertErrors)