#!@python@/bin/python

import json
import os
import subprocess
import re
import sys
from sys import stderr
from humanize import naturalsize

from tempfile import TemporaryDirectory

from datetime import (datetime, timedelta)
from dateutil.tz import (tzlocal, tzutc)
import dateutil.parser
import argparse

from tqdm import tqdm

from xdg import xdg_runtime_dir
import pathlib

import unshare
from pyprctl import cap_permitted, cap_inheritable, cap_effective, cap_ambient, Cap
from pwd import getpwnam

import signal
from time import sleep

from halo import Halo

from collections import deque


# Command-line interface: each operand is either a plain repository path or a
# 'REPO::ARCHIVE' specifier (main() dispatches on the presence of '::').
parser = argparse.ArgumentParser()
parser.add_argument('source', metavar='REPO_OR_ARCHIVE')
parser.add_argument('target', metavar='REPO_OR_ARCHIVE')
args = parser.parse_args()

# Shared keyword arguments for all Halo spinners: render to stderr, and only
# animate when stderr is an interactive terminal.
halo_args = {
    'stream': stderr,
    'enabled': stderr.isatty(),
    'spinner': 'arc'
}

# Password-database entry for the unprivileged 'borg' user; as_borg() drops
# privileges to this uid/gid before running borg subprocesses.
borg_pwd = getpwnam('borg')

def as_borg(caps=frozenset(), cwd=None):
    """Drop privileges to the 'borg' user; intended as a Popen preexec_fn.

    caps: optional collection of pyprctl.Cap values to retain across the
          uid switch (added to the permitted/inheritable/effective/ambient
          capability sets while still root).
    cwd:  optional directory to chdir into after switching user.
    """
    # The default used to be a mutable `set()`; it was never mutated, but
    # frozenset() is equivalent and avoids the shared-mutable-default trap.
    if caps:
        cap_permitted.add(*caps)
        cap_inheritable.add(*caps)
        cap_effective.add(*caps)
        cap_ambient.add(*caps)

    # Order matters: after setuid() we would no longer be privileged
    # enough to call setgid().
    os.setgid(borg_pwd.pw_gid)
    os.setuid(borg_pwd.pw_uid)

    if cwd is not None:
        os.chdir(cwd)

def read_repo(path):
    """Return the list of archive dicts in the borg repository at *path*.

    Runs `borg list --json` as the borg user and parses its stdout.

    Raises:
        subprocess.CalledProcessError: if borg exits non-zero.
        json.decoder.JSONDecodeError: if borg's output is not valid JSON.
    (ToSync catches exactly these two and retries.)
    """
    with Halo(text=f'Listing {path}', **halo_args) as sp:
        res = None
        with subprocess.Popen(['borg', 'list', '--info', '--lock-wait=600', '--json', path], stdout=subprocess.PIPE, preexec_fn=lambda: as_borg()) as proc:
            res = json.load(proc.stdout)['archives']
        # Popen.__exit__ has waited for the child by now; surface failures
        # explicitly — callers catch CalledProcessError, but Popen alone
        # would never raise it.
        if proc.returncode != 0:
            raise subprocess.CalledProcessError(proc.returncode, proc.args)
        if sp.enabled:
            sp.succeed(f'{len(res)} archives in {path}')
        else:
            print(f'{len(res)} archives in {path}', file=stderr)
    return res

class ToSync:
    """Iterator over archives present in the source repo but missing from
    the target repository.

    Checkpoint archives ('*.checkpoint') are skipped.  Listing errors
    (borg failure or unparsable JSON) are printed and retried forever;
    iteration ends when a fresh listing shows nothing left to sync.
    """

    def __init__(self):
        # Per-instance queue.  This used to be a class attribute, which
        # would have shared pending work between all ToSync instances.
        self.to_sync = deque()

    def __iter__(self):
        return self

    def __next__(self):
        if self.to_sync:
            return self.to_sync.popleft()

        while True:
            try:
                src = read_repo(args.source)
                dst = read_repo(args.target)
            except (subprocess.CalledProcessError, json.decoder.JSONDecodeError) as err:
                print(err, file=stderr)
                continue

            self.to_sync.extend([entry for entry in src if entry['name'] not in {dst_entry['name'] for dst_entry in dst} and not entry['name'].endswith('.checkpoint')])

            if self.to_sync:
                return self.to_sync.popleft()

            raise StopIteration

