Add support for SSL and server passwords

This commit is contained in:
melonhead 2010-02-22 14:33:03 -05:00
parent 5cc1366e0d
commit 4c147f2402
2 changed files with 73 additions and 12 deletions

13
bot.py
View File

@ -32,8 +32,17 @@ try:
if name in bot.conns: if name in bot.conns:
print 'ERROR: more than one connection named "%s"' % name print 'ERROR: more than one connection named "%s"' % name
raise ValueError raise ValueError
bot.conns[name] = IRC(conf['server'], conf['nick'], ssl = conf.get('ssl', False)
port=conf.get('port', 6667), channels=conf['channels'], conf=conf) password = conf.get('password', None)
if ssl:
bot.conns[name] = SSLIRC(conf['server'], conf['nick'],
port=conf.get('port', 6667), channels=conf['channels'], conf=conf,
password=password,
ignoreCertificateErrors=conf.get('ignore_cert', True))
else:
bot.conns[name] = IRC(conf['server'], conf['nick'],
port=conf.get('port', 6667), channels=conf['channels'], conf=conf,
password=password)
except Exception, e: except Exception, e:
print 'ERROR: malformed config file', Exception, e print 'ERROR: malformed config file', Exception, e
sys.exit() sys.exit()

View File

@ -4,6 +4,7 @@ import socket
import time import time
import thread import thread
import Queue import Queue
from ssl import wrap_socket, CERT_NONE, CERT_REQUIRED, SSLError
def decode(txt): def decode(txt):
for codec in ('utf-8', 'iso-8859-1', 'shift_jis', 'cp1252'): for codec in ('utf-8', 'iso-8859-1', 'shift_jis', 'cp1252'):
@ -13,7 +14,6 @@ def decode(txt):
continue continue
return txt.decode('utf-8', 'ignore') return txt.decode('utf-8', 'ignore')
class crlf_tcp(object): class crlf_tcp(object):
"Handles tcp connections that consist of utf-8 lines ending with crlf" "Handles tcp connections that consist of utf-8 lines ending with crlf"
@ -22,21 +22,37 @@ class crlf_tcp(object):
self.obuffer = "" self.obuffer = ""
self.oqueue = Queue.Queue() # lines to be sent out self.oqueue = Queue.Queue() # lines to be sent out
self.iqueue = Queue.Queue() # lines that were received self.iqueue = Queue.Queue() # lines that were received
self.socket = socket.socket(socket.AF_INET, socket.TCP_NODELAY) self.socket = self.create_socket()
self.host = host self.host = host
self.port = port self.port = port
self.timeout = timeout self.timeout = timeout
def create_socket(self):
return socket.socket(socket.AF_INET, socket.TCP_NODELAY)
def run(self): def run(self):
self.socket.connect((self.host, self.port)) self.socket.connect((self.host, self.port))
thread.start_new_thread(self.recv_loop, ()) thread.start_new_thread(self.recv_loop, ())
thread.start_new_thread(self.send_loop, ()) thread.start_new_thread(self.send_loop, ())
def recv_from_socket(self, nbytes):
return self.socket.recv(nbytes)
def get_timeout_exception_type(self):
return socket.timeout
def handle_receive_exception(self, error, last_timestamp):
if time.time() - last_timestamp > self.timeout:
self.iqueue.put(StopIteration)
self.socket.close()
return True
return False
def recv_loop(self): def recv_loop(self):
last_timestamp = time.time() last_timestamp = time.time()
while True: while True:
try: try:
data = self.socket.recv(4096) data = self.recv_from_socket(4096)
self.ibuffer += data self.ibuffer += data
if data: if data:
last_timestamp = time.time() last_timestamp = time.time()
@ -46,10 +62,8 @@ class crlf_tcp(object):
self.socket.close() self.socket.close()
return return
time.sleep(1) time.sleep(1)
except socket.timeout, e: except self.get_timeout_exception_type(), e:
if time.time() - last_timestamp > self.timeout: if self.handle_receive_exception(e, last_timestamp):
self.iqueue.put(StopIteration)
self.socket.close()
return return
continue continue
@ -66,6 +80,27 @@ class crlf_tcp(object):
sent = self.socket.send(self.obuffer) sent = self.socket.send(self.obuffer)
self.obuffer = self.obuffer[sent:] self.obuffer = self.obuffer[sent:]
class crlf_ssl_tcp(crlf_tcp):
"Handles ssl tcp connetions that consist of utf-8 lines ending with crlf"
def __init__(self, host, port, ignoreCertErrors, timeout=300):
self.ignoreCertErrors = ignoreCertErrors
crlf_tcp.__init__(self, host, port, timeout)
def create_socket(self):
return wrap_socket(crlf_tcp.create_socket(self), server_side=False, cert_reqs = CERT_NONE if self.ignoreCertErrors else CERT_REQUIRED )
def recv_from_socket(self, nbytes):
return self.socket.read(nbytes)
def get_timeout_exception_type(self):
return SSLError
def handle_receive_exception(self, error, last_timestamp):
# this is terrible
if not "timed out" in error.args[0]:
raise
return crlf_tcp.handle_receive_exception(self, error, last_timestamp)
irc_prefix_rem = re.compile(r'(.*?) (.*?) (.*)').match irc_prefix_rem = re.compile(r'(.*?) (.*?) (.*)').match
irc_noprefix_rem = re.compile(r'()(.*?) (.*)').match irc_noprefix_rem = re.compile(r'()(.*?) (.*)').match
irc_netmask_rem = re.compile(r':?([^!@]*)!?([^@]*)@?(.*)').match irc_netmask_rem = re.compile(r':?([^!@]*)!?([^@]*)@?(.*)').match
@ -75,12 +110,13 @@ irc_param_ref = re.compile(r'(?:^|(?<= ))(:.*|[^ ]+)').findall
class IRC(object): class IRC(object):
"handles the IRC protocol" "handles the IRC protocol"
#see the docs/ folder for more information on the protocol #see the docs/ folder for more information on the protocol
def __init__(self, server, nick, port=6667, channels=[], conf={}): def __init__(self, server, nick, port=6667, channels=[], conf={}, password=None):
self.channels = channels self.channels = channels
self.conf = conf self.conf = conf
self.server = server self.server = server
self.port = port self.port = port
self.nick = nick self.nick = nick
self.password = password
self.out = Queue.Queue() #responses from the server are placed here self.out = Queue.Queue() #responses from the server are placed here
# format: [rawline, prefix, command, params, # format: [rawline, prefix, command, params,
@ -89,9 +125,13 @@ class IRC(object):
thread.start_new_thread(self.parse_loop, ()) thread.start_new_thread(self.parse_loop, ())
def create_connection(self):
return crlf_tcp(self.server, self.port)
def connect(self): def connect(self):
self.conn = crlf_tcp(self.server, self.port) self.conn = self.create_connection()
thread.start_new_thread(self.conn.run, ()) thread.start_new_thread(self.conn.run, ())
self.set_pass(self.password)
self.set_nick(self.nick) self.set_nick(self.nick)
self.cmd("USER", self.cmd("USER",
[conf.get('user', 'skybot'), "3", "*", ':' + conf.get('realname', [conf.get('user', 'skybot'), "3", "*", ':' + conf.get('realname',
@ -119,6 +159,10 @@ class IRC(object):
if command == "PING": if command == "PING":
self.cmd("PONG", [params]) self.cmd("PONG", [params])
def set_pass(self, password):
if password:
self.cmd("PASS", [password])
def set_nick(self, nick): def set_nick(self, nick):
self.cmd("NICK", [nick]) self.cmd("NICK", [nick])
@ -180,3 +224,11 @@ class FakeIRC(IRC):
def cmd(self, command, params=None): def cmd(self, command, params=None):
pass pass
class SSLIRC(IRC):
def __init__(self, server, nick, port=6667, channels=[], conf={}, password=None, ignoreCertificateErrors=True):
self.ignoreCertErrors = ignoreCertificateErrors
IRC.__init__(self, server, nick, port, channels, conf, password)
def create_connection(self):
return crlf_ssl_tcp(self.server, self.port, self.ignoreCertErrors)