Skip to content

Commit 9aceb59

Browse files
Merge pull request #28 from SUNET/lundberg_jwsd_fix
Fix jwsd implementation
2 parents 82583e6 + 8c02b48 commit 9aceb59

File tree

8 files changed

+119
-38
lines changed

8 files changed

+119
-38
lines changed

src/auth_server/api.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22
from typing import Dict, Type, cast
33

44
from fastapi import FastAPI
5+
from fastapi.exceptions import RequestValidationError
56
from fastapi.middleware.cors import CORSMiddleware
67
from loguru import logger
8+
from starlette.requests import Request
9+
from starlette.responses import JSONResponse
710
from starlette.staticfiles import StaticFiles
11+
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
812

913
from auth_server.config import AuthServerConfig, ConfigurationError, FlowName, load_config
1014
from auth_server.context import ContextRequestRoute
@@ -78,4 +82,16 @@ def init_auth_server_api() -> AuthServer:
7882
app.mount(
7983
"/static", StaticFiles(packages=["auth_server"]), name="static"
8084
) # defaults to the "statics" directory (the ending s is not a mistake) because starlette says so
85+
86+
config = load_config()
87+
if config.debug or config.testing:
88+
# log more info about 422 errors to ease fault tracing
89+
@app.exception_handler(RequestValidationError)
90+
async def validation_exception_handler(request: Request, exc: RequestValidationError):
91+
92+
exc_str = f"{exc}".replace("\n", " ").replace(" ", " ")
93+
logger.exception(f"{exc}")
94+
content = {"status_code": 10422, "message": exc_str, "data": None}
95+
return JSONResponse(content=content, status_code=HTTP_422_UNPROCESSABLE_ENTITY)
96+
8197
return app

src/auth_server/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class TLSFEDMetadata(BaseModel):
5656
class AuthServerConfig(BaseSettings):
5757
app_name: str = Field(default="auth-server")
5858
environment: Environment = Field(default=Environment.PROD)
59+
debug: bool = False
5960
testing: bool = False
6061
log_level: str = Field(default="INFO")
6162
log_color: bool = True

src/auth_server/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class Context(BaseModel):
1515
client_cert: Optional[str] = None
1616
jws_obj: Optional[jws.JWS] = None
1717
detached_jws: Optional[str] = None
18+
detached_jws_body: Optional[str] = None
1819
model_config = ConfigDict(arbitrary_types_allowed=True)
1920

2021
def to_dict(self):

src/auth_server/flows.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ async def continue_transaction(self, continue_request: ContinueRequest) -> Optio
136136
self.state.proof_ok = await self.check_proof(
137137
gnap_key=self.state.grant_request.client.key, gnap_request=continue_request
138138
)
139+
if not self.state.proof_ok:
140+
logger.error("could not validate proof of key possession in continue response, aborting")
141+
raise StopTransactionException(status_code=401, detail="could not validate proof of key possession")
139142

