Skip to content

Commit aeeb43d

Browse files
committed
auth: Refactor get_id_token()
Use token source classes to hide the details from users. There are two sources: from file and from command.
1 parent 81ba2b4 commit aeeb43d

File tree

7 files changed

+76
-74
lines changed

7 files changed

+76
-74
lines changed

not_my_board/_agent.py

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
from dataclasses import dataclass, field
1414
from typing import List, Tuple
1515

16-
import not_my_board._auth as auth
17-
import not_my_board._http as http
1816
import not_my_board._jsonrpc as jsonrpc
1917
import not_my_board._models as models
2018
import not_my_board._usbip as usbip
@@ -25,18 +23,10 @@
2523
Address = Tuple[str, int]
2624

2725

28-
async def agent(hub_url, ca_files, token_store_path, token_cmd):
29-
io = _AgentIO(hub_url, http.Client(ca_files), token_store_path, token_cmd)
30-
async with Agent(hub_url, io) as agent_:
31-
await agent_.serve_forever()
32-
33-
34-
class _AgentIO:
35-
def __init__(self, hub_url, http_client, token_store_path, token_cmd):
26+
class AgentIO:
27+
def __init__(self, hub_url, http_client):
3628
self._hub_url = hub_url
3729
self._http = http_client
38-
self._token_store_path = token_store_path
39-
self._token_cmd = token_cmd
4030

4131
@contextlib.asynccontextmanager
4232
async def hub_rpc(self):
@@ -111,33 +101,19 @@ async def _handle_port_forward_client(self, proxy, target, client_r, client_w):
111101
await client_w.drain()
112102
await util.relay_streams(client_r, client_w, remote_r, remote_w)
113103

114-
async def get_id_token(self):
115-
if self._token_cmd:
116-
logger.debug("Executing token command: %s", self._token_cmd)
117-
proc = await asyncio.create_subprocess_shell(
118-
self._token_cmd, stdout=asyncio.subprocess.PIPE
119-
)
120-
stdout, _ = await proc.communicate()
121-
if proc.returncode:
122-
raise RuntimeError(f"{self._token_cmd!r} exited with {proc.returncode}")
123-
return stdout.decode("utf-8").rstrip()
124-
125-
return await auth.get_id_token(
126-
self._token_store_path, self._hub_url, self._http
127-
)
128-
129104

130105
class Agent(util.ContextStack):
131-
def __init__(self, hub_url, io):
106+
def __init__(self, hub_url, io, token_src):
132107
url = urllib.parse.urlsplit(hub_url)
133108
self._hub_host = url.netloc.split(":")[0]
134109
self._io = io
135110
self._locks = weakref.WeakValueDictionary()
136111
self._reservations = {}
112+
self._token_src = token_src
137113

138114
async def _context_stack(self, stack):
139115
self._hub = await stack.enter_async_context(self._io.hub_rpc())
140-
self._hub.set_api_object(_WebsocketInterface(self._io))
116+
self._hub.set_api_object(self._token_src)
141117
stack.push_async_callback(self._cleanup)
142118
self._unix_server = await stack.enter_async_context(self._io.unix_server(self))
143119

@@ -246,14 +222,6 @@ async def status(self):
246222
]
247223

248224

249-
class _WebsocketInterface:
250-
def __init__(self, io):
251-
self._io = io
252-
253-
async def get_id_token(self):
254-
return await self._io.get_id_token()
255-
256-
257225
def _filter_places(import_description, places):
258226
candidates = {}
259227

not_my_board/_auth/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from ._login import LoginFlow, get_id_token
1+
from ._login import IdTokenFromCmd, IdTokenFromFile, LoginFlow
22
from ._openid import Validator

not_my_board/_auth/_login.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,37 @@ async def oidc_callback_registered(self):
7171
self._ready_event.set()
7272

7373

