#!@python@/bin/python
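# Synchronise borg archives between two repositories: each archive missing
# from the target is mounted via ‘borg mount’ and re-created there with
# ‘borg create’, preserving its original timestamp.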

import json
import os
import subprocess
import re
import sys
import io
from sys import stderr
from traceback import format_exc
from humanize import naturalsize

from tempfile import TemporaryDirectory

from datetime import (datetime, timedelta)
from dateutil.tz import (tzlocal, tzutc)
import dateutil.parser
import argparse

from tqdm import tqdm

import pathlib

import unshare
from pyprctl import CapState, Cap, cap_ambient_raise, cap_ambient_is_set, set_keepcaps
from pwd import getpwnam

import logging

import signal
import time
import math

from halo import Halo

from collections import deque

import select
import fcntl

from multiprocessing import Process, Manager
from contextlib import closing


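# Spinner settings for interactive runs; the spinner only animates when stderr
# is a terminal, otherwise the plain log statements take over.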
halo_args = {
    'stream': stderr,
    'enabled': stderr.isatty(),
    'spinner': 'arc'
}

borg_pwd = getpwnam('borg')

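# Drop privileges to the ‘borg’ user, optionally carrying the given
# capabilities across setuid/setgid into the permitted, inheritable and
# ambient sets. Intended as a Popen preexec_fn.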
def as_borg(caps=set()):
    global logger

    try:
        if caps:
            c_state = CapState.get_current()
            c_state.permitted.update(caps)
            c_state.set_current()

            # logger.debug("before setgid/setuid: cap_permitted=%s", CapState.get_current().permitted)

            set_keepcaps(True)

        os.setgid(borg_pwd.pw_gid)
        os.setuid(borg_pwd.pw_uid)

        if caps:
            # logger.debug("after setgid/setuid: cap_permitted=%s", CapState.get_current().permitted)

            c_state = CapState.get_current()
            c_state.permitted = caps.copy()
            c_state.inheritable.update(caps)
            c_state.set_current()

            # logger.debug("cap_permitted=%s", CapState.get_current().permitted)
            # logger.debug("cap_inheritable=%s", CapState.get_current().inheritable)

            for cap in caps:
                cap_ambient_raise(cap)
                # logger.debug("cap_ambient[%s]=%s", cap, cap_ambient_is_set(cap))
    except Exception:
        logger.error(format_exc())
        raise

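# Run a borg command that prints JSON on stdout: stdout is buffered and parsed
# once the process exits, while stderr is forwarded line by line to the logger
# as it arrives.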
def borg_json(*args, **kwargs):
    global logger

    with subprocess.Popen(*args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, **kwargs) as proc:
        stdout_buffer = io.BytesIO()

        proc_logger = logger.getChild('borg')
        stdout_logger = proc_logger.getChild('stdout')
        stderr_logger = proc_logger.getChild('stderr')

        fcntl.fcntl(proc.stdout.fileno(), fcntl.F_SETFL, fcntl.fcntl(proc.stdout.fileno(), fcntl.F_GETFL) | os.O_NONBLOCK)
        fcntl.fcntl(proc.stderr.fileno(), fcntl.F_SETFL, fcntl.fcntl(proc.stderr.fileno(), fcntl.F_GETFL) | os.O_NONBLOCK)

        poll = select.poll()
        poll.register(proc.stdout, select.POLLIN | select.POLLHUP)
        poll.register(proc.stderr, select.POLLIN | select.POLLHUP)
        pollc = 2
        events = poll.poll()
        stderr_linebuf = bytearray()

        while pollc > 0 and len(events) > 0:
            for rfd, event in events:
                if event & select.POLLIN:
                    if rfd == proc.stdout.fileno():
                        try:
                            buf = os.read(proc.stdout.fileno(), 8192)
                            # stdout_logger.debug(buf)
                            stdout_buffer.write(buf)
                        except BlockingIOError:
                            pass
                    if rfd == proc.stderr.fileno():
                        try:
                            stderr_linebuf.extend(os.read(proc.stderr.fileno(), 8192))
                        except BlockingIOError:
                            pass

                        while stderr_linebuf:
                            line, sep, stderr_linebuf = stderr_linebuf.partition(b'\n')
                            if not sep:
                                stderr_linebuf = line
                                break

                            stderr_logger.info(line.decode())
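                # A bare POLLHUP (no POLLIN bit) means the pipe is closed and
                # fully drained, so only then is the fd unregistered.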
                if event == select.POLLHUP:
                    poll.unregister(rfd)
                    pollc -= 1

            if pollc > 0:
                events = poll.poll()

        for handler in proc_logger.handlers:
            handler.flush()

        ret = proc.wait()
        if ret != 0:
            raise subprocess.CalledProcessError(ret, proc.args)

        stdout_buffer.seek(0)
        return json.load(stdout_buffer)

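# List the archives of a repository using ‘borg list --json’.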
def read_repo(path):
    global logger

    with Halo(text=f'Listing {path}', **halo_args) as sp:
        if not sp.enabled:
            logger.debug('Listing %s...', path)
        res = borg_json(['borg', 'list', '--info', '--lock-wait=600', '--json', path], preexec_fn=lambda: as_borg())['archives']
        if sp.enabled:
            sp.succeed(f'{len(res)} archives in {path}')
        else:
            logger.info('%d archives in ‘%s’', len(res), path)
    return res

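# Iterator that yields archives present in the source repository but missing
# from the target (checkpoints excluded), re-listing both repositories
# whenever its queue runs dry.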
class ToSync:
    def __init__(self, source, target):
        self.source = source
        self.target = target
        self.to_sync = deque()

    def __iter__(self):
        return self

    def __next__(self):
        global logger

        if self.to_sync:
            return self.to_sync.popleft()

        while True:
            try:
                src = read_repo(self.source)
                dst = read_repo(self.target)
            except (subprocess.CalledProcessError, json.decoder.JSONDecodeError) as err:
                logger.error(err)
                continue

            dst_names = {dst_entry['name'] for dst_entry in dst}
            self.to_sync.extend([entry for entry in src
                                 if entry['name'] not in dst_names
                                 and not entry['name'].endswith('.checkpoint')])

            if self.to_sync:
                return self.to_sync.popleft()

            raise StopIteration

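# Copy a single archive. The work happens in a child process that unshares its
# mount namespace, chroots into an overlayfs copy of /, mounts the source
# archive under /borg and re-creates it in the target repository.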
def copy_archive(src_repo_path, dst_repo_path, entry):
    global logger

    def do_copy(tmpdir_q):
        global logger

        nonlocal src_repo_path, dst_repo_path, entry

        tmpdir = tmpdir_q.get()

        cache_suffix = None
        with Halo(text='Determining archive parameters', **halo_args) as sp:
            if not sp.enabled:
                logger.debug('Determining archive parameters...')
            match = re.compile(r'^(.*)-[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.(checkpoint|recreate)(\.[0-9]+)?)?').fullmatch(entry['name'])
            if match:
                repo_id = borg_json(['borg', 'info', '--info', '--lock-wait=600', '--json', src_repo_path], preexec_fn=lambda: as_borg())['repository']['id']

                if repo_id:
                    cache_suffix = f'{repo_id}_{match.group(1)}'
            if sp.enabled:
                sp.succeed(f'Will process {entry["name"]} ({dateutil.parser.isoparse(entry["start"])}, cache_suffix={cache_suffix})')
            else:
                logger.info('Will process ‘%s’ (%s, cache_suffix=%s)', entry['name'], dateutil.parser.isoparse(entry['start']), cache_suffix)

        logger.debug('Setting up environment...')
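        # Private mount namespace plus an overlayfs chroot: every mount set up
        # below is invisible to the rest of the system and disappears with the
        # process.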
        unshare.unshare(unshare.CLONE_NEWNS)
        subprocess.run(['mount', '--make-rprivate', '/'], check=True)
        chroot = pathlib.Path(tmpdir) / 'chroot'
        upper = pathlib.Path(tmpdir) / 'upper'
        work = pathlib.Path(tmpdir) / 'work'
        for path in [chroot,upper,work]:
            path.mkdir()
        subprocess.run(['mount', '-t', 'overlay', 'overlay', '-o', f'lowerdir=/,upperdir={upper},workdir={work}', chroot], check=True)
        bindMounts = ['nix', 'run', 'run/secrets.d', 'run/wrappers', 'proc', 'dev', 'sys', pathlib.Path(os.path.expanduser('~')).relative_to('/')]
        if os.environ.get('BORG_BASE_DIR'):
            bindMounts.append(pathlib.Path(os.environ['BORG_BASE_DIR']).relative_to('/'))
        if not ":" in src_repo_path:
            bindMounts.append(pathlib.Path(src_repo_path).relative_to('/'))
        if 'SSH_AUTH_SOCK' in os.environ:
            bindMounts.append(pathlib.Path(os.environ['SSH_AUTH_SOCK']).parent.relative_to('/'))
        for bindMount in bindMounts:
            (chroot / bindMount).mkdir(parents=True,exist_ok=True)
            subprocess.run(['mount', '--bind', pathlib.Path('/') / bindMount, chroot / bindMount], check=True)
        os.chroot(chroot)
        os.chdir('/')
        try:
            os.unlink('/etc/fuse.conf')
        except FileNotFoundError:
            pass
        pathlib.Path('/etc/fuse.conf').parent.mkdir(parents=True,exist_ok=True)
        with open('/etc/fuse.conf', 'w') as fuse_conf:
            fuse_conf.write('user_allow_other\nmount_max = 1000\n')
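        # Mount point for the source archive, owned by the borg user.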
        dir = pathlib.Path('/borg')
        dir.mkdir(parents=True,exist_ok=True,mode=0o0750)
        os.chown(dir, borg_pwd.pw_uid, borg_pwd.pw_gid)

        total_size = None
        total_files = None
        if stderr.isatty():
            with Halo(text='Determining size', **halo_args) as sp:
                stats = borg_json(['borg', 'info', '--info', '--json', '--lock-wait=600', f'{src_repo_path}::{entry["name"]}'], preexec_fn=lambda: as_borg())['archives'][0]['stats']
                total_size = stats['original_size']
                total_files = stats['nfiles']
                if sp.enabled:
                    sp.succeed(f'{total_files} files, {naturalsize(total_size, binary=True)}')
                else:
                    logger.info('%d files, %s', total_files, naturalsize(total_size, binary=True))
        with subprocess.Popen(['borg', 'mount', '-o', 'allow_other,ignore_permissions', '--foreground', '--progress', '--lock-wait=600', f'{src_repo_path}::{entry["name"]}', dir], preexec_fn=lambda: as_borg()) as mount_proc:
            with Halo(text='Waiting for mount', **halo_args) as sp:
                if not sp.enabled:
                    logger.debug('Waiting for mount...')
                wait_start = datetime.now()
                while True:
                    if os.path.ismount(dir):
                        break
                    elif datetime.now() - wait_start > timedelta(minutes=15):
                        mount_proc.terminate()
                        raise TimeoutError(f'borg mount of {src_repo_path}::{entry["name"]} did not appear within 15 minutes')
                    time.sleep(0.1)
                if sp.enabled:
                    sp.succeed('Mounted')
                else:
                    logger.info('Mounted %s', f'{src_repo_path}::{entry["name"]}')

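            # Retry ‘borg create’ until it either succeeds or the archive
            # turns out to already exist in the destination.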
            while True:
                with tqdm(total=total_size, unit_scale=True, unit_divisor=1024, unit='B', smoothing=0.01, disable=None, dynamic_ncols=True, maxinterval=0.5, miniters=1) as progress:
                    seen = 0
                    env = os.environ.copy()
                    create_args = ['borg',
                                   'create',
                                   '--lock-wait=600',
                                   '--one-file-system',
                                   '--compression=auto,zstd,10',
                                   '--chunker-params=10,23,16,4095',
                                   '--files-cache=ctime,size',
                                   '--show-rc',
                                   '--upload-buffer=100',
                                   '--log-json',
                                   '--progress',
                                   '--list',
                                   '--filter=AMEi-x?',
                                   '--stats'
                                   ]
                    archive_time = datetime.strptime(entry["time"], "%Y-%m-%dT%H:%M:%S.%f").replace(tzinfo=tzlocal()).astimezone(tzutc())
                    create_args += [f'--timestamp={archive_time.strftime("%Y-%m-%dT%H:%M:%S")}']
                    if cache_suffix:
                        env['BORG_FILES_CACHE_SUFFIX'] = cache_suffix
                    else:
                        create_args += ['--files-cache=disabled']
                    create_args += [f'{dst_repo_path}::{entry["name"]}', '.']

                    with subprocess.Popen(create_args, stdin=subprocess.DEVNULL, stderr=subprocess.PIPE, stdout=subprocess.PIPE, env=env, preexec_fn=lambda: as_borg(caps={Cap.DAC_READ_SEARCH}), cwd=dir) as proc:
                        last_list = None
                        last_list_time = time.monotonic_ns()
                        logger.info('Creating...')

                        proc_logger = logger.getChild('borg')
                        stdout_logger = proc_logger.getChild('stdout')
                        stderr_logger = proc_logger.getChild('stderr')

                        fcntl.fcntl(proc.stdout.fileno(), fcntl.F_SETFL, fcntl.fcntl(proc.stdout.fileno(), fcntl.F_GETFL) | os.O_NONBLOCK)
                        fcntl.fcntl(proc.stderr.fileno(), fcntl.F_SETFL, fcntl.fcntl(proc.stderr.fileno(), fcntl.F_GETFL) | os.O_NONBLOCK)

                        poll = select.poll()
                        poll.register(proc.stdout, select.POLLIN | select.POLLHUP)
                        poll.register(proc.stderr, select.POLLIN | select.POLLHUP)
                        pollc = 2
                        events = poll.poll()
                        stdout_linebuf = bytearray()
                        stderr_linebuf = bytearray()

                        while pollc > 0 and len(events) > 0:
                            # logger.debug('%d events', len(events))
                            for rfd, event in events:
                                # logger.debug('event %s', event)
                                if event & select.POLLIN:
                                    if rfd == proc.stdout.fileno():
                                        try:
                                            # logger.debug('reading stdout...')
                                            stdout_linebuf.extend(os.read(proc.stdout.fileno(), 8192))
                                            # logger.debug('read stdout, len(stdout_linebuf)=%d', len(stdout_linebuf))
                                        except BlockingIOError:
                                            pass

                                        while stdout_linebuf:
                                            # logger.debug('stdout line...')
                                            line, sep, stdout_linebuf = stdout_linebuf.partition(b'\n')
                                            if not sep:
                                                stdout_linebuf = line
                                                break

                                            stdout_logger.info(line.decode())
                                        # logger.debug('handled stdout lines, %d leftover', len(stdout_linebuf))
                                    if rfd == proc.stderr.fileno():
                                        try:
                                            # logger.debug('reading stderr...')
                                            stderr_linebuf.extend(os.read(proc.stderr.fileno(), 8192))
                                            # logger.debug('read stderr, len(stderr_linebuf)=%d', len(stderr_linebuf))
                                        except BlockingIOError:
                                            pass

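                                        # Split the buffer into complete lines; with --log-json,
                                        # borg emits one JSON object per stderr line and anything
                                        # unparseable is passed through verbatim.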
                                        while stderr_linebuf:
                                            # logger.debug('stderr line...')
                                            line, sep, stderr_linebuf = stderr_linebuf.partition(b'\n')
                                            if not sep:
                                                stderr_linebuf = line
                                                break

                                            try:
                                                json_line = json.loads(line)
                                            except json.decoder.JSONDecodeError:
                                                if progress.disable:
                                                    stderr_logger.error(line.decode())
                                                else:
                                                    tqdm.write(line.decode())
                                                continue

                                            # logger.debug('stderr line decoded: %s', json_line['type'] if 'type' in json_line else None)

                                            t = ''
                                            if 'time' in json_line and not progress.disable:
                                                ts = datetime.fromtimestamp(json_line['time']).replace(tzinfo=tzlocal())
                                                t = f'{ts.isoformat(timespec="minutes")} '
                                            if json_line['type'] == 'archive_progress' and not progress.disable:
                                                now = time.monotonic_ns()
                                                if last_list_time is None or now - last_list_time >= 3e9:
                                                    last_list_time = now
                                                    if 'path' in json_line and json_line['path']:
                                                        progress.set_description(f'… {json_line["path"]}', refresh=False)
                                                    else:
                                                        progress.set_description(None, refresh=False)
                                                elif last_list is not None:
                                                    progress.set_description(last_list, refresh=False)
                                                nfiles = json_line["nfiles"]
                                                if total_files is not None:
                                                    nfiles = f'{json_line["nfiles"]}/{total_files}'
                                                progress.set_postfix(compressed=naturalsize(json_line['compressed_size'], binary=True), deduplicated=naturalsize(json_line['deduplicated_size'], binary=True), nfiles=nfiles, refresh=False)
                                                progress.update(json_line["original_size"] - seen)
                                                seen = json_line["original_size"]
                                            elif json_line['type'] == 'archive_progress':
                                                now = time.monotonic_ns()
                                                if last_list_time is None or now - last_list_time >= 3e9:
                                                    last_list_time = now
                                                    if 'path' in json_line and json_line['path']:
                                                        stderr_logger.debug('… %s (%s)', json_line["path"], naturalsize(json_line["original_size"]))
                                                    else:
                                                        stderr_logger.debug('… (%s)', naturalsize(json_line["original_size"]))
                                            elif json_line['type'] == 'file_status':
                                                # tqdm.write(t + f'{json_line["status"]} {json_line["path"]}')
                                                last_list = f'{json_line["status"]} {json_line["path"]}'
                                                last_list_time = time.monotonic_ns()
                                                progress.set_description(last_list, refresh=False)
                                                if progress.disable:
                                                    stderr_logger.info(last_list)
                                            elif (json_line['type'] == 'log_message' or json_line['type'] == 'progress_message' or json_line['type'] == 'progress_percent') and ('message' in json_line or 'msgid' in json_line):
                                                if 'message' in json_line:
                                                    if progress.disable:
                                                        stderr_logger.info(t + json_line['message'])
                                                    else:
                                                        tqdm.write(t + json_line['message'])
                                                elif 'msgid' in json_line:
                                                    if progress.disable:
                                                        stderr_logger.info(t + json_line['msgid'])
                                                    else:
                                                        tqdm.write(t + json_line['msgid'])
                                            else:
                                                if progress.disable:
                                                    stderr_logger.info(t + line.decode())
                                                else:
                                                    tqdm.write(t + line.decode())
                                        # logger.debug('handled stderr lines, %d leftover', len(stderr_linebuf))
                                if event == select.POLLHUP:
                                    poll.unregister(rfd)
                                    pollc -= 1

                            if pollc > 0:
                                # logger.debug('polling %d fds...', pollc)
                                events = poll.poll()
                                # logger.debug('done polling')

                        # logger.debug('borg create closed stdout/stderr')
                        if stdout_linebuf:
                            logger.error('unterminated line leftover in stdout: %s', stdout_linebuf)
                        if stderr_linebuf:
                            logger.error('unterminated line leftover in stderr: %s', stderr_linebuf)
                        progress.set_description(None)
                        ret = proc.wait()
                        # logger.debug('borg create terminated; ret=%d', ret)
                        if ret != 0:
                            dst = None
                            try:
                                dst = read_repo(dst_repo_path)
                            except (subprocess.CalledProcessError, json.decoder.JSONDecodeError) as err:
                                logger.error(err)
                                continue
                            else:
                                if any(map(lambda other: entry['name'] == other['name'], dst)):
                                    logger.info('destination exists, terminating')
                                    break

                            logger.warning('destination does not exist, retrying')
                            continue
                        else:
                            # logger.debug('terminating')
                            break
            mount_proc.terminate()

    with Manager() as manager:
        tmpdir_q = manager.Queue(1)

        with closing(Process(target=do_copy, args=(tmpdir_q,), name='do_copy')) as p:
            p.start()

            with TemporaryDirectory(prefix=f'borg-mount_{entry["name"]}_', dir=os.environ.get('RUNTIME_DIRECTORY')) as tmpdir:
                tmpdir_q.put(tmpdir)
                p.join()
                return p.exitcode

def sigterm(signum, frame):
    raise SystemExit(128 + signum)

def main():
    signal.signal(signal.SIGTERM, sigterm)

    global logger
    logger = logging.getLogger(__name__)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter( logging.Formatter('[%(levelname)s](%(name)s): %(message)s') )
    if sys.stderr.isatty():
        console_handler.setFormatter( logging.Formatter('%(asctime)s [%(levelname)s](%(name)s): %(message)s') )

    burst_max = 1000
    burst = burst_max
    last_use = None
    inv_rate = 1e7
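    # Token-bucket throttle for console log records: bursts of up to 1000
    # records, refilled at one per 10 ms; once empty, emitting a record blocks.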
    def consume_filter(record):
        nonlocal burst, burst_max, inv_rate, last_use

        delay = None
        while True:
            now = time.monotonic_ns()
            burst = min(burst_max, burst + math.floor((now - last_use) / inv_rate)) if last_use else burst_max
            last_use = now

            if burst > 0:
                burst -= 1
                if delay:
                    delay = now - delay

                return True

            if delay is None:
                delay = now
            time.sleep(inv_rate / 1e9)
    console_handler.addFilter(consume_filter)

    logging.getLogger().addHandler(console_handler)

    # log uncaught exceptions
    def log_exceptions(type, value, tb):
        global logger

        logger.error(value)
        sys.__excepthook__(type, value, tb) # calls default excepthook

    sys.excepthook = log_exceptions

    parser = argparse.ArgumentParser(prog='copy')
    parser.add_argument('--verbosity', dest='log_level', action='append', type=int)
    parser.add_argument('--verbose', '-v', dest='log_level', action='append_const', const=1)
    parser.add_argument('--quiet', '-q', dest='log_level', action='append_const', const=-1)
    parser.add_argument('source', metavar='REPO_OR_ARCHIVE')
    parser.add_argument('target', metavar='REPO_OR_ARCHIVE')
    args = parser.parse_args()


    LOG_LEVELS = [logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL]
    DEFAULT_LOG_LEVEL = logging.ERROR
    log_level = LOG_LEVELS.index(DEFAULT_LOG_LEVEL)

    for adjustment in args.log_level or ():
        log_level = min(len(LOG_LEVELS) - 1, max(log_level - adjustment, 0))
    logger.setLevel(LOG_LEVELS[log_level])


    if "::" in args.source:
        (src_repo_path, _, src_archive) = args.source.partition("::")
        entry = None
        for candidate_entry in read_repo(src_repo_path):
            if candidate_entry['name'] != src_archive:
                continue
            entry = candidate_entry
            break

        if entry is None:
            logger.critical("Did not find archive ‘%s’", src_archive)
            sys.exit(1)

        copy_archive(src_repo_path, args.target, entry)
    else:
        for entry in ToSync(args.source, args.target):
            copy_archive(args.source, args.target, entry)

if __name__ == "__main__":
    sys.exit(main())