Skip to content

Refactor Hub to drop the Place class #42

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 74 additions & 105 deletions not_my_board/_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@

import asyncio
import contextlib
import contextvars
import itertools
import logging
import random
import traceback

import asgineer

import not_my_board._jsonrpc as jsonrpc
import not_my_board._models as models
import not_my_board._util as util

logger = logging.getLogger(__name__)
client_ip_var = contextvars.ContextVar("client_ip")
reservation_context_var = contextvars.ContextVar("reservation_context")
valid_tokens = ("dummy-token-1", "dummy-token-2")


Expand Down Expand Up @@ -65,158 +70,122 @@ async def _authorize_ws(ws):


class Hub:
_places = {}
_exporters = {}
_available = set()
_wait_queue = []
_reservations = {}

def __init__(self):
self._id_generator = itertools.count(start=1)

async def get_places(self):
return {"places": [p.desc for p in Place.all()]}
return {"places": [p.dict() for p in self._places.values()]}

async def agent_communicate(self, client_ip, rpc):
async with Place.reservation_context(client_ip) as ctx:
api = AgentAPI(ctx)
rpc.set_api_object(api)
client_ip_var.set(client_ip)
async with self._register_agent():
rpc.set_api_object(self)
await rpc.serve_forever()

async def exporter_communicate(self, client_ip, rpc):
client_ip_var.set(client_ip)
async with util.background_task(rpc.io_loop()) as io_loop:
place = await rpc.get_place()
with Place.register(place, rpc, client_ip):
export_desc = await rpc.get_place()
with self._register_place(export_desc, rpc, client_ip):
await io_loop


_hub = Hub()


class AgentAPI:
def __init__(self, reservation_context):
self._reservation_context = reservation_context

async def reserve(self, candidate_ids):
place = await Place.reserve(candidate_ids, self._reservation_context)
return place.desc["id"]

async def return_reservation(self, place_id):
await Place.return_by_id(place_id, self._reservation_context)


class Place:
_all_places = {}
_next_id = 1
_available = set()
_wait_queue = []
_reservations = {}

@classmethod
def all(cls):
return cls._all_places.values()

@classmethod
def _new_id(cls):
id_ = cls._next_id
cls._next_id += 1
return id_

@classmethod
@contextlib.contextmanager
def register(cls, desc, exporter, client_ip):
self = cls()
self._desc = desc
self._exporter = exporter

self._id = cls._new_id()
self._desc["id"] = self._id
self._desc["host"] = client_ip
def _register_place(self, export_desc, rpc, client_ip):
id_ = next(self._id_generator)
place = models.Place(id=id_, host=client_ip, **export_desc)

try:
logger.info("New place registered: %d", self._id)
cls._all_places[self._id] = self
cls._available.add(self._id)
logger.info("New place registered: %d", id_)
self._places[id_] = place
self._exporters[id_] = rpc
self._available.add(id_)
yield self
finally:
logger.info("Place disappeared: %d", self._id)
del cls._all_places[self._id]
cls._available.discard(self._id)
for candidates, _, future in cls._wait_queue:
candidates.discard(self._id)
logger.info("Place disappeared: %d", id_)
del self._places[id_]
del self._exporters[id_]
self._available.discard(id_)
for candidates, _, future in self._wait_queue:
candidates.discard(id_)
if not candidates and not future.done():
future.set_exception(Exception("All candidate places are gone"))

@property
def desc(self):
return self._desc

@classmethod
@contextlib.asynccontextmanager
async def reservation_context(cls, client_ip):
ctx = _ReservationContext(client_ip)
async def _register_agent(self):
ctx = object()
reservation_context_var.set(ctx)

try:
cls._reservations[ctx] = set()
yield ctx
self._reservations[ctx] = set()
yield
finally:
for place in cls._reservations[ctx].copy():
await cls.return_by_id(place, ctx)
del cls._reservations[ctx]
coros = [self.return_reservation(id_) for id_ in self._reservations[ctx]]
results = await asyncio.gather(*coros, return_exceptions=True)
del self._reservations[ctx]
for result in results:
if isinstance(result, Exception):
logger.warning("Error while deregistering agent: %s", result)

@classmethod
async def reserve(cls, candidate_ids, ctx):
existing_candidates = {id_ for id_ in candidate_ids if id_ in cls._all_places}
async def reserve(self, candidate_ids):
ctx = reservation_context_var.get()
existing_candidates = {id_ for id_ in candidate_ids if id_ in self._places}
if not existing_candidates:
raise RuntimeError("None of the candidates exist anymore")

available_candidates = existing_candidates & cls._available
available_candidates = existing_candidates & self._available
if available_candidates:
# TODO do something smart to get the best candidate
reserved_id = random.choice(list(available_candidates))

cls._available.remove(reserved_id)
cls._reservations[ctx].add(reserved_id)
self._available.remove(reserved_id)
self._reservations[ctx].add(reserved_id)
logger.info("Place reserved: %d", reserved_id)
place = cls._all_places[reserved_id]
else:
logger.debug(
"No places available, adding request to queue: %s",
str(existing_candidates),
)
future = asyncio.get_running_loop().create_future()
entry = (existing_candidates, ctx, future)
cls._wait_queue.append(entry)
self._wait_queue.append(entry)
try:
place = await future
reserved_id = await future
finally:
cls._wait_queue.remove(entry)
self._wait_queue.remove(entry)

