#!@python@/bin/python

import csv
import subprocess
import io
from datetime import datetime, timezone, timedelta
from dateutil.tz import gettz, tzlocal
import pytimeparse
import argparse
import re
import sys
import logging
import shlex
from collections import defaultdict, OrderedDict, deque, namedtuple
import configparser
from xdg import BaseDirectory
from functools import cache
from math import floor

try:
    from distutils.util import strtobool
except ImportError:
    # distutils was removed in Python 3.12 (PEP 632); provide a drop-in
    # replacement accepting the same spellings.
    def strtobool(val):
        """Convert a truthy/falsy string to 1 or 0; raise ValueError otherwise."""
        val = val.lower()
        if val in ('y', 'yes', 't', 'true', 'on', '1'):
            return 1
        if val in ('n', 'no', 'f', 'false', 'off', '0'):
            return 0
        raise ValueError(f'invalid truth value {val!r}')


# Configured (handler, level) by main(); defined at module level so the
# helpers below can log regardless of how they are invoked.
logger = logging.getLogger(__name__)


@cache
def _now():
    """Return the current UTC time, fixed for the lifetime of the process.

    Caching makes every snapshot created in one run share the same suffix.
    """
    return datetime.now(timezone.utc)


def _snap_name(item, time=None):
    """Build the auto-snapshot name ``<item>@auto_<ISO-8601 time>``.

    :param item: dataset name.
    :param time: aware datetime; defaults to ``_now()``. A ``+00:00`` offset
        is written as ``Z``.
    """
    if time is None:
        time = _now()
    suffix = re.sub(r'\+00:00$', r'Z', time.isoformat(timespec='seconds'))
    return f'{item}@auto_{suffix}'


def _log_cmd(*args):
    """Log, at debug level, a shell-quoted rendering of the command ``args``."""
    fmt_args = ' '.join(map(shlex.quote, args))
    logger.debug(f'Running command: {fmt_args}')


def _zfs_get(args, fields):
    """Run a ``zfs get -H -p -o …`` command and parse its tab-separated output.

    :param args: full argument vector of the ``zfs get`` invocation.
    :param fields: names of the output columns, in ``-o`` order.
    :returns: list of namedtuples, one per output line.
    :raises subprocess.CalledProcessError: if ``zfs`` exits non-zero, so that
        callers never act (e.g. destroy snapshots) on a partial listing.
    """
    Row = namedtuple('Row', fields)
    _log_cmd(*args)
    proc = subprocess.run(args, stdout=subprocess.PIPE, text=True, check=True)
    # -H output: one record per line, fields separated by tabs, no quoting.
    reader = csv.reader(io.StringIO(proc.stdout), delimiter='\t', quoting=csv.QUOTE_NONE)
    return [Row._make(row) for row in reader]


def _get_items():
    """Map each filesystem/volume to its ``li.yggdrasil:auto-snapshot`` setting.

    Only datasets where the property has a local, default, inherited,
    temporary or received value are listed.
    """
    args = ['zfs', 'get', '-H', '-p', '-o', 'name,value', '-t', 'filesystem,volume',
            '-s', 'local,default,inherited,temporary,received',
            'li.yggdrasil:auto-snapshot']
    return {row.name: bool(strtobool(row.setting))
            for row in _zfs_get(args, ['name', 'setting'])}