140143
# run the remaining steps in the flow
141144
steps = await self.steps()
@@ -170,15 +173,14 @@ async def check_proof(self, gnap_key: Key, gnap_request: Optional[Union[GrantReq
170173
return await check_jwsd_proof(
171174
request=self.request,
172175
gnap_key=gnap_key,
173-
gnap_request=gnap_request,
174-
key_reference=self.state.key_reference,
175176
access_token=self.state.continue_access_token,
176177
)
177178
else:
178179
raise NextFlowException(status_code=400, detail="no supported proof method")
179180

180181
async def create_claims(self) -> Claims:
181182
if self.state.auth_source is None:
183+
logger.error("no auth_source set, aborting")
182184
raise NextFlowException(status_code=400, detail="no auth source set")
183185

184186
claims = Claims(
@@ -384,7 +386,8 @@ async def handle_subject_response(self) -> Optional[GrantResponse]:
384386

385387
async def create_auth_token(self) -> Optional[GrantResponse]:
386388
if not self.state.proof_ok:
387-
return None
389+
logger.error("could not validate proof of key possession, running next flow")
390+
raise NextFlowException(status_code=401, detail="could not validate proof of key possession")
388391

389392
# Create claims
390393
claims = await self.create_claims()

src/auth_server/middleware.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# -*- coding: utf-8 -*-
2+
from typing import Optional
3+
24
from jwcrypto import jws
35
from jwcrypto.common import JWException
46
from loguru import logger
@@ -33,22 +35,43 @@ async def get_body(request: Request) -> bytes:
3335
return body
3436

3537

38+
def get_header_index(request: Request, header_key: bytes) -> Optional[int]:
39+
for key, value in request.scope["headers"]:
40+
if key == header_key:
41+
return request.scope["headers"].index((key, value))
42+
return None
43+
44+
45+
def set_header(request: Request, header_key: str, header_value: str) -> None:
46+
b_header_key = header_key.encode("utf-8")
47+
b_header_value = header_value.encode("utf-8")
48+
content_type_index = get_header_index(request, b_header_key)
49+
if content_type_index:
50+
logger.debug(
51+
f"Replacing header {request.scope['headers'][content_type_index]} with {(b_header_key, b_header_value)}"
52+
)
53+
request.scope["headers"][content_type_index] = (b_header_key, b_header_value)
54+
else:
55+
# no header to replace, just set it
56+
request.scope["headers"].append((b_header_key, b_header_value))
57+
58+
3659
class JOSEMiddleware(BaseHTTPMiddleware, ContextRequestMixin):
3760
def __init__(self, app):
3861
super().__init__(app)
3962

4063
async def dispatch(self, request: Request, call_next):
41-
if request.headers.get("content-type") == "application/jose":
42-
# Return a more helpful error message for a common mistake
43-
return return_error_response(status_code=422, detail="content-type needs to be application/jose+json")
64+
acceptable_jose_content_types = ["application/jose", "application/jose+json"]
65+
is_jose = request.headers.get("content-type") in acceptable_jose_content_types
66+
is_detached_jws = request.headers.get("Detached-JWS") is not None
4467

45-
if request.headers.get("content-type") == "application/jose+json":
68+
if is_jose and not is_detached_jws:
4669
request = self.make_context_request(request)
47-
logger.info("got application/jose request")
70+
logger.info("got application/jose+json request")
4871
body = await get_body(request)
4972
# deserialize jws
5073
body_str = body.decode("utf-8")
51-
logger.debug(f"body: {body_str}")
74+
logger.debug(f"JWS body: {body_str}")
5275
jwstoken = jws.JWS()
5376
try:
5477
jwstoken.deserialize(body_str)
@@ -62,5 +85,18 @@ async def dispatch(self, request: Request, call_next):
6285
request.context.jws_obj = jwstoken
6386
# replace body with unverified deserialized token - verification is done when verifying proof
6487
await set_body(request, jwstoken.objects["payload"])
88+
# set content-type to application/json as the body has changed
89+
set_header(request, "content-type", "application/json")
90+
# update content-length header to match the new body
91+
set_header(request, "content-length", str(len(jwstoken.objects["payload"])))
92+
93+
if is_detached_jws:
94+
request = self.make_context_request(request)
95+
logger.info("got detached jws request")
96+
# save original body for the detached jws validation
97+
body = await get_body(request)
98+
body_str = body.decode("utf-8")
99+
logger.debug(f"JWSD body: {body_str}")
100+
request.context.detached_jws_body = body_str
65101

66102
return await call_next(request)

src/auth_server/models/gnap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ class FinishInteraction(GnapBaseModel):
179179
method: FinishInteractionMethod
180180
uri: str
181181
nonce: str
182-
hash_method: HashMethod = Field(default=HashMethod.SHA_256)
182+
hash_method: Optional[HashMethod] = None
183183

184184

185185
class Hints(GnapBaseModel):

src/auth_server/proof/jws.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from cryptography.hazmat.primitives.hashes import SHA256, SHA384, SHA512
66
from fastapi import HTTPException
77
from jwcrypto import jwk, jws
8-
from jwcrypto.common import base64url_encode
8+
from jwcrypto.common import base64url_decode, base64url_encode
99
from loguru import logger
1010
from pydantic import ValidationError
1111

@@ -99,42 +99,42 @@ async def check_jws_proof(
9999
async def check_jwsd_proof(
100100
request: ContextRequest,
101101
gnap_key: Key,
102-
gnap_request: Union[GrantRequest, ContinueRequest],
103-
key_reference: Optional[str] = None,
104102
access_token: Optional[str] = None,
105103
) -> bool:
106-
if request.context.detached_jws is None:
104+
if request.context.detached_jws is None or request.context.detached_jws_body is None:
107105
raise HTTPException(status_code=400, detail="No detached JWS found")
108106

109107
logger.debug(f"detached_jws: {request.context.detached_jws}")
108+
logger.debug(f"detached_jws_body: {request.context.detached_jws_body}")
110109

111110
# recreate jws
112111
try:
113-
header, _, signature = request.context.detached_jws.split(".")
112+
header, client_payload_hash, signature = request.context.detached_jws.split(".")
114113
except ValueError as e:
115114
logger.error(f"invalid detached jws: {e}")
116-
return False
115+
raise HTTPException(status_code=400, detail="invalid format for detached jws")
117116

118-
gnap_request_orig = gnap_request.copy(deep=True)
119-
if isinstance(gnap_request_orig, GrantRequest) and key_reference is not None:
120-
# If key was sent as reference in grant request we need to mirror that when
121-
# rebuilding the request as that was what was signed
122-
assert isinstance(gnap_request_orig.client, Client) # please mypy
123-
gnap_request_orig.client.key = key_reference
117+
payload = base64url_encode(request.context.detached_jws_body)
118+
logger.debug(f"payload: {payload}")
119+
120+
# check hash of payload
121+
payload_hash = hash_with(SHA256(), request.context.detached_jws_body.encode())
122+
if payload_hash != base64url_decode(client_payload_hash):
123+
logger.error(f"invalid payload hash: {repr(payload_hash)}")
124+
raise HTTPException(status_code=400, detail="invalid payload hash")
124125

125-
logger.debug(f"gnap_request_orig: {gnap_request_orig.json(exclude_unset=True)}")
126-
payload = base64url_encode(gnap_request_orig.json(exclude_unset=True))
127126
raw_jws = f"{header}.{payload}.{signature}"
128-
_jws = jws.JWS()
127+
logger.debug(f"raw_jws: {raw_jws}")
129128

130129
# deserialize jws
130+
_jws = jws.JWS()
131131
try:
132132
_jws.deserialize(raw_jws=raw_jws)
133133
logger.info("Detached JWS token deserialized")
134134
logger.debug(f"JWS: {_jws.objects}")
135135
except jws.InvalidJWSObject as e:
136136
logger.error(f"Failed to deserialize detached jws: {e}")
137-
return False
137+
raise HTTPException(status_code=400, detail=str(e))
138138

139139
verify_jws(jws_obj=_jws, gnap_key=gnap_key)
140140

src/auth_server/tests/test_app.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from cryptography import x509
1515
from cryptography.hazmat.primitives.hashes import SHA256
1616
from jwcrypto import jwk, jws, jwt
17+
from jwcrypto.common import base64url_encode
1718
from starlette.testclient import TestClient
1819

1920
from auth_server.api import init_auth_server_api
@@ -308,7 +309,7 @@ def test_transaction_jws(self):
308309
)
309310
data = _jws.serialize(compact=True)
310311

311-
client_header = {"Content-Type": "application/jose+json"}
312+
client_header = {"Content-Type": "application/jose"}
312313
response = self.client.post("/transaction", content=data, headers=client_header)
313314

314315
assert response.status_code == 200
@@ -321,7 +322,7 @@ def test_transaction_jws(self):
321322
assert claims["auth_source"] == AuthSource.TEST
322323

323324
def test_transaction_jwsd(self):
324-
client_key_dict = self.client_jwk.export(as_dict=True)
325+
client_key_dict = self.client_jwk.export_public(as_dict=True)
325326
client_jwk = ECJWK(**client_key_dict)
326327
req = GrantRequest(
327328
client=Client(key=Key(proof=Proof(method=ProofMethod.JWSD), jwk=client_jwk)),
@@ -335,7 +336,15 @@ def test_transaction_jwsd(self):
335336
"uri": "http://testserver/transaction",
336337
"created": int(utc_now().timestamp()),
337338
}
338-
_jws = jws.JWS(payload=req.json(exclude_unset=True))
339+
340+
payload = req.model_dump_json(exclude_unset=True)
341+
342+
# create a hash of payload to send in payload place
343+
payload_digest = hash_with(SHA256(), payload.encode())
344+
payload_hash = base64url_encode(payload_digest)
345+
346+
# create detached jws
347+
_jws = jws.JWS(payload=payload)
339348
_jws.add_signature(
340349
key=self.client_jwk,
341350
protected=json.dumps(jws_header),
@@ -344,9 +353,11 @@ def test_transaction_jwsd(self):
344353

345354
# Remove payload from serialized jws
346355
header, _, signature = data.split(".")
347-
client_header = {"Detached-JWS": f"{header}..{signature}"}
356+
client_header = {"Detached-JWS": f"{header}.{payload_hash}.{signature}"}
348357

349-
response = self.client.post("/transaction", json=req.dict(exclude_unset=True), headers=client_header)
358+
response = self.client.post(
359+
"/transaction", content=req.model_dump_json(exclude_unset=True), headers=client_header
360+
)
350361

351362
assert response.status_code == 200
352363
assert "access_token" in response.json()
@@ -1148,7 +1159,7 @@ def test_transaction_jwsd_continue(self):
11481159
self.config["auth_flows"] = json.dumps(["InteractionFlow"])
11491160
self._update_app_config(config=self.config)
11501161

1151-
client_key_dict = self.client_jwk.export(as_dict=True)
1162+
client_key_dict = self.client_jwk.export_public(as_dict=True)
11521163
client_jwk = ECJWK(**client_key_dict)
11531164

11541165
req = GrantRequest(
@@ -1164,7 +1175,14 @@ def test_transaction_jwsd_continue(self):
11641175
"uri": "http://testserver/transaction",
11651176
"created": int(utc_now().timestamp()),
11661177
}
1167-
_jws = jws.JWS(payload=req.json(exclude_unset=True))
1178+
1179+
payload = req.model_dump_json(exclude_unset=True)
1180+
1181+
# create a hash of payload to send in payload place
1182+
payload_digest = hash_with(SHA256(), payload.encode())
1183+
payload_hash = base64url_encode(payload_digest)
1184+
1185+
_jws = jws.JWS(payload=payload)
11681186
_jws.add_signature(
11691187
key=self.client_jwk,
11701188
protected=json.dumps(jws_header),
@@ -1173,9 +1191,11 @@ def test_transaction_jwsd_continue(self):
11731191

11741192
# Remove payload from serialized jws
11751193
header, _, signature = data.split(".")
1176-
client_header = {"Detached-JWS": f"{header}..{signature}"}
1194+
client_header = {"Detached-JWS": f"{header}.{payload_hash}.{signature}"}
11771195

1178-
response = self.client.post("/transaction", json=req.dict(exclude_unset=True), headers=client_header)
1196+
response = self.client.post(
1197+
"/transaction", content=req.model_dump_json(exclude_unset=True), headers=client_header
1198+
)
11791199
assert response.status_code == 200
11801200

11811201
# continue response with no continue reference in uri
@@ -1207,7 +1227,11 @@ def test_transaction_jwsd_continue(self):
12071227
# calculate ath header value
12081228
access_token_hash = hash_with(SHA256(), continue_response["access_token"]["value"].encode())
12091229
jws_header["ath"] = base64.urlsafe_b64encode(access_token_hash).decode("ascii").rstrip("=")
1210-
_jws = jws.JWS(payload="{}")
1230+
# create hash of empty payload to send in payload place
1231+
payload = "{}"
1232+
payload_digest = hash_with(SHA256(), payload.encode())
1233+
payload_hash = base64url_encode(payload_digest)
1234+
_jws = jws.JWS(payload=payload)
12111235
_jws.add_signature(
12121236
key=self.client_jwk,
12131237
protected=json.dumps(jws_header),
@@ -1216,11 +1240,11 @@ def test_transaction_jwsd_continue(self):
12161240

12171241
# Remove payload from serialized jws
12181242
continue_header, _, continue_signature = continue_data.split(".")
1219-
client_header = {"Detached-JWS": f"{continue_header}..{continue_signature}"}
1243+
client_header = {"Detached-JWS": f"{continue_header}.{payload_hash}.{continue_signature}"}
12201244

12211245
authorization_header = f'GNAP {continue_response["access_token"]["value"]}'
12221246
client_header["Authorization"] = authorization_header
1223-
response = self.client.post(continue_response["uri"], json=dict(), headers=client_header)
1247+
response = self.client.post(continue_response["uri"], content=payload, headers=client_header)
12241248

12251249
assert response.status_code == 200
12261250
assert "access_token" in response.json()

0 commit comments

Comments
 (0)