Coverage for b4_backup/cli/utils.py: 100%
117 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-18 22:40 +0000
1import json
2import logging
3import os
4import shlex
5from collections.abc import Generator
6from contextlib import contextmanager
7from enum import Enum
8from pathlib import Path, PurePath
9from typing import Any
11import click
12import rich
13import typer
14from rich.table import Table
16from b4_backup import utils
17from b4_backup.cli.init import app, init
18from b4_backup.config_schema import DEFAULT, BaseConfig
19from b4_backup.exceptions import BaseBtrfsBackupError
20from b4_backup.main.dataclass import Snapshot
# Shared logger for the CLI helpers in this module.
log: logging.Logger = logging.getLogger(name="b4_backup.cli")
def validate_target(ctx: typer.Context, values: list[str]) -> list[str]:
    """
    A handler to validate target types.

    A value is accepted when at least one configured backup target (excluding the
    DEFAULT entry) is equal to it or lies below it as a path, so a parent path can be
    used to select a whole group of targets.

    Args:
        ctx: Typer context; its ``obj`` holds the loaded config
        values: Target values given on the command line (``None`` is tolerated)

    Returns:
        The values, unchanged

    Raises:
        typer.BadParameter: If a value matches no configured target
    """
    config: BaseConfig = ctx.obj

    options = set(config.backup_targets) - {DEFAULT}
    for value in values or []:
        if value is not None and not any(PurePath(x).is_relative_to(value) for x in options):
            # Sorted so the message is stable; set iteration order of strings varies
            # between interpreter runs because of hash randomization.
            raise typer.BadParameter(
                f"Unknown target. Available targets are: {', '.join(sorted(options))}"
            )

    return values
def _parse_arg(param: click.Argument | click.Option, args: list[str]) -> Any | list[Any]:
    """
    Extract the value(s) of a single click parameter from raw CLI args.

    Args are scanned left to right. When an option is given several times, possibly
    mixing aliases such as ``-c`` and ``--config``, the command-line order is kept and
    the last occurrence wins for non-``multiple`` parameters. Both ``--opt value`` and
    ``--opt=value`` are recognized. An option at the very end of ``args`` with no value
    after it is ignored; this happens while the user is still typing during
    autocompletion.

    Args:
        param: Click parameter to extract
        args: Raw cli args (not modified)

    Returns:
        ``param.default`` if no value was found, the list of all values if
        ``param.multiple`` is set, otherwise the last value.
    """
    opts = set(param.opts)
    # Click accepts "--name=value" for long options only.
    long_prefixes = tuple(f"{opt}=" for opt in param.opts if opt.startswith("--"))
    parsed_arg: list[Any] = []

    idx = 0
    while idx < len(args):
        arg = args[idx]
        if arg in opts:
            if idx + 1 >= len(args):
                # Dangling option without a value, e.g. "b4 --config <TAB>".
                break
            value = args[idx + 1]
            idx += 2
        elif long_prefixes and arg.startswith(long_prefixes):
            value = arg.split("=", 1)[1]
            idx += 1
        else:
            idx += 1
            continue

        # Hacky conversion, because I just don't get what's going on in these click types
        if isinstance(param.type, click.types.Path):
            value = Path(value)

        parsed_arg.append(value)

    if not parsed_arg:
        return param.default

    if param.multiple:
        return parsed_arg

    return parsed_arg[-1]
def parse_callback_args(app: typer.Typer, args: list[str]) -> dict[str, Any]:
    """
    Parse the parameters of the app's registered callback out of raw CLI args.

    Workaround for https://github.com/tiangolo/typer/issues/259: Typer does not run
    the callback before autocomplete functions, so its parameters have to be
    extracted by hand.

    Args:
        app: Typer CLI instance whose registered callback is inspected
        args: Raw cli args

    Returns:
        Mapping of each callback parameter name to its parsed value
    """
    registered = app.registered_callback
    assert registered is not None
    callback_params = typer.main.get_params_convertors_ctx_param_name_from_function(
        registered.callback
    )[0]
    return {param.name: _parse_arg(param, args) for param in callback_params}
def complete_target(ctx: typer.Context, incomplete: str) -> Generator[str, None, None]:
    """
    Autocomplete handler for target values.

    The app callback has not run at completion time, so its args are re-parsed from
    ``_TYPER_COMPLETE_ARGS`` and ``init`` is called manually to load the config.
    Suggestions are every configured target (excluding DEFAULT) plus all of its parent
    paths, in sorted order. Only those starting with ``incomplete`` and not already
    given as a target are yielded.
    """
    raw_args = shlex.split(os.getenv("_TYPER_COMPLETE_ARGS", ""))
    init(ctx, **parse_callback_args(app, raw_args))
    config: BaseConfig = ctx.obj

    candidates = set()
    for target in set(config.backup_targets) - {DEFAULT}:
        candidates.add(target)
        candidates.update(str(parent) for parent in PurePath(target).parents)

    ordered = sorted(candidates)
    already_taken = ctx.params.get("target") or []
    yield from (
        candidate
        for candidate in ordered
        if str(candidate).startswith(incomplete) and candidate not in already_taken
    )
