From 4d13e4cf6403e46b71d36267833932915f95e9b0 Mon Sep 17 00:00:00 2001 From: Andrew Srg Date: Mon, 12 Feb 2024 23:50:57 +0200 Subject: [PATCH 1/5] feature: add JWT id provider --- .gitignore | 2 + pyproject.toml | 2 + src/costy/adapters/auth/id_provider.py | 10 -- src/costy/adapters/auth/token.py | 105 ++++++++++++++++++ src/costy/domain/exceptions/__init__.py | 0 src/costy/domain/exceptions/access.py | 2 + src/costy/infrastructure/auth.py | 28 +++++ src/costy/main/web.py | 14 +++ .../api/dependencies/id_provider.py | 19 ++-- 9 files changed, 162 insertions(+), 20 deletions(-) delete mode 100644 src/costy/adapters/auth/id_provider.py create mode 100644 src/costy/adapters/auth/token.py create mode 100644 src/costy/domain/exceptions/__init__.py create mode 100644 src/costy/domain/exceptions/access.py create mode 100644 src/costy/infrastructure/auth.py diff --git a/.gitignore b/.gitignore index 0854d02..6f0482f 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,5 @@ dmypy.json .idea /pytest.ini +/docker-compose.yml +/keycloak/ diff --git a/pyproject.toml b/pyproject.toml index 55fe57e..8fefcec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,8 @@ dependencies = [ 'psycopg[binary]', 'alembic', 'adaptix', + 'aiohttp', + 'python-jose' ] [project.optional-dependencies] diff --git a/src/costy/adapters/auth/id_provider.py b/src/costy/adapters/auth/id_provider.py deleted file mode 100644 index 475e07b..0000000 --- a/src/costy/adapters/auth/id_provider.py +++ /dev/null @@ -1,10 +0,0 @@ -from costy.application.common.id_provider import IdProvider -from costy.domain.models.user import UserId - - -class SimpleIdProvider(IdProvider): - def __init__(self, user_id: UserId): - self.user_id = user_id - - async def get_current_user_id(self) -> UserId: - return self.user_id diff --git a/src/costy/adapters/auth/token.py b/src/costy/adapters/auth/token.py new file mode 100644 index 0000000..cfeee67 --- /dev/null +++ b/src/costy/adapters/auth/token.py @@ -0,0 +1,105 @@ +from datetime import datetime, timedelta +from typing import Literal + +from aiohttp import ClientSession +from jose import jwt + +from costy.application.common.id_provider import IdProvider +from costy.domain.exceptions.access import AuthenticationError +from costy.domain.models.user import UserId + +Algorithm = Literal[ + "HS256", "HS384", "HS512", + "RS256", "RS384", "RS512", +] + + +class JwtTokenProcessor: + def __init__( + self, + algorithm: Algorithm, + audience: str, + issuer: str, + ): + self.algorithm = algorithm + self.audience = audience + self.issuer = issuer + + def _fetch_rsa_key(self, jwks: dict, unverified_header: dict) -> dict: + rsa_key = {} + for key in jwks["keys"]: + if key["kid"] == unverified_header["kid"]: + rsa_key = { + "kty": key["kty"], + "kid": key["kid"], + "use": key["use"], + "n": key["n"], + "e": key["e"] + } + return rsa_key + + def validate_token(self, token: str, jwks: dict) -> UserId: + invalid_header_error = AuthenticationError( + {"detail": "Invalid header. Use an RS256 signed JWT Access Token"} + ) + try: + unverified_header = jwt.get_unverified_header(token) + except jwt.JWTError: + raise invalid_header_error + if unverified_header["alg"] == "HS256": + raise invalid_header_error + rsa_key = self._fetch_rsa_key(jwks, unverified_header) + try: + payload = jwt.decode( + token, + rsa_key, + algorithms=[self.algorithm], + audience=self.audience, + issuer=self.issuer + ) + return payload["sub"] + except jwt.ExpiredSignatureError: + raise AuthenticationError({"detail": "token is expired"}) + except jwt.JWTClaimsError: + raise AuthenticationError( + {"detail": "incorrect claims (check audience and issuer)"} + ) + except Exception: + raise AuthenticationError( + {"detail": "Unable to parse authentication token."} + ) + + +class KeySetProvider: + def __init__(self, uri: str, session: ClientSession, expired: timedelta): + self.session = session + self.jwks: dict[str, str] = {} + self.expired = expired + self.last_updated = None + self.uri = uri + + async def get_key_set(self): + if not self.jwks: + await self._request_new_key_set() + return self.jwks + + async def _request_new_key_set(self): + async with self.session.get(self.uri) as response: + self.jwks = await response.json() + self.last_updated = datetime.now() + + +class TokenIdProvider(IdProvider): + def __init__( + self, + token: str, + token_processor: JwtTokenProcessor, + key_set_provider: KeySetProvider, + ): + self.token_processor = token_processor + self.key_set_provider = key_set_provider + self.token = token + + async def get_current_user_id(self) -> UserId: + jwks = await self.key_set_provider.get_key_set() + return self.token_processor.validate_token(self.token, jwks) diff --git a/src/costy/domain/exceptions/__init__.py b/src/costy/domain/exceptions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/costy/domain/exceptions/access.py b/src/costy/domain/exceptions/access.py new file mode 100644 index 0000000..a9455d7 --- /dev/null +++ b/src/costy/domain/exceptions/access.py @@ -0,0 +1,2 @@ +class AuthenticationError(Exception): + pass diff --git a/src/costy/infrastructure/auth.py b/src/costy/infrastructure/auth.py new file mode 100644 index 0000000..d4a2d00 --- /dev/null +++ b/src/costy/infrastructure/auth.py @@ -0,0 +1,28 @@ +import typing +from datetime import timedelta + +from aiohttp import ClientSession + +from costy.adapters.auth.token import ( + Algorithm, + JwtTokenProcessor, + KeySetProvider, + TokenIdProvider, +) + + +def create_id_provider_factory( + audience: str, + algorithm: Algorithm, + issuer: str, + jwsk_uri: str, + web_session: ClientSession, + jwsk_expired: timedelta = timedelta(days=1) +) -> typing.Callable: + token_processor = JwtTokenProcessor(algorithm, audience, issuer) + jwsk_provider = KeySetProvider(jwsk_uri, web_session, jwsk_expired) + + def factory(token: str): + return TokenIdProvider(token, token_processor, jwsk_provider) + + return factory diff --git a/src/costy/main/web.py b/src/costy/main/web.py index 06ca1da..b7e800f 100644 --- a/src/costy/main/web.py +++ b/src/costy/main/web.py @@ -1,10 +1,13 @@ +import os from typing import Any, Callable, Coroutine, TypeVar from adaptix import Retort +from aiohttp import ClientSession from litestar import Litestar from litestar.di import Provide from sqlalchemy.orm import registry +from costy.infrastructure.auth import create_id_provider_factory from costy.infrastructure.config import get_db_connection_url from costy.infrastructure.db.main import get_engine, get_sessionmaker from costy.infrastructure.db.orm import create_tables, map_tables_to_models @@ -28,6 +31,15 @@ def init_app() -> Litestar: session_factory = get_sessionmaker(get_engine(get_db_connection_url())) ioc = IoC(session_factory=session_factory, retort=Retort()) + web_session = ClientSession() + id_provider_factory = create_id_provider_factory( + os.environ.get("AUTH0_AUDIENCE", ""), + "RS256", + os.environ.get("AUTH0_ISSUER", ""), + os.environ.get("AUTH0_JWKS_URI", ""), + web_session + ) + mapper_registry = registry() tables = create_tables(mapper_registry) map_tables_to_models(mapper_registry, tables) @@ -41,6 +53,8 @@ def init_app() -> Litestar: dependencies={ "ioc": Provide(singleton(ioc)), "id_provider": Provide(get_id_provider), + "id_provider_factory": Provide(id_provider_factory) }, + on_shutdown=[lambda: web_session.close()], debug=True ) diff --git a/src/costy/presentation/api/dependencies/id_provider.py b/src/costy/presentation/api/dependencies/id_provider.py index 4e21306..f6f7d57 100644 --- a/src/costy/presentation/api/dependencies/id_provider.py +++ b/src/costy/presentation/api/dependencies/id_provider.py @@ -1,17 +1,16 @@ -from typing import Any +from typing import Annotated, Callable from litestar.exceptions import HTTPException +from litestar.params import Parameter -from costy.adapters.auth.id_provider import SimpleIdProvider from costy.application.common.id_provider import IdProvider -from costy.domain.models.user import UserId -async def get_id_provider(cookies: dict[str, Any]) -> IdProvider: - # This is a simple version that will be improved - user_id: str | None = cookies.get("user_id") - if not user_id: +async def get_id_provider( + token: Annotated[str, Parameter(header="Authorization")], + id_provider_factory: Callable +) -> IdProvider: + if not token: raise HTTPException("Not authenticated", status_code=401) - if not user_id.isdigit(): - raise HTTPException("Not valid user_id", status_code=401) - return SimpleIdProvider(UserId(int(cookies["user_id"]))) + token_type, token = token.split(" ") + return id_provider_factory(token) From bac7849fe848eeb47a9d0c967277314937b52b2b Mon Sep 17 00:00:00 2001 From: Andrew Srg Date: Thu, 15 Feb 2024 17:56:29 +0200 Subject: [PATCH 2/5] feature: create external auth service integration --- .pre-commit-config.yaml | 1 + src/costy/adapters/auth/auth_gateway.py | 55 +++++++++++++++++++ src/costy/adapters/auth/token.py | 17 ++++-- src/costy/adapters/db/category_gateway.py | 15 ++--- src/costy/adapters/db/operation_gateway.py | 23 ++++---- src/costy/adapters/db/user_gateway.py | 15 +++-- src/costy/application/authenticate.py | 16 +++--- .../application/category/create_category.py | 17 +++--- .../category/read_available_categories.py | 2 +- src/costy/application/common/auth_gateway.py | 14 +++++ src/costy/application/common/user_gateway.py | 2 +- .../application/operation/create_operation.py | 4 +- src/costy/application/user/create_user.py | 2 +- src/costy/domain/models/user.py | 2 - src/costy/domain/services/user.py | 4 +- src/costy/infrastructure/auth.py | 4 +- src/costy/infrastructure/db/main.py | 6 +- src/costy/infrastructure/db/migrations/env.py | 8 +-- .../14d9cdbdf029_change_users_table.py | 37 +++++++++++++ src/costy/infrastructure/db/orm.py | 36 +++--------- src/costy/main/ioc.py | 44 ++++++++++++--- src/costy/main/web.py | 24 ++++---- src/costy/presentation/api/authenticate.py | 15 ++--- .../api/dependencies/id_provider.py | 15 +++-- src/costy/presentation/api/operation.py | 2 +- tests/domain/test_create.py | 4 +- 26 files changed, 257 insertions(+), 127 deletions(-) create mode 100644 src/costy/adapters/auth/auth_gateway.py create mode 100644 src/costy/application/common/auth_gateway.py create mode 100644 src/costy/infrastructure/db/migrations/versions/14d9cdbdf029_change_users_table.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 15317ff..6a123bd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,6 +31,7 @@ repos: - id: flake8 files: src exclude: "migrations/" + args: [--max-line-length, "120"] - repo: https://github.com/pre-commit/mirrors-mypy rev: 'v1.8.0' hooks: diff --git a/src/costy/adapters/auth/auth_gateway.py b/src/costy/adapters/auth/auth_gateway.py new file mode 100644 index 0000000..af26bd0 --- /dev/null +++ b/src/costy/adapters/auth/auth_gateway.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass + +from aiohttp import ClientSession +from sqlalchemy import Select, Table +from sqlalchemy.ext.asyncio import AsyncSession + +from costy.application.common.auth_gateway import AuthLoger +from costy.domain.exceptions.access import AuthenticationError +from costy.domain.models.user import UserId + + +@dataclass +class AuthSettings: + authorize_url: str + client_id: str + client_secret: str + audience: str + grant_type: str + + +class AuthGateway(AuthLoger): + def __init__( + self, + db_session: AsyncSession, + web_session: ClientSession, + table: Table, + settings: AuthSettings + ) -> None: + self.db_session = db_session + self.web_session = web_session + self.table = table + self.settings = settings + + async def authenticate(self, email: str, password: str) -> str: + url = self.settings.authorize_url + data = { + "username": email, + "password": password, + "client_id": self.settings.client_id, + "client_secret": self.settings.client_secret, + "audience": self.settings.audience, + "grant_type": self.settings.grant_type + } + async with self.web_session.post(url, data=data) as response: + if response.status == 200: + data = await response.json() + token = data.get("access_token") + if token: + return token + raise AuthenticationError() + + async def get_user_by_sub(self, sub: str) -> UserId: + query = Select(self.table).where(self.table.c.auth_id == sub) + result = await self.db_session.execute(query) + return tuple(result)[0][0] diff --git a/src/costy/adapters/auth/token.py b/src/costy/adapters/auth/token.py index cfeee67..add1215 100644 --- a/src/costy/adapters/auth/token.py +++ b/src/costy/adapters/auth/token.py @@ -4,6 +4,7 @@ from aiohttp import ClientSession from jose import jwt +from costy.application.common.auth_gateway import AuthLoger from costy.application.common.id_provider import IdProvider from costy.domain.exceptions.access import AuthenticationError from costy.domain.models.user import UserId @@ -38,7 +39,7 @@ def _fetch_rsa_key(self, jwks: dict, unverified_header: dict) -> dict: } return rsa_key - def validate_token(self, token: str, jwks: dict) -> UserId: + def validate_token(self, token: str, jwks: dict) -> str: invalid_header_error = AuthenticationError( {"detail": "Invalid header. Use an RS256 signed JWT Access Token"} ) @@ -81,6 +82,9 @@ def __init__(self, uri: str, session: ClientSession, expired: timedelta): async def get_key_set(self): if not self.jwks: await self._request_new_key_set() + if self.last_updated and datetime.now() - self.last_updated > self.expired: + # TODO: add use Cache-Control + await self._request_new_key_set() return self.jwks async def _request_new_key_set(self): @@ -92,14 +96,19 @@ async def _request_new_key_set(self): class TokenIdProvider(IdProvider): def __init__( self, - token: str, token_processor: JwtTokenProcessor, key_set_provider: KeySetProvider, + token: str | None = None ): self.token_processor = token_processor self.key_set_provider = key_set_provider self.token = token + self.auth_gateway: AuthLoger | None = None async def get_current_user_id(self) -> UserId: - jwks = await self.key_set_provider.get_key_set() - return self.token_processor.validate_token(self.token, jwks) + if self.token and self.auth_gateway: + jwks = await self.key_set_provider.get_key_set() + sub = self.token_processor.validate_token(self.token, jwks) + user = await self.auth_gateway.get_user_by_sub(sub) + return user # type: ignore + raise AuthenticationError() diff --git a/src/costy/adapters/db/category_gateway.py b/src/costy/adapters/db/category_gateway.py index 7be5a01..1225b14 100644 --- a/src/costy/adapters/db/category_gateway.py +++ b/src/costy/adapters/db/category_gateway.py @@ -1,5 +1,5 @@ from adaptix import Retort -from sqlalchemy import delete, or_, select +from sqlalchemy import Table, delete, or_, select from sqlalchemy.ext.asyncio import AsyncSession from costy.application.category.dto import CategoryDTO @@ -16,13 +16,14 @@ class CategoryGateway( CategoryReader, CategorySaver, CategoryDeleter, CategoriesReader ): - def __init__(self, session: AsyncSession, retort: Retort): + def __init__(self, session: AsyncSession, table: Table, retort: Retort): self.session = session + self.table = table self.retort = retort async def get_category(self, category_id: CategoryId) -> Category | None: - query = select(Category).where( - Category.id == category_id # type: ignore + query = select(self.table).where( + self.table.c.id == category_id ) result: Category | None = await self.session.scalar(query) return result @@ -39,10 +40,10 @@ async def delete_category(self, category_id: CategoryId) -> None: async def find_categories(self, user_id: UserId) -> list[CategoryDTO]: filter_expr = or_( - Category.user_id == user_id, # type: ignore - Category.user_id == None # type: ignore # noqa: E711 + self.table.c.user_id == user_id, # type: ignore + self.table.c.user_id == None # type: ignore # noqa: E711 ) - query = select(Category).where(filter_expr) + query = select(self.table).where(filter_expr) categories = list(await self.session.scalars(query)) dumped = self.retort.dump(categories, list[Category]) return self.retort.load(dumped, list[CategoryDTO]) diff --git a/src/costy/adapters/db/operation_gateway.py b/src/costy/adapters/db/operation_gateway.py index af11dab..c50cd81 100644 --- a/src/costy/adapters/db/operation_gateway.py +++ b/src/costy/adapters/db/operation_gateway.py @@ -1,4 +1,5 @@ -from sqlalchemy import delete, select +from adaptix import Retort +from sqlalchemy import Table, delete, select from sqlalchemy.ext.asyncio import AsyncSession from costy.application.common.operation_gateway import ( @@ -14,14 +15,16 @@ class OperationGateway( OperationReader, OperationSaver, OperationDeleter, OperationsReader ): - def __init__(self, session: AsyncSession): + def __init__(self, session: AsyncSession, table: Table, retort: Retort): self.session = session + self.table = table + self.retort = retort async def get_operation( self, operation_id: OperationId ) -> Operation | None: - query = select(Operation).where( - Operation.id == operation_id # type: ignore + query = select(self.table).where( + self.table.c.id == operation_id ) result: Operation | None = await self.session.scalar(query) return result @@ -31,8 +34,8 @@ async def save_operation(self, operation: Operation) -> None: await self.session.flush(objects=[operation]) async def delete_operation(self, operation_id: OperationId) -> None: - query = delete(Operation).where( - Operation.id == operation_id # type: ignore + query = delete(self.table).where( + self.table.c.id == operation_id ) await self.session.execute(query) @@ -40,11 +43,11 @@ async def find_operations_by_user( self, user_id: UserId, from_time: int | None, to_time: int | None ) -> list[Operation]: query = ( - select(Operation) - .where(Operation.user_id == user_id) # type: ignore + select(self.table) + .where(self.table.c.user_id == user_id) ) if from_time: - query = query.where(Operation.time >= from_time) # type: ignore + query = query.where(self.table.c.time >= from_time) if to_time: - query = query.where(Operation.time <= to_time) # type: ignore + query = query.where(self.table.c.time <= to_time) return list(await self.session.scalars(query)) diff --git a/src/costy/adapters/db/user_gateway.py b/src/costy/adapters/db/user_gateway.py index 5481e06..c3f03cd 100644 --- a/src/costy/adapters/db/user_gateway.py +++ b/src/costy/adapters/db/user_gateway.py @@ -1,4 +1,5 @@ -from sqlalchemy import select +from adaptix import Retort +from sqlalchemy import Table, select from sqlalchemy.ext.asyncio import AsyncSession from costy.application.common.user_gateway import UserReader, UserSaver @@ -6,19 +7,21 @@ class UserGateway(UserSaver, UserReader): - def __init__(self, session: AsyncSession): + def __init__(self, session: AsyncSession, table: Table, retort: Retort): self.session = session + self.table = table + self.retort = retort async def save_user(self, user: User) -> None: self.session.add(user) await self.session.flush(objects=[user]) async def get_user_by_id(self, user_id: UserId) -> User | None: - query = select(User).where(User.id == user_id) # type: ignore + query = select(self.table).where(self.table.c.id == user_id) result: User | None = await self.session.scalar(query) return result - async def get_user_by_email(self, email: str) -> User | None: - query = select(User).where(User.email == email) # type: ignore - result: User | None = await self.session.scalar(query) + async def get_user_by_auth_id(self, auth_id: str) -> User | None: + query = select(self.table).where(self.table.c.auth_id == auth_id) + result = await self.session.scalar(query) return result diff --git a/src/costy/application/authenticate.py b/src/costy/application/authenticate.py index ef2227f..46c83b2 100644 --- a/src/costy/application/authenticate.py +++ b/src/costy/application/authenticate.py @@ -1,9 +1,8 @@ from dataclasses import dataclass +from costy.application.common.auth_gateway import AuthLoger from costy.application.common.interactor import Interactor from costy.application.common.uow import UoW -from costy.application.common.user_gateway import UserReader -from costy.domain.models.user import UserId @dataclass @@ -12,12 +11,11 @@ class LoginInputDTO: password: str -class Authenticate(Interactor[LoginInputDTO, UserId | None]): - def __init__(self, user_gateway: UserReader, uow: UoW): - self.user_gateway = user_gateway +class Authenticate(Interactor[LoginInputDTO, str | None]): + def __init__(self, auth_gateway: AuthLoger, uow: UoW): + self.auth_gateway = auth_gateway self.uow = uow - async def __call__(self, data: LoginInputDTO) -> UserId | None: - user = await self.user_gateway.get_user_by_email(data.email) - # TODO: compare hashed passwords - return user.id if user else None + async def __call__(self, data: LoginInputDTO) -> str | None: + token = await self.auth_gateway.authenticate(data.email, data.password) + return token diff --git a/src/costy/application/category/create_category.py b/src/costy/application/category/create_category.py index 47a9176..272062e 100644 --- a/src/costy/application/category/create_category.py +++ b/src/costy/application/category/create_category.py @@ -1,6 +1,7 @@ from costy.domain.models.category import CategoryId, CategoryType from costy.domain.services.category import CategoryService +from ...domain.exceptions.access import AuthenticationError from ..common.category_gateway import CategorySaver from ..common.id_provider import IdProvider from ..common.interactor import Interactor @@ -23,10 +24,12 @@ def __init__( async def __call__(self, data: NewCategoryDTO) -> CategoryId: user_id = await self.id_provider.get_current_user_id() - category = self.category_service.create( - data.name, CategoryType.PERSONAL, user_id - ) - await self.category_db_gateway.save_category(category) - category_id = category.id - await self.uow.commit() - return category_id # type: ignore + if user_id: # type: ignore + category = self.category_service.create( + data.name, CategoryType.PERSONAL, user_id # type: ignore + ) + await self.category_db_gateway.save_category(category) + category_id = category.id + await self.uow.commit() + return category_id # type: ignore + raise AuthenticationError("User not found") diff --git a/src/costy/application/category/read_available_categories.py b/src/costy/application/category/read_available_categories.py index caaf7a8..b648bc2 100644 --- a/src/costy/application/category/read_available_categories.py +++ b/src/costy/application/category/read_available_categories.py @@ -26,4 +26,4 @@ async def __call__( self, data: Optional[ReadAvailableCategoriesDTO] = None ) -> List[CategoryDTO]: user_id = await self.id_provider.get_current_user_id() - return await self.category_db_gateway.find_categories(user_id) + return await self.category_db_gateway.find_categories(user_id) # type: ignore diff --git a/src/costy/application/common/auth_gateway.py b/src/costy/application/common/auth_gateway.py new file mode 100644 index 0000000..34551bc --- /dev/null +++ b/src/costy/application/common/auth_gateway.py @@ -0,0 +1,14 @@ +from abc import abstractmethod +from typing import Protocol + +from costy.domain.models.user import UserId + + +class AuthLoger(Protocol): + @abstractmethod + async def authenticate(self, email: str, password: str) -> str | None: + raise NotImplementedError + + @abstractmethod + async def get_user_by_sub(self, sub: str) -> UserId | None: + raise NotImplementedError diff --git a/src/costy/application/common/user_gateway.py b/src/costy/application/common/user_gateway.py index 1b852ac..9d6e606 100644 --- a/src/costy/application/common/user_gateway.py +++ b/src/costy/application/common/user_gateway.py @@ -11,7 +11,7 @@ async def get_user_by_id(self, user_id: UserId) -> User | None: raise NotImplementedError @abstractmethod - async def get_user_by_email(self, email: str) -> User | None: + async def get_user_by_auth_id(self, auth_id: str) -> User | None: raise NotImplementedError diff --git a/src/costy/application/operation/create_operation.py b/src/costy/application/operation/create_operation.py index 55a9e13..ef8e357 100644 --- a/src/costy/application/operation/create_operation.py +++ b/src/costy/application/operation/create_operation.py @@ -22,11 +22,11 @@ def __init__( async def __call__(self, data: NewOperationDTO) -> OperationId: user_id = await self.id_provider.get_current_user_id() - operation = self.operation_service.create( + operation = self.operation_service.create( # type: ignore data.amount, data.description, data.time, - user_id, + user_id, # type: ignore data.category_id, ) await self.operation_db_gateway.save_operation(operation) diff --git a/src/costy/application/user/create_user.py b/src/costy/application/user/create_user.py index e812b27..f5fffeb 100644 --- a/src/costy/application/user/create_user.py +++ b/src/costy/application/user/create_user.py @@ -19,7 +19,7 @@ def __init__( self.uow = uow async def __call__(self, data: NewUserDTO) -> UserId: - user = self.user_service.create(data.email, data.password) + user = self.user_service.create() await self.user_db_gateway.save_user(user) user_id = user.id await self.uow.commit() diff --git a/src/costy/domain/models/user.py b/src/costy/domain/models/user.py index 982bcc3..ec8ed45 100644 --- a/src/costy/domain/models/user.py +++ b/src/costy/domain/models/user.py @@ -7,5 +7,3 @@ @dataclass class User: id: UserId | None - email: str - hashed_password: str diff --git a/src/costy/domain/services/user.py b/src/costy/domain/services/user.py index ac6dc02..34d406c 100644 --- a/src/costy/domain/services/user.py +++ b/src/costy/domain/services/user.py @@ -2,5 +2,5 @@ class UserService: - def create(self, email: str, hashed_password: str) -> User: - return User(id=None, email=email, hashed_password=hashed_password) + def create(self) -> User: + return User(id=None) diff --git a/src/costy/infrastructure/auth.py b/src/costy/infrastructure/auth.py index d4a2d00..e9ed458 100644 --- a/src/costy/infrastructure/auth.py +++ b/src/costy/infrastructure/auth.py @@ -22,7 +22,7 @@ def create_id_provider_factory( token_processor = JwtTokenProcessor(algorithm, audience, issuer) jwsk_provider = KeySetProvider(jwsk_uri, web_session, jwsk_expired) - def factory(token: str): - return TokenIdProvider(token, token_processor, jwsk_provider) + async def factory(): + return TokenIdProvider(token_processor, jwsk_provider) return factory diff --git a/src/costy/infrastructure/db/main.py b/src/costy/infrastructure/db/main.py index 6cd6efb..af07f65 100644 --- a/src/costy/infrastructure/db/main.py +++ b/src/costy/infrastructure/db/main.py @@ -1,10 +1,10 @@ +from sqlalchemy import MetaData from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine, ) -from sqlalchemy.orm import registry def get_engine(url: str) -> AsyncEngine: @@ -15,5 +15,5 @@ def get_sessionmaker(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]: return async_sessionmaker(engine) -def get_registry() -> registry: - return registry() +def get_metadata() -> MetaData: + return MetaData() diff --git a/src/costy/infrastructure/db/migrations/env.py b/src/costy/infrastructure/db/migrations/env.py index 292d593..033eaad 100644 --- a/src/costy/infrastructure/db/migrations/env.py +++ b/src/costy/infrastructure/db/migrations/env.py @@ -4,7 +4,7 @@ from sqlalchemy import engine_from_config, pool from costy.infrastructure.config import get_db_connection_url -from costy.infrastructure.db.main import get_registry +from costy.infrastructure.db.main import get_metadata from costy.infrastructure.db.orm import create_tables # this is the Alembic Config object, which provides @@ -22,9 +22,9 @@ # for 'autogenerate' support # from myapp import mymodel # target_metadata = mymodel.Base.metadata -registry = get_registry() -create_tables(registry) -target_metadata = registry.metadata +metadata = get_metadata() +create_tables(metadata) +target_metadata = metadata # other values from the config, defined by the needs of env.py, # can be acquired: diff --git a/src/costy/infrastructure/db/migrations/versions/14d9cdbdf029_change_users_table.py b/src/costy/infrastructure/db/migrations/versions/14d9cdbdf029_change_users_table.py new file mode 100644 index 0000000..ff5c846 --- /dev/null +++ b/src/costy/infrastructure/db/migrations/versions/14d9cdbdf029_change_users_table.py @@ -0,0 +1,37 @@ +"""change users table + +Revision ID: 14d9cdbdf029 +Revises: f1c4a04700d3 +Create Date: 2024-02-15 15:28:12.592583 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '14d9cdbdf029' +down_revision: Union[str, None] = 'f1c4a04700d3' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('users', sa.Column('auth_id', sa.String(), nullable=False)) + op.drop_index('ix_users_email', table_name='users') + op.create_index(op.f('ix_users_auth_id'), 'users', ['auth_id'], unique=True) + op.drop_column('users', 'email') + op.drop_column('users', 'hashed_password') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('users', sa.Column('hashed_password', sa.VARCHAR(), autoincrement=False, nullable=False)) + op.add_column('users', sa.Column('email', sa.VARCHAR(), autoincrement=False, nullable=False)) + op.drop_index(op.f('ix_users_auth_id'), table_name='users') + op.create_index('ix_users_email', 'users', ['email'], unique=True) + op.drop_column('users', 'auth_id') + # ### end Alembic commands ### diff --git a/src/costy/infrastructure/db/orm.py b/src/costy/infrastructure/db/orm.py index cb7468e..8556ddd 100644 --- a/src/costy/infrastructure/db/orm.py +++ b/src/costy/infrastructure/db/orm.py @@ -1,28 +1,17 @@ -import typing -from typing import Type +from sqlalchemy import Column, ForeignKey, Integer, MetaData, String, Table -from sqlalchemy import Column, ForeignKey, Integer, String, Table -from sqlalchemy.orm import registry -from costy.domain.models.category import Category -from costy.domain.models.operation import Operation -from costy.domain.models.user import User - -Model = typing.Union[Category, Operation, User] - - -def create_tables(mapper_registry: registry) -> dict[Type[Model], Table]: +def create_tables(metadata: MetaData) -> dict[str, Table]: return { - User: Table( + "users": Table( "users", - mapper_registry.metadata, + metadata, Column("id", Integer, primary_key=True), - Column("email", String, unique=True, index=True, nullable=False), - Column("hashed_password", String, nullable=False), + Column("auth_id", String, unique=True, index=True, nullable=False) ), - Operation: Table( + "operations": Table( "operations", - mapper_registry.metadata, + metadata, Column("id", Integer, primary_key=True), Column("amount", Integer, nullable=False), Column("description", String), @@ -35,19 +24,12 @@ def create_tables(mapper_registry: registry) -> dict[Type[Model], Table]: nullable=True ), ), - Category: Table( + "categories": Table( "categories", - mapper_registry.metadata, + metadata, Column("id", Integer, primary_key=True), Column("name", String, nullable=False), Column("user_id", Integer, ForeignKey("users.id"), nullable=True), Column("kind", String, default="general"), ), } - - -def map_tables_to_models( - mapper_registry: registry, tables: dict[Type[Model], Table] -) -> None: - for model, table in tables.items(): - mapper_registry.map_imperatively(model, table) diff --git a/src/costy/main/ioc.py b/src/costy/main/ioc.py index a2795ee..95c55be 100644 --- a/src/costy/main/ioc.py +++ b/src/costy/main/ioc.py @@ -1,10 +1,14 @@ +import os from contextlib import asynccontextmanager from dataclasses import dataclass from typing import AsyncIterator from adaptix import Retort +from aiohttp import ClientSession +from sqlalchemy import Table from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from costy.adapters.auth.auth_gateway import AuthGateway, AuthSettings from costy.adapters.db.category_gateway import CategoryGateway from costy.adapters.db.operation_gateway import OperationGateway from costy.adapters.db.uow import OrmUoW @@ -28,35 +32,52 @@ @dataclass class Depends: session: AsyncSession + web_session: ClientSession uow: OrmUoW - retort: Retort + auth_gateway: AuthGateway class IoC(InteractorFactory): def __init__( self, session_factory: async_sessionmaker[AsyncSession], + web_session: ClientSession, + tables: dict[str, Table], retort: Retort ): self._session_factory = session_factory + self._web_session = web_session + self._tables = tables self._retort = retort + self._settings = AuthSettings( + authorize_url="https://dev-66quvcmw46dh86sh.us.auth0.com/oauth/token", + grant_type="password", + client_id=os.getenv("AUTH0_CLIENT_ID", ""), + client_secret=os.getenv("AUTH0_CLIENT_ID", ""), + audience="https://dev-66quvcmw46dh86sh.us.auth0.com/api/v2/" + ) @asynccontextmanager async def _init_depends(self) -> AsyncIterator[Depends]: session = self._session_factory() - yield Depends(session, OrmUoW(session), self._retort) + auth_gateway = AuthGateway(session, self._web_session, self._tables["users"], self._settings) + yield Depends(session, self._web_session, OrmUoW(session), auth_gateway) await session.close() + await self._web_session.close() @asynccontextmanager async def authenticate(self) -> AsyncIterator[Authenticate]: async with self._init_depends() as depends: - yield Authenticate(UserGateway(depends.session), depends.uow) + yield Authenticate( + depends.auth_gateway, + depends.uow + ) @asynccontextmanager async def create_user(self) -> AsyncIterator[CreateUser]: async with self._init_depends() as depends: yield CreateUser( - UserService(), UserGateway(depends.session), depends.uow + UserService(), UserGateway(depends.session, self._tables["users"], self._retort), depends.uow ) @asynccontextmanager @@ -64,9 +85,10 @@ async def create_operation( self, id_provider: IdProvider ) -> AsyncIterator[CreateOperation]: async with self._init_depends() as depends: + id_provider.auth_gateway = depends.auth_gateway # type: ignore yield CreateOperation( OperationService(), - OperationGateway(depends.session), + OperationGateway(depends.session, self._tables["operations"], self._retort), id_provider, depends.uow ) @@ -75,9 +97,10 @@ async def read_operation( self, id_provider: IdProvider ) -> AsyncIterator[ReadOperation]: async with self._init_depends() as depends: + id_provider.auth_gateway = depends.auth_gateway # type: ignore yield ReadOperation( OperationService(), - OperationGateway(depends.session), + OperationGateway(depends.session, self._tables["operations"], self._retort), id_provider, depends.uow ) @@ -87,9 +110,10 @@ async def read_list_operation( self, id_provider: IdProvider ) -> AsyncIterator[ReadListOperation]: async with self._init_depends() as depends: + id_provider.auth_gateway = depends.auth_gateway # type: ignore yield ReadListOperation( OperationService(), - OperationGateway(depends.session), + OperationGateway(depends.session, self._tables["operations"], self._retort), id_provider, depends.uow ) @@ -99,9 +123,10 @@ async def create_category( self, id_provider: IdProvider ) -> AsyncIterator[CreateCategory]: async with self._init_depends() as depends: + id_provider.auth_gateway = depends.auth_gateway # type: ignore yield CreateCategory( CategoryService(), - CategoryGateway(depends.session, depends.retort), + CategoryGateway(depends.session, self._tables["categories"], self._retort), id_provider, depends.uow ) @@ -111,9 +136,10 @@ async def read_available_categories( self, id_provider: IdProvider ) -> AsyncIterator[ReadAvailableCategories]: async with self._init_depends() as depends: + id_provider.auth_gateway = depends.auth_gateway # type: ignore yield ReadAvailableCategories( CategoryService(), - CategoryGateway(depends.session, depends.retort), + CategoryGateway(depends.session, self._tables["categories"], self._retort), id_provider, depends.uow ) diff --git a/src/costy/main/web.py b/src/costy/main/web.py index b7e800f..d6a3e9e 100644 --- a/src/costy/main/web.py +++ b/src/costy/main/web.py @@ -5,13 +5,17 @@ from aiohttp import ClientSession from litestar import Litestar from litestar.di import Provide -from sqlalchemy.orm import registry from costy.infrastructure.auth import create_id_provider_factory from costy.infrastructure.config import get_db_connection_url -from costy.infrastructure.db.main import get_engine, get_sessionmaker -from costy.infrastructure.db.orm import create_tables, map_tables_to_models +from costy.infrastructure.db.main import ( + get_engine, + get_metadata, + get_sessionmaker, +) +from costy.infrastructure.db.orm import create_tables from costy.main.ioc import IoC +from costy.presentation.api.authenticate import AuthenticationController from costy.presentation.api.category import CategoryController from costy.presentation.api.dependencies.id_provider import get_id_provider from costy.presentation.api.operation import OperationController @@ -28,10 +32,13 @@ async def func() -> T: def init_app() -> Litestar: - session_factory = get_sessionmaker(get_engine(get_db_connection_url())) - ioc = IoC(session_factory=session_factory, retort=Retort()) + base_metadata = get_metadata() + tables = create_tables(base_metadata) + session_factory = get_sessionmaker(get_engine(get_db_connection_url())) web_session = ClientSession() + ioc = IoC(session_factory=session_factory, web_session=web_session, tables=tables, retort=Retort()) + id_provider_factory = create_id_provider_factory( os.environ.get("AUTH0_AUDIENCE", ""), "RS256", @@ -40,15 +47,12 @@ def init_app() -> Litestar: web_session ) - mapper_registry = registry() - tables = create_tables(mapper_registry) - map_tables_to_models(mapper_registry, tables) - return Litestar( route_handlers=( + AuthenticationController, UserController, OperationController, - CategoryController + CategoryController, ), dependencies={ "ioc": Provide(singleton(ioc)), diff --git a/src/costy/presentation/api/authenticate.py b/src/costy/presentation/api/authenticate.py index 5352d24..4144198 100644 --- a/src/costy/presentation/api/authenticate.py +++ b/src/costy/presentation/api/authenticate.py @@ -1,5 +1,4 @@ from litestar import Controller, Response, post -from litestar.datastructures import Cookie from costy.application.authenticate import LoginInputDTO from costy.presentation.interactor_factory import InteractorFactory @@ -8,12 +7,10 @@ class AuthenticationController(Controller): path = "/auth" - @post() - async def login( - self, ioc: InteractorFactory, data: LoginInputDTO - ) -> Response[str]: + @post(status_code=200) + async def login(self, ioc: InteractorFactory, data: LoginInputDTO) -> Response: async with ioc.authenticate() as authenticate: - user_id = await authenticate(data) - return Response( - "ok", cookies=[Cookie(key="user_id", value=str(user_id))] - ) + token = await authenticate(data) + if token: + return Response({"token": token}, status_code=200) + return Response({"error": "Text"}, status_code=400) diff --git a/src/costy/presentation/api/dependencies/id_provider.py b/src/costy/presentation/api/dependencies/id_provider.py index f6f7d57..10d28c1 100644 --- a/src/costy/presentation/api/dependencies/id_provider.py +++ b/src/costy/presentation/api/dependencies/id_provider.py @@ -1,16 +1,15 @@ -from typing import Annotated, Callable - from litestar.exceptions import HTTPException -from litestar.params import Parameter +from costy.adapters.auth.token import TokenIdProvider from costy.application.common.id_provider import IdProvider async def get_id_provider( - token: Annotated[str, Parameter(header="Authorization")], - id_provider_factory: Callable + headers: dict, + id_provider_factory: TokenIdProvider ) -> IdProvider: - if not token: + if not headers.get("authorization"): raise HTTPException("Not authenticated", status_code=401) - token_type, token = token.split(" ") - return id_provider_factory(token) + token_type, token = headers.get("authorization", "").split(" ") + id_provider_factory.token = token + return id_provider_factory diff --git a/src/costy/presentation/api/operation.py b/src/costy/presentation/api/operation.py index dccaf1d..1ab01af 100644 --- a/src/costy/presentation/api/operation.py +++ b/src/costy/presentation/api/operation.py @@ -10,7 +10,7 @@ class OperationController(Controller): path = '/operations' - @get("{operation_id:int}") + @get("/{operation_id:int}") async def get_operation( self, ioc: InteractorFactory, diff --git a/tests/domain/test_create.py b/tests/domain/test_create.py index d7ff151..4136769 100644 --- a/tests/domain/test_create.py +++ b/tests/domain/test_create.py @@ -11,8 +11,8 @@ @pytest.mark.parametrize("domain_service, data, expected_model", [ ( UserService(), - ("email@test.com", "password"), - User(None, "email@test.com", "password") + (), + User(None) ), ( OperationService(), From 2b9682d0b03c1e9b16a3ea383fb6e38941f782fd Mon Sep 17 00:00:00 2001 From: Andrew Srg Date: Sun, 18 Feb 2024 23:03:50 +0200 Subject: [PATCH 3/5] refactor: mypy refactor | fix tests --- .pre-commit-config.yaml | 2 +- src/costy/adapters/auth/auth_gateway.py | 29 ++++++--------- src/costy/adapters/auth/token.py | 25 ++++++------- src/costy/adapters/db/category_gateway.py | 6 ++-- src/costy/adapters/db/user_gateway.py | 12 +++---- .../application/category/create_category.py | 4 +-- .../category/read_available_categories.py | 2 +- src/costy/application/common/auth_gateway.py | 2 +- src/costy/application/common/user_gateway.py | 4 --- .../application/operation/create_operation.py | 4 +-- src/costy/infrastructure/auth.py | 6 ++-- src/costy/infrastructure/config.py | 35 +++++++++++++++++++ src/costy/main/ioc.py | 16 +++------ src/costy/main/web.py | 18 ++++++---- src/costy/presentation/api/authenticate.py | 2 +- src/costy/presentation/api/category.py | 2 +- .../api/dependencies/id_provider.py | 8 ++--- src/costy/presentation/api/operation.py | 3 +- src/costy/presentation/api/user.py | 2 +- tests/adapters/__init__.py | 0 tests/adapters/test_auth_adapter.py | 6 ++++ .../category/test_create_category.py | 14 ++++---- .../operation/test_create_operation.py | 14 ++++---- tests/application/test_authenticate.py | 24 +++++++------ tests/application/user/test_create_user.py | 11 +++--- tests/domain/test_create.py | 2 +- 26 files changed, 141 insertions(+), 112 deletions(-) create mode 100644 tests/adapters/__init__.py create mode 100644 tests/adapters/test_auth_adapter.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6a123bd..f20db98 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,7 +36,7 @@ repos: rev: 'v1.8.0' hooks: - id: mypy - additional_dependencies: [] + args: [--install-types, --non-interactive, --ignore-missing-imports] - repo: https://github.com/PyCQA/bandit rev: 1.7.7 hooks: diff --git a/src/costy/adapters/auth/auth_gateway.py b/src/costy/adapters/auth/auth_gateway.py index af26bd0..b9d145d 100644 --- a/src/costy/adapters/auth/auth_gateway.py +++ b/src/costy/adapters/auth/auth_gateway.py @@ -1,21 +1,11 @@ -from dataclasses import dataclass - from aiohttp import ClientSession -from sqlalchemy import Select, Table +from sqlalchemy import Table, select from sqlalchemy.ext.asyncio import AsyncSession from costy.application.common.auth_gateway import AuthLoger from costy.domain.exceptions.access import AuthenticationError from costy.domain.models.user import UserId - - -@dataclass -class AuthSettings: - authorize_url: str - client_id: str - client_secret: str - audience: str - grant_type: str +from costy.infrastructure.config import AuthSettings class AuthGateway(AuthLoger): @@ -42,14 +32,17 @@ async def authenticate(self, email: str, password: str) -> str: "grant_type": self.settings.grant_type } async with self.web_session.post(url, data=data) as response: + response_data = await response.json() if response.status == 200: - data = await response.json() - token = data.get("access_token") + token: str | None = response_data.get("access_token") if token: return token - raise AuthenticationError() + raise AuthenticationError(response_data) - async def get_user_by_sub(self, sub: str) -> UserId: - query = Select(self.table).where(self.table.c.auth_id == sub) + async def get_user_id_by_sub(self, sub: str) -> UserId: + query = select(self.table).where(self.table.c.auth_id == sub) result = await self.db_session.execute(query) - return tuple(result)[0][0] + try: + return UserId(next(result.mappings())["id"]) + except StopIteration: + raise AuthenticationError("Invalid auth sub. User is not exists.") diff --git a/src/costy/adapters/auth/token.py b/src/costy/adapters/auth/token.py index add1215..394677a 100644 --- a/src/costy/adapters/auth/token.py +++ b/src/costy/adapters/auth/token.py @@ -1,7 +1,8 @@ from datetime import datetime, timedelta -from typing import Literal +from typing import Any, Literal from aiohttp import ClientSession +from jose import exceptions as jwt_exc from jose import jwt from costy.application.common.auth_gateway import AuthLoger @@ -26,7 +27,7 @@ def __init__( self.audience = audience self.issuer = issuer - def _fetch_rsa_key(self, jwks: dict, unverified_header: dict) -> dict: + def _fetch_rsa_key(self, jwks: dict[Any, Any], unverified_header: dict[str, str]) -> dict[str, str]: rsa_key = {} for key in jwks["keys"]: if key["kid"] == unverified_header["kid"]: @@ -39,19 +40,19 @@ def _fetch_rsa_key(self, jwks: dict, unverified_header: dict) -> dict: } return rsa_key - def validate_token(self, token: str, jwks: dict) -> str: + def validate_token(self, token: str, jwks: dict[Any, Any]) -> str: invalid_header_error = AuthenticationError( {"detail": "Invalid header. Use an RS256 signed JWT Access Token"} ) try: unverified_header = jwt.get_unverified_header(token) - except jwt.JWTError: + except jwt_exc.JWTError: raise invalid_header_error if unverified_header["alg"] == "HS256": raise invalid_header_error rsa_key = self._fetch_rsa_key(jwks, unverified_header) try: - payload = jwt.decode( + payload: dict[str, str] = jwt.decode( token, rsa_key, algorithms=[self.algorithm], @@ -59,9 +60,9 @@ def validate_token(self, token: str, jwks: dict) -> str: issuer=self.issuer ) return payload["sub"] - except jwt.ExpiredSignatureError: + except jwt_exc.ExpiredSignatureError: raise AuthenticationError({"detail": "token is expired"}) - except jwt.JWTClaimsError: + except jwt_exc.JWTClaimsError: raise AuthenticationError( {"detail": "incorrect claims (check audience and issuer)"} ) @@ -76,10 +77,10 @@ def __init__(self, uri: str, session: ClientSession, expired: timedelta): self.session = session self.jwks: dict[str, str] = {} self.expired = expired - self.last_updated = None + self.last_updated: datetime | None = None self.uri = uri - async def get_key_set(self): + async def get_key_set(self) -> dict[Any, Any]: if not self.jwks: await self._request_new_key_set() if self.last_updated and datetime.now() - self.last_updated > self.expired: @@ -87,7 +88,7 @@ async def get_key_set(self): await self._request_new_key_set() return self.jwks - async def _request_new_key_set(self): + async def _request_new_key_set(self) -> None: async with self.session.get(self.uri) as response: self.jwks = await response.json() self.last_updated = datetime.now() @@ -109,6 +110,6 @@ async def get_current_user_id(self) -> UserId: if self.token and self.auth_gateway: jwks = await self.key_set_provider.get_key_set() sub = self.token_processor.validate_token(self.token, jwks) - user = await self.auth_gateway.get_user_by_sub(sub) - return user # type: ignore + user_id = await self.auth_gateway.get_user_id_by_sub(sub) + return user_id raise AuthenticationError() diff --git a/src/costy/adapters/db/category_gateway.py b/src/costy/adapters/db/category_gateway.py index 1225b14..351f39a 100644 --- a/src/costy/adapters/db/category_gateway.py +++ b/src/costy/adapters/db/category_gateway.py @@ -34,14 +34,14 @@ async def save_category(self, category: Category) -> None: async def delete_category(self, category_id: CategoryId) -> None: query = delete(Category).where( - Category.id == category_id # type: ignore + self.table.c.id == category_id ) await self.session.execute(query) async def find_categories(self, user_id: UserId) -> list[CategoryDTO]: filter_expr = or_( - self.table.c.user_id == user_id, # type: ignore - self.table.c.user_id == None # type: ignore # noqa: E711 + self.table.c.user_id == user_id, + self.table.c.user_id == None # noqa: E711 ) query = select(self.table).where(filter_expr) categories = list(await self.session.scalars(query)) diff --git a/src/costy/adapters/db/user_gateway.py b/src/costy/adapters/db/user_gateway.py index c3f03cd..48c4910 100644 --- a/src/costy/adapters/db/user_gateway.py +++ b/src/costy/adapters/db/user_gateway.py @@ -18,10 +18,10 @@ async def save_user(self, user: User) -> None: async def get_user_by_id(self, user_id: UserId) -> User | None: query = select(self.table).where(self.table.c.id == user_id) - result: User | None = await self.session.scalar(query) - return result - - async def get_user_by_auth_id(self, auth_id: str) -> User | None: - query = select(self.table).where(self.table.c.auth_id == auth_id) result = await self.session.scalar(query) - return result + try: + data = next(result.mapping()) + user: User = self.retort.load(data, User) + return user + except StopIteration: + return None diff --git a/src/costy/application/category/create_category.py b/src/costy/application/category/create_category.py index 272062e..c218377 100644 --- a/src/costy/application/category/create_category.py +++ b/src/costy/application/category/create_category.py @@ -24,9 +24,9 @@ def __init__( async def __call__(self, data: NewCategoryDTO) -> CategoryId: user_id = await self.id_provider.get_current_user_id() - if user_id: # type: ignore + if user_id: category = self.category_service.create( - data.name, CategoryType.PERSONAL, user_id # type: ignore + data.name, CategoryType.PERSONAL, user_id ) await self.category_db_gateway.save_category(category) category_id = category.id diff --git a/src/costy/application/category/read_available_categories.py b/src/costy/application/category/read_available_categories.py index b648bc2..caaf7a8 100644 --- a/src/costy/application/category/read_available_categories.py +++ b/src/costy/application/category/read_available_categories.py @@ -26,4 +26,4 @@ async def __call__( self, data: Optional[ReadAvailableCategoriesDTO] = None ) -> List[CategoryDTO]: user_id = await self.id_provider.get_current_user_id() - return await self.category_db_gateway.find_categories(user_id) # type: ignore + return await self.category_db_gateway.find_categories(user_id) diff --git a/src/costy/application/common/auth_gateway.py b/src/costy/application/common/auth_gateway.py index 34551bc..1aedbe6 100644 --- a/src/costy/application/common/auth_gateway.py +++ b/src/costy/application/common/auth_gateway.py @@ -10,5 +10,5 @@ async def authenticate(self, email: str, password: str) -> str | None: raise NotImplementedError @abstractmethod - async def get_user_by_sub(self, sub: str) -> UserId | None: + async def get_user_id_by_sub(self, sub: str) -> UserId: raise NotImplementedError diff --git a/src/costy/application/common/user_gateway.py b/src/costy/application/common/user_gateway.py index 9d6e606..2c71cd6 100644 --- a/src/costy/application/common/user_gateway.py +++ b/src/costy/application/common/user_gateway.py @@ -10,10 +10,6 @@ class UserReader(Protocol): async def get_user_by_id(self, user_id: UserId) -> User | None: raise NotImplementedError - @abstractmethod - async def get_user_by_auth_id(self, auth_id: str) -> User | None: - raise NotImplementedError - @runtime_checkable class UserSaver(Protocol): diff --git a/src/costy/application/operation/create_operation.py b/src/costy/application/operation/create_operation.py index ef8e357..55a9e13 100644 --- a/src/costy/application/operation/create_operation.py +++ b/src/costy/application/operation/create_operation.py @@ -22,11 +22,11 @@ def __init__( async def __call__(self, data: NewOperationDTO) -> OperationId: user_id = await self.id_provider.get_current_user_id() - operation = self.operation_service.create( # type: ignore + operation = self.operation_service.create( data.amount, data.description, data.time, - user_id, # type: ignore + user_id, data.category_id, ) await self.operation_db_gateway.save_operation(operation) diff --git a/src/costy/infrastructure/auth.py b/src/costy/infrastructure/auth.py index e9ed458..6017a61 100644 --- a/src/costy/infrastructure/auth.py +++ b/src/costy/infrastructure/auth.py @@ -1,5 +1,5 @@ -import typing from datetime import timedelta +from typing import Any, Callable, Coroutine from aiohttp import ClientSession @@ -18,11 +18,11 @@ def create_id_provider_factory( jwsk_uri: str, web_session: ClientSession, jwsk_expired: timedelta = timedelta(days=1) -) -> typing.Callable: +) -> Callable[[], Coroutine[Any, Any, TokenIdProvider]]: token_processor = JwtTokenProcessor(algorithm, audience, issuer) jwsk_provider = KeySetProvider(jwsk_uri, web_session, jwsk_expired) - async def factory(): + async def factory() -> TokenIdProvider: return TokenIdProvider(token_processor, jwsk_provider) return factory diff --git a/src/costy/infrastructure/config.py b/src/costy/infrastructure/config.py index 47a368f..da49f77 100644 --- a/src/costy/infrastructure/config.py +++ b/src/costy/infrastructure/config.py @@ -1,4 +1,20 @@ import os +from dataclasses import dataclass + + +class SettingError(Exception): + pass + + +@dataclass +class AuthSettings: + authorize_url: str + client_id: str + client_secret: str + audience: str + grant_type: str + issuer: str + jwks_uri: str def get_db_connection_url() -> str: @@ -11,3 +27,22 @@ def get_db_connection_url() -> str: if not all([user, password, host, port, db_name]): raise Exception("Database credentials not exists") return f"postgresql+psycopg://{user}:{password}@{host}:{port}/{db_name}" + + +def get_auth_settings() -> AuthSettings: + return AuthSettings( + authorize_url=_get_env_var("AUTH0_AUTHORIZE_URL"), + grant_type="password", + client_id=_get_env_var("AUTH0_CLIENT_ID"), + client_secret=_get_env_var("AUTH0_CLIENT_SECRET"), + audience=_get_env_var("AUTH0_AUDIENCE"), + issuer=_get_env_var("AUTH0_ISSUER"), + jwks_uri=_get_env_var("AUTH0_JWKS_URI") + ) + + +def _get_env_var(name: str) -> str: + try: + return os.environ[name] + except KeyError: + raise SettingError(f'Environment variable "{name}" not exists') diff --git a/src/costy/main/ioc.py b/src/costy/main/ioc.py index 95c55be..2fe7aa7 100644 --- a/src/costy/main/ioc.py +++ b/src/costy/main/ioc.py @@ -1,4 +1,3 @@ -import os from contextlib import asynccontextmanager from dataclasses import dataclass from typing import AsyncIterator @@ -8,7 +7,7 @@ from sqlalchemy import Table from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from costy.adapters.auth.auth_gateway import AuthGateway, AuthSettings +from costy.adapters.auth.auth_gateway import AuthGateway from costy.adapters.db.category_gateway import CategoryGateway from costy.adapters.db.operation_gateway import OperationGateway from costy.adapters.db.uow import OrmUoW @@ -26,6 +25,7 @@ from costy.domain.services.category import CategoryService from costy.domain.services.operation import OperationService from costy.domain.services.user import UserService +from costy.infrastructure.config import AuthSettings from costy.presentation.interactor_factory import InteractorFactory @@ -43,19 +43,14 @@ def __init__( session_factory: async_sessionmaker[AsyncSession], web_session: ClientSession, tables: dict[str, Table], - retort: Retort + retort: Retort, + auth_settings: AuthSettings ): self._session_factory = session_factory self._web_session = web_session self._tables = tables self._retort = retort - self._settings = AuthSettings( - authorize_url="https://dev-66quvcmw46dh86sh.us.auth0.com/oauth/token", - grant_type="password", - client_id=os.getenv("AUTH0_CLIENT_ID", ""), - client_secret=os.getenv("AUTH0_CLIENT_ID", ""), - audience="https://dev-66quvcmw46dh86sh.us.auth0.com/api/v2/" - ) + self._settings = auth_settings @asynccontextmanager async def _init_depends(self) -> AsyncIterator[Depends]: @@ -63,7 +58,6 @@ async def _init_depends(self) -> AsyncIterator[Depends]: auth_gateway = AuthGateway(session, self._web_session, self._tables["users"], self._settings) yield Depends(session, self._web_session, OrmUoW(session), auth_gateway) await session.close() - await self._web_session.close() @asynccontextmanager async def authenticate(self) -> AsyncIterator[Authenticate]: diff --git a/src/costy/main/web.py b/src/costy/main/web.py index d6a3e9e..b128089 100644 --- a/src/costy/main/web.py +++ b/src/costy/main/web.py @@ -1,4 +1,3 @@ -import os from typing import Any, Callable, Coroutine, TypeVar from adaptix import Retort @@ -7,7 +6,10 @@ from litestar.di import Provide from costy.infrastructure.auth import create_id_provider_factory -from costy.infrastructure.config import get_db_connection_url +from costy.infrastructure.config import ( + get_auth_settings, + get_db_connection_url, +) from costy.infrastructure.db.main import ( get_engine, get_metadata, @@ -37,13 +39,15 @@ def init_app() -> Litestar: session_factory = get_sessionmaker(get_engine(get_db_connection_url())) web_session = ClientSession() - ioc = IoC(session_factory=session_factory, web_session=web_session, tables=tables, retort=Retort()) + + auth_settings = get_auth_settings() + ioc = IoC(session_factory, web_session, tables, Retort(), auth_settings) id_provider_factory = create_id_provider_factory( - os.environ.get("AUTH0_AUDIENCE", ""), + auth_settings.audience, "RS256", - os.environ.get("AUTH0_ISSUER", ""), - os.environ.get("AUTH0_JWKS_URI", ""), + auth_settings.issuer, + auth_settings.jwks_uri, web_session ) @@ -57,7 +61,7 @@ def init_app() -> Litestar: dependencies={ "ioc": Provide(singleton(ioc)), "id_provider": Provide(get_id_provider), - "id_provider_factory": Provide(id_provider_factory) + "id_provider_pure": Provide(id_provider_factory) }, on_shutdown=[lambda: web_session.close()], debug=True diff --git a/src/costy/presentation/api/authenticate.py b/src/costy/presentation/api/authenticate.py index 4144198..8e3b8c2 100644 --- a/src/costy/presentation/api/authenticate.py +++ b/src/costy/presentation/api/authenticate.py @@ -8,7 +8,7 @@ class AuthenticationController(Controller): path = "/auth" @post(status_code=200) - async def login(self, ioc: InteractorFactory, data: LoginInputDTO) -> Response: + async def login(self, ioc: InteractorFactory, data: LoginInputDTO) -> Response[dict[str, str]]: async with ioc.authenticate() as authenticate: token = await authenticate(data) if token: diff --git a/src/costy/presentation/api/category.py b/src/costy/presentation/api/category.py index 5325f51..fe77025 100644 --- a/src/costy/presentation/api/category.py +++ b/src/costy/presentation/api/category.py @@ -1,6 +1,6 @@ from litestar import Controller, get -from costy.application.category.read_available_categories import CategoryDTO +from costy.application.category.dto import CategoryDTO from costy.application.common.id_provider import IdProvider from costy.presentation.interactor_factory import InteractorFactory diff --git a/src/costy/presentation/api/dependencies/id_provider.py b/src/costy/presentation/api/dependencies/id_provider.py index 10d28c1..99900cf 100644 --- a/src/costy/presentation/api/dependencies/id_provider.py +++ b/src/costy/presentation/api/dependencies/id_provider.py @@ -5,11 +5,11 @@ async def get_id_provider( - headers: dict, - id_provider_factory: TokenIdProvider + headers: dict[str, str], + id_provider_pure: TokenIdProvider ) -> IdProvider: if not headers.get("authorization"): raise HTTPException("Not authenticated", status_code=401) token_type, token = headers.get("authorization", "").split(" ") - id_provider_factory.token = token - return id_provider_factory + id_provider_pure.token = token + return id_provider_pure diff --git a/src/costy/presentation/api/operation.py b/src/costy/presentation/api/operation.py index 1ab01af..5c62b9d 100644 --- a/src/costy/presentation/api/operation.py +++ b/src/costy/presentation/api/operation.py @@ -1,8 +1,7 @@ from litestar import Controller, get, post from costy.application.common.id_provider import IdProvider -from costy.application.operation.create_operation import NewOperationDTO -from costy.application.operation.read_list_operation import ListOperationDTO +from costy.application.operation.dto import ListOperationDTO, NewOperationDTO from costy.domain.models.operation import Operation, OperationId from costy.presentation.interactor_factory import InteractorFactory diff --git a/src/costy/presentation/api/user.py b/src/costy/presentation/api/user.py index 79a002c..a5ba2b1 100644 --- a/src/costy/presentation/api/user.py +++ b/src/costy/presentation/api/user.py @@ -1,6 +1,6 @@ from litestar import Controller, post -from costy.application.user.create_user import NewUserDTO +from costy.application.user.dto import NewUserDTO from costy.domain.models.user import UserId from costy.presentation.interactor_factory import InteractorFactory diff --git a/tests/adapters/__init__.py b/tests/adapters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/adapters/test_auth_adapter.py b/tests/adapters/test_auth_adapter.py new file mode 100644 index 0000000..9377eb7 --- /dev/null +++ b/tests/adapters/test_auth_adapter.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.mark.asyncio +def test_auth_adapter(): + pass diff --git a/tests/application/category/test_create_category.py b/tests/application/category/test_create_category.py index ff65235..6bb71da 100644 --- a/tests/application/category/test_create_category.py +++ b/tests/application/category/test_create_category.py @@ -3,13 +3,13 @@ import pytest from pytest import fixture -from costy.application.category.create_category import ( - CreateCategory, - NewCategoryDTO, -) +from costy.application.category.create_category import CreateCategory +from costy.application.category.dto import NewCategoryDTO from costy.application.common.category_gateway import CategorySaver +from costy.application.common.id_provider import IdProvider from costy.application.common.uow import UoW -from costy.domain.models.category import Category, CategoryType +from costy.domain.models.category import Category, CategoryId, CategoryType +from costy.domain.models.user import UserId @fixture @@ -18,7 +18,7 @@ def category_info() -> NewCategoryDTO: @fixture -def interactor(id_provider, category_id, user_id, category_info) -> CreateCategory: +def interactor(id_provider: IdProvider, category_id: CategoryId, user_id: UserId, category_info: NewCategoryDTO) -> CreateCategory: category_service = Mock() category_service.create.return_value = Category( id=None, @@ -37,5 +37,5 @@ async def save_category_mock(category: Category) -> None: @pytest.mark.asyncio -async def test_create_operation(interactor: CreateCategory, category_info: NewCategoryDTO, category_id) -> None: +async def test_create_operation(interactor: CreateCategory, category_info: NewCategoryDTO, category_id: CategoryId) -> None: assert await interactor(category_info) == category_id diff --git a/tests/application/operation/test_create_operation.py b/tests/application/operation/test_create_operation.py index aa33506..bb57a2c 100644 --- a/tests/application/operation/test_create_operation.py +++ b/tests/application/operation/test_create_operation.py @@ -3,14 +3,14 @@ import pytest from pytest import fixture +from costy.application.common.id_provider import IdProvider from costy.application.common.operation_gateway import OperationSaver from costy.application.common.uow import UoW -from costy.application.operation.create_operation import ( - CreateOperation, - NewOperationDTO, -) +from costy.application.operation.create_operation import CreateOperation +from costy.application.operation.dto import NewOperationDTO from costy.domain.models.category import CategoryId -from costy.domain.models.operation import Operation +from costy.domain.models.operation import Operation, OperationId +from costy.domain.models.user import UserId @fixture @@ -24,7 +24,7 @@ def operation_info() -> NewOperationDTO: @fixture -def interactor(id_provider, operation_id, user_id, operation_info) -> CreateOperation: +def interactor(id_provider: IdProvider, operation_id: OperationId, user_id: UserId, operation_info: NewOperationDTO) -> CreateOperation: operation_service = Mock() operation_service.create.return_value = Operation( id=None, @@ -45,5 +45,5 @@ async def save_operation_mock(operation: Operation) -> None: @pytest.mark.asyncio -async def test_create_operation(interactor: CreateOperation, operation_info: NewOperationDTO, operation_id) -> None: +async def test_create_operation(interactor: CreateOperation, operation_info: NewOperationDTO, operation_id: OperationId) -> None: assert await interactor(operation_info) == operation_id diff --git a/tests/application/test_authenticate.py b/tests/application/test_authenticate.py index 3ee43c6..37b0120 100644 --- a/tests/application/test_authenticate.py +++ b/tests/application/test_authenticate.py @@ -4,9 +4,9 @@ from pytest import fixture from costy.application.authenticate import Authenticate, LoginInputDTO +from costy.application.common.auth_gateway import AuthLoger from costy.application.common.uow import UoW -from costy.application.common.user_gateway import UserReader -from costy.domain.models.user import User, UserId +from costy.domain.models.user import UserId @fixture @@ -15,16 +15,18 @@ def login_info() -> LoginInputDTO: @fixture -def interactor(user_id, login_info): - user_gateway = Mock(spec=UserReader) - user_gateway.get_user_by_email.return_value = User( - id=user_id, email=login_info.email, - hashed_password=login_info.password, - ) +def token() -> str: + return "token" + + +@fixture +def interactor(user_id: UserId, login_info: LoginInputDTO) -> Authenticate: + auth_gateway = Mock(spec=AuthLoger) + auth_gateway.authenticate.return_value = "token" uow = Mock(spec=UoW) - return Authenticate(user_gateway, uow) + return Authenticate(auth_gateway, uow) @pytest.mark.asyncio -async def test_authenticate(interactor: Authenticate, user_id: UserId, login_info: LoginInputDTO): - assert await interactor(login_info) == user_id +async def test_authenticate(interactor: Authenticate, user_id: UserId, login_info: LoginInputDTO, token: str) -> None: + assert await interactor(login_info) == token diff --git a/tests/application/user/test_create_user.py b/tests/application/user/test_create_user.py index 3f63aa5..d63ac4b 100644 --- a/tests/application/user/test_create_user.py +++ b/tests/application/user/test_create_user.py @@ -5,7 +5,8 @@ from costy.application.common.uow import UoW from costy.application.common.user_gateway import UserSaver -from costy.application.user.create_user import CreateUser, NewUserDTO +from costy.application.user.create_user import CreateUser +from costy.application.user.dto import NewUserDTO from costy.domain.models.user import User, UserId from costy.domain.services.user import UserService @@ -16,15 +17,13 @@ def user_info() -> NewUserDTO: @fixture -def interactor(user_id, user_info): +def interactor(user_id: UserId, user_info: NewUserDTO) -> CreateUser: user_service = Mock(spec=UserService) user_service.create.return_value = User( id=None, - email=user_info.email, - hashed_password=user_info.password ) - async def mock_save_user(user: User): + async def mock_save_user(user: User) -> None: user.id = user_id user_gateway = Mock(spec=UserSaver) @@ -34,5 +33,5 @@ async def mock_save_user(user: User): @pytest.mark.asyncio -async def test_create_user(interactor: CreateUser, user_info: NewUserDTO, user_id: UserId): +async def test_create_user(interactor: CreateUser, user_info: NewUserDTO, user_id: UserId) -> None: assert await interactor(user_info) == user_id diff --git a/tests/domain/test_create.py b/tests/domain/test_create.py index 4136769..6cdc891 100644 --- a/tests/domain/test_create.py +++ b/tests/domain/test_create.py @@ -25,5 +25,5 @@ Category(None, "test", CategoryType.GENERAL.value, UserId(9999)) ), ]) -def test_create_domain_service(domain_service, data, expected_model): +def test_create_domain_service(domain_service, data, expected_model): # type: ignore assert domain_service.create(*data) == expected_model From 519e85753f2d47b97c98dc1d19fc474321b72601 Mon Sep 17 00:00:00 2001 From: Andrew Srg Date: Mon, 19 Feb 2024 21:50:24 +0200 Subject: [PATCH 4/5] feature: add auth tests --- src/costy/infrastructure/db/main.py | 8 +++- tests/adapters/test_auth_adapter.py | 33 +++++++++++++- tests/conftest.py | 4 +- tests/infrastructure.py | 67 +++++++++++++++++++++++++++++ 4 files changed, 108 insertions(+), 4 deletions(-) create mode 100644 tests/infrastructure.py diff --git a/src/costy/infrastructure/db/main.py b/src/costy/infrastructure/db/main.py index af07f65..89473da 100644 --- a/src/costy/infrastructure/db/main.py +++ b/src/costy/infrastructure/db/main.py @@ -1,3 +1,4 @@ +import pytest from sqlalchemy import MetaData from sqlalchemy.ext.asyncio import ( AsyncEngine, @@ -6,9 +7,14 @@ create_async_engine, ) +from costy.infrastructure.config import SettingError + def get_engine(url: str) -> AsyncEngine: - return create_async_engine(url, future=True) + try: + return create_async_engine(url, future=True) + except SettingError: + pytest.skip("Auth settings env var are not exists.") def get_sessionmaker(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]: diff --git a/tests/adapters/test_auth_adapter.py b/tests/adapters/test_auth_adapter.py index 9377eb7..58f2a46 100644 --- a/tests/adapters/test_auth_adapter.py +++ b/tests/adapters/test_auth_adapter.py @@ -1,6 +1,35 @@ +import os + import pytest +from pytest_asyncio import fixture + +from costy.adapters.auth.auth_gateway import AuthGateway +from costy.application.common.auth_gateway import AuthLoger +from costy.infrastructure.config import AuthSettings, get_auth_settings + + +@fixture +async def auth_settings() -> AuthSettings: + return get_auth_settings() + + +@fixture +async def auth_adapter(db_session, web_session, db_tables, auth_settings: AuthSettings) -> AuthLoger: + return AuthGateway(db_session, web_session, db_tables["users"], auth_settings) + + +@fixture +async def credentials() -> dict[str, str]: # type: ignore + try: + return { + "username": os.environ["TEST_AUTH_USER"], + "password": os.environ["TEST_AUTH_PASSWORD"] + } + except KeyError: + pytest.skip("No test user credentials.") @pytest.mark.asyncio -def test_auth_adapter(): - pass +async def test_auth_adapter(auth_adapter: AuthLoger, credentials: dict[str, str]): + result = await auth_adapter.authenticate(credentials["username"], credentials["password"]) + assert isinstance(result, str) diff --git a/tests/conftest.py b/tests/conftest.py index 754af96..629a6f5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,14 @@ from unittest.mock import Mock -from pytest import fixture +from pytest_asyncio import fixture from costy.application.common.id_provider import IdProvider from costy.domain.models.category import CategoryId from costy.domain.models.operation import OperationId from costy.domain.models.user import UserId +pytest_plugins = ["tests.infrastructure"] + @fixture def user_id() -> UserId: diff --git a/tests/infrastructure.py b/tests/infrastructure.py new file mode 100644 index 0000000..b3b88b3 --- /dev/null +++ b/tests/infrastructure.py @@ -0,0 +1,67 @@ +import os +from typing import AsyncGenerator + +import pytest +from aiohttp import ClientSession +from pytest_asyncio import fixture +from sqlalchemy import Table +from sqlalchemy.exc import OperationalError +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) + +from costy.infrastructure.db.main import get_metadata +from costy.infrastructure.db.orm import create_tables + + +@fixture(scope='session') +async def db_url() -> str: # type: ignore + try: + return os.environ['TEST_DB_URL'] + except KeyError: + pytest.skip("TEST_DB_URL env variable not set") + + +@fixture(scope='session') +async def db_engine(db_url: str) -> AsyncEngine: + return create_async_engine(db_url, future=True) + + +@fixture(scope='session') +async def db_sessionmaker(db_engine: AsyncEngine) -> async_sessionmaker[AsyncSession]: + return async_sessionmaker(db_engine) + + +@fixture +async def db_session(db_sessionmaker: async_sessionmaker[AsyncSession]) -> AsyncSession: + session = db_sessionmaker() + yield session + # clean up database + await session.rollback() + + +@fixture(scope='session') +async def db_tables(db_engine: AsyncEngine) -> AsyncGenerator[None, dict[str, Table]] | None: + metadata = get_metadata() + tables = create_tables(metadata) + + try: + async with db_engine.begin() as conn: + await conn.run_sync(metadata.drop_all) + await conn.run_sync(metadata.create_all) + except OperationalError: + pytest.skip("Connection to database is faield.") + + yield tables + + async with db_engine.begin() as conn: + await conn.run_sync(metadata.drop_all) + + +@fixture() +async def web_session() -> ClientSession: + async with ClientSession() as session: + yield session From e70a11e21e4494f1ae1060fd9b0be163b0b72158 Mon Sep 17 00:00:00 2001 From: Andrew Srg Date: Mon, 19 Feb 2024 22:32:25 +0200 Subject: [PATCH 5/5] feature: add auth tests --- tests/adapters/test_auth_adapter.py | 32 ++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/tests/adapters/test_auth_adapter.py b/tests/adapters/test_auth_adapter.py index 58f2a46..dc0e889 100644 --- a/tests/adapters/test_auth_adapter.py +++ b/tests/adapters/test_auth_adapter.py @@ -2,12 +2,23 @@ import pytest from pytest_asyncio import fixture +from sqlalchemy import Table, insert +from sqlalchemy.ext.asyncio import AsyncSession from costy.adapters.auth.auth_gateway import AuthGateway from costy.application.common.auth_gateway import AuthLoger +from costy.domain.models.user import UserId from costy.infrastructure.config import AuthSettings, get_auth_settings +@fixture +async def auth_sub() -> str: # type: ignore + try: + return os.environ["TEST_AUTH_USER_SUB"] + except KeyError: + pytest.skip("No test user sub environment variable.") + + @fixture async def auth_settings() -> AuthSettings: return get_auth_settings() @@ -29,7 +40,22 @@ async def credentials() -> dict[str, str]: # type: ignore pytest.skip("No test user credentials.") +@fixture +async def created_user(db_session: AsyncSession, db_tables: dict[str, Table], auth_sub: str) -> UserId: + stmt = insert(db_tables["users"]).values(auth_id=auth_sub) + result = await db_session.execute(stmt) + await db_session.flush() + created_user_id = result.inserted_primary_key[0] + return UserId(created_user_id) + + +@pytest.mark.asyncio +async def test_authenticate(auth_adapter: AuthLoger, credentials: dict[str, str]): + token = await auth_adapter.authenticate(credentials["username"], credentials["password"]) + assert isinstance(token, str) + + @pytest.mark.asyncio -async def test_auth_adapter(auth_adapter: AuthLoger, credentials: dict[str, str]): - result = await auth_adapter.authenticate(credentials["username"], credentials["password"]) - assert isinstance(result, str) +async def test_get_user_id_by_sup(auth_adapter: AuthLoger, auth_sub: str, created_user: UserId): + user_id = await auth_adapter.get_user_id_by_sub(auth_sub) + assert user_id == created_user