Skip to content

Commit

Permalink
Agent: Refactor to Remove IO from Agent Class
Browse files Browse the repository at this point in the history
Without IO the Agent class can be tested with unit tests.
  • Loading branch information
holesch committed Apr 5, 2024
1 parent 8c236cf commit e646c42
Showing 1 changed file with 113 additions and 80 deletions.
193 changes: 113 additions & 80 deletions not_my_board/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import contextlib
import functools
import ipaddress
import logging
import os
Expand All @@ -20,59 +21,112 @@


async def agent(hub_url):
async with Agent(hub_url) as a:
await a.serve_forever()
io = _AgentIO(hub_url)
async with Agent(hub_url, io) as agent_:
await agent_.serve_forever()


class Agent(util.ContextStack):
class _AgentIO:
def __init__(self, hub_url):
self._hub_url = hub_url
self._reserved_places = {}
self._pending = set()
url = urllib.parse.urlsplit(hub_url)
ws_scheme = "ws" if url.scheme == "http" else "wss"
self._ws_uri = f"{ws_scheme}://{url.netloc}/ws-agent"
self._hub_host = url.netloc.split(":")[0]

async def _context_stack(self, stack):
@contextlib.asynccontextmanager
async def hub_rpc(self):
auth = "Bearer dummy-token-1"
self._hub = await stack.enter_async_context(
jsonrpc.WebsocketChannel(self._ws_uri, start=False, auth=auth)
)

stack.push_async_callback(self._cleanup)
url = f"{self._hub_url}/ws-agent"
async with jsonrpc.WebsocketChannel(url, auth=auth) as rpc:
yield rpc

@contextlib.asynccontextmanager
async def unix_server(self, api_obj):
socket_path = pathlib.Path("/run") / "not-my-board-agent.sock"
self._unix_server = await stack.enter_async_context(
util.UnixServer(self._handle_client, socket_path)
)

os.chmod(socket_path, 0o660)
try:
shutil.chown(socket_path, group="not-my-board")
except Exception as e:
logger.warning(
'Failed to change group on agent socket "%s": %s', socket_path, e
)

async def _cleanup(self):
for _, place in self._reserved_places.items():
if place.is_attached:
await place.detach()
connection_handler = functools.partial(self._handle_unix_client, api_obj)
async with util.UnixServer(connection_handler, socket_path) as unix_server:
os.chmod(socket_path, 0o660)
try:
shutil.chown(socket_path, group="not-my-board")
except Exception as e:
logger.warning(
'Failed to change group on agent socket "%s": %s', socket_path, e
)

@jsonrpc.hidden
async def serve_forever(self):
await util.run_concurrently(
self._unix_server.serve_forever(), self._hub.communicate_forever()
)
yield unix_server

async def _handle_client(self, reader, writer):
@staticmethod
async def _handle_unix_client(api_obj, reader, writer):
async def send(data):
writer.write(data + b"\n")
await writer.drain()

socket_channel = jsonrpc.Channel(send, reader, self)
await socket_channel.communicate_forever()
channel = jsonrpc.Channel(send, reader, api_obj)
await channel.communicate_forever()

async def get_places(self):
response = await http.get_json(f"{self._hub_url}/api/v1/places")
return [models.Place(**p) for p in response["places"]]

@staticmethod
async def usbip_refresh_status():
await usbip.refresh_vhci_status()

@staticmethod
def usbip_is_attached(vhci_port):
return usbip.is_attached(vhci_port)

@staticmethod
async def usbip_attach(proxy, target, port_num, usbid):
tunnel = http.open_tunnel(*proxy, *target)
async with tunnel as (reader, writer, trailing_data):
if trailing_data:
raise ProtocolError("USB/IP implementation cannot handle trailing data")
return await usbip.attach(reader, writer, usbid, port_num)

@staticmethod
def usbip_detach(vhci_port):
usbip.detach(vhci_port)

async def port_forward(self, ready_event, proxy, target, local_port):
localhost = "127.0.0.1"
connection_handler = functools.partial(
self._handle_port_forward_client, proxy, target
)
async with util.Server(connection_handler, localhost, local_port) as server:
ready_event.set()
await server.serve_forever()

@staticmethod
async def _handle_port_forward_client(proxy, target, client_r, client_w):
async with http.open_tunnel(*proxy, *target) as (
remote_r,
remote_w,
trailing_data,
):
client_w.write(trailing_data)
await client_w.drain()
await util.relay_streams(client_r, client_w, remote_r, remote_w)


class Agent(util.ContextStack):
def __init__(self, hub_url, io):
url = urllib.parse.urlsplit(hub_url)
self._hub_host = url.netloc.split(":")[0]
self._io = io
self._reserved_places = {}
self._pending = set()

async def _context_stack(self, stack):
self._hub = await stack.enter_async_context(self._io.hub_rpc())
stack.push_async_callback(self._cleanup)
self._unix_server = await stack.enter_async_context(self._io.unix_server(self))

async def serve_forever(self):
await self._unix_server.serve_forever()

async def _cleanup(self):
for _, place in self._reserved_places.items():
if place.is_attached:
await place.detach()

async def reserve(self, import_description):
import_description = models.ImportDesc(**import_description)
Expand All @@ -86,8 +140,7 @@ async def reserve(self, import_description):