74-
async def get_id_token(token_store_path, hub_url, http_client):
75-
token_store = _TokenStore(token_store_path)
76-
async with token_store:
77-
id_token, refresh_token = token_store.get_tokens(hub_url)
78-
id_token, refresh_token = await ensure_fresh(
79-
id_token, refresh_token, http_client
80-
)
81-
token_store.save_tokens(hub_url, id_token, refresh_token)
74+
class IdTokenFromFile:
75+
def __init__(self, hub_url, http_client, token_store_path):
76+
self._hub_url = hub_url
77+
self._http = http_client
78+
self._token_store = _TokenStore(token_store_path)
79+
80+
async def get_id_token(self):
81+
async with self._token_store:
82+
id_token, refresh_token = self._token_store.get_tokens(self._hub_url)
83+
id_token, refresh_token = await ensure_fresh(
84+
id_token, refresh_token, self._http
85+
)
86+
self._token_store.save_tokens(self._hub_url, id_token, refresh_token)
87+
88+
return id_token
8289

83-
return id_token
90+
91+
class IdTokenFromCmd:
92+
def __init__(self, hub_url, http_client, cmd):
93+
self._hub_url = hub_url
94+
self._http = http_client
95+
self._cmd = cmd
96+
97+
async def get_id_token(self):
98+
proc = await asyncio.create_subprocess_shell(
99+
self._cmd, stdout=asyncio.subprocess.PIPE
100+
)
101+
stdout, _ = await proc.communicate()
102+
if proc.returncode:
103+
raise RuntimeError(f"{self._cmd!r} exited with {proc.returncode}")
104+
return stdout.decode("utf-8").rstrip()
84105

85106

86107
class _TokenStore(util.ContextStack):

not_my_board/_export.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import h11
1010

11-
import not_my_board._auth as auth
1211
import not_my_board._http as http
1312
import not_my_board._jsonrpc as jsonrpc
1413
import not_my_board._models as models
@@ -18,21 +17,14 @@
1817
logger = logging.getLogger(__name__)
1918

2019

21-
async def export(hub_url, place, ca_files, token_store_path):
22-
http_client = http.Client(ca_files)
23-
async with Exporter(hub_url, place, http_client, token_store_path) as exporter:
24-
await exporter.register_place()
25-
await exporter.serve_forever()
26-
27-
2820
class Exporter(util.ContextStack):
29-
def __init__(self, hub_url, export_desc_path, http_client, token_store_path):
21+
def __init__(self, hub_url, export_desc_path, http_client, token_src):
3022
self._hub_url = hub_url
3123
self._ip_to_tasks_map = {}
3224
export_desc_content = export_desc_path.read_text()
3325
self._place = models.ExportDesc(**util.toml_loads(export_desc_content))
3426
self._http = http_client
35-
self._token_store_path = token_store_path
27+
self._token_src = token_src
3628