# TODO refactor Place class
# pylint: disable=protected-access
try:
await place._exporter.set_allowed_ips([ctx.client_ip])
except Exception:
await cls.return_by_id(place._id, ctx)
raise

return place

@classmethod
async def return_by_id(cls, place_id, ctx):
cls._reservations[ctx].remove(place_id)
if place_id in cls._all_places:
for candidates, new_ctx, future in cls._wait_queue:
client_ip = client_ip_var.get()
async with util.on_error(self.return_reservation, reserved_id):
rpc = self._exporters[reserved_id]
await rpc.set_allowed_ips([client_ip])

return reserved_id

async def return_reservation(self, place_id):
ctx = reservation_context_var.get()
self._reservations[ctx].remove(place_id)
if place_id in self._places:
for candidates, new_ctx, future in self._wait_queue:
if place_id in candidates and not future.done():
cls._reservations[new_ctx].add(place_id)
self._reservations[new_ctx].add(place_id)
logger.info("Place returned and reserved again: %d", place_id)
future.set_result(cls._all_places[place_id])
future.set_result(place_id)
break
else:
logger.info("Place returned: %d", place_id)
cls._available.add(place_id)
# pylint: disable=protected-access
await cls._all_places[place_id]._exporter.set_allowed_ips([])
self._available.add(place_id)
rpc = self._exporters[place_id]
await rpc.set_allowed_ips([])
else:
logger.info("Place returned, but it doesn't exist: %d", place_id)


class _ReservationContext:
def __init__(self, client_ip):
self._client_ip = client_ip

@property
def client_ip(self):
return self._client_ip
_hub = Hub()


class ProtocolError(Exception):
Expand Down
5 changes: 4 additions & 1 deletion not_my_board/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,7 @@ class ExportDesc(pydantic.BaseModel):

class Place(ExportDesc):
id: pydantic.PositiveInt
host: pydantic.IPvAnyAddress
# host: pydantic.IPvAnyAddress
# can't serialize IP address with json.dumps()
# TODO: maybe drop pydantic as a dependency
host: str
1 change: 1 addition & 0 deletions not_my_board/_util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
background_task,
cancel_tasks,
connect,
on_error,
relay_streams,
run,
run_concurrently,
Expand Down
11 changes: 11 additions & 0 deletions not_my_board/_util/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,17 @@ async def cancel_tasks(tasks):
pass


@contextlib.asynccontextmanager
async def on_error(callback, /, *args, **kwargs):
"""Calls a cleanup callback, if an exception is raised within the
context manager.
"""
async with contextlib.AsyncExitStack() as stack:
stack.push_async_callback(callback, *args, **kwargs)
yield
stack.pop_all()


@contextlib.asynccontextmanager
async def connect(*args, **kwargs):
"""Wraps `asyncio.open_connection()` in a context manager
Expand Down
42 changes: 20 additions & 22 deletions tests/test_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

import not_my_board._hub as hubmodule
import not_my_board._jsonrpc as jsonrpc
import not_my_board._util as util

DEFAULT_EXPORTER_IP = "3.1.1.1"
Expand Down Expand Up @@ -34,7 +35,9 @@ async def get_place(self):
"port": 1234,
"parts": [
{
"compatible": "test-board",
"compatible": [
"test-board",
],
"tcp": {
"test-if": {
"host": "localhost",
Expand Down Expand Up @@ -75,35 +78,30 @@ async def test_register_exporter(hub):
assert len(places["places"]) == 0


class FakeAgent:
def __init__(self, register_event):
self._register_event = register_event
def fake_rpc_pair():
proxy_to_server = asyncio.Queue()
server_to_proxy = asyncio.Queue()

def set_api_object(self, api_obj):
self._api_obj = api_obj
self._register_event.set()
async def receive_iter(queue):
while True:
data = await queue.get()
yield data
queue.task_done()

async def serve_forever(self):
# wait forever
await asyncio.Event().wait()

def __getattr__(self, method_name):
if method_name.startswith("_"):
raise AttributeError(f"invalid attribute '{method_name}'")
return getattr(self._api_obj, method_name)
server = jsonrpc.Server(server_to_proxy.put, receive_iter(proxy_to_server))
proxy = jsonrpc.Proxy(proxy_to_server.put, receive_iter(server_to_proxy))
return server, proxy


# pylint: disable=redefined-outer-name
@contextlib.asynccontextmanager
async def register_agent(hub):
agent_ip = DEFAULT_AGENT_IP
register_event = asyncio.Event()
fake_agent = FakeAgent(register_event)
coro = hub.agent_communicate(agent_ip, fake_agent)
server, proxy = fake_rpc_pair()
coro = hub.agent_communicate(agent_ip, server)
async with util.background_task(coro):
async with asyncio.timeout(2):
await register_event.wait()
yield fake_agent
async with util.background_task(proxy.io_loop()):
yield proxy


async def test_reserve_place(hub):
Expand All @@ -119,7 +117,7 @@ async def test_reserve_place(hub):
async def test_reserve_non_existent(hub):
async with register_agent(hub) as agent:
candidate_ids = [42]
with pytest.raises(RuntimeError) as execinfo:
with pytest.raises(jsonrpc.RemoteError) as execinfo:
await agent.reserve(candidate_ids)
assert "None of the candidates exist anymore" in str(execinfo.value)

Expand Down