diff --git a/seacatauth/client/service.py b/seacatauth/client/service.py index 18349cdb..129f662b 100644 --- a/seacatauth/client/service.py +++ b/seacatauth/client/service.py @@ -1,4 +1,5 @@ import base64 +import binascii import datetime import logging import re @@ -489,63 +490,107 @@ async def validate_client_authorize_options( return True - async def authenticate_client_request(self, request, expected_client_id: str) -> str: + async def authenticate_client_request( + self, + request, + expected_client_id: typing.Optional[str] = None + ) -> typing.Optional[str]: """ Verify client ID and secret. """ - client_dict = await self.get(expected_client_id) - token_endpoint_auth_method = client_dict.get("token_endpoint_auth_method", "client_secret_basic") - if token_endpoint_auth_method == "none": - # Public client - no authentication required - return expected_client_id - - # Check secret expiration - client_secret_expires_at = client_dict.get("client_secret_expires_at", None) - if client_secret_expires_at and client_secret_expires_at < datetime.datetime.now(datetime.timezone.utc): - raise exceptions.ClientAuthenticationError("Expired client secret.", client_id=expected_client_id) + if expected_client_id: + # Client ID is known - Use the pre-configured authentication method + client_dict = await self.get(expected_client_id) + expected_auth_method = client_dict.get("token_endpoint_auth_method", "client_secret_basic") + if expected_auth_method == "none": + return expected_client_id + if expected_auth_method == "client_secret_basic": + client_id, client_secret = self._get_credentials_from_authorization_header(request) + elif expected_auth_method == "client_secret_post": + client_id, client_secret = await self._get_credentials_from_post_data(request) + else: + raise NotImplementedError("Unsupported client authentication method: {}".format(expected_auth_method)) - client_secret_hash = client_dict.get("__client_secret", None) - if token_endpoint_auth_method == "client_secret_basic": - try: - auth_header = request.headers.get("Authorization") - _, basic_auth = auth_header.split(" ") - client_id, client_secret = base64.urlsafe_b64decode( - basic_auth.encode("ascii")).decode("ascii").split(":") - except Exception as e: + if not client_id: raise exceptions.ClientAuthenticationError( - "Falied to get client credentials from Authorization header: {}.".format(e), - client_id=expected_client_id + "Failed to get client credentials from request.", + client_id=expected_client_id, ) - if client_id != expected_client_id: + elif client_id != expected_client_id: raise exceptions.ClientAuthenticationError( "Client IDs do not match (expected {!r}).".format(expected_client_id), client_id=client_id, ) - if generic.argon2_verify(client_secret_hash, client_secret): - return client_id - else: - raise exceptions.ClientAuthenticationError("Incorrect client secret.", client_id=client_id) - elif token_endpoint_auth_method == "client_secret_post": - post_data = await request.post() - client_id = post_data.get("client_id") - client_secret = post_data.get("client_secret") - if client_id != expected_client_id: + else: + # Client ID is not known in advance - Try to extract it from the request + client_id, client_secret = self._get_credentials_from_authorization_header(request) + if client_id and client_secret: + auth_method = "client_secret_basic" + else: + client_id, client_secret = await self._get_credentials_from_post_data(request) + if client_id and client_secret: + auth_method = "client_secret_post" + else: + # Public client - Authentication not required + # auth_method = "none" + return None + + assert client_id + client_dict = await self.get(client_id) + expected_auth_method = client_dict.get("token_endpoint_auth_method", "client_secret_basic") + if auth_method != expected_auth_method: raise exceptions.ClientAuthenticationError( - "Client IDs do not match (expected {!r}).".format(expected_client_id), + "Unexpected authentication method (expected {!r}, {!r}).".format( + expected_auth_method, auth_method), client_id=client_id, ) - if generic.argon2_verify(client_secret_hash, client_secret): + elif auth_method == "none": + # Public client - no secret verification required return client_id - else: - raise exceptions.ClientAuthenticationError("Incorrect client secret.", client_id=client_id) - elif token_endpoint_auth_method == "client_secret_jwt": - raise ValueError("Unsupported token_endpoint_auth_method value: {}".format(token_endpoint_auth_method)) - elif token_endpoint_auth_method == "private_key_jwt": - raise ValueError("Unsupported token_endpoint_auth_method value: {}".format(token_endpoint_auth_method)) - else: - raise ValueError("Unsupported token_endpoint_auth_method value: {}".format(token_endpoint_auth_method)) + # Check secret expiration + client_secret_expires_at = client_dict.get("client_secret_expires_at", None) + if client_secret_expires_at and client_secret_expires_at < datetime.datetime.now(datetime.timezone.utc): + raise exceptions.ClientAuthenticationError("Expired client secret.", client_id=expected_client_id) + + # Verify client secret + client_secret_hash = client_dict.get("__client_secret", None) + if not generic.argon2_verify(client_secret_hash, client_secret): + raise exceptions.ClientAuthenticationError("Incorrect client secret.", client_id=client_id) + + return client_id + + def _get_credentials_from_authorization_header( + self, request + ) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]: + auth_header = request.headers.get("Authorization") + if not auth_header: + return None, None + try: + token_type, auth_token = auth_header.split(" ") + except ValueError: + return None, None + if token_type != "Basic": + return None, None + try: + auth_token_decoded = base64.urlsafe_b64decode(auth_token.encode("ascii")).decode("ascii") + except (binascii.Error, UnicodeDecodeError): + return None, None + try: + client_id, client_secret = auth_token_decoded.split(":") + except ValueError: + return None, None + return client_id, client_secret + + + async def _get_credentials_from_post_data( + self, request + ) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]: + post_data = await request.post() + if not ("client_id" in post_data and "client_secret" in post_data): + return None, None + return post_data["client_id"], post_data["client_secret"] def _check_grant_types(self, grant_types, response_types):