#!@python@/bin/python

from os import environ
import sys

from http.server import BaseHTTPRequestHandler, HTTPServer

import subprocess
import json

from urllib.parse import urlparse

from textwrap import dedent


def _format_prom_attrs(**attrs):
    if not attrs:
        return ''

    return '{' + ','.join(map(lambda k: f'{k}="{attrs[k]}"', attrs)) + '}'

def _format_prom_metrics(metricName, metricType, metrics, metricHelp=''):
    metricStr = dedent(f'''
      # HELP {metricName} {metricHelp}
      # TYPE {metricName} {metricType}
    ''').lstrip()
    for (attrs, val) in metrics:
        attrs_str = _format_prom_attrs(**attrs)
        metricStr += dedent(f'''
            {metricName}{attrs_str} {val}
        ''').lstrip()
    return metricStr


class CAKEMetrics:
    """Process-wide snapshot of statistics for every CAKE qdisc on the host.

    Obtain the shared object through :meth:`instance`; direct construction
    raises ``RuntimeError``.  Call :meth:`update` to refresh the snapshot from
    ``tc -s -j qdisc show`` and render it with :meth:`json_text` or
    :meth:`prometheus`.

    ``attrs`` is ``None`` until the first successful update.  After that it
    maps a device name to a dict with the keys ``bytes``, ``packets``,
    ``drops``, ``overlimits``, ``requeues`` and ``tins``.  ``tins`` maps a
    tin name to ``{'bytes': ..., 'packets': ...}``.
    """

    _instance = None

    # Tin labels applied when a qdisc reports exactly four tins; any other
    # tin count currently produces no per-tin statistics.
    # NOTE(review): presumably the diffserv4 tin order as reported by tc — confirm.
    _FOUR_TIN_NAMES = ('bulk', 'best-effort', 'video', 'voice')

    # Per-qdisc counters copied from tc's JSON (order also fixes JSON key order).
    _COUNTERS = ('bytes', 'packets', 'drops', 'overlimits', 'requeues')

    @classmethod
    def instance(cls):
        """Return the shared instance, creating it (with no data) on first use."""
        if cls._instance is None:
            # Bypass __init__, which exists only to forbid direct construction.
            cls._instance = cls.__new__(cls)
            cls._instance.attrs = None
        return cls._instance

    def __init__(self):
        raise RuntimeError('Call instance() instead')

    @staticmethod
    def _run_tc():
        """Return the parsed JSON printed by ``tc -s -j qdisc show``.

        Raises:
            OSError: if ``tc`` cannot be started.
            subprocess.CalledProcessError: if ``tc`` exits non-zero.
            json.JSONDecodeError: if the output is not valid JSON.
        """
        result = subprocess.run(
            ['tc', '-s', '-j', 'qdisc', 'show'],
            stdout=subprocess.PIPE,
            check=True,
        )
        return json.loads(result.stdout)

    @classmethod
    def _parse_tc_output(cls, tc_output):
        """Extract CAKE statistics from parsed tc JSON, keyed by device name.

        Entries whose ``kind`` is not ``cake`` are skipped.  If several CAKE
        qdiscs report the same ``dev``, the last one wins.
        """
        attrs = {}
        for qdisc in tc_output:
            if qdisc.get('kind') != 'cake':
                continue

            tin_stats = qdisc.get('tins', [])
            tin_names = cls._FOUR_TIN_NAMES if len(tin_stats) == 4 else ()
            tins = {
                name: {'bytes': tin['sent_bytes'], 'packets': tin['sent_packets']}
                for name, tin in zip(tin_names, tin_stats)
            }

            device = {counter: qdisc[counter] for counter in cls._COUNTERS}
            device['tins'] = tins
            attrs[qdisc['dev']] = device
        return attrs

    def update(self):
        """Refresh ``attrs`` from the current output of ``tc``.

        On failure the exception propagates and the previous snapshot is kept.
        """
        self.attrs = self._parse_tc_output(self._run_tc())

    def json_text(self):
        """Return the snapshot as JSON text (``'null'`` before the first update)."""
        return json.dumps(self.attrs)

    def prometheus(self):
        """Render the snapshot in the Prometheus text format as UTF-8 bytes.

        Before the first update the snapshot is treated as empty, so only the
        HELP/TYPE header lines are produced.
        """
        devices = self.attrs or {}
        sections = []

        # Order kept stable for scrapers/diffs: bytes first, then the rest.
        for counter in ('bytes', 'packets', 'overlimits', 'requeues', 'drops'):
            sections.append(_format_prom_metrics(
                f'cake_{counter}', 'counter',
                [({'dev': dev}, stats[counter]) for dev, stats in devices.items()],
            ))

        for counter in ('bytes', 'packets'):
            sections.append(_format_prom_metrics(
                f'cake_tin_{counter}', 'counter',
                [({'dev': dev, 'tin': tin}, tin_stats[counter])
                 for dev, stats in devices.items()
                 for tin, tin_stats in stats['tins'].items()],
            ))

        return ''.join(sections).encode('utf-8')


class CAKEMetricsServer(BaseHTTPRequestHandler):
    def log_message(self, format, *args):
        pass

    def do_GET(self):
        cake_metrics = CAKEMetrics.instance()
        cake_metrics.update()

        url = urlparse(self.path)

        match url.path:
            case '/metrics.json':
                self.send_response(200)
                self.send_header("Content-type", "application/json")
                self.end_headers()

                self.wfile.write(cake_metrics.json_text().encode('utf-8'))
            case '/metrics':
                self.send_response(200)
                self.send_header("Content-type", "text/plain")
                self.end_headers()

                self.wfile.write(cake_metrics.prometheus())
            case _:
                self.send_response(404)
                self.end_headers()


def main():
    """Run the CAKE metrics HTTP server until interrupted.

    The listen address comes from the environment:
      ``CAKE_HOSTNAME`` -- host name or address to bind
      ``CAKE_PORT``     -- TCP port to bind (integer)

    Exits with an error message if either variable is missing or the port
    is not an integer.  Returns ``None`` (exit status 0) on Ctrl-C.
    """
    try:
        hostname = environ['CAKE_HOSTNAME']
        port_text = environ['CAKE_PORT']
    except KeyError as err:
        sys.exit(f'Missing required environment variable {err.args[0]}')

    try:
        port = int(port_text)
    except ValueError:
        sys.exit(f'CAKE_PORT must be an integer, got {port_text!r}')

    # The context manager closes the listening socket on the way out.
    with HTTPServer((hostname, port), CAKEMetricsServer) as webServer:
        try:
            webServer.serve_forever()
        except KeyboardInterrupt:
            pass

# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())