diff --git a/not_my_board/_agent.py b/not_my_board/_agent.py index de17e5d..66e2649 100644 --- a/not_my_board/_agent.py +++ b/not_my_board/_agent.py @@ -26,16 +26,17 @@ Address = Tuple[str, int] -async def agent(hub_url, ca_files): - io = _AgentIO(hub_url, http.Client(ca_files)) +async def agent(hub_url, ca_files, token_store_path): + io = _AgentIO(hub_url, http.Client(ca_files), token_store_path) async with Agent(hub_url, io) as agent_: await agent_.serve_forever() class _AgentIO: - def __init__(self, hub_url, http_client): + def __init__(self, hub_url, http_client, token_store_path): self._hub_url = hub_url self._http = http_client + self._token_store_path = token_store_path @contextlib.asynccontextmanager async def hub_rpc(self): @@ -111,7 +112,9 @@ async def _handle_port_forward_client(self, proxy, target, client_r, client_w): await util.relay_streams(client_r, client_w, remote_r, remote_w) async def get_id_token(self): - return await auth.get_id_token(self._hub_url, self._http) + return await auth.get_id_token( + self._token_store_path, self._hub_url, self._http + ) class Agent(util.ContextStack): diff --git a/not_my_board/_auth/_login.py b/not_my_board/_auth/_login.py index d24fa5a..c3e3618 100644 --- a/not_my_board/_auth/_login.py +++ b/not_my_board/_auth/_login.py @@ -10,11 +10,11 @@ class LoginFlow(util.ContextStack): - def __init__(self, hub_url, http_client): + def __init__(self, hub_url, http_client, token_store_path): self._hub_url = hub_url self._http = http_client + self._token_store = _TokenStore(token_store_path) self._show_claims = [] - self._token_store = _TokenStore() async def _context_stack(self, stack): url = f"{self._hub_url}/api/v1/auth-info" @@ -49,8 +49,8 @@ async def finish(self): auth_response, self._http ) - async with _TokenStore() as token_store: - token_store.save_tokens(self._hub_url, id_token, refresh_token) + async with self._token_store: + self._token_store.save_tokens(self._hub_url, id_token, refresh_token) if self._show_claims: # filter claims to only show relevant ones @@ -71,8 +71,9 @@ async def oidc_callback_registered(self): self._ready_event.set() -async def get_id_token(hub_url, http_client): - async with _TokenStore() as token_store: +async def get_id_token(token_store_path, hub_url, http_client): + token_store = _TokenStore(token_store_path) + async with token_store: id_token, refresh_token = token_store.get_tokens(hub_url) id_token, refresh_token = await ensure_fresh( id_token, refresh_token, http_client @@ -83,15 +84,17 @@ async def get_id_token(hub_url, http_client): class _TokenStore(util.ContextStack): - _path = pathlib.Path("/var/lib/not-my-board/auth_tokens.json") + def __init__(self, path_str=None): + path = pathlib.Path(path_str) - def __init__(self): - if not self._path.exists(): - self._path.parent.mkdir(parents=True, exist_ok=True) - self._path.touch(mode=0o600) + if not path.exists(): + path.parent.mkdir(parents=True, exist_ok=True) + path.touch(mode=0o600) - if not os.access(self._path, os.R_OK | os.W_OK): - raise RuntimeError(f"Not allowed to access {self._path}") + if not os.access(path, os.R_OK | os.W_OK): + raise RuntimeError(f"Not allowed to access {path}") + + self._path = path async def _context_stack(self, stack): # pylint: disable-next=consider-using-with # false positive diff --git a/not_my_board/_export.py b/not_my_board/_export.py index 3ecdd54..a077f9d 100644 --- a/not_my_board/_export.py +++ b/not_my_board/_export.py @@ -18,20 +18,21 @@ logger = logging.getLogger(__name__) -async def export(hub_url, place, ca_files): +async def export(hub_url, place, ca_files, token_store_path): http_client = http.Client(ca_files) - async with Exporter(hub_url, place, http_client) as exporter: + async with Exporter(hub_url, place, http_client, token_store_path) as exporter: await exporter.register_place() await exporter.serve_forever() class Exporter(util.ContextStack): - def __init__(self, hub_url, export_desc_path, http_client): + def __init__(self, hub_url, export_desc_path, http_client, token_store_path): self._hub_url = hub_url self._ip_to_tasks_map = {} export_desc_content = export_desc_path.read_text() self._place = models.ExportDesc(**util.toml_loads(export_desc_content)) self._http = http_client + self._token_store_path = token_store_path tcp_targets = { f"{tcp.host}:{tcp.port}".encode() @@ -126,7 +127,9 @@ async def _tunnel(self, client_r, client_w, target, trailing_data): await util.relay_streams(client_r, client_w, remote_r, remote_w) async def get_id_token(self): - return await auth.get_id_token(self._hub_url, self._http) + return await auth.get_id_token( + self._token_store_path, self._hub_url, self._http + ) def format_date_time(dt=None): diff --git a/not_my_board/cli/__init__.py b/not_my_board/cli/__init__.py index 37765c5..1f1bee6 100644 --- a/not_my_board/cli/__init__.py +++ b/not_my_board/cli/__init__.py @@ -19,6 +19,9 @@ __version__ = "dev" +TOKEN_STORE_PATH = "/var/lib/not-my-board/auth_tokens.json" + + # pylint: disable=too-many-statements def main(): parser = argparse.ArgumentParser(description="Setup, manage and use a board farm") @@ -155,11 +158,11 @@ def _hub_command(_): async def _export_command(args): - await export(args.hub_url, args.export_description, args.cacert) + await export(args.hub_url, args.export_description, args.cacert, TOKEN_STORE_PATH) async def _agent_command(args): - await agent(args.hub_url, args.cacert) + await agent(args.hub_url, args.cacert, TOKEN_STORE_PATH) async def _reserve_command(args): @@ -215,7 +218,8 @@ async def _uevent_command(args): async def _login_command(args): http_client = http.Client(args.cacert) - async with auth.LoginFlow(args.hub_url, http_client) as login: + token_store_path = "/var/lib/not-my-board/auth_tokens.json" + async with auth.LoginFlow(args.hub_url, http_client, token_store_path) as login: print( f"{Format.BOLD}" "Open the following link in your browser and log in:"