Coverage for b4_backup/cli/utils.py: 100%

117 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-18 22:40 +0000

1import json 

2import logging 

3import os 

4import shlex 

5from collections.abc import Generator 

6from contextlib import contextmanager 

7from enum import Enum 

8from pathlib import Path, PurePath 

9from typing import Any 

10 

11import click 

12import rich 

13import typer 

14from rich.table import Table 

15 

16from b4_backup import utils 

17from b4_backup.cli.init import app, init 

18from b4_backup.config_schema import DEFAULT, BaseConfig 

19from b4_backup.exceptions import BaseBtrfsBackupError 

20from b4_backup.main.dataclass import Snapshot 

21 

# Logger used by the CLI helper functions in this module.
log = logging.getLogger(name="b4_backup.cli")

23 

24 

def validate_target(ctx: typer.Context, values: list[str]) -> list[str]:
    """
    A handler to validate target types.

    A value is accepted when at least one configured backup target (excluding
    ``DEFAULT``) is equal to it or nested below it, i.e. the value may be a full
    target or one of its parent paths. ``None`` entries are skipped.

    Args:
        ctx: Typer context; ``ctx.obj`` holds the loaded ``BaseConfig``
        values: Target values given on the command line. May be ``None`` when
            the option was omitted (NOTE(review): typer passes ``None`` for an
            empty list option whose default is ``None`` — confirm for this option)

    Returns:
        ``values`` unchanged

    Raises:
        typer.BadParameter: If a value matches no configured target
    """
    config: BaseConfig = ctx.obj

    options = set(config.backup_targets) - {DEFAULT}
    # Iterate over an empty list when no values were given instead of crashing on None.
    for value in values or []:
        if value is not None and not any(PurePath(x).is_relative_to(value) for x in options):
            # Sorted so the listed alternatives are stable between runs.
            raise typer.BadParameter(
                f"Unknown target. Available targets are: {', '.join(sorted(options))}"
            )

    return values

35 

36 

def _parse_arg(param: click.Argument | click.Option, args: list[str]) -> Any | list[Any]:
    """
    Extract the value(s) of one callback parameter from raw CLI args.

    A minimal stand-in for click's parser, used where click has not parsed the
    args yet (see ``parse_callback_args``). ``args`` is scanned left to right;
    ``<opt> <value>`` is recognised for every name in ``param.opts`` and
    ``--opt=value`` for long (``--``) names. An option that is the last token
    (no value yet, e.g. while completing) is ignored.

    NOTE(review): boolean flags are not special-cased — the token following a
    flag is taken as its value. Values are only converted for ``click.Path``
    parameters (to ``pathlib.Path``); all others stay raw strings.

    Args:
        param: Click parameter whose ``opts`` are looked up in ``args``
        args: Raw cli args; not modified

    Returns:
        ``param.default`` if no value for the parameter was found; otherwise a
        list of all values in command-line order when ``param.multiple`` is set,
        else the last value on the command line.
    """
    option_names = set(param.opts)
    long_prefixes = tuple(f"{opt}=" for opt in param.opts if opt.startswith("--"))
    parsed_arg: list[Any] = []

    idx = 0
    while idx < len(args):
        arg = args[idx]
        if arg in option_names:
            if idx + 1 >= len(args):
                # Option without a value (still being typed): nothing to parse.
                break
            value: Any = args[idx + 1]
            idx += 2
        elif arg.startswith(long_prefixes):
            value = arg.split("=", 1)[1]
            idx += 1
        else:
            idx += 1
            continue

        # Hacky conversion, because I just don't get what's going on in these click types
        if isinstance(param.type, click.types.Path):
            value = Path(value)

        parsed_arg.append(value)

    if not parsed_arg:
        return param.default

    if param.multiple:
        return parsed_arg

    return parsed_arg[-1]

64 

65 

