Coverage for b4_backup/main/connection.py: 100%
129 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-18 22:40 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-18 22:40 +0000
1from __future__ import annotations
3import contextlib
4import logging
5import re
6import shlex
7import subprocess
8from abc import ABCMeta, abstractmethod
9from dataclasses import asdict, dataclass
10from pathlib import PurePath
12import paramiko
14from b4_backup import exceptions
# Module-level logger used by the connection classes in this file.
log = logging.getLogger("b4_backup.connection")
@dataclass
class URL:
    """
    Contains the components of a connection URL.

    Args:
        protocol: protocol used. eg. ssh
        user: Username
        password: Password
        host: Hostname
        port: Port
        location: Protocol specific location
    """

    protocol: str | None = None
    user: str = "root"
    password: str | None = None
    host: str | None = None
    port: int = 0
    location: PurePath = PurePath("/")

    # Un-annotated on purpose: these are class-level helpers, not dataclass fields.
    _url_pattern = re.compile(
        r"^(?:(?P<protocol>[a-zA-Z0-9_\-]+):\/\/)(?:(?P<user>[a-zA-Z0-9_\-]+)(?::(?P<password>[a-zA-Z0-9_\-]+))?@)?(?P<host>[a-zA-Z0-9_.\-]+)(?::(?P<port>[0-9]+))?(?P<location>\/[\/a-zA-Z0-9_.\-]*)?$"
    )
    _local_dir_pattern = re.compile(r"^(?P<location>[\/a-zA-Z0-9_.\-]+)$")
    # Default port per protocol; used when the URL does not state a port.
    _protocol_mapping = {"ssh": 22, None: 0}

    @classmethod
    def from_url(cls, source: str) -> URL:
        """
        Create an instance by providing an URL string.

        The string is first tried as a full URL
        (``protocol://[user[:password]@]host[:port][/location]``) and then as a
        plain path, in which case only ``location`` is taken from it.

        Args:
            source: URL string or plain path

        Returns:
            Instance of the class this method was called on. The protocol is
            lower-cased; a missing user, port or location falls back to the
            class default, the protocol's default port, or "/".

        Raises:
            exceptions.InvalidConnectionUrlError: If source matches neither pattern.
        """
        result = cls._url_pattern.match(source) or cls._local_dir_pattern.match(source)
        if result is None:
            raise exceptions.InvalidConnectionUrlError(
                f"The connection url {source} got an invalid format."
            )

        # The plain-path pattern only defines "location", so every other
        # component is looked up with .get() and treated as unset.
        groups = result.groupdict()

        protocol = groups.get("protocol")
        if protocol is not None:
            protocol = protocol.lower()

        # Build via cls (not URL) so subclasses get instances of themselves.
        return cls(
            protocol=protocol,
            user=groups.get("user") or cls.user,
            password=groups.get("password"),
            host=groups.get("host"),
            port=int(groups.get("port") or cls._protocol_mapping.get(protocol, 0)),
            location=PurePath(groups.get("location") or "/"),
        )
class Connection(metaclass=ABCMeta):
    """Abstract base for wrappers that execute commands on a machine."""

    def __init__(self, location: PurePath) -> None:
        """
        Args:
            location: Target directory or file.
        """
        self.connected: bool = False
        # When set, leaving a "with" block does not close the connection.
        self.keep_open = False
        self.location = location

    @classmethod
    def from_url(cls, url: str | None) -> Connection | contextlib.nullcontext:
        """
        Build the connection object that matches the given URL.

        Args:
            url: URL string to parse. None yields a no-op context manager.

        Returns:
            A local connection for URLs without protocol, an SSH connection for
            "ssh" URLs, or a null context if no URL was given.

        Raises:
            exceptions.UnknownProtocolError: If the protocol is not supported.
        """
        if url is None:
            return contextlib.nullcontext()

        target = URL.from_url(url)
        protocol = target.protocol

        if protocol == "ssh":
            assert target.host is not None

            return SSHConnection(
                host=target.host,
                port=target.port,
                user=target.user,
                password=target.password,
                location=target.location,
            )

        if protocol is not None:
            raise exceptions.UnknownProtocolError

        return LocalConnection(target.location)

    @abstractmethod
    def run_process(self, command: list[str]) -> str:
        """
        Execute a non-interactive process and hand back its output.

        Args:
            command: List of parameters
        Returns:
            stdout of process.
        """

    @abstractmethod
    def open(self) -> Connection:
        """
        Establish the connection to the target host.

        Returns:
            Itself
        """

    @abstractmethod
    def close(self) -> None:
        """Tear the connection down."""

    @property
    @abstractmethod
    def exec_prefix(self) -> str:
        """
        Returns:
            Prefix to run commands on the target using local commands.
        """

    def __enter__(self) -> Connection:
        """Open the connection when entering a "with" statement."""
        return self.open()

    def __exit__(self, *args, **kwargs) -> None:
        """Close the connection when leaving a "with" statement, unless keep_open is set."""
        if self.keep_open:
            return

        self.close()
