summaryrefslogtreecommitdiff
path: root/overlays/nftables-prometheus-exporter/nftables-prometheus-exporter.py
blob: 484228c8af93a073761704c6b33e8675c5de9464 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!@python@/bin/python

import json

from os import environ
import sys

from http.server import BaseHTTPRequestHandler, HTTPServer

from urllib.parse import urlparse

from textwrap import dedent

import subprocess


def _format_prom_attrs(**attrs):
    """Render keyword arguments as a Prometheus label set.

    Returns ``''`` when no labels are given, otherwise ``{k1="v1",k2="v2"}``
    in argument order. Each value is converted with ``str()`` and escaped as
    the text exposition format requires (backslash, double quote, newline),
    so arbitrary strings such as rule comments cannot corrupt the output.
    """
    if not attrs:
        return ''

    def escape(value):
        # Order matters: escape backslashes first so the escapes added for
        # quotes and newlines are not themselves doubled.
        return (
            str(value)
            .replace('\\', '\\\\')
            .replace('"', '\\"')
            .replace('\n', '\\n')
        )

    return '{' + ','.join(
        f'{key}="{escape(value)}"' for key, value in attrs.items()
    ) + '}'

def _format_prom_metrics(metricName, metricType, metrics, metricHelp=''):
    """Render one metric family in Prometheus text exposition format.

    Args:
        metricName: metric family name, used in HELP/TYPE and every sample.
        metricType: Prometheus type string, e.g. ``'gauge'`` or ``'counter'``.
        metrics: iterable of ``(labels, value)`` pairs, where ``labels`` is a
            dict rendered by ``_format_prom_attrs``.
        metricHelp: HELP text; backslashes and newlines are escaped.

    Returns:
        The HELP and TYPE comment lines followed by one sample line per
        pair; every line, including the last, is newline-terminated.
    """
    # The exposition format allows only backslash and newline escapes in HELP.
    help_text = str(metricHelp).replace('\\', '\\\\').replace('\n', '\\n')
    lines = [
        f'# HELP {metricName} {help_text}',
        f'# TYPE {metricName} {metricType}',
    ]
    # Built line by line instead of via textwrap.dedent, whose margin
    # detection would mangle any sample containing an embedded newline.
    lines.extend(
        f'{metricName}{_format_prom_attrs(**attrs)} {val}'
        for attrs, val in metrics
    )
    return '\n'.join(lines) + '\n'


class NFTMetrics:
    """Shared snapshot of nftables statistics collected with ``nft --json``.

    Use :meth:`instance` to get the process-wide object; calling the class
    directly raises ``RuntimeError``. :meth:`update` replaces ``attrs`` with a
    fresh snapshot, and :meth:`json_text` / :meth:`prometheus` render the most
    recent one.
    """

    _instance = None

    # ``nft list <name>`` queries run by update(), in this order.
    _QUERIES = ('ruleset', 'counters', 'maps', 'meters', 'sets')

    @classmethod
    def instance(cls):
        """Return the shared instance, creating it with ``attrs = None`` on first use."""
        if cls._instance is None:
            # __new__ skips __init__, which exists only to forbid direct use.
            cls._instance = cls.__new__(cls)
            cls._instance.attrs = None
        return cls._instance

    def __init__(self):
        raise RuntimeError('Call instance() instead')

    def update(self):
        """Re-query nft and replace ``self.attrs`` with the new snapshot.

        Raises:
            OSError: if the ``nft`` executable cannot be run.
            subprocess.CalledProcessError: if an ``nft`` invocation fails.
            RuntimeError: if nft reports a JSON schema version other than 1.
        """
        queries = {}
        for query_name in self._QUERIES:
            process = subprocess.run(
                ('nft', '--json', 'list', query_name),
                stdout=subprocess.PIPE, check=True, text=True,
            )
            data = json.loads(process.stdout)
            # Entry 0 carries the metainfo header; the rest are listed objects.
            version = data['nftables'][0]['metainfo']['json_schema_version']
            if version != 1:
                raise RuntimeError(f'nftables json schema v{version} is not supported')
            queries[query_name] = data['nftables'][1:]

        def extract_query(query_name, type_name):
            # Listed objects are keyed by their type, e.g. {"rule": {...}}.
            return [
                item[type_name]
                for item in queries[query_name]
                if type_name in item
            ]

        self.attrs = {
            'rules_count': len(extract_query('ruleset', 'rule')),
            'chain_count': len(extract_query('ruleset', 'chain')),
            'counters': extract_query('counters', 'counter'),
            'maps': extract_query('maps', 'map'),
            'meters': extract_query('meters', 'meter'),
            'sets': extract_query('sets', 'set'),
        }

    def json_text(self):
        """Return the current snapshot serialized as a JSON string."""
        return json.dumps(self.attrs)

    @staticmethod
    def _labels(item, exclude):
        """Build a Prometheus label dict from an nft object, skipping ``exclude`` keys."""
        labels = {}
        for key, value in item.items():
            if key in exclude:
                continue
            # Label names must match [a-zA-Z0-9_]; nft uses keys like "gc-interval".
            name = ''.join(
                ch if (ch.isascii() and ch.isalnum()) or ch == '_' else '_'
                for ch in key
            )
            # List-valued fields (e.g. "flags") become comma-separated text.
            if isinstance(value, list):
                value = ','.join(map(str, value))
            labels[name] = value
        return labels

    @classmethod
    def _counter_series(cls, counters):
        """Return ``(bytes_series, packets_series)`` of ``(labels, value)`` pairs."""
        byte_series = []
        packet_series = []
        for counter in counters:
            labels = cls._labels(counter, {'bytes', 'packets'})
            byte_series.append((labels, counter['bytes']))
            packet_series.append((labels, counter['packets']))
        return byte_series, packet_series

    @classmethod
    def _elem_count_series(cls, items):
        """Return ``(labels, element_count)`` pairs for nft maps, meters or sets."""
        # Each item is labelled from its own fields; nft omits "elem" when empty.
        return [
            (cls._labels(item, {'elem'}), len(item.get('elem', [])))
            for item in items
        ]

    def prometheus(self):
        """Render the current snapshot in Prometheus text format as UTF-8 bytes.

        ``update()`` must have been called first so that ``attrs`` is populated.
        """
        counter_bytes, counter_packets = self._counter_series(self.attrs['counters'])
        parts = [
            _format_prom_metrics('nftables_rules_count', 'gauge', [({}, self.attrs['rules_count'])], 'Number of nftables rules'),
            _format_prom_metrics('nftables_chains_count', 'gauge', [({}, self.attrs['chain_count'])], 'Number of nftables chains'),
            _format_prom_metrics('nftables_counter_bytes', 'counter', counter_bytes),
            _format_prom_metrics('nftables_counter_packets_count', 'counter', counter_packets),
            _format_prom_metrics('nftables_map_elem_count', 'gauge', self._elem_count_series(self.attrs['maps'])),
            _format_prom_metrics('nftables_meter_elem_count', 'gauge', self._elem_count_series(self.attrs['meters'])),
            _format_prom_metrics('nftables_set_elem_count', 'gauge', self._elem_count_series(self.attrs['sets'])),
        ]
        return ''.join(parts).encode('utf-8')

