diff --git a/not_my_board/_agent.py b/not_my_board/_agent.py index 2a28081..2d40a44 100644 --- a/not_my_board/_agent.py +++ b/not_my_board/_agent.py @@ -2,6 +2,7 @@ import asyncio import contextlib +import functools import ipaddress import logging import os @@ -20,59 +21,112 @@ async def agent(hub_url): - async with Agent(hub_url) as a: - await a.serve_forever() + io = _AgentIO(hub_url) + async with Agent(hub_url, io) as agent_: + await agent_.serve_forever() -class Agent(util.ContextStack): +class _AgentIO: def __init__(self, hub_url): self._hub_url = hub_url - self._reserved_places = {} - self._pending = set() - url = urllib.parse.urlsplit(hub_url) - ws_scheme = "ws" if url.scheme == "http" else "wss" - self._ws_uri = f"{ws_scheme}://{url.netloc}/ws-agent" - self._hub_host = url.netloc.split(":")[0] - async def _context_stack(self, stack): + @contextlib.asynccontextmanager + async def hub_rpc(self): auth = "Bearer dummy-token-1" - self._hub = await stack.enter_async_context( - jsonrpc.WebsocketChannel(self._ws_uri, start=False, auth=auth) - ) - - stack.push_async_callback(self._cleanup) + url = f"{self._hub_url}/ws-agent" + async with jsonrpc.WebsocketChannel(url, auth=auth) as rpc: + yield rpc + @contextlib.asynccontextmanager + async def unix_server(self, api_obj): socket_path = pathlib.Path("/run") / "not-my-board-agent.sock" - self._unix_server = await stack.enter_async_context( - util.UnixServer(self._handle_client, socket_path) - ) - os.chmod(socket_path, 0o660) - try: - shutil.chown(socket_path, group="not-my-board") - except Exception as e: - logger.warning( - 'Failed to change group on agent socket "%s": %s', socket_path, e - ) - - async def _cleanup(self): - for _, place in self._reserved_places.items(): - if place.is_attached: - await place.detach() + connection_handler = functools.partial(self._handle_unix_client, api_obj) + async with util.UnixServer(connection_handler, socket_path) as unix_server: + os.chmod(socket_path, 0o660) + try: + shutil.chown(socket_path, group="not-my-board") + except Exception as e: + logger.warning( + 'Failed to change group on agent socket "%s": %s', socket_path, e + ) - @jsonrpc.hidden - async def serve_forever(self): - await util.run_concurrently( - self._unix_server.serve_forever(), self._hub.communicate_forever() - ) + yield unix_server - async def _handle_client(self, reader, writer): + @staticmethod + async def _handle_unix_client(api_obj, reader, writer): async def send(data): writer.write(data + b"\n") await writer.drain() - socket_channel = jsonrpc.Channel(send, reader, self) - await socket_channel.communicate_forever() + channel = jsonrpc.Channel(send, reader, api_obj) + await channel.communicate_forever() + + async def get_places(self): + response = await http.get_json(f"{self._hub_url}/api/v1/places") + return [models.Place(**p) for p in response["places"]] + + @staticmethod + async def usbip_refresh_status(): + await usbip.refresh_vhci_status() + + @staticmethod + def usbip_is_attached(vhci_port): + return usbip.is_attached(vhci_port) + + @staticmethod + async def usbip_attach(proxy, target, port_num, usbid): + tunnel = http.open_tunnel(*proxy, *target) + async with tunnel as (reader, writer, trailing_data): + if trailing_data: + raise ProtocolError("USB/IP implementation cannot handle trailing data") + return await usbip.attach(reader, writer, usbid, port_num) + + @staticmethod + def usbip_detach(vhci_port): + usbip.detach(vhci_port) + + async def port_forward(self, ready_event, proxy, target, local_port): + localhost = "127.0.0.1" + connection_handler = functools.partial( + self._handle_port_forward_client, proxy, target + ) + async with util.Server(connection_handler, localhost, local_port) as server: + ready_event.set() + await server.serve_forever() + + @staticmethod + async def _handle_port_forward_client(proxy, target, client_r, client_w): + async with http.open_tunnel(*proxy, *target) as ( + remote_r, + remote_w, + trailing_data, + ): + client_w.write(trailing_data) + await client_w.drain() + await util.relay_streams(client_r, client_w, remote_r, remote_w) + + +class Agent(util.ContextStack): + def __init__(self, hub_url, io): + url = urllib.parse.urlsplit(hub_url) + self._hub_host = url.netloc.split(":")[0] + self._io = io + self._reserved_places = {} + self._pending = set() + + async def _context_stack(self, stack): + self._hub = await stack.enter_async_context(self._io.hub_rpc()) + stack.push_async_callback(self._cleanup) + self._unix_server = await stack.enter_async_context(self._io.unix_server(self)) + + async def serve_forever(self): + await self._unix_server.serve_forever() + + async def _cleanup(self): + for _, place in self._reserved_places.items(): + if place.is_attached: + await place.detach() async def reserve(self, import_description): import_description = models.ImportDesc(**import_description) @@ -86,8 +140,7 @@ async def reserve(self, import_description): self._pending.add(name) try: - response = await http.get_json(f"{self._hub_url}/api/v1/places") - places = [models.Place(**p) for p in response["places"]] + places = await self._io.get_places() candidates = self._filter_places(import_description, places) candidate_ids = list(candidates) @@ -140,7 +193,7 @@ async def list(self): ] async def status(self): - await usbip.refresh_vhci_status() + await self._io.usbip_refresh_status() return [ {"place": name, **status} for name, place in self._reserved_places.items() @@ -160,7 +213,7 @@ def _filter_places(self, import_description, places): if matching: real_host = self._real_host(place.host) reserved_places[place.id] = ReservedPlace( - import_description, place, real_host, matching + import_description, place, real_host, matching, self._io ) return reserved_places @@ -208,7 +261,7 @@ def _part_to_set(part): class ReservedPlace: - def __init__(self, import_description, place, real_host, matching): + def __init__(self, import_description, place, real_host, matching, io): self._import_description = import_description self._place = place self._tunnels = [] @@ -223,6 +276,7 @@ def __init__(self, import_description, place, real_host, matching): for usb_name, usb_import_description in imported_part.usb.items(): self._tunnels.append( UsbTunnel( + io, part_name=name, iface_name=usb_name, proxy=proxy, @@ -234,6 +288,7 @@ def __init__(self, import_description, place, real_host, matching): for tcp_name, tcp_import_description in imported_part.tcp.items(): self._tunnels.append( TcpTunnel( + io, part_name=name, iface_name=tcp_name, proxy=proxy, @@ -292,7 +347,8 @@ class UsbTunnel(util.ContextStack): _target = "usb.not-my-board.localhost", 3240 _ready_timeout = 5 - def __init__(self, part_name, iface_name, proxy, usbid, port_num): + def __init__(self, io, part_name, iface_name, proxy, usbid, port_num): + self._io = io self._part_name = part_name self._iface_name = iface_name self._name = f"{part_name}.{iface_name}" @@ -318,7 +374,10 @@ async def _tunnel_task(self, ready_event): try: while True: try: - await self._attach() + self._vhci_port = await self._io.usbip_attach( + self._proxy, self._target, self._port_num, self._usbid + ) + logger.debug("%s: USB device attached", self._name) ready_event.set() retry_timeout = 1 except Exception: @@ -327,19 +386,9 @@ async def _tunnel_task(self, ready_event): retry_timeout = min(2 * retry_timeout, 30) finally: if self._vhci_port is not None: - usbip.detach(self._vhci_port) + self._io.usbip_detach(self._vhci_port) logger.debug("%s: USB device detached", self._name) - async def _attach(self): - tunnel = http.open_tunnel(*self._proxy, *self._target) - async with tunnel as (reader, writer, trailing_data): - if trailing_data: - raise ProtocolError("USB/IP implementation cannot handle trailing data") - self._vhci_port = await usbip.attach( - reader, writer, self._usbid, self._port_num - ) - logger.debug("%s: USB device attached", self._name) - @property def part_name(self): return self._part_name @@ -355,12 +404,15 @@ def type_name(self): @property def attached(self): return ( - usbip.is_attached(self._vhci_port) if self._vhci_port is not None else False + self._io.usbip_is_attached(self._vhci_port) + if self._vhci_port is not None + else False ) class TcpTunnel(util.ContextStack): - def __init__(self, part_name, iface_name, proxy, remote, local_port): + def __init__(self, io, part_name, iface_name, proxy, remote, local_port): + self._io = io self._part_name = part_name self._iface_name = iface_name self._name = f"{part_name}.{iface_name}" @@ -371,9 +423,10 @@ def __init__(self, part_name, iface_name, proxy, remote, local_port): async def _context_stack(self, stack): ready_event = asyncio.Event() - await stack.enter_async_context( - util.background_task(self._tunnel_task(ready_event)) + coro = self._io.port_forward( + ready_event, self._proxy, self._remote, self._local_port ) + await stack.enter_async_context(util.background_task(coro)) await ready_event.wait() self._is_attached = True @@ -381,26 +434,6 @@ async def __aexit__(self, exc_type, exc, tb): super().__aexit__(exc_type, exc, tb) self._is_attached = False - async def _tunnel_task(self, ready_event): - localhost = "127.0.0.1" - async with util.Server( - self._handle_client, localhost, self._local_port - ) as server: - ready_event.set() - await server.serve_forever() - - async def _handle_client(self, client_r, client_w): - logger.debug("%s: Opening tunnel", self._name) - async with http.open_tunnel(*self._proxy, *self._remote) as ( - remote_r, - remote_w, - trailing_data, - ): - logger.debug("%s: Tunnel created, relaying data", self._name) - client_w.write(trailing_data) - await client_w.drain() - await util.relay_streams(client_r, client_w, remote_r, remote_w) - @property def part_name(self): return self._part_name