class LocalConnection(Connection):
    """Connection implementation that runs commands directly on this machine."""

    def __init__(self, location: PurePath) -> None:
        """
        Args:
            location: Target directory or file.
        """
        super().__init__(location)

        self.location: PurePath = location

    def run_process(self, command: list[str]) -> str:
        """
        Execute a non-interactive local process and hand back its output.

        Args:
            command: List of parameters
        Returns:
            stdout of process.
        Raises:
            exceptions.FailedProcessError: If the process exits with a non-zero code.
        """
        log.debug("Start local process:\n%s", command)
        with subprocess.Popen(  # noqa: S603
            command,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        ) as proc:
            raw_out, raw_err = proc.communicate()
            out_text = raw_out.decode()
            err_text = raw_err.decode()

        if proc.returncode != 0:
            raise exceptions.FailedProcessError(command, out_text, err_text)

        return out_text

    def open(self) -> Connection:
        """
        Mark the local connection as open.

        Returns:
            Itself
        """
        log.info("Opening local connection to %s", self.location)
        self.connected = True

        return self

    def close(self) -> None:
        """Mark the local connection as closed."""
        assert self.connected, "Connection already closed"

        self.connected = False

    @property
    def exec_prefix(self) -> str:
        """
        Returns:
            Prefix to run commands on the target using local commands.
            Empty, because local commands need no prefix.
        """
        return ""
class SSHConnection(Connection):
    """
    A connection wrapper to execute commands on remote machines via SSH.

    Opened clients are cached in ``ssh_client_pool`` under ``(host, port, user)``,
    so instances pointing at the same endpoint reuse one paramiko client.
    """

    ssh_client_pool: dict[tuple[str, int, str], paramiko.SSHClient] = {}

    def __init__(
        self,
        host: str,
        location: PurePath,
        port: int = 22,
        user: str = "root",
        password: str | None = None,
    ) -> None:
        """
        Args:
            host: Hostname
            location: Target directory or file
            port: Port
            user: Username
            password: Optional password. SSH key recommended.
        """
        super().__init__(location)

        self.host = host
        self.port = port
        self.user = user
        self.password = password
        # Must be assigned, not merely annotated: otherwise using the connection
        # before open() raises AttributeError instead of the "Not connected" assertion.
        self._ssh_client: paramiko.SSHClient | None = None

    @property
    def _pool_key(self) -> tuple[str, int, str]:
        """Key under which this connection's client is stored in ssh_client_pool."""
        return (self.host, self.port, self.user)

    def run_process(self, command: list[str]) -> str:
        """
        Run a process without interaction and return the result.

        Args:
            command: List of parameters
        Returns:
            stdout of process.
        Raises:
            exceptions.FailedProcessError: If the remote process exits with a non-zero code.
        """
        assert self._ssh_client, "Not connected"

        log.debug("Start SSH process:\n%s", command)

        # exec_command takes a single string, so the parameters are shell-quoted and joined.
        _stdin, stdout, stderr = self._ssh_client.exec_command(shlex.join(command))
        stdout_str = stdout.read().decode()
        stderr_str = stderr.read().decode()

        if stdout.channel.recv_exit_status():
            raise exceptions.FailedProcessError(command, stdout_str, stderr_str)

        return stdout_str

    def open(self) -> SSHConnection:
        """
        Open the connection to the target host.

        Reuses the pooled client for this host, port and user if there is one;
        otherwise a new client is connected and added to the pool.

        Returns:
            Itself
        """
        ssh_client = SSHConnection.ssh_client_pool.get(self._pool_key)
        if ssh_client is None:
            ssh_client = paramiko.SSHClient()
            ssh_client.load_system_host_keys()
            # Unknown hosts are rejected rather than silently trusted.
            ssh_client.set_missing_host_key_policy(paramiko.RejectPolicy())

            log.info("Opening ssh connection to %s@%s:%s", self.user, self.host, self.port)
            ssh_client.connect(
                self.host,
                username=self.user,
                password=self.password,
                port=self.port,
            )
            SSHConnection.ssh_client_pool[self._pool_key] = ssh_client

        self.connected = True
        self._ssh_client = ssh_client

        return self

    def close(self) -> None:
        """Close the connection and drop its client from the pool."""
        assert self.connected, "Connection already closed"
        assert self._ssh_client

        log.info("Closing ssh connection to %s %s", self.host, self.location)
        self._ssh_client.close()

        # Another instance sharing the pooled client may already have removed the
        # entry (or a fresh client may have taken its place), so only remove the
        # entry if it still refers to the client closed here.
        pool = SSHConnection.ssh_client_pool
        if pool.get(self._pool_key) is self._ssh_client:
            del pool[self._pool_key]

        self.connected = False
        self._ssh_client = None

    @property
    def exec_prefix(self) -> str:
        """
        Returns:
            Prefix to run commands on the target using local commands.
        """
        return f"ssh -p {self.port} {self.user}@{self.host} "