self._pending.add(name)
try:
response = await http.get_json(f"{self._hub_url}/api/v1/places")
places = [models.Place(**p) for p in response["places"]]
places = await self._io.get_places()

candidates = self._filter_places(import_description, places)
candidate_ids = list(candidates)
Expand Down Expand Up @@ -140,7 +193,7 @@ async def list(self):
]

async def status(self):
await usbip.refresh_vhci_status()
await self._io.usbip_refresh_status()
return [
{"place": name, **status}
for name, place in self._reserved_places.items()
Expand All @@ -160,7 +213,7 @@ def _filter_places(self, import_description, places):
if matching:
real_host = self._real_host(place.host)
reserved_places[place.id] = ReservedPlace(
import_description, place, real_host, matching
import_description, place, real_host, matching, self._io
)

return reserved_places
Expand Down Expand Up @@ -208,7 +261,7 @@ def _part_to_set(part):


class ReservedPlace:
def __init__(self, import_description, place, real_host, matching):
def __init__(self, import_description, place, real_host, matching, io):
self._import_description = import_description
self._place = place
self._tunnels = []
Expand All @@ -223,6 +276,7 @@ def __init__(self, import_description, place, real_host, matching):
for usb_name, usb_import_description in imported_part.usb.items():
self._tunnels.append(
UsbTunnel(
io,
part_name=name,
iface_name=usb_name,
proxy=proxy,
Expand All @@ -234,6 +288,7 @@ def __init__(self, import_description, place, real_host, matching):
for tcp_name, tcp_import_description in imported_part.tcp.items():
self._tunnels.append(
TcpTunnel(
io,
part_name=name,
iface_name=tcp_name,
proxy=proxy,
Expand Down Expand Up @@ -292,7 +347,8 @@ class UsbTunnel(util.ContextStack):
_target = "usb.not-my-board.localhost", 3240
_ready_timeout = 5

def __init__(self, part_name, iface_name, proxy, usbid, port_num):
def __init__(self, io, part_name, iface_name, proxy, usbid, port_num):
self._io = io
self._part_name = part_name
self._iface_name = iface_name
self._name = f"{part_name}.{iface_name}"
Expand All @@ -318,7 +374,10 @@ async def _tunnel_task(self, ready_event):
try:
while True:
try:
await self._attach()
self._vhci_port = await self._io.usbip_attach(
self._proxy, self._target, self._port_num, self._usbid
)
logger.debug("%s: USB device attached", self._name)
ready_event.set()
retry_timeout = 1
except Exception:
Expand All @@ -327,19 +386,9 @@ async def _tunnel_task(self, ready_event):
retry_timeout = min(2 * retry_timeout, 30)
finally:
if self._vhci_port is not None:
usbip.detach(self._vhci_port)
self._io.usbip_detach(self._vhci_port)
logger.debug("%s: USB device detached", self._name)

async def _attach(self):
tunnel = http.open_tunnel(*self._proxy, *self._target)
async with tunnel as (reader, writer, trailing_data):
if trailing_data:
raise ProtocolError("USB/IP implementation cannot handle trailing data")
self._vhci_port = await usbip.attach(
reader, writer, self._usbid, self._port_num
)
logger.debug("%s: USB device attached", self._name)

@property
def part_name(self):
return self._part_name
Expand All @@ -355,12 +404,15 @@ def type_name(self):
@property
def attached(self):
return (
usbip.is_attached(self._vhci_port) if self._vhci_port is not None else False
self._io.usbip_is_attached(self._vhci_port)
if self._vhci_port is not None
else False
)


class TcpTunnel(util.ContextStack):
def __init__(self, part_name, iface_name, proxy, remote, local_port):
def __init__(self, io, part_name, iface_name, proxy, remote, local_port):
self._io = io
self._part_name = part_name
self._iface_name = iface_name
self._name = f"{part_name}.{iface_name}"
Expand All @@ -371,36 +423,17 @@ def __init__(self, part_name, iface_name, proxy, remote, local_port):

async def _context_stack(self, stack):
ready_event = asyncio.Event()
await stack.enter_async_context(
util.background_task(self._tunnel_task(ready_event))
coro = self._io.port_forward(
ready_event, self._proxy, self._remote, self._local_port
)
await stack.enter_async_context(util.background_task(coro))
await ready_event.wait()
self._is_attached = True

async def __aexit__(self, exc_type, exc, tb):
super().__aexit__(exc_type, exc, tb)
self._is_attached = False

async def _tunnel_task(self, ready_event):
localhost = "127.0.0.1"
async with util.Server(
self._handle_client, localhost, self._local_port
) as server:
ready_event.set()
await server.serve_forever()

async def _handle_client(self, client_r, client_w):
logger.debug("%s: Opening tunnel", self._name)
async with http.open_tunnel(*self._proxy, *self._remote) as (
remote_r,
remote_w,
trailing_data,
):
logger.debug("%s: Tunnel created, relaying data", self._name)
client_w.write(trailing_data)
await client_w.drain()
await util.relay_streams(client_r, client_w, remote_r, remote_w)

@property
def part_name(self):
return self._part_name
Expand Down

0 comments on commit e646c42

Please sign in to comment.