class NFTMetricsServer(BaseHTTPRequestHandler):
    """HTTP handler exposing nftables statistics.

    ``/metrics`` serves Prometheus text format and ``/metrics.json`` the raw
    snapshot as JSON; both re-query nft on every request. Any other path is
    answered with 404 without running nft. If collecting the metrics fails,
    the client gets a 500 and the error is written to stderr.
    """

    # Content type of the Prometheus text exposition format, version 0.0.4.
    PROMETHEUS_CONTENT_TYPE = 'text/plain; version=0.0.4; charset=utf-8'

    def log_message(self, format, *args):
        # Per-request access logging is deliberately suppressed.
        pass

    def do_GET(self):
        """Serve ``/metrics`` or ``/metrics.json``; answer 404 for any other path."""
        path = urlparse(self.path).path
        if path not in ('/metrics', '/metrics.json'):
            # Don't shell out to nft for requests that will be rejected anyway.
            self._respond(404, 'text/plain; charset=utf-8', b'Not Found\n')
            return

        try:
            nft_metrics = NFTMetrics.instance()
            nft_metrics.update()
            if path == '/metrics.json':
                content_type = 'application/json'
                body = nft_metrics.json_text().encode('utf-8')
            else:
                content_type = self.PROMETHEUS_CONTENT_TYPE
                body = nft_metrics.prometheus()
        except (OSError, subprocess.SubprocessError, ValueError, LookupError, RuntimeError) as err:
            # log_error() routes through the silenced log_message(), so report
            # directly; the client still gets a proper response instead of a
            # dropped connection.
            sys.stderr.write(f'nftables scrape failed: {err!r}\n')
            self._respond(500, 'text/plain; charset=utf-8', b'nftables scrape failed\n')
            return

        self._respond(200, content_type, body)

    def _respond(self, status, content_type, body):
        """Send a complete response: status line, headers, then the ``body`` bytes."""
        self.send_response(status)
        self.send_header("Content-type", content_type)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)


def main():
    """Serve nftables metrics on ``NFT_HOSTNAME``:``NFT_PORT`` until interrupted.

    Both environment variables are required; if either is missing, or the
    port is not an integer, the process exits with an explanatory message.
    """
    hostname = environ.get('NFT_HOSTNAME')
    raw_port = environ.get('NFT_PORT')
    if hostname is None or raw_port is None:
        # Previously this surfaced as int(None) / a bind to host "None".
        sys.exit('NFT_HOSTNAME and NFT_PORT must both be set')
    try:
        port = int(raw_port)
    except ValueError:
        sys.exit(f'NFT_PORT must be an integer, got {raw_port!r}')

    # The context manager closes the listening socket on the way out.
    with HTTPServer((hostname, port), NFTMetricsServer) as webServer:
        try:
            webServer.serve_forever()
        except KeyboardInterrupt:
            pass

if __name__ == "__main__":
    # Whatever main() returns becomes the process exit status.
    raise SystemExit(main())