summaryrefslogtreecommitdiff
path: root/hosts/surtr/email/internal-policy-server/internal_policy_server/__main__.py
blob: 04f1a59aa4635b4ac047f57f7c66cbf963f54561 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from systemd.daemon import listen_fds
from sdnotify import SystemdNotifier
from socketserver import StreamRequestHandler, ThreadingMixIn
from systemd_socketserver import SystemdSocketServer
import sys
from threading import Thread
from psycopg_pool import ConnectionPool
from psycopg.rows import namedtuple_row

import logging


class PolicyHandler(StreamRequestHandler):
    """Decide whether the current client may send to the requested recipient.

    Each connection carries one request made of ``name=value`` lines ended by
    an empty line and is answered with a single ``action=...`` line followed
    by an empty line.  (The format matches Postfix's policy delegation
    protocol; confirm against the MTA configuration.)

    Verdict rules implemented by :meth:`handle`:

    * no ``mailbox_mapping`` row matches the recipient -> ``DUNNO``
    * the matching row is not ``internal`` -> ``DUNNO``
    * the matching row is ``internal`` -> ``DUNNO`` only when the
      authenticated user has a ``mailbox_mapping_access`` entry for it,
      otherwise a ``550 5.7.0`` rejection.
    """

    @staticmethod
    def _read_request(rfile):
        """Collect ``name=value`` attributes from *rfile* into a dict.

        Reading stops at the first empty line, at end of input, or at the
        first line that contains no ``=``.  Only the first ``=`` separates
        the name from the value.
        """
        attributes = {}
        while line := rfile.readline().removesuffix(b'\n'):
            key, separator, val = line.partition(b'=')
            if not separator:
                break
            attributes[key.decode()] = val.decode()
        return attributes

    @staticmethod
    def _split_recipient(recipient):
        """Split *recipient* into ``(local, extension, domain)``.

        The domain is everything after the first ``@`` (empty when there is
        none).  The extension is everything after the first ``+`` of the
        local part, or ``None`` when the local part has no ``+``.
        """
        local, _, domain = recipient.partition('@')
        local, plus, extension = local.partition('+')
        return local, (extension if plus else None), domain

    def handle(self):
        """Read one request, look the recipient up and write the verdict.

        Attributes missing from the request are treated like empty ones, so
        an incomplete request still receives an answer instead of aborting
        with ``KeyError``/``ValueError`` before any reply is written.
        """
        logger.debug('Handling new connection...')

        self.args = self._read_request(self.rfile)
        logger.info('Connection parameters: %s', self.args)

        # A non-empty client certificate subject takes precedence over the
        # SASL login name; with neither present the client is anonymous.
        user = self.args.get('ccert_subject') or self.args.get('sasl_username') or None

        local, extension, domain = self._split_recipient(self.args.get('recipient', ''))
        logger.debug('Parsed recipient address: %s', {'local': local, 'extension': extension, 'domain': domain})

        allowed = False
        with self.server.db_pool.connection() as conn:
            with conn.cursor() as cur:
                cur.row_factory = namedtuple_row
                cur.execute(
                    'SELECT id, internal FROM "mailbox_mapping"'
                    ' WHERE ("local" = %(local)s OR "local" IS NULL)'
                    ' AND ("extension" = %(extension)s OR "extension" IS NULL)'
                    ' AND "domain" = %(domain)s',
                    # A missing extension is looked up as the empty string.
                    params={'local': local, 'extension': extension or '', 'domain': domain},
                    prepare=True,
                )
                mapping = cur.fetchone()
                if mapping is None:
                    logger.debug('Recipient is not local')
                    allowed = True
                elif not mapping.internal:
                    logger.debug('Recipient mailbox is not internal')
                    allowed = True
                elif user:
                    # Internal mapping: only users with an explicit access
                    # entry may send to it.
                    cur.execute(
                        'SELECT EXISTS(SELECT true FROM "mailbox_mapping_access"'
                        ' INNER JOIN "mailbox" ON "mailbox".id = "mailbox_mapping_access"."mailbox"'
                        ' WHERE mailbox_mapping = %(mailbox_mapping)s AND "mailbox"."mailbox" = %(user)s)'
                        ' as "exists"',
                        params={'mailbox_mapping': mapping.id, 'user': user},
                        prepare=True,
                    )
                    access = cur.fetchone()
                    if access is not None:
                        allowed = access.exists

        action = 'DUNNO' if allowed else '550 5.7.0 Recipient mailbox mapping not authorized for current user'

        logger.info('Reached verdict: %s', {'allowed': allowed, 'action': action})
        self.wfile.write(f'action={action}\n\n'.encode())

class ThreadedSystemdSocketServer(ThreadingMixIn, SystemdSocketServer):
    """Threaded socket server bound to a file descriptor handed over by systemd.

    Every instance owns a ``db_pool`` attribute; request handlers reach it
    through ``self.server.db_pool``.
    """

    def __init__(self, fd, RequestHandlerClass):
        """Wrap *fd*, then create the database pool and wait for it.

        The pool is built without explicit connection parameters.
        Construction does not return until ``wait()`` on the pool does.
        """
        super().__init__(fd, RequestHandlerClass)

        pool = ConnectionPool(min_size=1)
        # Publish the pool on the server before blocking on it.
        self.db_pool = pool
        pool.wait()

def main():
    """Set up logging, start one server thread per systemd socket, signal readiness.

    Returns:
        ``0`` once every server thread has been started and ``READY=1`` has
        been sent; ``2`` when no listening sockets were passed in.
    """
    # The request handler logs through this module-level name.
    global logger
    logger = logging.getLogger(__name__)

    # Timestamps are only added when stderr is a terminal.
    log_format = '[%(levelname)s](%(name)s): %(message)s'
    if sys.stderr.isatty():
        log_format = '%(asctime)s ' + log_format

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter(log_format))
    logger.addHandler(console_handler)
    logger.setLevel(logging.DEBUG)

    # log uncaught exceptions
    def log_exceptions(exc_type, exc_value, exc_tb):
        logger.error(exc_value)
        sys.__excepthook__(exc_type, exc_value, exc_tb)  # calls default excepthook

    sys.excepthook = log_exceptions

    servers = [ThreadedSystemdSocketServer(fd, PolicyHandler) for fd in listen_fds()]
    if not servers:
        # Say why we are exiting instead of failing silently.
        logger.error('No listening sockets were passed by systemd')
        return 2

    for server in servers:
        Thread(name=f'Server for fd{server.fileno()}', target=server.serve_forever).start()

    SystemdNotifier().notify('READY=1')

    return 0

# Script entry point: main()'s return value becomes the process exit status.
if __name__ == '__main__':
    raise SystemExit(main())