def parse_callback_args(app: typer.Typer, args: list[str]) -> dict[str, Any]:
    """
    Reconstruct the keyword arguments of the app's callback from raw CLI args.

    Workaround for https://github.com/tiangolo/typer/issues/259: typer does not
    invoke the callback before autocompletion functions, so completion handlers
    have to parse the callback's parameters themselves.

    Args:
        app: Typer CLI instance; must have a registered callback
        args: Raw cli args

    Returns:
        Mapping of each callback parameter name to its value parsed by ``_parse_arg``
    """
    callback_info = app.registered_callback
    assert callback_info is not None
    # Only the click parameter list (first element of the returned tuple) is needed.
    params, *_ = typer.main.get_params_convertors_ctx_param_name_from_function(
        callback_info.callback
    )

    return {param.name: _parse_arg(param, args) for param in params}

91 

92 

def complete_target(ctx: typer.Context, incomplete: str) -> Generator[str, None, None]:
    """
    A handler to provide autocomplete for target types.

    The app callback is not run before completion handlers (see
    ``parse_callback_args``), so it is invoked here manually via ``init`` using
    the args from ``_TYPER_COMPLETE_ARGS``; afterwards ``ctx.obj`` holds the config.

    Candidates are every configured target except ``DEFAULT`` plus all of their
    parent paths, in sorted order, restricted to those starting with
    ``incomplete`` and not already present in ``ctx.params["target"]``.

    Args:
        ctx: Typer context
        incomplete: The partial value typed so far

    Yields:
        Matching target candidates
    """
    raw_args = os.getenv("_TYPER_COMPLETE_ARGS", "")
    try:
        args = shlex.split(raw_args)
    except ValueError:
        # An unterminated quote is normal while the user is still typing; fall back
        # to whitespace splitting instead of aborting the whole completion.
        args = raw_args.split()

    parsed_args = parse_callback_args(app, args)
    init(ctx, **parsed_args)
    config: BaseConfig = ctx.obj

    candidates: set[str] = set()
    for target in set(config.backup_targets) - {DEFAULT}:
        candidates.add(target)
        candidates.update(str(parent) for parent in PurePath(target).parents)

    taken_targets = ctx.params.get("target") or []
    for candidate in sorted(candidates):
        if candidate.startswith(incomplete) and candidate not in taken_targets:
            yield candidate

110 

111 

112class ErrorHandler: 

113 """Handles errors during execution.""" 

114 

115 errors: list[Exception] 

116 

117 def __init__(self) -> None: # noqa: D107 

118 self.errors = [] 

119 

120 def add(self, exc: Exception) -> None: 

121 """ 

122 Add exception to list. 

123 

124 Args: 

125 exc: Exception to add 

126 """ 

127 log.exception(exc) 

128 self.errors.append(exc) 

129 

130 def finalize(self) -> None: 

131 """Raise errors if any exist.""" 

132 if self.errors: 

133 raise ExceptionGroup("Errors during loop execution", self.errors) 

134 

135 

@contextmanager
def error_handler() -> Generator["ErrorHandler", None, None]:
    """
    A wrapper around the CLI error handler.

    Yields an ``ErrorHandler`` to the ``with`` body; if the body completes,
    ``finalize`` is called so collected errors are raised. A
    ``BaseBtrfsBackupError`` is reported as a known error, any other
    ``Exception`` as an unknown error (with its traceback logged); both end the
    command with ``typer.Exit(1)``.

    click/typer control-flow exceptions (``typer.Exit``, ``typer.Abort``, usage
    errors) raised in the body are re-raised unchanged so click handles them.

    Raises:
        typer.Exit: With code 1 after reporting an error
    """
    try:
        err_handler = ErrorHandler()
        yield err_handler
        err_handler.finalize()

    except (
        click.exceptions.Exit,
        click.exceptions.Abort,
        click.exceptions.ClickException,
    ):
        # typer.Exit/typer.Abort subclass RuntimeError and ClickException subclasses
        # Exception, so without this clause e.g. `raise typer.Exit(0)` in the body
        # would be reported as an unknown error and turned into exit code 1.
        raise
    except BaseBtrfsBackupError as exc:
        log.debug("An error occured (%s)", type(exc).__name__, exc_info=exc)
        rich.print(f"[red]An error occured ({type(exc).__name__})")
        rich.print(exc)
        raise typer.Exit(1) from exc
    except Exception as exc:
        log.exception("An unknown error occured (%s)", type(exc).__name__)
        rich.print(f"[red]An unknown error occured ({type(exc).__name__})")
        rich.print(exc)
        raise typer.Exit(1) from exc

