-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #48 from AndrewSergienko/2.x/auth
Basic auth0 authentication
- Loading branch information
Showing
40 changed files
with
609 additions
and
178 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,3 +52,5 @@ dmypy.json | |
|
||
.idea | ||
/pytest.ini | ||
/docker-compose.yml | ||
/keycloak/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from aiohttp import ClientSession | ||
from sqlalchemy import Table, select | ||
from sqlalchemy.ext.asyncio import AsyncSession | ||
|
||
from costy.application.common.auth_gateway import AuthLoger | ||
from costy.domain.exceptions.access import AuthenticationError | ||
from costy.domain.models.user import UserId | ||
from costy.infrastructure.config import AuthSettings | ||
|
||
|
||
class AuthGateway(AuthLoger): | ||
def __init__( | ||
self, | ||
db_session: AsyncSession, | ||
web_session: ClientSession, | ||
table: Table, | ||
settings: AuthSettings | ||
) -> None: | ||
self.db_session = db_session | ||
self.web_session = web_session | ||
self.table = table | ||
self.settings = settings | ||
|
||
async def authenticate(self, email: str, password: str) -> str: | ||
url = self.settings.authorize_url | ||
data = { | ||
"username": email, | ||
"password": password, | ||
"client_id": self.settings.client_id, | ||
"client_secret": self.settings.client_secret, | ||
"audience": self.settings.audience, | ||
"grant_type": self.settings.grant_type | ||
} | ||
async with self.web_session.post(url, data=data) as response: | ||
response_data = await response.json() | ||
if response.status == 200: | ||
token: str | None = response_data.get("access_token") | ||
if token: | ||
return token | ||
raise AuthenticationError(response_data) | ||
|
||
async def get_user_id_by_sub(self, sub: str) -> UserId: | ||
query = select(self.table).where(self.table.c.auth_id == sub) | ||
result = await self.db_session.execute(query) | ||
try: | ||
return UserId(next(result.mappings())["id"]) | ||
except StopIteration: | ||
raise AuthenticationError("Invalid auth sub. User is not exists.") |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from datetime import datetime, timedelta | ||
from typing import Any, Literal | ||
|
||
from aiohttp import ClientSession | ||
from jose import exceptions as jwt_exc | ||
from jose import jwt | ||
|
||
from costy.application.common.auth_gateway import AuthLoger | ||
from costy.application.common.id_provider import IdProvider | ||
from costy.domain.exceptions.access import AuthenticationError | ||
from costy.domain.models.user import UserId | ||
|
||
Algorithm = Literal[ | ||
"HS256", "HS384", "HS512", | ||
"RS256", "RS384", "RS512", | ||
] | ||
|
||
|
||
class JwtTokenProcessor: | ||
def __init__( | ||
self, | ||
algorithm: Algorithm, | ||
audience: str, | ||
issuer: str, | ||
): | ||
self.algorithm = algorithm | ||
self.audience = audience | ||
self.issuer = issuer | ||
|
||
def _fetch_rsa_key(self, jwks: dict[Any, Any], unverified_header: dict[str, str]) -> dict[str, str]: | ||
rsa_key = {} | ||
for key in jwks["keys"]: | ||
if key["kid"] == unverified_header["kid"]: | ||
rsa_key = { | ||
"kty": key["kty"], | ||
"kid": key["kid"], | ||
"use": key["use"], | ||
"n": key["n"], | ||
"e": key["e"] | ||
} | ||
return rsa_key | ||
|
||
def validate_token(self, token: str, jwks: dict[Any, Any]) -> str: | ||
invalid_header_error = AuthenticationError( | ||
{"detail": "Invalid header. Use an RS256 signed JWT Access Token"} | ||
) | ||
try: | ||
unverified_header = jwt.get_unverified_header(token) | ||
except jwt_exc.JWTError: | ||
raise invalid_header_error | ||
if unverified_header["alg"] == "HS256": | ||
raise invalid_header_error | ||
rsa_key = self._fetch_rsa_key(jwks, unverified_header) | ||
try: | ||
payload: dict[str, str] = jwt.decode( | ||
token, | ||
rsa_key, | ||
algorithms=[self.algorithm], | ||
audience=self.audience, | ||
issuer=self.issuer | ||
) | ||
return payload["sub"] | ||
except jwt_exc.ExpiredSignatureError: | ||
raise AuthenticationError({"detail": "token is expired"}) | ||
except jwt_exc.JWTClaimsError: | ||
raise AuthenticationError( | ||
{"detail": "incorrect claims (check audience and issuer)"} | ||
) | ||
except Exception: | ||
raise AuthenticationError( | ||
{"detail": "Unable to parse authentication token."} | ||
) | ||
|
||
|
||
class KeySetProvider: | ||
def __init__(self, uri: str, session: ClientSession, expired: timedelta): | ||
self.session = session | ||
self.jwks: dict[str, str] = {} | ||
self.expired = expired | ||
self.last_updated: datetime | None = None | ||
self.uri = uri | ||
|
||
async def get_key_set(self) -> dict[Any, Any]: | ||
if not self.jwks: | ||
await self._request_new_key_set() | ||
if self.last_updated and datetime.now() - self.last_updated > self.expired: | ||
# TODO: add use Cache-Control | ||
await self._request_new_key_set() | ||
return self.jwks | ||
|
||
async def _request_new_key_set(self) -> None: | ||
async with self.session.get(self.uri) as response: | ||
self.jwks = await response.json() | ||
self.last_updated = datetime.now() | ||
|
||
|
||
class TokenIdProvider(IdProvider): | ||
def __init__( | ||
self, | ||
token_processor: JwtTokenProcessor, | ||
key_set_provider: KeySetProvider, | ||
token: str | None = None | ||
): | ||
self.token_processor = token_processor | ||
self.key_set_provider = key_set_provider | ||
self.token = token | ||
self.auth_gateway: AuthLoger | None = None | ||
|
||
async def get_current_user_id(self) -> UserId: | ||
if self.token and self.auth_gateway: | ||
jwks = await self.key_set_provider.get_key_set() | ||
sub = self.token_processor.validate_token(self.token, jwks) | ||
user_id = await self.auth_gateway.get_user_id_by_sub(sub) | ||
return user_id | ||
raise AuthenticationError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,27 @@ | ||
from sqlalchemy import select | ||
from adaptix import Retort | ||
from sqlalchemy import Table, select | ||
from sqlalchemy.ext.asyncio import AsyncSession | ||
|
||
from costy.application.common.user_gateway import UserReader, UserSaver | ||
from costy.domain.models.user import User, UserId | ||
|
||
|
||
class UserGateway(UserSaver, UserReader): | ||
def __init__(self, session: AsyncSession): | ||
def __init__(self, session: AsyncSession, table: Table, retort: Retort): | ||
self.session = session | ||
self.table = table | ||
self.retort = retort | ||
|
||
async def save_user(self, user: User) -> None: | ||
self.session.add(user) | ||
await self.session.flush(objects=[user]) | ||
|
||
async def get_user_by_id(self, user_id: UserId) -> User | None: | ||
query = select(User).where(User.id == user_id) # type: ignore | ||
result: User | None = await self.session.scalar(query) | ||
return result | ||
|
||
async def get_user_by_email(self, email: str) -> User | None: | ||
query = select(User).where(User.email == email) # type: ignore | ||
result: User | None = await self.session.scalar(query) | ||
return result | ||
query = select(self.table).where(self.table.c.id == user_id) | ||
result = await self.session.scalar(query) | ||
try: | ||
data = next(result.mapping()) | ||
user: User = self.retort.load(data, User) | ||
return user | ||
except StopIteration: | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.