summaryrefslogtreecommitdiffstats
path: root/pinolo/bot.py
blob: 7690b3bb80fa1e33e3ee2b0673823acd22be6992 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
# -*- coding: utf-8 -*-
"""
    pinolo.bot
    ~~~~~~~~~~

    The Bot class contains functions to start and stop the bot, handle network
    traffic and loading of plugins.

    :copyright: (c) 2013 Daniel Kertesz
    :license: BSD, see LICENSE for more details.
"""
import os
import re
import ssl
import socket
import select
import errno
import time
import logging
import Queue
import traceback
from pprint import pprint
import pinolo.plugins
from pinolo.signals import SignalDispatcher
from pinolo.irc import IRCConnection, COMMAND_ALIASES
from pinolo.database import init_db
from pinolo.config import empty_config


# Main-loop tunables.

# Upper bound (seconds) for a single select() call, i.e. the longest one
# main-loop pass waits for network activity before checking the queue and
# the crontab.
SELECT_TIMEOUT = 1

# Minimum number of seconds between two crontab runs in Bot.handle_cron().
CRONTAB_INTERVAL = 60

# Module logger; getLogger() with no name returns the root logger.
log = logging.getLogger()


class Bot(SignalDispatcher):
    """Main Bot controller class.

    Handle the network stuff, must be initialized with a configuration object.

    The bot owns one IRCConnection per entry of ``config['servers']``, runs a
    select()-based main loop over their sockets, executes callbacks queued on
    ``self.coda`` and loads, activates and deactivates the plugins.
    """

    def __init__(self, config):
        SignalDispatcher.__init__(self)
        self.config = config
        # server name -> IRCConnection
        self.connections = {}
        # socket -> IRCConnection, used to map select() results back to their
        # connection; filled in by start() and updated when SSL wraps a socket.
        self.connection_map = {}
        # Sockets on which the server closed the connection (EOF). Remembered
        # so that a later "writable" event on them is not mistaken for the
        # completion of a new connection.
        self._eof_sockets = set()
        # Items are sequences (fn, arg1, arg2, ...), see check_queue().
        self.coda = Queue.Queue()
        self.plugins = []
        self.db_uri = "sqlite:///%s" % os.path.join(
            self.config["datadir"], "db.sqlite")
        self.db_engine = None
        self.running = False

        for server in config['servers']:
            server_config = config['servers'][server]
            ircc = IRCConnection(server, server_config, self)
            self.connections[server] = ircc

    def start(self):
        """Load and activate plugins, connect to every server, run the main
        loop until it stops and finally shut the bot down."""
        # Here we also load and activate the plugins
        self.load_plugins()
        # XXX Database get initialized HERE (after loading the plugins but
        # before activating them).
        self.db_engine = init_db(self.db_uri)
        self.activate_plugins()

        self.signal_emit("pre_connect")

        for conn_name, conn_obj in self.connections.items():
            log.info("Connecting to server: %s", conn_name)
            conn_obj.connect()

        # Register the sockets only after connect() has been called on all
        # connections.
        for conn_obj in self.connections.values():
            self.connection_map[conn_obj.socket] = conn_obj

        self.running = True
        self.main_loop()

        # at last...
        self.shutdown()

    def main_loop(self):
        """Main loop. Here we handle the network connections and buffers,
        dispatching events to the IRC clients when needed.

        Runs until ``self.running`` becomes False; handle_network() clears it
        when no connection is active anymore."""

        self._last_crontab = time.time()

        while self.running:
            # handle_network() will block for at most SELECT_TIMEOUT seconds
            # during the select() syscall
            self.handle_network()
            self.check_queue()
            self.handle_cron()

    def do_handshake(self, s):
        """Try to advance the SSL handshake on the SSL socket ``s``.

        Returns True when ``s.do_handshake()`` completed, False when it needs
        more I/O first (SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE). Any other
        SSL error is re-raised.
        """
        try:
            s.do_handshake()
        except ssl.SSLError as err:
            if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
                               ssl.SSL_ERROR_WANT_WRITE):
                return False
            raise
        return True

    def handle_network(self):
        """Run one select() round over the connections' sockets.

        Readable sockets are drained into their connection's in-buffer and
        ``check_in_buffer()`` is called; writable sockets get as much of their
        out-buffer as the OS accepts. Blocks for at most SELECT_TIMEOUT
        seconds. Sets ``self.running`` to False when no connection is active.
        """
        # For the select() call we must create two distinct groups of sockets
        # to watch for: all the active sockets must be checked for reading, but
        # only sockets with a non empty out-buffer will be checked for writing.
        connections = list(self.connections.values())
        in_sockets = [conn.socket for conn in connections if conn.active]
        out_sockets = [conn.socket for conn in connections if conn.out_buffer]

        # This is ugly. XXX
        if not in_sockets:
            log.warning("No more active connections. exiting...")
            self.running = False
            return

        readable, writable, _ = select.select(in_sockets,
                                              out_sockets,
                                              [],
                                              SELECT_TIMEOUT)

        for s in readable:
            self._read_from(s)

        for s in writable:
            self._write_to(s)

    def _read_from(self, s):
        """Read everything available on socket ``s`` into its connection's
        in-buffer, then let the connection process it. On EOF the connection
        is marked as disconnected and inactive."""
        conn_obj = self.connection_map[s]

        # Do SSL handshake if needed
        if conn_obj.ssl_must_handshake and conn_obj.connected:
            if not self.do_handshake(conn_obj.socket):
                return

        # We must read data from socket until the syscall returns EAGAIN;
        # when the OS signals EAGAIN the socket would block reading.
        while True:
            try:
                chunk = s.recv(512)
            except (socket.error, ssl.SSLError) as err:
                if err.args[0] in (errno.EAGAIN, ssl.SSL_ERROR_WANT_READ):
                    break
                raise

            if chunk == '':
                conn_obj.connected = False
                conn_obj.active = False
                # Remember the dead socket: the write phase must not take a
                # later "writable" event on it for a fresh connection.
                self._eof_sockets.add(s)
                log.error("%s disconnected (EOF from server)", conn_obj.name)
                break
            conn_obj.in_buffer += chunk

        conn_obj.check_in_buffer()

    def _write_to(self, s):
        """Flush as much of the out-buffer of the connection owning the
        writable socket ``s`` as possible, completing the connection setup
        (and SSL wrapping/handshake, if configured) on its first writable
        event."""
        conn_obj = self.connection_map[s]

        # The server closed this socket (EOF seen while reading): it is not
        # connected, so drop the pending output instead of treating the
        # socket as a newly established connection.
        if s in self._eof_sockets:
            log.error("Trying to write to a non connected socket!")
            conn_obj.out_buffer = ""
            return

        # If this is the first time we get a "writable" status then
        # we are actually connected to the remote server.
        if not conn_obj.connected:
            log.info("Connected to %s", conn_obj.name)
            conn_obj.connected = True

            # SSL socket setup
            if conn_obj.config["ssl"]:
                conn_obj.wrap_ssl()
                # swap the socket in the connection map with the ssl one
                self.connection_map[conn_obj.socket] = conn_obj
                del self.connection_map[s]
                s = conn_obj.socket

        # SSL handshake
        if conn_obj.ssl_must_handshake and conn_obj.connected:
            if not self.do_handshake(s):
                return

        while conn_obj.out_buffer:
            try:
                # An artificial pause could be inserted here to avoid
                # flooding, but *decent* servers such as inspircd no longer
                # have that problem.
                sent = s.send(conn_obj.out_buffer)
            except (socket.error, ssl.SSLError) as err:
                if err.args[0] in (errno.EAGAIN, ssl.SSL_ERROR_WANT_WRITE):
                    break
                raise
            conn_obj.out_buffer = conn_obj.out_buffer[sent:]

    def check_queue(self):
        """Check the thread queue and run at most one queued callback.

        Each item is a sequence ``(fn, arg1, arg2, ...)``; ``fn(arg1, ...)``
        is called here, in the main loop.

        THIS IS JUST A PROTOTYPE!
        We should pass the IRC event in the Thread object, so we can later send
        the output to the correct channel or nickname.
        """
        try:
            data = list(self.coda.get_nowait())
        except Queue.Empty:
            return
        fn = data.pop(0)
        fn(*data)

    def handle_cron(self):
        """A simple crontab that will be run approximately every
        CRONTAB_INTERVAL seconds.

        Currently it only records when the interval has elapsed; no periodic
        job is run here yet."""
        now = time.time()

        if (now - self._last_crontab) >= CRONTAB_INTERVAL:
            self._last_crontab = now

    def quit(self, message="Ctrl-C"):
        """Quit all connected clients, passing ``message`` to each
        connection's quit()."""

        log.info("Shutting down all connections")
        for conn_obj in self.connections.values():
            conn_obj.quit(message)

    def load_plugins(self, exit_on_fail=False):
        """Load all plugins from the plugins module.

        Every ``*.py`` file in the ``plugins`` directory next to this module
        whose name does not start with ``_`` is imported as
        ``pinolo.plugins.<name>``, unless ``<name>`` is listed in the
        ``disabled_plugins`` config key. Files are processed in sorted order
        so that the load order is deterministic.

        Emits ``pre_load_plugins``, then ``plugin_loaded`` for every module
        imported successfully, then ``post_load_plugins``.

        :param exit_on_fail: re-raise the import error instead of reporting
            it and skipping the plugin.
        """

        def my_import(name):
            """Import a dotted module name and return the leaf module
            (taken from effbot)."""

            m = __import__(name)
            for n in name.split(".")[1:]:
                m = getattr(m, n)
            return m

        plugins_dir = os.path.join(
            os.path.abspath(os.path.dirname(__file__)),
            "plugins")

        self.signal_emit("pre_load_plugins")

        disabled_plugins = self.config.get("disabled_plugins", [])

        # Skip private modules such as __init__.py. ".*" (not ".+") so that
        # one-letter plugin names like "x.py" are accepted too.
        filtro = re.compile(r"^[^_].*\.py$")
        for filename in sorted(filter(filtro.match, os.listdir(plugins_dir))):
            plugin_name = os.path.splitext(filename)[0]

            if plugin_name in disabled_plugins:
                log.info("Not loading disabled plugin (from config): %s",
                         plugin_name)
                continue

            log.info("Loading plugin %s", plugin_name)
            try:
                module = my_import("pinolo.plugins." + plugin_name)
            except Exception as e:
                print("Failed to load plugin '%s':" % plugin_name)
                for line in traceback.format_exception_only(type(e), e):
                    print("- " + line.rstrip("\n"))
                if exit_on_fail:
                    raise
                # Skip this plugin: otherwise "plugin_loaded" would be emitted
                # with an unbound (first failure) or stale (previous plugin)
                # module object.
                continue

            self.signal_emit("plugin_loaded", plugin_name=plugin_name,
                             plugin_module=module)

        self.signal_emit("post_load_plugins")

    @staticmethod
    def _plugin_name(plugin_class):
        """Return the short plugin name: the last component of the name of
        the module defining ``plugin_class``."""
        return plugin_class.__module__.split(".")[-1]

    def activate_plugins(self):
        """Call the activate method on all loaded plugins.

        Every class in ``pinolo.plugins.registry`` is instantiated with its
        configuration (``config['plugins'][<name>]`` when present, otherwise
        ``empty_config(config, <name>)``), activated, stored in
        ``self.plugins``; its COMMAND_ALIASES are merged into the global alias
        table and ``plugin_activated`` is emitted.
        """
        for _, plugin_class in pinolo.plugins.registry:
            plugin_name = self._plugin_name(plugin_class)
            log.info("Activating plugin %s", plugin_name)
            # A missing "plugins" section means no plugin is configured.
            plugins_config = self.config.get("plugins", {})
            if plugin_name in plugins_config:
                plugin_config = plugins_config[plugin_name]
            else:
                plugin_config = empty_config(self.config, plugin_name)

            p_obj = plugin_class(self, plugin_config)
            p_obj.activate()
            self.plugins.append(p_obj)
            COMMAND_ALIASES.update(p_obj.COMMAND_ALIASES.items())
            self.signal_emit("plugin_activated", plugin_name=plugin_name,
                             plugin_object=p_obj)

    def deactivate_plugins(self):
        """Call deactivate method on all the loaded plugins.

        ``plugin_deactivated`` is emitted with the same short plugin name used
        for ``plugin_activated`` (the basename of the plugin's module).

        TODO: Should we also destroy the plugin objects?
        """
        for plugin in self.plugins:
            # __class__ rather than type(): also correct for old-style classes.
            plugin_name = self._plugin_name(plugin.__class__)
            plugin.deactivate()
            self.signal_emit("plugin_deactivated", plugin_name=plugin_name,
                             plugin_object=plugin)

    def shutdown(self):
        """Log the shutdown and deactivate all plugins."""
        log.info("Bot shutdown")
        self.deactivate_plugins()