154 

155 

class OutputFormat(str, Enum):
    """An enumeration of supported output formats."""

    RICH = "rich"
    JSON = "json"
    RAW = "raw"

    @staticmethod
    def _subvolume_paths(snapshot: Snapshot) -> list[str]:
        """Return the snapshot's unescaped subvolumes as path strings rooted at "/"."""
        return [str(PurePath("/") / x) for x in snapshot.subvolumes_unescaped]

    @staticmethod
    def _print_plain(text: str) -> None:
        """
        Print machine-readable text verbatim.

        Without these options rich word-wraps lines longer than the console width
        (typically 80 columns when output is piped) and interprets ``[...]`` markup
        and ``:name:`` emoji codes, which can corrupt JSON/raw output.
        NOTE(review): assumes ``utils.CONSOLE`` is a ``rich.console.Console``.
        """
        utils.CONSOLE.print(text, soft_wrap=True, markup=False, emoji=False)

    @classmethod
    def output(
        cls,
        snapshots: dict[str, Snapshot],
        title: str,
        output_format: "OutputFormat",
    ) -> None:
        """
        Output the snapshots in the specified format.

        Args:
            snapshots: The snapshots to output, keyed by snapshot name
            title: The title of the output (used lowercased as host in JSON/raw)
            output_format: The format to output the snapshots in; anything other
                than RICH or JSON falls through to raw output
        """
        if output_format == OutputFormat.RICH:
            cls.output_rich(snapshots, title)
        elif output_format == OutputFormat.JSON:
            cls.output_json(snapshots, title)
        else:
            cls.output_raw(snapshots, title)

    @classmethod
    def output_rich(cls, snapshots: dict[str, Snapshot], title: str) -> None:
        """Output the snapshots as a rich table, newest name first (reverse-sorted)."""
        table = Table(title=title)

        table.add_column("Name", style="cyan", no_wrap=True)
        table.add_column("Subvolumes", style="magenta")

        for snapshot_name in sorted(snapshots, reverse=True):
            table.add_row(
                snapshot_name,
                "\n".join(cls._subvolume_paths(snapshots[snapshot_name])),
            )

        utils.CONSOLE.print(table)

    @classmethod
    def output_json(cls, snapshots: dict[str, Snapshot], title: str) -> None:
        """Output the snapshots as JSON: {"host": ..., "snapshots": {name: [paths]}}."""
        document = {
            "host": title.lower(),
            "snapshots": {
                snapshot_name: cls._subvolume_paths(snapshot)
                for snapshot_name, snapshot in snapshots.items()
            },
        }
        cls._print_plain(json.dumps(document, sort_keys=True, indent=2))

    @classmethod
    def output_raw(cls, snapshots: dict[str, Snapshot], title: str) -> None:
        """Output one line per subvolume: "<host> <snapshot name> <subvolume path>"."""
        host = title.lower()
        # Paths are rooted at "/" exactly like the rich and JSON formats (the previous
        # PurePath(' / ') produced paths such as " / /home").
        lines = [
            f"{host} {snapshot_name} {path}"
            for snapshot_name, snapshot in snapshots.items()
            for path in cls._subvolume_paths(snapshot)
        ]
        cls._print_plain("\n".join(lines))