def copy_archive(src_repo_path, dst_repo_path, entry):
    """Copy one archive from *src_repo_path* into *dst_repo_path*.

    Forks a child which unshares its mount namespace, builds a writable
    overlayfs chroot of '/', FUSE-mounts the source archive at /borg inside
    it, and runs `borg create` over the mounted tree to re-create the
    archive (with its original timestamp) in the destination repository.
    Progress is rendered on stderr via Halo/tqdm.  The parent waits for the
    child and terminates the program if the child failed.

    entry: one archive dict from `borg list --json`; 'name', 'start' and
           'time' are used.
    """
    cache_suffix = None
    with Halo(text=f'Determine archive parameters', **halo_args) as sp:
        # Archives are named '<series>-<ISO timestamp>[.checkpoint|.recreate…]';
        # derive a files-cache suffix from (repo id, series) so successive
        # archives of the same series share a borg files cache.
        match = re.compile(r'^(.*)-[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.(checkpoint|recreate)(\.[0-9]+)?)?').fullmatch(entry['name'])
        if match:
            repo_id = None
            with subprocess.Popen(['borg', 'info', '--info', '--lock-wait=600', '--json', src_repo_path], stdout=subprocess.PIPE, preexec_fn=lambda: as_borg()) as proc:
                repo_id = json.load(proc.stdout)['repository']['id']
            if repo_id:
                cache_suffix = f'{repo_id}_{match.group(1)}'
        if sp.enabled:
            sp.succeed(f'Will process {entry["name"]} ({dateutil.parser.isoparse(entry["start"])}, cache_suffix={cache_suffix})')
        else:
            print(f'Will process {entry["name"]} ({dateutil.parser.isoparse(entry["start"])}, cache_suffix={cache_suffix})', file=stderr)
    with TemporaryDirectory(prefix=f'borg-mount_{entry["name"]}_', dir=os.environ.get('RUNTIME_DIRECTORY')) as tmpdir:
        child = os.fork()
        if child == 0:
            # Child: isolate our mounts from the rest of the system.
            unshare.unshare(unshare.CLONE_NEWNS)
            subprocess.run(['mount', '--make-rprivate', '/'], check=True)
            chroot = pathlib.Path(tmpdir) / 'chroot'
            upper = pathlib.Path(tmpdir) / 'upper'
            work = pathlib.Path(tmpdir) / 'work'
            for path in [chroot, upper, work]:
                path.mkdir()
            # Writable overlay of '/' so we can adjust /etc/fuse.conf and
            # create /borg without touching the real root filesystem.
            subprocess.run(['mount', '-t', 'overlay', 'overlay', '-o', f'lowerdir=/,upperdir={upper},workdir={work}', chroot], check=True)
            bindMounts = ['nix', 'run', 'run/secrets.d', 'run/wrappers', 'proc', 'dev', 'sys', pathlib.Path(os.path.expanduser('~')).relative_to('/')]
            if os.environ.get('BORG_BASE_DIR'):
                bindMounts.append(pathlib.Path(os.environ['BORG_BASE_DIR']).relative_to('/'))
            if not ":" in src_repo_path:
                # Local (path-style) source repository: expose it in the chroot.
                bindMounts.append(pathlib.Path(src_repo_path).relative_to('/'))
            if 'SSH_AUTH_SOCK' in os.environ:
                bindMounts.append(pathlib.Path(os.environ['SSH_AUTH_SOCK']).parent.relative_to('/'))
            for bindMount in bindMounts:
                (chroot / bindMount).mkdir(parents=True, exist_ok=True)
                subprocess.run(['mount', '--bind', pathlib.Path('/') / bindMount, chroot / bindMount], check=True)
            os.chroot(chroot)
            os.chdir('/')
            try:
                os.unlink('/etc/fuse.conf')
            except FileNotFoundError:
                pass
            pathlib.Path('/etc/fuse.conf').parent.mkdir(parents=True, exist_ok=True)
            with open('/etc/fuse.conf', 'w') as fuse_conf:
                # borg mounts as the borg user, but root reads the tree.
                fuse_conf.write('user_allow_other\nmount_max = 1000\n')
            mount_dir = pathlib.Path('/borg')  # renamed from `dir` (shadowed builtin)
            mount_dir.mkdir(parents=True, exist_ok=True, mode=0o0750)
            os.chown(mount_dir, borg_pwd.pw_uid, borg_pwd.pw_gid)

            total_size = None
            total_files = None
            if stderr.isatty():
                # Sizes are only needed to scale the interactive progress bar.
                with Halo(text=f'Determine size', **halo_args) as sp:
                    with subprocess.Popen(['borg', 'info', '--info', '--json', '--lock-wait=600', f'{src_repo_path}::{entry["name"]}'], stdout=subprocess.PIPE, text=True, preexec_fn=lambda: as_borg()) as proc:
                        stats = json.load(proc.stdout)['archives'][0]['stats']
                        total_size = stats['original_size']
                        total_files = stats['nfiles']
                    if sp.enabled:
                        sp.succeed(f'{total_files} files, {naturalsize(total_size, binary=True)}')
                    else:
                        print(f'{total_files} files, {naturalsize(total_size, binary=True)}', file=stderr)
            with subprocess.Popen(['borg', 'mount', '-o', 'allow_other,ignore_permissions', '--foreground', '--progress', '--lock-wait=600', f'{src_repo_path}::{entry["name"]}', mount_dir], preexec_fn=lambda: as_borg()) as mount_proc:
                with Halo(text='Waiting for mount', **halo_args) as sp:
                    wait_start = datetime.now()
                    while True:
                        if os.path.ismount(mount_dir):
                            break
                        elif datetime.now() - wait_start > timedelta(minutes=15):
                            # BUG FIX: this branch previously referenced an
                            # undefined name (`ret.check_returncode()`) and
                            # crashed with NameError.  Give up explicitly.
                            mount_proc.terminate()
                            raise TimeoutError(f'{src_repo_path}::{entry["name"]} did not mount within 15 minutes')
                        sleep(0.1)
                    if sp.enabled:
                        sp.succeed('Mounted')
                    else:
                        print('Mounted', file=stderr)
                # Retry `borg create` until it succeeds (transient failures,
                # e.g. lock contention, are printed and retried).
                while True:
                    try:
                        with tqdm(total=total_size, unit_scale=True, unit_divisor=1024, unit='B', smoothing=0.01, disable=None, dynamic_ncols=True, maxinterval=0.5, miniters=1) as progress:
                            seen = 0
                            env = os.environ.copy()
                            create_args = ['borg',
                                           'create',
                                           '--lock-wait=600',
                                           '--one-file-system',
                                           '--compression=auto,zstd,10',
                                           '--chunker-params=10,23,16,4095',
                                           '--files-cache=ctime,size',
                                           '--show-rc',
                                           # '--remote-ratelimit=20480',
                                           '--log-json',
                                           '--progress',
                                           '--list',
                                           '--filter=AMEi-x?',
                                           '--stats'
                                           ]
                            # Re-create the archive with the source archive's
                            # original timestamp (borg expects UTC).
                            archive_time = datetime.strptime(entry["time"], "%Y-%m-%dT%H:%M:%S.%f").replace(tzinfo=tzlocal()).astimezone(tzutc())
                            create_args += [f'--timestamp={archive_time.strftime("%Y-%m-%dT%H:%M:%S")}']
                            if cache_suffix:
                                env['BORG_FILES_CACHE_SUFFIX'] = cache_suffix
                            else:
                                create_args += ['--files-cache=disabled']
                            create_args += [f'{dst_repo_path}::{entry["name"]}', '.']

                            with subprocess.Popen(create_args, stdin=subprocess.DEVNULL, stderr=subprocess.PIPE, text=True, env=env, preexec_fn=lambda: as_borg(caps={Cap.DAC_READ_SEARCH}, cwd=mount_dir)) as proc:
                                last_list = None
                                last_list_time = None
                                # borg --log-json emits one JSON object per
                                # stderr line; anything unparsable is passed
                                # through verbatim.
                                for line in proc.stderr:
                                    try:
                                        json_line = json.loads(line)
                                    except json.decoder.JSONDecodeError:
                                        tqdm.write(line)
                                        continue

                                    t = ''
                                    if 'time' in json_line and stderr.isatty():
                                        ts = datetime.fromtimestamp(json_line['time']).replace(tzinfo=tzlocal())
                                        t = f'{ts.isoformat(timespec="minutes")} '
                                    if json_line['type'] == 'archive_progress':
                                        # Alternate (3s cadence) between the
                                        # current path and the last file status.
                                        if last_list_time is None or ((datetime.now() - last_list_time) // timedelta(seconds=3)) % 2 == 1:
                                            if 'path' in json_line and json_line['path']:
                                                progress.set_description(f'… {json_line["path"]}', refresh=False)
                                            else:
                                                progress.set_description(None, refresh=False)
                                        elif last_list is not None:
                                            progress.set_description(last_list, refresh=False)
                                        nfiles=json_line["nfiles"]
                                        if total_files is not None:
                                            nfiles=f'{json_line["nfiles"]}/{total_files}'
                                        progress.set_postfix(compressed=naturalsize(json_line['compressed_size'], binary=True), deduplicated=naturalsize(json_line['deduplicated_size'], binary=True), nfiles=nfiles, refresh=False)
                                        progress.update(json_line["original_size"] - seen)
                                        seen = json_line["original_size"]
                                    elif json_line['type'] == 'file_status':
                                        last_list = f'{json_line["status"]} {json_line["path"]}'
                                        last_list_time = datetime.now()
                                        progress.set_description(last_list, refresh=False)
                                        if not stderr.isatty():
                                            print(last_list, file=stderr)
                                    elif (json_line['type'] == 'log_message' or json_line['type'] == 'progress_message' or json_line['type'] == 'progress_percent') and ('message' in json_line or 'msgid' in json_line):
                                        if 'message' in json_line:
                                            tqdm.write(t + json_line['message'])
                                        elif 'msgid' in json_line:
                                            tqdm.write(t + json_line['msgid'])
                                    else:
                                        tqdm.write(t + line)
                                progress.set_description(None)
                                if proc.wait() != 0:
                                    continue
                    except subprocess.CalledProcessError as err:
                        print(err, file=stderr)
                        continue
                    else:
                        break
                mount_proc.terminate()
            # Child exits here without running parent cleanup/atexit hooks.
            os._exit(0)
        else:
            # Parent: reap children until ours exits; propagate any failure.
            while True:
                waitpid, waitstatus = os.wait()
                # BUG FIX: os.wait() returns an *encoded* status word, not an
                # exit code — decode it before testing for failure.
                if os.WIFEXITED(waitstatus):
                    code = os.WEXITSTATUS(waitstatus)
                else:
                    code = 128 + os.WTERMSIG(waitstatus)
                if code != 0:
                    sys.exit(code)
                if waitpid == child:
                    break

def sigterm(signum, frame):
    """Signal handler: turn SIGTERM into a clean SystemExit.

    Uses the shell convention of exiting with status 128 + signal number,
    so unwinding (context managers, TemporaryDirectory cleanup) still runs.
    """
    sys.exit(128 + signum)

def main():
    """Entry point: sync one named archive, or every missing archive.

    If the source argument contains '::' it names a single archive, which
    is looked up in the source repository and copied; otherwise every
    archive present in the source but missing from the target is copied.
    """
    signal.signal(signal.SIGTERM, sigterm)

    if "::" in args.source:
        (src_repo_path, _, src_archive) = args.source.partition("::")
        entry = None
        for candidate_entry in read_repo(src_repo_path):
            # BUG FIX: this compared `entry['name']` while entry was still
            # None, crashing on the first iteration.
            if candidate_entry['name'] != src_archive:
                continue
            entry = candidate_entry
            break

        if entry is None:
            print("Did not find archive", file=stderr)
            # BUG FIX: os.exit() does not exist; sys.exit() is intended.
            sys.exit(1)

        copy_archive(src_repo_path, args.target, entry)
    else:
        for entry in ToSync():
            copy_archive(args.source, args.target, entry)

# main() returns None on success, so this exits with status 0 unless an
# exception (e.g. SystemExit raised by sigterm) sets a different code.
if __name__ == "__main__":
    sys.exit(main())