Coverage for b4_backup/main/connection.py: 100%

129 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-18 22:40 +0000

1from __future__ import annotations 

2 

3import contextlib 

4import logging 

5import re 

6import shlex 

7import subprocess 

8from abc import ABCMeta, abstractmethod 

9from dataclasses import asdict, dataclass 

10from pathlib import PurePath 

11 

12import paramiko 

13 

14from b4_backup import exceptions 

15 

# Logger used by the connection classes in this module.
log = logging.getLogger(name="b4_backup.connection")

17 

18 

@dataclass
class URL:
    """
    A parsed connection URL.

    `from_url` accepts two input forms:

    * ``protocol://[user[:password]@]host[:port][/location]``, e.g. ``ssh://root@nas:22/backups``
    * a bare path made of ``[/a-zA-Z0-9_.-]`` characters, e.g. ``/mnt/backups``, which
      yields ``protocol=None``.

    Attributes:
        protocol: Protocol used, e.g. ``ssh`` (lower-cased by `from_url`); None for a local path.
        user: Username. Defaults to ``root``.
        password: Password, if one was embedded in the URL.
        host: Hostname; None for a local path.
        port: Port. `from_url` falls back to the protocol's default (22 for ssh), else 0.
        location: Protocol specific location. Defaults to ``/``.
    """

    protocol: str | None = None
    user: str = "root"
    password: str | None = None
    host: str | None = None
    port: int = 0
    location: PurePath = PurePath("/")

    # The following carry no annotation, so dataclass treats them as plain class
    # attributes rather than fields.
    _url_pattern = re.compile(
        r"^(?:(?P<protocol>[a-zA-Z0-9_\-]+):\/\/)(?:(?P<user>[a-zA-Z0-9_\-]+)(?::(?P<password>[a-zA-Z0-9_\-]+))?@)?(?P<host>[a-zA-Z0-9_.\-]+)(?::(?P<port>[0-9]+))?(?P<location>\/[\/a-zA-Z0-9_.\-]*)?$"
    )
    _local_dir_pattern = re.compile(r"^(?P<location>[\/a-zA-Z0-9_.\-]+)$")
    # Default port per protocol, used when the URL does not specify one.
    _protocol_mapping = {"ssh": 22, None: 0}

    @classmethod
    def from_url(cls, source: str) -> URL:
        """
        Create an instance by parsing a URL string.

        Args:
            source: URL string, or a plain local path.

        Returns:
            Instance of ``cls`` holding the parsed components.

        Raises:
            exceptions.InvalidConnectionUrlError: If ``source`` matches neither form.
        """
        result = cls._url_pattern.match(source) or cls._local_dir_pattern.match(source)
        if not result:
            raise exceptions.InvalidConnectionUrlError(
                f"The connection url {source} got an invalid format."
            )

        # The local-path pattern only defines the "location" group, so every other
        # component is looked up with .get() and falls back to the field defaults.
        groups = result.groupdict()

        protocol = groups.get("protocol")
        if protocol is not None:
            protocol = protocol.lower()

        port = groups.get("port")

        # Use ``cls`` (not ``URL``) so subclasses get instances of themselves back.
        return cls(
            protocol=protocol,
            user=groups.get("user") or cls.user,
            password=groups.get("password"),
            host=groups.get("host"),
            port=int(port) if port else cls._protocol_mapping.get(protocol, 0),
            location=PurePath(groups.get("location") or "/"),
        )

81 

82 

class Connection(metaclass=ABCMeta):
    """
    Abstract wrapper for running commands on a machine.

    Instances are context managers: entering a "with" block calls `open`, and
    leaving it calls `close` unless `keep_open` has been set to True.
    """

    def __init__(self, location: PurePath) -> None:
        """
        Args:
            location: Target directory or file.
        """
        self.connected: bool = False
        self.keep_open = False
        self.location = location

    @classmethod
    def from_url(cls, url: str | None) -> Connection | contextlib.nullcontext:
        """
        Build the connection object that matches the protocol of a URL.

        Args:
            url: URL string to parse, or None.

        Returns:
            A `contextlib.nullcontext` when url is None, a `LocalConnection` for a
            plain path, or an `SSHConnection` for an ``ssh://`` URL.

        Raises:
            exceptions.UnknownProtocolError: For any other protocol.
        """
        if url is None:
            # Nothing to connect to: hand back a do-nothing context manager.
            return contextlib.nullcontext()

        target = URL.from_url(url)
        protocol = target.protocol

        if protocol is None:
            return LocalConnection(target.location)

        if protocol != "ssh":
            raise exceptions.UnknownProtocolError

        # URL.from_url only sets a protocol together with a host; this narrows the type.
        assert target.host is not None
        return SSHConnection(
            host=target.host,
            port=target.port,
            user=target.user,
            password=target.password,
            location=target.location,
        )

    @abstractmethod
    def run_process(self, command: list[str]) -> str:
        """
        Run a non-interactive process and return its output.

        Args:
            command: List of parameters
        Returns:
            stdout of process.
        """

    @abstractmethod
    def open(self) -> Connection:
        """
        Open the connection to the target host.

        Returns:
            Itself
        """

    @abstractmethod
    def close(self) -> None:
        """Close the connection."""

    @property
    @abstractmethod
    def exec_prefix(self) -> str:
        """
        Returns:
            Prefix to run commands on the target using local commands.
        """

    def __enter__(self) -> Connection:
        """Enter a "with" block by opening the connection."""
        return self.open()

    def __exit__(self, *args, **kwargs) -> None:
        """Leave a "with" block, closing the connection unless keep_open is set."""
        if self.keep_open:
            return
        self.close()