3729
tcp_targets = {
3830
f"{tcp.host}:{tcp.port}".encode()
@@ -127,9 +119,7 @@ async def _tunnel(self, client_r, client_w, target, trailing_data):
127119
await util.relay_streams(client_r, client_w, remote_r, remote_w)
128120

129121
async def get_id_token(self):
130-
return await auth.get_id_token(
131-
self._token_store_path, self._hub_url, self._http
132-
)
122+
return await self._token_src.get_id_token()
133123

134124

135125
def format_date_time(dt=None):

not_my_board/cli/__init__.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
import pathlib
77
import sys
88

9+
import not_my_board._agent as agent
910
import not_my_board._auth as auth
1011
import not_my_board._client as client
12+
import not_my_board._export as export
1113
import not_my_board._http as http
1214
import not_my_board._util as util
13-
from not_my_board._agent import agent
14-
from not_my_board._export import export
1515
from not_my_board._hub import run_hub
1616

1717
try:
@@ -160,11 +160,26 @@ def _hub_command(_):
160160

161161

162162
async def _export_command(args):
163-
await export(args.hub_url, args.export_description, args.cacert, TOKEN_STORE_PATH)
163+
http_client = http.Client(args.cacert)
164+
token_src = auth.IdTokenFromFile(args.hub_url, http_client, TOKEN_STORE_PATH)
165+
async with export.Exporter(
166+
args.hub_url, args.export_description, http_client, token_src
167+
) as exporter:
168+
await exporter.register_place()
169+
await exporter.serve_forever()
164170

165171

166172
async def _agent_command(args):
167-
await agent(args.hub_url, args.cacert, TOKEN_STORE_PATH, args.token_cmd)
173+
http_client = http.Client(args.cacert)
174+
io = agent.AgentIO(args.hub_url, http_client)
175+
176+
if args.token_cmd:
177+
token_src = auth.IdTokenFromCmd(args.hub_url, http_client, args.token_cmd)
178+
else:
179+
token_src = auth.IdTokenFromFile(args.hub_url, http_client, TOKEN_STORE_PATH)
180+
181+
async with agent.Agent(args.hub_url, io, token_src) as agent_:
182+
await agent_.serve_forever()
168183

169184

170185
async def _reserve_command(args):

tests/test_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ async def port_forward(self, ready_event, proxy, target, local_port):
212212
@pytest.fixture()
213213
async def agent_io():
214214
io = FakeAgentIO()
215-
async with agentmodule.Agent(HUB_URL, io) as agent:
215+
async with agentmodule.Agent(HUB_URL, io, None) as agent:
216216
async with util.background_task(agent.serve_forever()):
217217
yield io
218218

tests/test_auth.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -210,12 +210,11 @@ def token_store_path():
210210

211211

212212
class FakeExporter:
213-
def __init__(self, rpc, token_store_path_, http, id_token=None):
213+
def __init__(self, rpc, token_src, http):
214214
rpc.set_api_object(self)
215215
self._rpc = rpc
216216
self._http = http
217-
self._token_store_path = token_store_path_
218-
self._id_token = id_token
217+
self._token_src = token_src
219218

220219
async def communicate_forever(self):
221220
await self._rpc.communicate_forever()
@@ -228,9 +227,15 @@ async def register_place(self):
228227
await self._rpc.register_place(place)
229228

230229
async def get_id_token(self):
231-
if self._id_token is not None:
232-
return self._id_token
233-
return await auth.get_id_token(self._token_store_path, HUB_URL, self._http)
230+
return await self._token_src.get_id_token()
231+
232+
233+
class FakeTokenSource:
234+
def __init__(self, id_token):
235+
self._id_token = id_token
236+
237+
async def get_id_token(self):
238+
return self._id_token
234239

235240

236241
async def test_login_success(hub, http_client, token_store_path):
@@ -247,7 +252,8 @@ async def test_login_success(hub, http_client, token_store_path):
247252
assert claims["aud"] == CLIENT_ID
248253

249254
rpc1, rpc2 = fake_rpc_pair()
250-
fake_exporter = FakeExporter(rpc1, token_store_path, http_client)
255+
token_src = auth.IdTokenFromFile(HUB_URL, http_client, token_store_path)
256+
fake_exporter = FakeExporter(rpc1, token_src, http_client)
251257
hub_coro = hub.communicate("3.1.1.1", rpc2)
252258
exporter_coro = fake_exporter.communicate_forever()
253259
async with util.background_task(hub_coro):
@@ -299,7 +305,8 @@ async def test_permissions(claims, is_allowed):
299305
tokens = http_client_.issue_tokens(full_claims)
300306

301307
rpc1, rpc2 = fake_rpc_pair()
302-
fake_exporter = FakeExporter(rpc1, "", http_client_, tokens["id"])
308+
token_src = FakeTokenSource(tokens["id"])
309+
fake_exporter = FakeExporter(rpc1, token_src, http_client_)
303310
hub_coro = hub_.communicate("3.1.1.1", rpc2)
304311
exporter_coro = fake_exporter.communicate_forever()
305312
async with util.background_task(hub_coro):
@@ -327,7 +334,8 @@ async def test_permission_lost(hub, http_client, token_store_path, fake_time):
327334
token_store_path.write_text(json.dumps(token_store_content))
328335

329336
rpc1, rpc2 = fake_rpc_pair()
330-
fake_exporter = FakeExporter(rpc1, token_store_path, http_client)
337+
token_src = auth.IdTokenFromFile(HUB_URL, http_client, token_store_path)
338+
fake_exporter = FakeExporter(rpc1, token_src, http_client)
331339
hub_coro = hub.communicate("3.1.1.1", rpc2)
332340
exporter_coro = fake_exporter.communicate_forever()
333341

0 commit comments

Comments
 (0)