112class ErrorHandler:
113 """Handles errors during execution."""
115 errors: list[Exception]
117 def __init__(self) -> None: # noqa: D107
118 self.errors = []
120 def add(self, exc: Exception) -> None:
121 """
122 Add exception to list.
124 Args:
125 exc: Exception to add
126 """
127 log.exception(exc)
128 self.errors.append(exc)
130 def finalize(self) -> None:
131 """Raise errors if any exist."""
132 if self.errors:
133 raise ExceptionGroup("Errors during loop execution", self.errors)
@contextmanager
def error_handler() -> Generator[ErrorHandler, None, None]:
    """
    A wrapper around the CLI error handler.

    Yields an ``ErrorHandler`` for collecting errors; its ``finalize`` runs when the
    ``with`` block completes without raising. Project errors and unexpected
    exceptions are printed and converted into ``typer.Exit(1)``. Click/Typer
    control-flow exceptions (``typer.Exit``, ``typer.Abort``, ``click.ClickException``)
    are re-raised unchanged so Click can handle them as usual.

    Yields:
        The error handler for the wrapped block

    Raises:
        typer.Exit: With exit code 1 when an error reached this handler
    """
    try:
        err_handler = ErrorHandler()
        yield err_handler
        err_handler.finalize()

    except BaseBtrfsBackupError as exc:
        log.debug("An error occured (%s)", type(exc).__name__, exc_info=exc)
        rich.print(f"[red]An error occured ({type(exc).__name__})")
        rich.print(exc)
        raise typer.Exit(1) from exc
    except (click.ClickException, typer.Exit, typer.Abort):
        # In Click, Exit and Abort subclass RuntimeError, so without this clause an
        # intentional exit (even with code 0) or a usage error raised inside the block
        # would be reported as an unknown error and turned into exit code 1.
        raise
    except Exception as exc:
        log.exception("An unknown error occured (%s)", type(exc).__name__)
        rich.print(f"[red]An unknown error occured ({type(exc).__name__})")
        rich.print(exc)
        raise typer.Exit(1) from exc
class OutputFormat(str, Enum):
    """An enumeration of supported output formats."""

    RICH = "rich"
    JSON = "json"
    RAW = "raw"

    @classmethod
    def output(
        cls,
        snapshots: dict[str, Snapshot],
        title: str,
        output_format: "OutputFormat",
    ) -> None:
        """
        Output the snapshots in the specified format.

        Args:
            snapshots: The snapshots to output, keyed by snapshot name
            title: The title of the output (lower-cased as the host name in JSON/raw)
            output_format: The format to output the snapshots in; any value other
                than RICH or JSON falls back to the raw format
        """
        if output_format == OutputFormat.RICH:
            cls.output_rich(snapshots, title)
        elif output_format == OutputFormat.JSON:
            cls.output_json(snapshots, title)
        else:
            cls.output_raw(snapshots, title)

    @classmethod
    def output_rich(cls, snapshots: dict[str, Snapshot], title: str) -> None:
        """Output the snapshots as a rich table, in reverse order of snapshot name."""
        table = Table(title=title)

        table.add_column("Name", style="cyan", no_wrap=True)
        table.add_column("Subvolumes", style="magenta")

        for snapshot_name in sorted(snapshots, reverse=True):
            table.add_row(
                snapshot_name,
                "\n".join(
                    [str(PurePath("/") / x) for x in snapshots[snapshot_name].subvolumes_unescaped]
                ),
            )

        utils.CONSOLE.print(table)

    @classmethod
    def output_json(cls, snapshots: dict[str, Snapshot], title: str) -> None:
        """
        Output the snapshots in a JSON format.

        The document has a ``host`` key (lower-cased title) and a ``snapshots`` key
        mapping each snapshot name to its absolute subvolume paths; keys are sorted.
        """
        cls._print_plain(
            json.dumps(
                {
                    "host": title.lower(),
                    "snapshots": {
                        snapshot_name: [
                            str(PurePath("/") / x) for x in snapshot.subvolumes_unescaped
                        ]
                        for snapshot_name, snapshot in snapshots.items()
                    },
                },
                sort_keys=True,
                indent=2,
            )
        )

    @classmethod
    def output_raw(cls, snapshots: dict[str, Snapshot], title: str) -> None:
        """
        Output the snapshots in a raw format.

        One line per subvolume: ``<host> <snapshot name> <absolute subvolume path>``,
        using the same ``/``-rooted paths as the rich and JSON formats.
        """
        lines = [
            f"{title.lower()} {snapshot_name} {PurePath('/') / subvolume}"
            for snapshot_name, snapshot in snapshots.items()
            for subvolume in snapshot.subvolumes_unescaped
        ]
        cls._print_plain("\n".join(lines))

    @classmethod
    def _print_plain(cls, text: str) -> None:
        """
        Print machine-readable text verbatim.

        Markup, emoji codes and highlighting are disabled so sequences like ``[...]``
        or ``:name:`` inside snapshot/subvolume names are printed literally. Soft wrap
        is enabled so long lines are not broken at the console width, which would
        corrupt the JSON document and the one-entry-per-line raw format.
        """
        # NOTE(review): assumes utils.CONSOLE is a rich Console (it is used to print a
        # rich Table in output_rich) — confirm if it is ever swapped out.
        utils.CONSOLE.print(text, markup=False, emoji=False, highlight=False, soft_wrap=True)