#!@python@/bin/python

import json

from os import environ
import sys

from http.server import BaseHTTPRequestHandler, HTTPServer

from urllib.parse import urlparse

from textwrap import dedent

import subprocess


def _format_prom_attrs(**attrs):
    """Render keyword arguments as a Prometheus label set.

    Returns '' when there are no labels, otherwise '{name="value",...}' in
    insertion order.

    Label names are coerced to the exposition-format grammar
    ([a-zA-Z_][a-zA-Z0-9_]*) by replacing any other character with '_'.
    This matters because callers pass raw nft JSON keys, which may contain
    characters such as '-'.

    Values are converted with str() and then escaped (backslash, double
    quote, newline), so arbitrary text cannot break the output line.
    """
    if not attrs:
        return ''

    def label_name(key):
        name = ''.join(
            ch if ch.isascii() and (ch.isalnum() or ch == '_') else '_'
            for ch in key
        )
        # Label names may be neither empty nor start with a digit.
        return name if name and not name[0].isdigit() else '_' + name

    def label_value(value):
        # Escape backslashes first so the escapes added afterwards survive.
        return (
            str(value)
            .replace('\\', '\\\\')
            .replace('"', '\\"')
            .replace('\n', '\\n')
        )

    return '{' + ','.join(
        f'{label_name(key)}="{label_value(value)}"' for key, value in attrs.items()
    ) + '}'

def _format_prom_metrics(metricName, metricType, metrics, metricHelp=''):
    """Render one metric family in the Prometheus text exposition format.

    Args:
        metricName: metric family name, used in HELP, TYPE and every sample line.
        metricType: Prometheus type string, e.g. 'gauge' or 'counter'.
        metrics: iterable of (labels dict, value) pairs, one sample line each.
        metricHelp: HELP text. Backslashes and newlines in it are escaped.

    Returns:
        The '# HELP' and '# TYPE' lines followed by one line per sample,
        each terminated by a newline.
    """
    # The exposition format requires '\' and newline to be escaped in HELP text.
    help_text = metricHelp.replace('\\', '\\\\').replace('\n', '\\n')
    lines = [
        f'# HELP {metricName} {help_text}',
        f'# TYPE {metricName} {metricType}',
    ]
    lines.extend(
        f'{metricName}{_format_prom_attrs(**attrs)} {val}'
        for attrs, val in metrics
    )
    return '\n'.join(lines) + '\n'


