"""Socket-activated policy service: may this client certificate use this sender?

Each connection carries one request made of ``key=value`` lines terminated by
an empty line (the request/reply format used by Postfix policy delegation).
The ``sender`` address is looked up for the mailbox named by ``ccert_subject``
and the reply is a single ``action=...`` line followed by an empty line.
"""

import logging
import sys
from socketserver import StreamRequestHandler, ThreadingMixIn
from threading import Thread

from psycopg.rows import namedtuple_row
from psycopg_pool import ConnectionPool
from sdnotify import SystemdNotifier
from systemd.daemon import listen_fds
from systemd_socketserver import SystemdSocketServer

# Created at import time so PolicyHandler always has a logger, even when the
# module is imported and main() never runs; main() attaches handler and level.
logger = logging.getLogger(__name__)

# Rows returned are the mappings that authorize %(user)s for the given address.
# A NULL "local" or "extension" column acts as a wildcard for that component.
_MAPPING_QUERY = (
    'SELECT "mailbox"."mailbox" as "user", "local", "extension", "domain" '
    'FROM "mailbox" '
    'INNER JOIN "mailbox_mapping" ON "mailbox".id = "mailbox_mapping"."mailbox" '
    'WHERE "mailbox"."mailbox" = %(user)s '
    'AND ("local" = %(local)s OR "local" IS NULL) '
    'AND ("extension" = %(extension)s OR "extension" IS NULL) '
    'AND "domain" = %(domain)s'
)


class PolicyHandler(StreamRequestHandler):
    """Answer one policy request per connection.

    Expects ``self.server`` to expose a ``db_pool`` attribute (see
    ``ThreadedSystemdSocketServer``).
    """

    def handle(self):
        """Read the request, decide, and write the ``action=`` reply."""
        logger.debug('Handling new connection...')

        self.args = self._read_request()
        logger.info('Connection parameters: %s', self.args)

        allowed = self._sender_allowed(
            self.args.get('ccert_subject'),
            self.args.get('sender'),
        )

        if allowed:
            action = 'DUNNO'
        else:
            action = '550 5.7.0 Sender address not authorized for current user'
        logger.info('Reached verdict: %s', {'allowed': allowed, 'action': action})
        self.wfile.write(f'action={action}\n\n'.encode())

    def _read_request(self):
        """Collect ``key=value`` lines into a dict.

        Reading stops at the first empty line, at end of stream, or at the
        first line that contains no ``=``.
        """
        args = {}
        while line := self.rfile.readline().removesuffix(b'\n'):
            key, sep, val = line.partition(b'=')
            if not sep:
                break
            args[key.decode()] = val.decode()
        return args

    @staticmethod
    def _split_address(address):
        """Split *address* into ``(local, extension, domain)``.

        Returns ``None`` when *address* contains no ``@`` (for example the
        empty sender).  ``extension`` is ``None`` when the local part has no
        ``+``.
        """
        # Split at the LAST '@' so a local part that itself contains '@'
        # stays intact instead of leaking into the domain.
        local, at, domain = address.rpartition('@')
        if not at:
            return None

        extension = None
        if '+' in local:
            local, extension = local.split(sep='+', maxsplit=1)
        return local, extension, domain

    def _sender_allowed(self, user, sender):
        """Return True when a mapping authorizes *user* to send as *sender*.

        A missing attribute or an address without a domain can never satisfy
        the query (it requires an exact domain match), so those cases are
        refused here without touching the database instead of raising and
        leaving the client without a reply.
        """
        if user is None or sender is None:
            logger.warning('Request lacks ccert_subject or sender attribute')
            return False

        parsed = self._split_address(sender)
        if parsed is None:
            logger.warning('Sender %r has no domain part', sender)
            return False

        local, extension, domain = parsed
        logger.debug('Parsed address: %s', {'local': local, 'extension': extension, 'domain': domain})

        params = {
            'user': user,
            'local': local,
            # An address without extension is matched as the empty extension.
            'extension': extension if extension is not None else '',
            'domain': domain,
        }

        allowed = False
        # The connection is checked out only now, after all parsing that can
        # fail, so a bad request never occupies a pool slot.
        with self.server.db_pool.connection() as conn, conn.cursor() as cur:
            cur.row_factory = namedtuple_row
            cur.execute(_MAPPING_QUERY, params=params, prepare=True)
            for record in cur:
                logger.debug('Received result: %s', record)
                allowed = True
        return allowed


class ThreadedSystemdSocketServer(ThreadingMixIn, SystemdSocketServer):
    """Server on an inherited socket fd; ThreadingMixIn gives each request its own thread."""

    def __init__(self, fd, RequestHandlerClass):
        """Bind to *fd* and open the database pool shared by all handlers."""
        super().__init__(fd, RequestHandlerClass)

        # No conninfo is passed; presumably connection settings come from the
        # environment -- TODO confirm for the deployment.
        self.db_pool = ConnectionPool(min_size=1)
        # Do not start serving before the pool has its minimum connection.
        self.db_pool.wait()


def main():
    """Configure logging, start one server thread per inherited socket.

    Returns:
        0 after readiness has been signalled, 2 when no socket was passed in.
    """
    # Timestamps are added only on a terminal; presumably the service manager
    # timestamps captured output itself -- verify if the format matters.
    if sys.stderr.isatty():
        log_format = '%(asctime)s [%(levelname)s](%(name)s): %(message)s'
    else:
        log_format = '[%(levelname)s](%(name)s): %(message)s'

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter(log_format))
    logger.addHandler(console_handler)
    logger.setLevel(logging.DEBUG)

    # log uncaught exceptions
    def log_exceptions(exc_type, value, tb):
        logger.error(value)
        sys.__excepthook__(exc_type, value, tb)  # calls default excepthook

    sys.excepthook = log_exceptions

    servers = [ThreadedSystemdSocketServer(fd, PolicyHandler) for fd in listen_fds()]
    if not servers:
        logger.error('No listening sockets were passed in; nothing to serve')
        return 2

    for server in servers:
        Thread(name=f'Server for fd{server.fileno()}', target=server.serve_forever).start()

    SystemdNotifier().notify('READY=1')

    return 0


if __name__ == '__main__':
    sys.exit(main())