diff options
Diffstat (limited to 'hosts/surtr/email/internal-policy-server/internal_policy_server')
-rw-r--r-- | hosts/surtr/email/internal-policy-server/internal_policy_server/__init__.py | 0 | ||||
-rw-r--r-- | hosts/surtr/email/internal-policy-server/internal_policy_server/__main__.py | 106 |
2 files changed, 106 insertions, 0 deletions
diff --git a/hosts/surtr/email/internal-policy-server/internal_policy_server/__init__.py b/hosts/surtr/email/internal-policy-server/internal_policy_server/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/hosts/surtr/email/internal-policy-server/internal_policy_server/__init__.py | |||
diff --git a/hosts/surtr/email/internal-policy-server/internal_policy_server/__main__.py b/hosts/surtr/email/internal-policy-server/internal_policy_server/__main__.py new file mode 100644 index 00000000..04f1a59a --- /dev/null +++ b/hosts/surtr/email/internal-policy-server/internal_policy_server/__main__.py | |||
@@ -0,0 +1,106 @@ | |||
1 | from systemd.daemon import listen_fds | ||
2 | from sdnotify import SystemdNotifier | ||
3 | from socketserver import StreamRequestHandler, ThreadingMixIn | ||
4 | from systemd_socketserver import SystemdSocketServer | ||
5 | import sys | ||
6 | from threading import Thread | ||
7 | from psycopg_pool import ConnectionPool | ||
8 | from psycopg.rows import namedtuple_row | ||
9 | |||
10 | import logging | ||
11 | |||
12 | |||
13 | class PolicyHandler(StreamRequestHandler): | ||
14 | def handle(self): | ||
15 | logger.debug('Handling new connection...') | ||
16 | |||
17 | self.args = dict() | ||
18 | |||
19 | line = None | ||
20 | while line := self.rfile.readline().removesuffix(b'\n'): | ||
21 | if b'=' not in line: | ||
22 | break | ||
23 | |||
24 | key, val = line.split(sep=b'=', maxsplit=1) | ||
25 | self.args[key.decode()] = val.decode() | ||
26 | |||
27 | logger.info('Connection parameters: %s', self.args) | ||
28 | |||
29 | allowed = False | ||
30 | user = None | ||
31 | if self.args['sasl_username']: | ||
32 | user = self.args['sasl_username'] | ||
33 | if self.args['ccert_subject']: | ||
34 | user = self.args['ccert_subject'] | ||
35 | |||
36 | with self.server.db_pool.connection() as conn: | ||
37 | local, domain = self.args['recipient'].split(sep='@', maxsplit=1) | ||
38 | extension = None | ||
39 | if '+' in local: | ||
40 | local, extension = local.split(sep='+', maxsplit=1) | ||
41 | |||
42 | logger.debug('Parsed recipient address: %s', {'local': local, 'extension': extension, 'domain': domain}) | ||
43 | |||
44 | with conn.cursor() as cur: | ||
45 | cur.row_factory = namedtuple_row | ||
46 | cur.execute('SELECT id, internal FROM "mailbox_mapping" WHERE ("local" = %(local)s OR "local" IS NULL) AND ("extension" = %(extension)s OR "extension" IS NULL) AND "domain" = %(domain)s', params = {'local': local, 'extension': extension if extension is not None else '', 'domain': domain}, prepare = True) | ||
47 | if (row := cur.fetchone()) is not None: | ||
48 | if not row.internal: | ||
49 | logger.debug('Recipient mailbox is not internal') | ||
50 | allowed = True | ||
51 | elif user: | ||
52 | cur.execute('SELECT EXISTS(SELECT true FROM "mailbox_mapping_access" INNER JOIN "mailbox" ON "mailbox".id = "mailbox_mapping_access"."mailbox" WHERE mailbox_mapping = %(mailbox_mapping)s AND "mailbox"."mailbox" = %(user)s) as "exists"', params = { 'mailbox_mapping': row.id, 'user': user }, prepare = True) | ||
53 | if (row := cur.fetchone()) is not None: | ||
54 | allowed = row.exists | ||
55 | else: | ||
56 | logger.debug('Recipient is not local') | ||
57 | allowed = True | ||
58 | |||
59 | action = '550 5.7.0 Recipient mailbox mapping not authorized for current user' | ||
60 | if allowed: | ||
61 | action = 'DUNNO' | ||
62 | |||
63 | logger.info('Reached verdict: %s', {'allowed': allowed, 'action': action}) | ||
64 | self.wfile.write(f'action={action}\n\n'.encode()) | ||
65 | |||
66 | class ThreadedSystemdSocketServer(ThreadingMixIn, SystemdSocketServer): | ||
67 | def __init__(self, fd, RequestHandlerClass): | ||
68 | super().__init__(fd, RequestHandlerClass) | ||
69 | |||
70 | self.db_pool = ConnectionPool(min_size=1) | ||
71 | self.db_pool.wait() | ||
72 | |||
73 | def main(): | ||
74 | global logger | ||
75 | logger = logging.getLogger(__name__) | ||
76 | console_handler = logging.StreamHandler() | ||
77 | console_handler.setFormatter( logging.Formatter('[%(levelname)s](%(name)s): %(message)s') ) | ||
78 | if sys.stderr.isatty(): | ||
79 | console_handler.setFormatter( logging.Formatter('%(asctime)s [%(levelname)s](%(name)s): %(message)s') ) | ||
80 | logger.addHandler(console_handler) | ||
81 | logger.setLevel(logging.DEBUG) | ||
82 | |||
83 | # log uncaught exceptions | ||
84 | def log_exceptions(type, value, tb): | ||
85 | global logger | ||
86 | |||
87 | logger.error(value) | ||
88 | sys.__excepthook__(type, value, tb) # calls default excepthook | ||
89 | |||
90 | sys.excepthook = log_exceptions | ||
91 | |||
92 | fds = listen_fds() | ||
93 | servers = [ThreadedSystemdSocketServer(fd, PolicyHandler) for fd in fds] | ||
94 | |||
95 | if servers: | ||
96 | for server in servers: | ||
97 | Thread(name=f'Server for fd{server.fileno()}', target=server.serve_forever).start() | ||
98 | else: | ||
99 | return 2 | ||
100 | |||
101 | SystemdNotifier().notify('READY=1') | ||
102 | |||
103 | return 0 | ||
104 | |||
105 | if __name__ == '__main__': | ||
106 | sys.exit(main()) | ||