summaryrefslogtreecommitdiff
path: root/hosts/surtr/email/internal-policy-server/internal_policy_server/__main__.py
diff options
context:
space:
mode:
Diffstat (limited to 'hosts/surtr/email/internal-policy-server/internal_policy_server/__main__.py')
-rw-r--r--hosts/surtr/email/internal-policy-server/internal_policy_server/__main__.py106
1 files changed, 106 insertions, 0 deletions
diff --git a/hosts/surtr/email/internal-policy-server/internal_policy_server/__main__.py b/hosts/surtr/email/internal-policy-server/internal_policy_server/__main__.py
new file mode 100644
index 00000000..04f1a59a
--- /dev/null
+++ b/hosts/surtr/email/internal-policy-server/internal_policy_server/__main__.py
@@ -0,0 +1,106 @@
1from systemd.daemon import listen_fds
2from sdnotify import SystemdNotifier
3from socketserver import StreamRequestHandler, ThreadingMixIn
4from systemd_socketserver import SystemdSocketServer
5import sys
6from threading import Thread
7from psycopg_pool import ConnectionPool
8from psycopg.rows import namedtuple_row
9
10import logging
11
12
13class PolicyHandler(StreamRequestHandler):
14 def handle(self):
15 logger.debug('Handling new connection...')
16
17 self.args = dict()
18
19 line = None
20 while line := self.rfile.readline().removesuffix(b'\n'):
21 if b'=' not in line:
22 break
23
24 key, val = line.split(sep=b'=', maxsplit=1)
25 self.args[key.decode()] = val.decode()
26
27 logger.info('Connection parameters: %s', self.args)
28
29 allowed = False
30 user = None
31 if self.args['sasl_username']:
32 user = self.args['sasl_username']
33 if self.args['ccert_subject']:
34 user = self.args['ccert_subject']
35
36 with self.server.db_pool.connection() as conn:
37 local, domain = self.args['recipient'].split(sep='@', maxsplit=1)
38 extension = None
39 if '+' in local:
40 local, extension = local.split(sep='+', maxsplit=1)
41
42 logger.debug('Parsed recipient address: %s', {'local': local, 'extension': extension, 'domain': domain})
43
44 with conn.cursor() as cur:
45 cur.row_factory = namedtuple_row
46 cur.execute('SELECT id, internal FROM "mailbox_mapping" WHERE ("local" = %(local)s OR "local" IS NULL) AND ("extension" = %(extension)s OR "extension" IS NULL) AND "domain" = %(domain)s', params = {'local': local, 'extension': extension if extension is not None else '', 'domain': domain}, prepare = True)
47 if (row := cur.fetchone()) is not None:
48 if not row.internal:
49 logger.debug('Recipient mailbox is not internal')
50 allowed = True
51 elif user:
52 cur.execute('SELECT EXISTS(SELECT true FROM "mailbox_mapping_access" INNER JOIN "mailbox" ON "mailbox".id = "mailbox_mapping_access"."mailbox" WHERE mailbox_mapping = %(mailbox_mapping)s AND "mailbox"."mailbox" = %(user)s) as "exists"', params = { 'mailbox_mapping': row.id, 'user': user }, prepare = True)
53 if (row := cur.fetchone()) is not None:
54 allowed = row.exists
55 else:
56 logger.debug('Recipient is not local')
57 allowed = True
58
59 action = '550 5.7.0 Recipient mailbox mapping not authorized for current user'
60 if allowed:
61 action = 'DUNNO'
62
63 logger.info('Reached verdict: %s', {'allowed': allowed, 'action': action})
64 self.wfile.write(f'action={action}\n\n'.encode())
65
66class ThreadedSystemdSocketServer(ThreadingMixIn, SystemdSocketServer):
67 def __init__(self, fd, RequestHandlerClass):
68 super().__init__(fd, RequestHandlerClass)
69
70 self.db_pool = ConnectionPool(min_size=1)
71 self.db_pool.wait()
72
73def main():
74 global logger
75 logger = logging.getLogger(__name__)
76 console_handler = logging.StreamHandler()
77 console_handler.setFormatter( logging.Formatter('[%(levelname)s](%(name)s): %(message)s') )
78 if sys.stderr.isatty():
79 console_handler.setFormatter( logging.Formatter('%(asctime)s [%(levelname)s](%(name)s): %(message)s') )
80 logger.addHandler(console_handler)
81 logger.setLevel(logging.DEBUG)
82
83 # log uncaught exceptions
84 def log_exceptions(type, value, tb):
85 global logger
86
87 logger.error(value)
88 sys.__excepthook__(type, value, tb) # calls default excepthook
89
90 sys.excepthook = log_exceptions
91
92 fds = listen_fds()
93 servers = [ThreadedSystemdSocketServer(fd, PolicyHandler) for fd in fds]
94
95 if servers:
96 for server in servers:
97 Thread(name=f'Server for fd{server.fileno()}', target=server.serve_forever).start()
98 else:
99 return 2
100
101 SystemdNotifier().notify('READY=1')
102
103 return 0
104
105if __name__ == '__main__':
106 sys.exit(main())