168 

169 

class LocalConnection(Connection):
    """Connection wrapper that runs commands directly on the local machine."""

    def __init__(self, location: PurePath) -> None:
        """
        Args:
            location: Target directory or file.
        """
        super().__init__(location)
        self.location: PurePath = location

    def run_process(self, command: list[str]) -> str:
        """
        Run a non-interactive local process and return its output.

        Args:
            command: List of parameters
        Returns:
            stdout of process.
        Raises:
            exceptions.FailedProcessError: If the process exits with a non-zero code.
        """
        log.debug("Start local process:\n%s", command)
        completed = subprocess.run(  # noqa: S603
            command,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=False,
        )
        out_text = completed.stdout.decode()
        err_text = completed.stderr.decode()

        if completed.returncode != 0:
            raise exceptions.FailedProcessError(command, out_text, err_text)

        return out_text

    def open(self) -> Connection:
        """
        Mark the local connection as open.

        Returns:
            Itself
        """
        log.info("Opening local connection to %s", self.location)
        self.connected = True
        return self

    def close(self) -> None:
        """Mark the local connection as closed."""
        assert self.connected, "Connection already closed"
        self.connected = False

    @property
    def exec_prefix(self) -> str:
        """
        Returns:
            Prefix to run commands on the target using local commands.
        """
        # Commands already execute locally, so nothing is prepended.
        return ""

232 

233 

class SSHConnection(Connection):
    """
    A connection wrapper to execute commands on remote machines via SSH.

    Connected paramiko clients are cached in the class-level ``ssh_client_pool``,
    keyed by ``(host, port, user)``, so instances targeting the same account reuse
    one client.
    """

    ssh_client_pool: dict[tuple[str, int, str], paramiko.SSHClient] = {}

    def __init__(
        self,
        host: str,
        location: PurePath,
        port: int = 22,
        user: str = "root",
        password: str | None = None,
    ) -> None:
        """
        Args:
            host: Hostname
            location: Target directory or file
            port: Port
            user: Username
            password: Optional password. SSH key recommended.
        """
        super().__init__(location)

        self.host = host
        self.port = port
        self.user = user
        self.password = password
        # Set by open() and cleared by close(). It is initialised here so that
        # run_process() on an unopened connection fails the "Not connected"
        # assert instead of raising AttributeError.
        self._ssh_client: paramiko.SSHClient | None = None

    def run_process(self, command: list[str]) -> str:
        """
        Run a non-interactive process on the remote host and return its output.

        Args:
            command: List of parameters; shell-quoted with shlex.join before sending.
        Returns:
            stdout of process.
        Raises:
            exceptions.FailedProcessError: If the remote exit status is non-zero.
        """
        assert self._ssh_client, "Not connected"

        log.debug("Start SSH process:\n%s", command)

        _stdin, stdout, stderr = self._ssh_client.exec_command(shlex.join(command))
        stdout_str = stdout.read().decode()
        stderr_str = stderr.read().decode()

        if stdout.channel.recv_exit_status():
            raise exceptions.FailedProcessError(command, stdout_str, stderr_str)

        return stdout_str

    def open(self) -> SSHConnection:
        """
        Open the connection to the target host, reusing a pooled client if present.

        Returns:
            Itself
        """
        key = (self.host, self.port, self.user)
        ssh_client = SSHConnection.ssh_client_pool.get(key)
        if ssh_client is None:
            ssh_client = paramiko.SSHClient()
            ssh_client.load_system_host_keys()
            # Unknown host keys are rejected rather than auto-added.
            ssh_client.set_missing_host_key_policy(paramiko.RejectPolicy())

            log.info("Opening ssh connection to %s@%s:%s", self.user, self.host, self.port)
            try:
                ssh_client.connect(
                    self.host,
                    username=self.user,
                    password=self.password,
                    port=self.port,
                )
            except BaseException:
                # Don't leak a half-initialised client when the connect fails.
                ssh_client.close()
                raise
            SSHConnection.ssh_client_pool[key] = ssh_client

        self.connected = True
        self._ssh_client = ssh_client

        return self

    def close(self) -> None:
        """Close the connection and drop its client from the pool."""
        assert self.connected, "Connection already closed"
        assert self._ssh_client

        log.info("Closing ssh connection to %s %s", self.host, self.location)
        self._ssh_client.close()

        key = (self.host, self.port, self.user)
        # Another instance sharing the pooled client may already have closed it
        # and removed the entry, or replaced it with a fresh client. Only evict
        # the entry if it is still our own client.
        # NOTE(review): closing still shuts the client for every instance that
        # reused it from the pool; reference counting would be needed to avoid that.
        if SSHConnection.ssh_client_pool.get(key) is self._ssh_client:
            del SSHConnection.ssh_client_pool[key]

        self.connected = False
        self._ssh_client = None

    @property
    def exec_prefix(self) -> str:
        """
        Returns:
            Prefix to run commands on the target using local commands.
        """
        return f"ssh -p {self.port} {self.user}@{self.host} "