Skip to content

Commit

Permalink
[WIP]: Add OIDC Auth
Browse files Browse the repository at this point in the history
  • Loading branch information
holesch committed Mar 31, 2024
1 parent 0c8ba6e commit 51ca1d9
Show file tree
Hide file tree
Showing 7 changed files with 308 additions and 6 deletions.
7 changes: 7 additions & 0 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ py.install_sources(
subdir: 'not_my_board/_jsonrpc',
)

py.install_sources(
'not_my_board/_auth/__init__.py',
'not_my_board/_auth/_openid.py',
'not_my_board/_auth/_login.py',
subdir: 'not_my_board/_auth',
)

py.install_sources(
'not_my_board/cli/__init__.py',
subdir: 'not_my_board/cli',
Expand Down
1 change: 1 addition & 0 deletions not_my_board/_auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._login import LoginFlow, get_id_token
91 changes: 91 additions & 0 deletions not_my_board/_auth/_login.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import asyncio
import json
import os
import pathlib

import not_my_board._jsonrpc as jsonrpc
import not_my_board._util as util

from ._openid import Client, IdentityProvider, ensure_fresh

state_home = pathlib.Path(
os.environ.get("XDG_STATE_HOME", pathlib.Path.home() / ".local/state")
)
token_store = state_home / "not-my-board/auth_tokens.json"


class LoginFlow(util.ContextStack):
def __init__(self, hub_url):
self._hub_url = hub_url

async def _context_stack(self, stack):
# todo from settings
self._client_id = "6e2750e5-4f1e-42d8-bdcf-7c794c154e01"
self._issuer = "https://login.microsoftonline.com/common/v2.0"
redirect_uri = f"{self._hub_url}/oidc-callback"

identity_provider = await IdentityProvider.from_url(self._issuer)
self._client = Client(self._client_id, identity_provider, redirect_uri)

ready_event = asyncio.Event()
notification_api = _HubNotifications(ready_event)

channel_url = f"{self._hub_url}/ws-login"
hub = jsonrpc.WebsocketChannel(channel_url, api_obj=notification_api)
self._hub = await stack.enter_async_context(hub)

coro = self._hub.get_authentication_response(self._client.state)
self._auth_response_task = await stack.enter_async_context(
util.background_task(coro)
)

await ready_event.wait()

async def finish(self):
auth_response = await self._auth_response_task
tokens = await self._client.request_tokens(auth_response)

# drop unused enries, like "access_token"
to_store = {k: tokens[k] for k in ("refresh_token", "id_token")}

if not token_store.exists():
token_store.parent.mkdir(parents=True, exist_ok=True)
token_store.touch(mode=0o600)

with token_store.open("r+") as f:
async with util.flock(f):
f.seek(0)
f.truncate()
f.write(json.dumps(to_store))

@property
def login_url(self):
return self._client.login_url


class _HubNotifications:
def __init__(self, ready_event):
self._ready_event = ready_event

async def oidc_callback_registered(self):
self._ready_event.set()


async def get_id_token():
# todo from hub
client_id = "6e2750e5-4f1e-42d8-bdcf-7c794c154e01"
issuer = "https://login.microsoftonline.com/common/v2.0"

identity_provider = await IdentityProvider.from_url(issuer)

with token_store.open("r+") as f:
async with util.flock(f):
tokens = json.loads(f.read())
fresh_tokens = ensure_fresh(tokens, identity_provider, client_id)
if fresh_tokens != tokens:
to_store = {k: fresh_tokens[k] for k in ("refresh_token", "id_token")}
f.seek(0)
f.truncate()
f.write(json.dumps(to_store))

return fresh_tokens["id_token"]
151 changes: 151 additions & 0 deletions not_my_board/_auth/_openid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
#!/usr/bin/env python3

import base64
import dataclasses
import hashlib
import secrets
import urllib.parse

import jwt

import not_my_board._http as http


@dataclasses.dataclass
class IdentityProvider:
authorization_endpoint: str
token_endpoint: str
issuer: str
jwks_uri: str

@classmethod
async def from_url(cls, issuer_url):
config_url = urllib.parse.urljoin(
f"{issuer_url}/", ".well-known/openid-configuration"
)
config = await http.get_json(config_url)

init_args = {
field.name: config[field.name] for field in dataclasses.fields(cls)
}
return cls(**init_args)


class Client:
def __init__(self, client_id, identity_provider, redirect_uri):
self._client_id = client_id
self._identity_provider = identity_provider
self._redirect_uri = redirect_uri
self._state = secrets.token_urlsafe()
self._nonce = secrets.token_urlsafe()
self._code_verifier = secrets.token_urlsafe()

hashed = hashlib.sha256(self._code_verifier.encode()).digest()
code_challange = base64.urlsafe_b64encode(hashed).rstrip(b"=").decode("ascii")

auth_params = {
"scope": "openid profile offline_access",
"response_type": "code",
"client_id": client_id,
"redirect_uri": redirect_uri,
"state": self._state,
"nonce": self._nonce,
"prompt": "consent",
"code_challenge": code_challange,
"code_challenge_method": "S256",
}

url_parts = list(
urllib.parse.urlparse(self._identity_provider.authorization_endpoint)
)
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update(auth_params)

url_parts[4] = urllib.parse.urlencode(query)

self._login_url = urllib.parse.urlunparse(url_parts)

@property
def state(self):
return self._state

@property
def login_url(self):
return self._login_url

async def request_tokens(self, auth_response):
if "error" in auth_response:
if "error_description" in auth_response:
msg = f'{auth_response["error_description"]} ({auth_response["error"]})'
else:
msg = auth_response["error"]

raise ProtocolError(f"Authentication error: {msg}")

url = self._identity_provider.token_endpoint
params = {
"grant_type": "authorization_code",
"code": auth_response["code"],
"redirect_uri": self._redirect_uri,
"client_id": self._client_id,
"code_verifier": self._code_verifier,
}
response = await http.post_form(url, params)

if response["token_type"].lower() != "bearer":
raise ProtocolError(
f'Expected token type "Bearer", got "{response["token_type"]}"'
)

id_token = verify(response["id_token"], self._client_id)
if id_token["nonce"] != self._nonce:
raise ProtocolError(
"Nonce in the ID token doesn't match the one in the authorization request"
)

return response


async def verify(token, client_id):
unverified_token = jwt.api_jwt.decode_complete(
token, options={"verify_signature": False}
)
kid = unverified_token["header"]["kid"]
issuer = unverified_token["payload"]["iss"]

identity_provider = await IdentityProvider.from_url(issuer)
jwk_set_raw = await http.get_json(identity_provider.jwks_uri)
jwk_set = jwt.PyJWKSet.from_dict(jwk_set_raw)

for key in jwk_set.keys:
if key.public_key_use in ["sig", None] and key.key_id == kid:
signing_key = key
break
else:
raise ProtocolError(f'Unable to find a signing key that matches "{kid}"')

return jwt.decode(
token,
key=signing_key.key,
algorithms="RS256",
audience=client_id,
)


async def ensure_fresh(tokens, identity_provider, client_id):
try:
verify(tokens["id_token"], client_id)
return tokens
except Exception:
params = {
"grant_type": "refresh_token",
"refresh_token": tokens["refresh_token"],
"client_id": client_id,
}
new_tokens = await http.post_form(identity_provider.token_endpoint, params)
verify(new_tokens["id_token"], client_id)
return new_tokens


class ProtocolError(Exception):
pass
51 changes: 45 additions & 6 deletions not_my_board/_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,24 @@ def hub():

@asgineer.to_asgi
async def asgi_app(request):
response = (404, {}, "Page not found")

if isinstance(request, asgineer.WebsocketRequest):
if request.path == "/ws-agent":
return await _handle_agent(request)
await _handle_agent(request)
elif request.path == "/ws-exporter":
return await _handle_exporter(request)
await request.close()
return
await _handle_exporter(request)
elif request.path == "/ws-login":
await _handle_login(request)
else:
await request.close()
response = None
elif isinstance(request, asgineer.HttpRequest):
if request.path == "/api/v1/places":
return await _hub.get_places()
return 404, {}, "Page not found"
response = await _hub.get_places()
elif request.path == "/oidc-callback":
response = await _hub.oidc_callback(request.querydict)
return response


async def _handle_agent(ws):
Expand All @@ -54,6 +61,13 @@ async def _handle_exporter(ws):
await _hub.exporter_communicate(client_ip, exporter)


async def _handle_login(ws):
await ws.accept()
client_ip = ws.scope["client"][0]
channel = jsonrpc.Channel(ws.send, ws.receive_iter())
await _hub.login_communicate(client_ip, channel)


async def _authorize_ws(ws):
try:
auth = ws.headers["authorization"]
Expand All @@ -76,6 +90,7 @@ class Hub:
_available = set()
_wait_queue = []
_reservations = {}
_pending_callbacks = {}

def __init__(self):
self._id_generator = itertools.count(start=1)
Expand All @@ -99,6 +114,12 @@ async def exporter_communicate(self, client_ip, rpc):
with self._register_place(export_desc, rpc, client_ip):
await com_task

@jsonrpc.hidden
async def login_communicate(self, client_ip, rpc):
client_ip_var.set(client_ip)
rpc.set_api_object(self)
await rpc.communicate_forever()

@contextlib.contextmanager
def _register_place(self, export_desc, rpc, client_ip):
id_ = next(self._id_generator)
Expand Down Expand Up @@ -188,6 +209,24 @@ async def return_reservation(self, place_id):
else:
logger.info("Place returned, but it doesn't exist: %d", place_id)

async def get_authentication_response(self, state):
future = asyncio.get_running_loop().create_future()
self._pending_callbacks[state] = future
try:
channel = jsonrpc.get_current_channel()
await channel.oidc_callback_registered(_notification=True)
return await future
finally:
del self._pending_callbacks[state]

@jsonrpc.hidden
async def oidc_callback(self, query):
future = self._pending_callbacks[query["state"]]
if not future.done():
future.set_result(query)

return "Continue in not-my-board CLI"


_hub = Hub()

Expand Down
12 changes: 12 additions & 0 deletions not_my_board/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pathlib
import sys

import not_my_board._auth as auth
import not_my_board._client as client
import not_my_board._util as util
from not_my_board._agent import agent
Expand Down Expand Up @@ -104,6 +105,10 @@ def add_verbose_arg(subparser):
add_verbose_arg(subparser)
subparser.add_argument("devpath", help="devpath attribute of uevent")

subparser = add_subcommand("login", help="Log in to a hub")
add_verbose_arg(subparser)
subparser.add_argument("hub_url", help="http(s) URL of the hub")

args = parser.parse_args()

# Don't use escape sequences, if stdout is not a tty
Expand Down Expand Up @@ -196,6 +201,13 @@ async def _uevent_command(args):
await client.uevent(args.devpath)


async def _login_command(args):
async with auth.LoginFlow(args.hub_url) as login:
print(login.login_url)
await login.finish()
print("Login was successful")


class Format:
RESET = "\033[0m"
BOLD = "\033[1m"
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
"asgineer",
"h11",
"pydantic ~= 1.10",
"pyjwt[crypto]",
"tomli; python_version < '3.11'",
"typing_extensions; python_version < '3.9'",
"uvicorn",
Expand Down

0 comments on commit 51ca1d9

Please sign in to comment.