Skip to content

Commit

Permalink
Merge pull request #96 from shawnz/generic-user-in-oidc-request-v2
Browse files Browse the repository at this point in the history
Make User generic in OIDC Request class
  • Loading branch information
aliev authored Aug 20, 2024
2 parents 7a8ce10 + 01b4755 commit 73a03c3
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 9 deletions.
30 changes: 24 additions & 6 deletions aioauth/oidc/core/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@

from typing import Any, Optional, TypeVar

from aioauth.requests import BaseRequest, Query as BaseQuery, Post
from aioauth.requests import (
BaseRequest as BaseOAuth2Request,
Query as OAuth2Query,
Post,
TPost,
TUser,
)


@dataclass
class Query(BaseQuery):
class Query(OAuth2Query):
# Space delimited, case sensitive list of ASCII string values that
# specifies whether the Authorization Server prompts the End-User for
# reauthentication and consent. The defined values are: none, login,
Expand All @@ -15,17 +21,29 @@ class Query(BaseQuery):
prompt: Optional[str] = None


TQuery = TypeVar("TQuery", bound=Query)


@dataclass
class Request(BaseRequest[Query, Post, Any]):
class BaseRequest(BaseOAuth2Request[TQuery, TPost, TUser]):
"""
Object that contains a client's complete request with extensions as defined
by OpenID Core.
https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
"""

query: TQuery
post: TPost
user: Optional[TUser] = None


TRequest = TypeVar("TRequest", bound=BaseRequest)


@dataclass
class Request(BaseRequest[Query, Post, Any]):
"""Object that contains a client's complete request."""

query: Query = field(default_factory=Query)
post: Post = field(default_factory=Post)
user: Optional[Any] = None


TRequest = TypeVar("TRequest", bound=Request)
6 changes: 3 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from aioauth.collections import HTTPHeaderDict
from aioauth.constances import default_headers
from aioauth.requests import Post, Query, Request
from aioauth.requests import BaseRequest, Post, Query
from aioauth.responses import ErrorResponse, Response

EMPTY_KEYS = {
Expand Down Expand Up @@ -298,7 +298,7 @@ def get_keys(query: Union[Query, Post]) -> Dict[str, Any]:


async def check_query_values(
request: Request, responses, query_dict: Dict, endpoint_func, value
request: BaseRequest, responses, query_dict: Dict, endpoint_func, value
):
keys = set(query_dict.keys()) & set(responses.keys())

Expand All @@ -322,7 +322,7 @@ async def check_query_values(


async def check_request_validators(
request: Request,
request: BaseRequest,
endpoint_func: Callable,
):
query_dict = {}
Expand Down

0 comments on commit 73a03c3

Please sign in to comment.