diff --git a/meson.build b/meson.build index 31ce30f..8cdce4f 100644 --- a/meson.build +++ b/meson.build @@ -34,6 +34,13 @@ py.install_sources( subdir: 'not_my_board/_jsonrpc', ) +py.install_sources( + 'not_my_board/_auth/__init__.py', + 'not_my_board/_auth/_openid.py', + 'not_my_board/_auth/_login.py', + subdir: 'not_my_board/_auth', +) + py.install_sources( 'not_my_board/cli/__init__.py', subdir: 'not_my_board/cli', diff --git a/not_my_board/_auth/__init__.py b/not_my_board/_auth/__init__.py new file mode 100644 index 0000000..f91b49e --- /dev/null +++ b/not_my_board/_auth/__init__.py @@ -0,0 +1 @@ +from ._login import LoginFlow, get_id_token diff --git a/not_my_board/_auth/_login.py b/not_my_board/_auth/_login.py new file mode 100644 index 0000000..1d7f14d --- /dev/null +++ b/not_my_board/_auth/_login.py @@ -0,0 +1,121 @@ +import asyncio +import json +import os +import pathlib + +import not_my_board._jsonrpc as jsonrpc +import not_my_board._util as util + +from ._openid import AuthRequest, ensure_fresh + + +class LoginFlow(util.ContextStack): + def __init__(self, hub_url, http_client): + self._hub_url = hub_url + self._http = http_client + self._show_claims = [] + self._token_store = _TokenStore() + + async def _context_stack(self, stack): + url = f"{self._hub_url}/api/v1/auth-info" + auth_info = await self._http.get_json(url) + redirect_uri = f"{self._hub_url}/oidc-callback" + + self._request = await AuthRequest.create( + auth_info["issuer"], auth_info["client_id"], redirect_uri, self._http + ) + + ready_event = asyncio.Event() + notification_api = _HubNotifications(ready_event) + + channel_url = f"{self._hub_url}/ws-login" + hub = jsonrpc.WebsocketChannel( + channel_url, self._http, api_obj=notification_api + ) + self._hub = await stack.enter_async_context(hub) + + coro = self._hub.get_authentication_response(self._request.state) + self._auth_response_task = await stack.enter_async_context( + util.background_task(coro) + ) + + await ready_event.wait() + + self._show_claims = auth_info.get("show_claims") + + async def finish(self): + auth_response = await self._auth_response_task + id_token, refresh_token, claims = await self._request.request_tokens( + auth_response, self._http + ) + + async with _TokenStore() as token_store: + token_store.save_tokens(self._hub_url, id_token, refresh_token) + + if self._show_claims: + # filter claims to only show relevant ones + return {k: v for k, v in claims.items() if k in self._show_claims} + else: + return claims + + @property + def login_url(self): + return self._request.login_url + + +class _HubNotifications: + def __init__(self, ready_event): + self._ready_event = ready_event + + async def oidc_callback_registered(self): + self._ready_event.set() + + +async def get_id_token(hub_url, http_client): + async with _TokenStore() as token_store: + id_token, refresh_token = token_store.get_tokens(hub_url) + id_token, refresh_token = await ensure_fresh( + id_token, refresh_token, http_client + ) + token_store.save_tokens(hub_url, id_token, refresh_token) + + return id_token + + +class _TokenStore(util.ContextStack): + _path = pathlib.Path("/var/lib/not-my-board/auth_tokens.json") + + def __init__(self): + if not self._path.exists(): + self._path.parent.mkdir(parents=True, exist_ok=True) + self._path.touch(mode=0o600) + + if not os.access(self._path, os.R_OK | os.W_OK): + raise RuntimeError(f"Not allowed to access {self._path}") + + async def _context_stack(self, stack): + # pylint: disable-next=consider-using-with # false positive + self._f = stack.enter_context(self._path.open("r+")) + await stack.enter_async_context(util.flock(self._f)) + content = self._f.read() + self._hub_tokens_map = json.loads(content) if content else {} + + def get_tokens(self, hub_url): + if hub_url not in self._hub_tokens_map: + raise RuntimeError("Login required") + + tokens = self._hub_tokens_map[hub_url] + return tokens["id"], tokens["refresh"] + + def save_tokens(self, hub_url, id_token, refresh_token): + new_tokens = { + "id": id_token, + "refresh": refresh_token, + } + old_tokens = self._hub_tokens_map.get(hub_url) + + if old_tokens != new_tokens: + self._hub_tokens_map[hub_url] = new_tokens + self._f.seek(0) + self._f.truncate() + self._f.write(json.dumps(self._hub_tokens_map)) diff --git a/not_my_board/_auth/_openid.py b/not_my_board/_auth/_openid.py new file mode 100644 index 0000000..e61cb51 --- /dev/null +++ b/not_my_board/_auth/_openid.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 + +import base64 +import dataclasses +import hashlib +import secrets +import urllib.parse + +import jwt + + +@dataclasses.dataclass +class IdentityProvider: + issuer: str + authorization_endpoint: str + token_endpoint: str + jwks_uri: str + + @classmethod + async def from_url(cls, issuer_url, http_client): + config_url = urllib.parse.urljoin( + f"{issuer_url}/", ".well-known/openid-configuration" + ) + config = await http_client.get_json(config_url) + + init_args = { + field.name: config[field.name] for field in dataclasses.fields(cls) + } + return cls(**init_args) + + +@dataclasses.dataclass +class AuthRequest: + client_id: str + redirect_uri: str + state: str + nonce: str + code_verifier: str + identity_provider: IdentityProvider + + @classmethod + async def create(cls, issuer_url, client_id, redirect_uri, http_client): + identity_provider = await IdentityProvider.from_url(issuer_url, http_client) + state = secrets.token_urlsafe() + nonce = secrets.token_urlsafe() + code_verifier = secrets.token_urlsafe() + + return cls( + client_id, redirect_uri, state, nonce, code_verifier, identity_provider + ) + + @property + def login_url(self): + hashed = hashlib.sha256(self.code_verifier.encode()).digest() + code_challange = base64.urlsafe_b64encode(hashed).rstrip(b"=").decode("ascii") + + auth_params = { + "scope": "openid profile offline_access", + "response_type": "code", + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "state": self.state, + "nonce": self.nonce, + "prompt": "consent", + "code_challenge": code_challange, + "code_challenge_method": "S256", + } + + url_parts = list( + urllib.parse.urlparse(self.identity_provider.authorization_endpoint) + ) + query = dict(urllib.parse.parse_qsl(url_parts[4])) + query.update(auth_params) + + url_parts[4] = urllib.parse.urlencode(query) + + return urllib.parse.urlunparse(url_parts) + + async def request_tokens(self, auth_response, http_client): + if "error" in auth_response: + if "error_description" in auth_response: + msg = f'{auth_response["error_description"]} ({auth_response["error"]})' + else: + msg = auth_response["error"] + + raise RuntimeError(f"Authentication error: {msg}") + + url = self.identity_provider.token_endpoint + params = { + "grant_type": "authorization_code", + "code": auth_response["code"], + "redirect_uri": self.redirect_uri, + "client_id": self.client_id, + "code_verifier": self.code_verifier, + } + response = await http_client.post_form(url, params) + + if response["token_type"].lower() != "bearer": + raise RuntimeError( + f'Expected token type "Bearer", got "{response["token_type"]}"' + ) + + claims = await verify(response["id_token"], self.client_id, http_client) + if claims["nonce"] != self.nonce: + raise RuntimeError( + "Nonce in the ID token doesn't match the one in the authorization request" + ) + + return response["id_token"], response["refresh_token"], claims + + +async def ensure_fresh(id_token, refresh_token, http_client): + if _needs_refresh(id_token): + claims = jwt.decode(id_token, options={"verify_signature": False}) + issuer_url = claims["iss"] + client_id = claims["aud"] + identity_provider = await IdentityProvider.from_url(issuer_url, http_client) + + params = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": client_id, + } + response = await http_client.post_form(identity_provider.token_endpoint, params) + return response["id_token"], response["refresh_token"] + else: + return id_token, refresh_token + + +def _needs_refresh(id_token): + try: + jwt.decode( + id_token, + options={ + "verify_signature": False, + "require": ["iss", "sub", "aud", "exp", "iat"], + "verify_exp": True, + "verify_iat": True, + "verify_nbf": True, + }, + ) + except Exception: + return True + return False + + +async def verify(token, client_id, http_client): + unverified_token = jwt.api_jwt.decode_complete( + token, options={"verify_signature": False} + ) + kid = unverified_token["header"]["kid"] + issuer = unverified_token["payload"]["iss"] + + identity_provider = await IdentityProvider.from_url(issuer, http_client) + jwk_set_raw = await http_client.get_json(identity_provider.jwks_uri) + jwk_set = jwt.PyJWKSet.from_dict(jwk_set_raw) + + for key in jwk_set.keys: + if key.public_key_use in ["sig", None] and key.key_id == kid: + signing_key = key + break + else: + raise RuntimeError(f'Unable to find a signing key that matches "{kid}"') + + return jwt.decode( + token, + key=signing_key.key, + algorithms="RS256", + audience=client_id, + ) diff --git a/not_my_board/_hub.py b/not_my_board/_hub.py index ecb58e2..5252d3f 100644 --- a/not_my_board/_hub.py +++ b/not_my_board/_hub.py @@ -48,8 +48,8 @@ async def _handle_lifespan(scope, receive, send): else: config = {} - hub = Hub() - await hub.startup(config) + hub = Hub(config) + await hub.startup() scope["state"]["hub"] = hub except Exception as err: await send({"type": "lifespan.startup.failed", "message": str(err)}) @@ -70,18 +70,26 @@ async def _handle_lifespan(scope, receive, send): @asgineer.to_asgi async def _handle_request(request): hub = request.scope["state"]["hub"] + response = (404, {}, "Page not found") if isinstance(request, asgineer.WebsocketRequest): if request.path == "/ws-agent": return await _handle_agent(hub, request) elif request.path == "/ws-exporter": return await _handle_exporter(hub, request) - await request.close() - return + elif request.path == "/ws-login": + await _handle_login(hub, request) + else: + await request.close() + response = None elif isinstance(request, asgineer.HttpRequest): if request.path == "/api/v1/places": - return await hub.get_places() - return 404, {}, "Page not found" + response = await hub.get_places() + elif request.path == "/api/v1/auth-info": + response = hub.auth_info() + elif request.path == "/oidc-callback": + response = await hub.oidc_callback(request.querydict) + return response async def _handle_agent(hub, ws): @@ -98,6 +106,13 @@ async def _handle_exporter(hub, ws): await hub.exporter_communicate(client_ip, exporter) +async def _handle_login(hub, ws): + await ws.accept() + client_ip = ws.scope["client"][0] + channel = jsonrpc.Channel(ws.send, ws.receive_iter()) + await hub.login_communicate(client_ip, channel) + + async def _authorize_ws(ws): try: auth = ws.headers["authorization"] @@ -120,11 +135,12 @@ class Hub: _available = set() _wait_queue = [] _reservations = {} + _pending_callbacks = {} - def __init__(self): - self._id_generator = itertools.count(start=1) + def __init__(self, config=None): + if config is None: + config = {} - async def startup(self, config): if "log_level" in config: log_level_str = config["log_level"] log_level_map = { @@ -139,6 +155,13 @@ async def startup(self, config): format="%(levelname)s: %(name)s: %(message)s", level=log_level ) + self._config = config + + self._id_generator = itertools.count(start=1) + + async def startup(self): + pass + async def shutdown(self): pass @@ -161,6 +184,12 @@ async def exporter_communicate(self, client_ip, rpc): with self._register_place(export_desc, rpc, client_ip): await com_task + @jsonrpc.hidden + async def login_communicate(self, client_ip, rpc): + client_ip_var.set(client_ip) + rpc.set_api_object(self) + await rpc.communicate_forever() + @contextlib.contextmanager def _register_place(self, export_desc, rpc, client_ip): id_ = next(self._id_generator) @@ -250,6 +279,29 @@ async def return_reservation(self, place_id): else: logger.info("Place returned, but it doesn't exist: %d", place_id) + async def get_authentication_response(self, state): + future = asyncio.get_running_loop().create_future() + self._pending_callbacks[state] = future + try: + channel = jsonrpc.get_current_channel() + await channel.oidc_callback_registered(_notification=True) + return await future + finally: + del self._pending_callbacks[state] + + @jsonrpc.hidden + async def oidc_callback(self, query): + future = self._pending_callbacks[query["state"]] + if not future.done(): + future.set_result(query) + + return "Continue in not-my-board CLI" + + @jsonrpc.hidden + def auth_info(self): + # TODO check and filter config + return self._config.get("auth_info", {}) + def _unmap_ip(ip_str): """Resolve IPv4-mapped-on-IPv6 to an IPv4 address""" diff --git a/not_my_board/cli/__init__.py b/not_my_board/cli/__init__.py index 4ede0e9..37765c5 100644 --- a/not_my_board/cli/__init__.py +++ b/not_my_board/cli/__init__.py @@ -5,7 +5,9 @@ import pathlib import sys +import not_my_board._auth as auth import not_my_board._client as client +import not_my_board._http as http import not_my_board._util as util from not_my_board._agent import agent from not_my_board._export import export @@ -114,6 +116,11 @@ def add_cacert_arg(subparser): add_verbose_arg(subparser) subparser.add_argument("devpath", help="devpath attribute of uevent") + subparser = add_subcommand("login", help="Log in to a hub") + add_verbose_arg(subparser) + add_cacert_arg(subparser) + subparser.add_argument("hub_url", help="http(s) URL of the hub") + args = parser.parse_args() # Don't use escape sequences, if stdout is not a tty @@ -206,6 +213,25 @@ async def _uevent_command(args): await client.uevent(args.devpath) +async def _login_command(args): + http_client = http.Client(args.cacert) + async with auth.LoginFlow(args.hub_url, http_client) as login: + print( + f"{Format.BOLD}" + "Open the following link in your browser and log in:" + f"{Format.RESET}" + ) + print(login.login_url) + claims = await login.finish() + print( + f"{Format.GREEN}{Format.BOLD}" + "Login was successful, your token has the following claims:" + f"{Format.RESET}" + ) + for key, value in claims.items(): + print(f"{Format.BOLD}{key}: {Format.RESET}{value}") + + class Format: RESET = "\033[0m" BOLD = "\033[1m" diff --git a/pyproject.toml b/pyproject.toml index 3823f10..0639c7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "async-timeout; python_version < '3.11'", "h11", "pydantic ~= 1.10", + "pyjwt[crypto]", "tomli; python_version < '3.11'", "typing_extensions; python_version < '3.9'", "uvicorn",