def prune(config, dry_run, keep_newest):
    """Destroy auto-snapshots that no rule in the ``[KEEP]`` section retains.

    Only snapshots whose name equals ``_snap_name(dataset, creation)`` are
    considered. Rules:

    * ``within`` (required): keep every snapshot created no earlier than this
      duration before the newest snapshot of the same dataset.
    * ``secondly`` … ``yearly`` (default 0): keep one snapshot in each of the
      N most recent periods of that granularity that contain snapshots; a
      negative N keeps one snapshot in every period.
    * ``timezone`` (default: local time): zone in which periods are computed.

    :param config: ConfigParser with ``gettimedelta``/``gettimezone`` converters.
    :param dry_run: pass ``-n`` to ``zfs destroy`` so nothing is removed.
    :param keep_newest: keep the newest snapshot of each period instead of
        the oldest.
    """
    Snap = namedtuple('Snap', ['name', 'creation'])
    items = defaultdict(list)
    args = ['zfs', 'get', '-H', '-p', '-o', 'name,value', '-t', 'snapshot', 'creation']
    for row in _zfs_get(args, ['name', 'timestamp']):
        creation = datetime.fromtimestamp(int(row.timestamp), timezone.utc)
        base_name, _, _ = row.name.rpartition('@')
        # Skip snapshots that do not follow this tool's naming scheme.
        if _snap_name(base_name, time=creation) != row.name:
            continue
        items[base_name].append(Snap(name=row.name, creation=creation))

    keep = set()
    kept_count = defaultdict(lambda: defaultdict(int))
    KeptBecause = namedtuple('KeptBecause', ['rule', 'ix', 'base', 'period'])
    # snapshot name -> reasons it is kept, in first-kept order
    kept_because = OrderedDict()

    def keep_because(base, snap, rule, period=None):
        kept_count[rule][base] += 1
        kept_because.setdefault(snap, deque()).append(
            KeptBecause(rule=rule, ix=kept_count[rule][base], base=base, period=period))
        keep.add(snap)

    within = config.gettimedelta('KEEP', 'within')
    if within > timedelta(seconds=0):
        for base, snaps in items.items():
            time_ref = max(snaps, key=lambda snap: snap.creation, default=None)
            if time_ref is None:
                logger.warning(f'Nothing to keep for ‘{base}’')
                continue
            logger.info(f'Using ‘{time_ref.name}’ as time reference for ‘{base}’')
            within_cutoff = time_ref.creation - within
            for snap in snaps:
                if snap.creation >= within_cutoff:
                    keep_because(base, snap.name, 'within')
    else:
        logger.warning('Skipping rule ‘within’ since retention period is zero')

    # Creation times are UTC; period boundaries (days, weeks, …) are
    # evaluated in this zone.
    prune_timezone = config.gettimezone('KEEP', 'timezone', fallback=None)
    if prune_timezone is None:
        prune_timezone = tzlocal()

    PRUNING_PATTERNS = OrderedDict([
        ("secondly", lambda t: t.strftime('%Y-%m-%d %H:%M:%S')),
        ("minutely", lambda t: t.strftime('%Y-%m-%d %H:%M')),
        ("5m", lambda t: (t.strftime('%Y-%m-%d %H'), floor(t.minute / 5) * 5)),
        ("hourly", lambda t: t.strftime('%Y-%m-%d %H')),
        ("daily", lambda t: t.strftime('%Y-%m-%d')),
        ("weekly", lambda t: t.strftime('%G-%V')),
        ("monthly", lambda t: t.strftime('%Y-%m')),
        ("yearly", lambda t: t.strftime('%Y')),
    ])

    for rule, pattern in PRUNING_PATTERNS.items():
        desired_count = config.getint('KEEP', rule, fallback=0)
        for base, snaps in items.items():
            # Group by period; sort order makes the first snapshot of each
            # group the one to keep (newest if keep_newest, else oldest).
            periods = OrderedDict()
            for snap in sorted(snaps, key=lambda snap: snap.creation, reverse=keep_newest):
                period = pattern(snap.creation.astimezone(prune_timezone))
                periods.setdefault(period, deque()).append(snap)

            # In both modes, walk periods from the most recent one backwards.
            ordered_periods = periods.items() if keep_newest else reversed(periods.items())
            to_keep = desired_count
            for period, period_snaps in ordered_periods:
                if to_keep == 0:
                    break
                keep_because(base, period_snaps[0].name, rule, period=period)
                to_keep -= 1
            if to_keep > 0:
                logger.debug(f'Missing {to_keep} to fulfill {rule}={desired_count} for ‘{base}’')

    for snap, reasons in kept_because.items():
        reasons_str = ', '.join(map(str, reasons))
        logger.info(f'Keeping ‘{snap}’ because: {reasons_str}')

    all_snaps = {snap.name for snaps in items.values() for snap in snaps}
    to_destroy = all_snaps - keep
    if not to_destroy:
        logger.info('Nothing to prune')

    for snap in sorted(to_destroy):
        args = ['zfs', 'destroy']
        if dry_run:
            args += ['-n']
        args += [snap]
        _log_cmd(*args)
        subprocess.run(args, check=True)
        if dry_run:
            logger.info(f'Would have pruned ‘{snap}’')
        else:
            logger.info(f'Pruned ‘{snap}’')


def rename(snapshots, destroy=False):
    """Rename snapshots to the auto-snapshot name derived from their creation time.

    :param snapshots: snapshot names (``dataset@name``).
    :param destroy: if two snapshots map to the same new name, destroy the
        later one instead of skipping it.
    """
    if not snapshots:
        # `zfs get creation` without targets would list every dataset.
        return

    args = ['zfs', 'get', '-H', '-p', '-o', 'name,value', 'creation', *snapshots]
    rows = _zfs_get(args, ['name', 'timestamp'])

    renamed_to = set()
    for row in rows:
        if '@' not in row.name:
            logger.warning(f'Not renaming ‘{row.name}’ since it is not a snapshot')
            continue
        creation = datetime.fromtimestamp(int(row.timestamp), timezone.utc)
        base_name, _, _ = row.name.rpartition('@')
        new_name = _snap_name(base_name, time=creation)
        if new_name == row.name:
            logger.debug(f'Not renaming ‘{row.name}’ since name is already correct')
            continue

        if new_name in renamed_to:
            if destroy:
                logger.warning(f'Destroying ‘{row.name}’ since ‘{new_name}’ was already renamed to')
                destroy_args = ['zfs', 'destroy', row.name]
                _log_cmd(*destroy_args)
                subprocess.run(destroy_args, check=True)
            else:
                logger.info(f'Skipping ‘{row.name}’ since ‘{new_name}’ was already renamed to')
            continue

        logger.info(f'Renaming ‘{row.name}’ to ‘{new_name}’')
        rename_args = ['zfs', 'rename', row.name, new_name]
        _log_cmd(*rename_args)
        subprocess.run(rename_args, check=True)
        renamed_to.add(new_name)


