diff --git a/doc/reference/cli.md b/doc/reference/cli.md index 29e47eb..254e913 100644 --- a/doc/reference/cli.md +++ b/doc/reference/cli.md @@ -37,6 +37,14 @@ Here's a description of all the commands and options `not-my-board` supports. **`status`** \[**`-h`**|**`--help`**\] \[**`-v`**|**`--verbose`**\] \[**`-n`**|**`--no-header`**\] : Show status of attached places and its interfaces. +**`edit`** \[**`-h`**|**`--help`**\] \[**`-v`**|**`--verbose`**\] *name* +: Edit the import description of a reserved or attached place. It opens a + temporary file with the current import description with the configured editor. + After the editor is closed, the reservation is updated. The editor used is + chosen from the `VISUAL` or the `EDITOR` environment variable, in that order. + If none is set, then `vi` is used. This doesn't modify the actual import + description used to attach the place. + **`uevent`** \[**`-h`**|**`--help`**\] \[**`-v`**|**`--verbose`**\] *devpath* : Handle Kernel uevent for USB devices. This should be called by the device manager, e.g. *udev*(7). diff --git a/not_my_board/_agent.py b/not_my_board/_agent.py index 3842f41..a2f94a5 100644 --- a/not_my_board/_agent.py +++ b/not_my_board/_agent.py @@ -12,7 +12,7 @@ import urllib.parse import weakref from dataclasses import dataclass, field -from typing import List, Tuple +from typing import Mapping, Tuple import not_my_board._jsonrpc as jsonrpc import not_my_board._models as models @@ -88,8 +88,12 @@ async def usbip_attach(self, proxy, target, port_num, usbid): return await usbip.attach(reader, writer, usbid, port_num) @staticmethod - def usbip_detach(vhci_port): + async def usbip_detach(vhci_port): usbip.detach(vhci_port) + # Unfortunately it takes ~ 0.5 seconds for the connection to close and + # for the remote device to be available again. Wait a bit, so an + # immediate attach after the detach succeeds. + await asyncio.sleep(2) async def port_forward(self, ready_event, proxy, target, local_port): localhost = "127.0.0.1" @@ -131,7 +135,9 @@ async def serve_forever(self): await self._unix_server.serve_forever() async def _cleanup(self): - coros = [t.close() for r in self._reservations.values() for t in r.tunnels] + coros = [ + t.close() for r in self._reservations.values() for t in r.tunnels.values() + ] await util.run_concurrently(*coros) @@ -145,18 +151,27 @@ async def reserve(self, name, import_description_toml): places = await self._io.get_places() - candidates = _filter_places(import_description, places) - if not candidates: + tunnel_descs_by_id = _filter_places(import_description, places) + if not tunnel_descs_by_id: raise RuntimeError("No matching place found") - candidate_ids = list(candidates) + candidate_ids = list(tunnel_descs_by_id) place_id = await self._hub.reserve(candidate_ids) - tunnels = [ - desc.tunnel_cls(desc, self._hub_host, self._io) - for desc in candidates[place_id] - ] - self._reservations[name] = _Reservation(place_id, tunnels) + for p in places: + if p.id == place_id: + place = p + break + else: + raise RuntimeError("Hub returned invalid Place ID") + + tunnels = { + desc: desc.tunnel_cls(desc, self._hub_host, self._io) + for desc in tunnel_descs_by_id[place_id] + } + self._reservations[name] = _Reservation( + import_description_toml, place, tunnels + ) async def return_reservation(self, name, force=False): async with self._reservation(name) as reservation: @@ -165,7 +180,7 @@ async def return_reservation(self, name, force=False): await self._detach_reservation(reservation) else: raise RuntimeError(f'Place "{name}" is still attached') - await self._hub.return_reservation(reservation.place_id) + await self._hub.return_reservation(reservation.place.id) del self._reservations[name] async def attach(self, name): @@ -173,7 +188,7 @@ async def attach(self, name): if reservation.is_attached: raise RuntimeError(f'Place "{name}" is already attached') - coros = [t.open() for t in reservation.tunnels] + coros = [t.open() for t in reservation.tunnels.values()] async with util.on_error(self._detach_reservation, reservation): await util.run_concurrently(*coros) @@ -187,7 +202,7 @@ async def detach(self, name): await self._detach_reservation(reservation) async def _detach_reservation(self, reservation): - coros = [t.close() for t in reservation.tunnels] + coros = [t.close() for t in reservation.tunnels.values()] await util.run_concurrently(*coros) reservation.is_attached = False @@ -227,9 +242,66 @@ async def status(self): "attached": tunnel.is_attached(), } for name, reservation in self._reservations.items() - for tunnel in reservation.tunnels + for tunnel in reservation.tunnels.values() ] + async def get_import_description(self, name): + async with self._reservation(name) as reservation: + return reservation.import_description_toml + + async def update_import_description(self, name, import_description_toml): + async with self._reservation(name) as reservation: + parsed = util.toml_loads(import_description_toml) + import_description = models.ImportDesc(name=name, **parsed) + + imported_part_sets = [ + (name, _part_to_set(imported_part)) + for name, imported_part in import_description.parts.items() + ] + + matching = _find_matching(imported_part_sets, reservation.place) + if not matching: + raise RuntimeError("New import description doesn't match with place") + + new_tunnel_descs = _create_tunnel_descriptions( + import_description, reservation.place, matching + ) + + old_tunnel_descs = reservation.tunnels.keys() + + to_remove = old_tunnel_descs - new_tunnel_descs + to_add = new_tunnel_descs - old_tunnel_descs + to_keep = old_tunnel_descs & new_tunnel_descs + + new_tunnels = { + desc: desc.tunnel_cls(desc, self._hub_host, self._io) for desc in to_add + } + for desc in to_keep: + new_tunnels[desc] = reservation.tunnels[desc] + + if reservation.is_attached: + # close removed tunnels + removed_tunnels = [ + t for desc, t in reservation.tunnels.items() if desc in to_remove + ] + coros = [t.close() for t in removed_tunnels] + await util.run_concurrently(*coros) + + async def restore_removed(): + coros = [t.open() for t in removed_tunnels] + await util.run_concurrently(*coros) + + async with util.on_error(restore_removed): + # open added tunnels + coros = [ + t.open() for desc, t in new_tunnels.items() if desc in to_add + ] + await util.run_concurrently(*coros) + + # everything ok: update reservation + reservation.import_description_toml = import_description_toml + reservation.tunnels = new_tunnels + def _filter_places(import_description, places): candidates = {} @@ -360,7 +432,7 @@ class _UsbTunnel(_Tunnel): async def close(self): await super().close() if self._vhci_port is not None: - self._io.usbip_detach(self._vhci_port) + await self._io.usbip_detach(self._vhci_port) async def _task_func(self): retry_timeout = 1 @@ -420,9 +492,10 @@ class _TcpTunnelDesc(_TunnelDesc): @dataclass class _Reservation: - place_id: int + import_description_toml: str + place: models.Place is_attached: bool = field(default=False, init=False) - tunnels: List[_Tunnel] + tunnels: Mapping[_TunnelDesc, _Tunnel] class ProtocolError(Exception): diff --git a/not_my_board/_client.py b/not_my_board/_client.py index 00e4da3..1a5644c 100644 --- a/not_my_board/_client.py +++ b/not_my_board/_client.py @@ -5,6 +5,7 @@ import logging import os import pathlib +import tempfile import not_my_board._jsonrpc as jsonrpc @@ -62,6 +63,38 @@ async def status(): return await agent.status() +async def edit(name): + async with agent_channel() as agent: + import_description_toml = await agent.get_import_description(name) + new_content = None + + try: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".toml", delete_on_close=False + ) as file: + file.write(import_description_toml) + file.close() + + editor = os.environ.get("VISUAL") or os.environ.get("EDITOR") or "vi" + proc = await asyncio.create_subprocess_exec(editor, file.name) + await proc.wait() + + new_content = pathlib.Path(file.name).read_text() + + if proc.returncode: + raise RuntimeError(f"{editor!r} exited with {proc.returncode}") + + await agent.update_import_description(name, new_content) + except Exception as e: + if new_content is not None: + message = ( + "Failed to edit, here is your changed import description for reference:\n" + + new_content.rstrip() + ) + raise RuntimeError(message) from e + raise + + async def uevent(devpath): # devpath has a leading "/", so joining with the / operator doesn't # work diff --git a/not_my_board/cli/__init__.py b/not_my_board/cli/__init__.py index d729b2f..d91ec6f 100644 --- a/not_my_board/cli/__init__.py +++ b/not_my_board/cli/__init__.py @@ -105,7 +105,7 @@ def add_cacert_arg(subparser): subparser.add_argument( "-k", "--keep", action="store_true", help="don't return reservation" ) - subparser.add_argument("name", help="name of the place to attach") + subparser.add_argument("name", help="name of the place to detach") subparser = add_subcommand("list", help="list reserved places") add_verbose_arg(subparser) @@ -125,11 +125,15 @@ def add_cacert_arg(subparser): add_verbose_arg(subparser) subparser.add_argument("devpath", help="devpath attribute of uevent") - subparser = add_subcommand("login", help="Log in to a hub") + subparser = add_subcommand("login", help="log in to a hub") add_verbose_arg(subparser) add_cacert_arg(subparser) subparser.add_argument("hub_url", help="http(s) URL of the hub") + subparser = add_subcommand("edit", help="edit a reserved place") + add_verbose_arg(subparser) + subparser.add_argument("name", help="name of the place to edit") + args = parser.parse_args() # Don't use escape sequences, if stdout is not a tty @@ -272,6 +276,10 @@ async def _login_command(args): print(f"{Format.BOLD}{key}: {Format.RESET}{value}") +async def _edit_command(args): + await client.edit(args.name) + + class Format: RESET = "\033[0m" BOLD = "\033[1m" diff --git a/tests/test_agent.py b/tests/test_agent.py index 879023a..ffa7c81 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -163,7 +163,7 @@ async def usbip_attach(self, proxy, target, port_num, usbid): self.detach_event[port_num] = asyncio.Event() return port_num - def usbip_detach(self, vhci_port): + async def usbip_detach(self, vhci_port): if vhci_port in self.attached: del self.attached[vhci_port] self.detach_event[vhci_port].set()