From 75695d3e42bfe15483cefa43f316a4ae11a3bcca Mon Sep 17 00:00:00 2001 From: Gregor Kleen Date: Wed, 21 May 2025 09:24:30 +0200 Subject: ... --- .../internal_policy_server/__main__.py | 106 +++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 hosts/surtr/email/internal-policy-server/internal_policy_server/__main__.py (limited to 'hosts/surtr/email/internal-policy-server/internal_policy_server/__main__.py') diff --git a/hosts/surtr/email/internal-policy-server/internal_policy_server/__main__.py b/hosts/surtr/email/internal-policy-server/internal_policy_server/__main__.py new file mode 100644 index 00000000..04f1a59a --- /dev/null +++ b/hosts/surtr/email/internal-policy-server/internal_policy_server/__main__.py @@ -0,0 +1,106 @@ +from systemd.daemon import listen_fds +from sdnotify import SystemdNotifier +from socketserver import StreamRequestHandler, ThreadingMixIn +from systemd_socketserver import SystemdSocketServer +import sys +from threading import Thread +from psycopg_pool import ConnectionPool +from psycopg.rows import namedtuple_row + +import logging + + +class PolicyHandler(StreamRequestHandler): + def handle(self): + logger.debug('Handling new connection...') + + self.args = dict() + + line = None + while line := self.rfile.readline().removesuffix(b'\n'): + if b'=' not in line: + break + + key, val = line.split(sep=b'=', maxsplit=1) + self.args[key.decode()] = val.decode() + + logger.info('Connection parameters: %s', self.args) + + allowed = False + user = None + if self.args['sasl_username']: + user = self.args['sasl_username'] + if self.args['ccert_subject']: + user = self.args['ccert_subject'] + + with self.server.db_pool.connection() as conn: + local, domain = self.args['recipient'].split(sep='@', maxsplit=1) + extension = None + if '+' in local: + local, extension = local.split(sep='+', maxsplit=1) + + logger.debug('Parsed recipient address: %s', {'local': local, 'extension': extension, 'domain': domain}) + + with conn.cursor() as cur: + cur.row_factory = namedtuple_row + cur.execute('SELECT id, internal FROM "mailbox_mapping" WHERE ("local" = %(local)s OR "local" IS NULL) AND ("extension" = %(extension)s OR "extension" IS NULL) AND "domain" = %(domain)s', params = {'local': local, 'extension': extension if extension is not None else '', 'domain': domain}, prepare = True) + if (row := cur.fetchone()) is not None: + if not row.internal: + logger.debug('Recipient mailbox is not internal') + allowed = True + elif user: + cur.execute('SELECT EXISTS(SELECT true FROM "mailbox_mapping_access" INNER JOIN "mailbox" ON "mailbox".id = "mailbox_mapping_access"."mailbox" WHERE mailbox_mapping = %(mailbox_mapping)s AND "mailbox"."mailbox" = %(user)s) as "exists"', params = { 'mailbox_mapping': row.id, 'user': user }, prepare = True) + if (row := cur.fetchone()) is not None: + allowed = row.exists + else: + logger.debug('Recipient is not local') + allowed = True + + action = '550 5.7.0 Recipient mailbox mapping not authorized for current user' + if allowed: + action = 'DUNNO' + + logger.info('Reached verdict: %s', {'allowed': allowed, 'action': action}) + self.wfile.write(f'action={action}\n\n'.encode()) + +class ThreadedSystemdSocketServer(ThreadingMixIn, SystemdSocketServer): + def __init__(self, fd, RequestHandlerClass): + super().__init__(fd, RequestHandlerClass) + + self.db_pool = ConnectionPool(min_size=1) + self.db_pool.wait() + +def main(): + global logger + logger = logging.getLogger(__name__) + console_handler = logging.StreamHandler() + console_handler.setFormatter( logging.Formatter('[%(levelname)s](%(name)s): %(message)s') ) + if sys.stderr.isatty(): + console_handler.setFormatter( logging.Formatter('%(asctime)s [%(levelname)s](%(name)s): %(message)s') ) + logger.addHandler(console_handler) + logger.setLevel(logging.DEBUG) + + # log uncaught exceptions + def log_exceptions(type, value, tb): + global logger + + logger.error(value) + sys.__excepthook__(type, value, tb) # calls default excepthook + + sys.excepthook = log_exceptions + + fds = listen_fds() + servers = [ThreadedSystemdSocketServer(fd, PolicyHandler) for fd in fds] + + if servers: + for server in servers: + Thread(name=f'Server for fd{server.fileno()}', target=server.serve_forever).start() + else: + return 2 + + SystemdNotifier().notify('READY=1') + + return 0 + +if __name__ == '__main__': + sys.exit(main()) -- cgit v1.2.3