def autosnap():
    """Snapshot every dataset whose auto-snapshot property is enabled.

    A dataset whose descendants are all enabled is snapshotted recursively
    (``zfs snapshot -r``); datasets covered by such a recursive ancestor are
    not snapshotted again. The new snapshots are then passed to rename().
    """
    items = _get_items()

    recursive_roots, single_items = set(), set()
    for item_name, is_included in items.items():
        if not is_included:
            continue
        children = {sub_name for sub_name in items if sub_name.startswith(f'{item_name}/')}
        if children and all(items[sub_name] for sub_name in children):
            recursive_roots.add(item_name)
        else:
            single_items.add(item_name)

    # Drop anything already covered by a recursive snapshot of an ancestor.
    for item_name in recursive_roots | single_items:
        if any(item_name.startswith(f'{super_name}/') for super_name in recursive_roots):
            recursive_roots.discard(item_name)
            single_items.discard(item_name)

    all_snap_names = set()

    def do_snapshot(*snap_items, recursive=False):
        snap_names = {_snap_name(item) for item in snap_items}
        if recursive:
            # The root itself plus true descendants only; a plain prefix
            # match would also pick up siblings such as `pool/ab` for `pool/a`.
            for snap_item in snap_items:
                all_snap_names.update(
                    _snap_name(item) for item in items
                    if item == snap_item or item.startswith(f'{snap_item}/'))
        else:
            all_snap_names.update(snap_names)

        args = ['zfs', 'snapshot']
        if recursive:
            args += ['-r']
        args += snap_names
        _log_cmd(*args)
        subprocess.run(args, check=True)

    if single_items:
        do_snapshot(*single_items)
    if recursive_roots:
        do_snapshot(*recursive_roots, recursive=True)
    if not single_items and not recursive_roots:
        logger.warning('No snapshots to create')

    for snap in all_snap_names:
        logger.info(f'Created ‘{snap}’')

    if all_snap_names:
        rename(snapshots=all_snap_names)


def main():
    """Command-line entry point: set up logging, parse arguments, dispatch."""
    console_handler = logging.StreamHandler()
    log_format = '[%(levelname)s](%(name)s): %(message)s'
    if sys.stderr.isatty():
        log_format = '%(asctime)s ' + log_format
    console_handler.setFormatter(logging.Formatter(log_format))
    logger.addHandler(console_handler)

    # log uncaught exceptions
    def log_exceptions(exc_type, exc_value, exc_tb):
        logger.error(exc_value)
        sys.__excepthook__(exc_type, exc_value, exc_tb)  # calls default excepthook

    sys.excepthook = log_exceptions

    parser = argparse.ArgumentParser(prog='zfssnap')
    parser.add_argument('--verbose', '-v', action='count', default=0)
    subparsers = parser.add_subparsers()
    parser.set_defaults(cmd=autosnap)

    rename_parser = subparsers.add_parser('rename')
    rename_parser.add_argument('snapshots', nargs='+')
    rename_parser.add_argument('--destroy', action='store_true', default=False)
    rename_parser.set_defaults(cmd=rename)

    prune_parser = subparsers.add_parser('prune')
    prune_parser.add_argument('--config', '-c', dest='config_files', nargs='*', default=list())
    prune_parser.add_argument('--dry-run', '-n', action='store_true', default=False)
    prune_parser.add_argument('--keep-newest', action='store_true', default=False)
    prune_parser.set_defaults(cmd=prune)

    args = parser.parse_args()

    if args.verbose <= 0:
        logger.setLevel(logging.WARNING)
    elif args.verbose <= 1:
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.DEBUG)

    parsed = vars(args)
    cmd_args = {key: parsed[key]
                for key in ('snapshots', 'dry_run', 'destroy', 'keep_newest')
                if key in parsed}

    if 'config_files' in parsed:
        def convert_timedelta(secs_str):
            secs = pytimeparse.parse(secs_str)
            if secs is None:
                raise ValueError(f'Could not parse timedelta expression ‘{secs_str}’')
            return timedelta(seconds=secs)

        def convert_timezone(tz_str):
            # gettz returns None for unknown names; fail instead of silently
            # falling back to another zone.
            tz = gettz(tz_str)
            if tz is None:
                raise ValueError(f'Could not parse timezone ‘{tz_str}’')
            return tz

        config = configparser.ConfigParser(converters={
            'timedelta': convert_timedelta,
            'timezone': convert_timezone,
        })
        search_files = args.config_files or [*BaseDirectory.load_config_paths('zfssnap.ini')]
        read_files = config.read(search_files)

        def format_config_files(files):
            if not files:
                return 'no files'
            return ', '.join(f'‘{file}’' for file in files)

        if not read_files:
            raise Exception(f'Found no config files. Tried: {format_config_files(search_files)}')

        logger.debug(f'Read following config files: {format_config_files(read_files)}')

        cmd_args['config'] = config

    args.cmd(**cmd_args)


if __name__ == '__main__':
    sys.exit(main())