class NFTMetrics:
    """Shared snapshot of nftables state, rendered as JSON or Prometheus text.

    Get the shared object with ``NFTMetrics.instance()``; direct construction
    raises. ``update()`` refreshes the snapshot by running ``nft --json`` and
    must be called before ``json_text()`` or ``prometheus()``.
    """

    _instance = None

    # Keys of an nft counter object that are sample values rather than labels.
    _COUNTER_VALUE_KEYS = frozenset({'bytes', 'packets'})
    # Key holding the element list of an nft map/meter/set: counted, not a label.
    _ELEM_VALUE_KEYS = frozenset({'elem'})

    @classmethod
    def instance(cls):
        """Return the shared instance, creating it (with no snapshot) on first use."""
        if cls._instance is None:
            # Bypass __init__, which exists only to forbid direct construction.
            cls._instance = cls.__new__(cls)
            cls._instance.attrs = None
        return cls._instance

    def __init__(self):
        raise RuntimeError('Call instance() instead')

    def update(self):
        """Query nftables and replace the stored snapshot in ``self.attrs``.

        Runs ``nft --json list <what>`` for the ruleset, counters, maps,
        meters and sets. Stores the rule and chain counts plus the raw
        counter, map, meter and set objects.

        Raises:
            subprocess.CalledProcessError: if an ``nft`` invocation fails.
            RuntimeError: if nft reports a JSON schema version other than 1.
        """
        queries = {}

        for query_name in ('ruleset', 'counters', 'maps', 'meters', 'sets'):
            process = subprocess.run(
                ('nft', '--json', 'list', query_name),
                stdout=subprocess.PIPE, check=True, text=True,
            )
            data = json.loads(process.stdout)
            # Entry 0 is the metainfo header; the remaining entries are the objects.
            version = data['nftables'][0]['metainfo']['json_schema_version']
            if version != 1:
                raise RuntimeError(f'nftables json schema v{version} is not supported')
            queries[query_name] = data['nftables'][1:]

        def extract_query(query_name, type_name):
            # Objects are dicts keyed by their type (e.g. {'rule': {...}}); keep one type.
            return [
                item[type_name]
                for item in queries[query_name]
                if type_name in item
            ]

        self.attrs = {
            'rules_count': len(extract_query('ruleset', 'rule')),
            'chain_count': len(extract_query('ruleset', 'chain')),
            'counters': extract_query('counters', 'counter'),
            'maps': extract_query('maps', 'map'),
            'meters': extract_query('meters', 'meter'),
            'sets': extract_query('sets', 'set'),
        }

    def json_text(self):
        """Return the current snapshot serialized as a JSON string."""
        return json.dumps(self.attrs)

    @staticmethod
    def _labels_without(item, excluded):
        """Return a copy of the nft object ``item`` without the keys in ``excluded``."""
        return {key: value for key, value in item.items() if key not in excluded}

    @classmethod
    def _elem_count_samples(cls, items):
        """Build (labels, element count) samples for nft maps, meters or sets.

        Labels are each object's own fields minus its 'elem' list. Objects
        without an 'elem' key count as 0 elements instead of raising KeyError.
        """
        return [
            (cls._labels_without(item, cls._ELEM_VALUE_KEYS), len(item.get('elem', ())))
            for item in items
        ]

    def prometheus(self):
        """Render the current snapshot in the Prometheus text exposition format.

        Returns:
            The exposition text encoded as UTF-8 bytes.
        """
        parts = [
            _format_prom_metrics('nftables_rules_count', 'gauge', [({}, self.attrs['rules_count'])], 'Number of nftables rules'),
            _format_prom_metrics('nftables_chains_count', 'gauge', [({}, self.attrs['chain_count'])], 'Number of nftables chains'),
        ]

        counter_bytes = []
        counter_packets = []
        for counter in self.attrs['counters']:
            labels = self._labels_without(counter, self._COUNTER_VALUE_KEYS)
            counter_bytes.append((labels, counter['bytes']))
            counter_packets.append((labels, counter['packets']))
        parts.append(_format_prom_metrics('nftables_counter_bytes', 'counter', counter_bytes))
        parts.append(_format_prom_metrics('nftables_counter_packets_count', 'counter', counter_packets))

        # Element-count gauges: labels come from each map/meter/set object itself.
        for metric_name, attr_key in (
            ('nftables_map_elem_count', 'maps'),
            ('nftables_meter_elem_count', 'meters'),
            ('nftables_set_elem_count', 'sets'),
        ):
            samples = self._elem_count_samples(self.attrs[attr_key])
            parts.append(_format_prom_metrics(metric_name, 'gauge', samples))

        return ''.join(parts).encode('utf-8')

class NFTMetricsServer(BaseHTTPRequestHandler):
    def log_message(self, format, *args):
        pass
    
    def do_GET(self):
        nft_metrics = NFTMetrics.instance()
        nft_metrics.update()

        url = urlparse(self.path)

        match url.path:
            case '/metrics.json':
                self.send_response(200)
                self.send_header("Content-type", "application/json")
                self.end_headers()

                self.wfile.write(nft_metrics.json_text().encode('utf-8'))
            case '/metrics':
                self.send_response(200)
                self.send_header("Content-type", "text/plain")
                self.end_headers()
                
                self.wfile.write(nft_metrics.prometheus())
            case _:
                self.send_response(404)
                self.end_headers()


def main():
    """Serve nftables metrics over HTTP until interrupted.

    The listen address comes from the NFT_HOSTNAME and NFT_PORT environment
    variables, and both are required. Exits with an error message if either
    is missing or if NFT_PORT is not an integer.
    """
    hostname = environ.get('NFT_HOSTNAME')
    port = environ.get('NFT_PORT')
    if hostname is None or port is None:
        sys.exit('NFT_HOSTNAME and NFT_PORT must both be set')
    try:
        port_number = int(port)
    except ValueError:
        sys.exit(f'NFT_PORT must be an integer, got {port!r}')

    # The context manager closes the listening socket when serve_forever exits.
    with HTTPServer((hostname, port_number), NFTMetricsServer) as web_server:
        web_server.serve_forever()

if __name__ == "__main__":
    # Run the server; main()'s return value becomes the process exit status.
    raise SystemExit(main())