Skip to content

Commit

Permalink
Merge pull request #48 from AndrewSergienko/2.x/auth
Browse files Browse the repository at this point in the history
Basic auth0 authentication
  • Loading branch information
andiserg authored Feb 21, 2024
2 parents 2c59fa8 + e70a11e commit 7c3e1cd
Show file tree
Hide file tree
Showing 40 changed files with 609 additions and 178 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,5 @@ dmypy.json

.idea
/pytest.ini
/docker-compose.yml
/keycloak/
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ repos:
- id: flake8
files: src
exclude: "migrations/"
args: [--max-line-length, "120"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.8.0'
hooks:
- id: mypy
additional_dependencies: []
args: [--install-types, --non-interactive, --ignore-missing-imports]
- repo: https://github.com/PyCQA/bandit
rev: 1.7.7
hooks:
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ dependencies = [
'psycopg[binary]',
'alembic',
'adaptix',
'aiohttp',
'python-jose'
]

[project.optional-dependencies]
Expand Down
48 changes: 48 additions & 0 deletions src/costy/adapters/auth/auth_gateway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from aiohttp import ClientSession
from sqlalchemy import Table, select
from sqlalchemy.ext.asyncio import AsyncSession

from costy.application.common.auth_gateway import AuthLoger
from costy.domain.exceptions.access import AuthenticationError
from costy.domain.models.user import UserId
from costy.infrastructure.config import AuthSettings


class AuthGateway(AuthLoger):
def __init__(
self,
db_session: AsyncSession,
web_session: ClientSession,
table: Table,
settings: AuthSettings
) -> None:
self.db_session = db_session
self.web_session = web_session
self.table = table
self.settings = settings

async def authenticate(self, email: str, password: str) -> str:
url = self.settings.authorize_url
data = {
"username": email,
"password": password,
"client_id": self.settings.client_id,
"client_secret": self.settings.client_secret,
"audience": self.settings.audience,
"grant_type": self.settings.grant_type
}
async with self.web_session.post(url, data=data) as response:
response_data = await response.json()
if response.status == 200:
token: str | None = response_data.get("access_token")
if token:
return token
raise AuthenticationError(response_data)

async def get_user_id_by_sub(self, sub: str) -> UserId:
query = select(self.table).where(self.table.c.auth_id == sub)
result = await self.db_session.execute(query)
try:
return UserId(next(result.mappings())["id"])
except StopIteration:
raise AuthenticationError("Invalid auth sub. User is not exists.")
10 changes: 0 additions & 10 deletions src/costy/adapters/auth/id_provider.py

This file was deleted.

115 changes: 115 additions & 0 deletions src/costy/adapters/auth/token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from datetime import datetime, timedelta
from typing import Any, Literal

from aiohttp import ClientSession
from jose import exceptions as jwt_exc
from jose import jwt

from costy.application.common.auth_gateway import AuthLoger
from costy.application.common.id_provider import IdProvider
from costy.domain.exceptions.access import AuthenticationError
from costy.domain.models.user import UserId

Algorithm = Literal[
"HS256", "HS384", "HS512",
"RS256", "RS384", "RS512",
]


class JwtTokenProcessor:
def __init__(
self,
algorithm: Algorithm,
audience: str,
issuer: str,
):
self.algorithm = algorithm
self.audience = audience
self.issuer = issuer

def _fetch_rsa_key(self, jwks: dict[Any, Any], unverified_header: dict[str, str]) -> dict[str, str]:
rsa_key = {}
for key in jwks["keys"]:
if key["kid"] == unverified_header["kid"]:
rsa_key = {
"kty": key["kty"],
"kid": key["kid"],
"use": key["use"],
"n": key["n"],
"e": key["e"]
}
return rsa_key

def validate_token(self, token: str, jwks: dict[Any, Any]) -> str:
invalid_header_error = AuthenticationError(
{"detail": "Invalid header. Use an RS256 signed JWT Access Token"}
)
try:
unverified_header = jwt.get_unverified_header(token)
except jwt_exc.JWTError:
raise invalid_header_error
if unverified_header["alg"] == "HS256":
raise invalid_header_error
rsa_key = self._fetch_rsa_key(jwks, unverified_header)
try:
payload: dict[str, str] = jwt.decode(
token,
rsa_key,
algorithms=[self.algorithm],
audience=self.audience,
issuer=self.issuer
)
return payload["sub"]
except jwt_exc.ExpiredSignatureError:
raise AuthenticationError({"detail": "token is expired"})
except jwt_exc.JWTClaimsError:
raise AuthenticationError(
{"detail": "incorrect claims (check audience and issuer)"}
)
except Exception:
raise AuthenticationError(
{"detail": "Unable to parse authentication token."}
)


class KeySetProvider:
def __init__(self, uri: str, session: ClientSession, expired: timedelta):
self.session = session
self.jwks: dict[str, str] = {}
self.expired = expired
self.last_updated: datetime | None = None
self.uri = uri

async def get_key_set(self) -> dict[Any, Any]:
if not self.jwks:
await self._request_new_key_set()
if self.last_updated and datetime.now() - self.last_updated > self.expired:
# TODO: add use Cache-Control
await self._request_new_key_set()
return self.jwks

async def _request_new_key_set(self) -> None:
async with self.session.get(self.uri) as response:
self.jwks = await response.json()
self.last_updated = datetime.now()


class TokenIdProvider(IdProvider):
def __init__(
self,
token_processor: JwtTokenProcessor,
key_set_provider: KeySetProvider,
token: str | None = None
):
self.token_processor = token_processor
self.key_set_provider = key_set_provider
self.token = token
self.auth_gateway: AuthLoger | None = None

async def get_current_user_id(self) -> UserId:
if self.token and self.auth_gateway:
jwks = await self.key_set_provider.get_key_set()
sub = self.token_processor.validate_token(self.token, jwks)
user_id = await self.auth_gateway.get_user_id_by_sub(sub)
return user_id
raise AuthenticationError()
17 changes: 9 additions & 8 deletions src/costy/adapters/db/category_gateway.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from adaptix import Retort
from sqlalchemy import delete, or_, select
from sqlalchemy import Table, delete, or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from costy.application.category.dto import CategoryDTO
Expand All @@ -16,13 +16,14 @@
class CategoryGateway(
CategoryReader, CategorySaver, CategoryDeleter, CategoriesReader
):
def __init__(self, session: AsyncSession, retort: Retort):
def __init__(self, session: AsyncSession, table: Table, retort: Retort):
self.session = session
self.table = table
self.retort = retort

async def get_category(self, category_id: CategoryId) -> Category | None:
query = select(Category).where(
Category.id == category_id # type: ignore
query = select(self.table).where(
self.table.c.id == category_id
)
result: Category | None = await self.session.scalar(query)
return result
Expand All @@ -33,16 +34,16 @@ async def save_category(self, category: Category) -> None:

async def delete_category(self, category_id: CategoryId) -> None:
query = delete(Category).where(
Category.id == category_id # type: ignore
self.table.c.id == category_id
)
await self.session.execute(query)

async def find_categories(self, user_id: UserId) -> list[CategoryDTO]:
filter_expr = or_(
Category.user_id == user_id, # type: ignore
Category.user_id == None # type: ignore # noqa: E711
self.table.c.user_id == user_id,
self.table.c.user_id == None # noqa: E711
)
query = select(Category).where(filter_expr)
query = select(self.table).where(filter_expr)
categories = list(await self.session.scalars(query))
dumped = self.retort.dump(categories, list[Category])
return self.retort.load(dumped, list[CategoryDTO])
23 changes: 13 additions & 10 deletions src/costy/adapters/db/operation_gateway.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from sqlalchemy import delete, select
from adaptix import Retort
from sqlalchemy import Table, delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from costy.application.common.operation_gateway import (
Expand All @@ -14,14 +15,16 @@
class OperationGateway(
OperationReader, OperationSaver, OperationDeleter, OperationsReader
):
def __init__(self, session: AsyncSession):
def __init__(self, session: AsyncSession, table: Table, retort: Retort):
self.session = session
self.table = table
self.retort = retort

async def get_operation(
self, operation_id: OperationId
) -> Operation | None:
query = select(Operation).where(
Operation.id == operation_id # type: ignore
query = select(self.table).where(
self.table.c.id == operation_id
)
result: Operation | None = await self.session.scalar(query)
return result
Expand All @@ -31,20 +34,20 @@ async def save_operation(self, operation: Operation) -> None:
await self.session.flush(objects=[operation])

async def delete_operation(self, operation_id: OperationId) -> None:
query = delete(Operation).where(
Operation.id == operation_id # type: ignore
query = delete(self.table).where(
self.table.c.id == operation_id
)
await self.session.execute(query)

async def find_operations_by_user(
self, user_id: UserId, from_time: int | None, to_time: int | None
) -> list[Operation]:
query = (
select(Operation)
.where(Operation.user_id == user_id) # type: ignore
select(self.table)
.where(self.table.c.user_id == user_id)
)
if from_time:
query = query.where(Operation.time >= from_time) # type: ignore
query = query.where(self.table.c.time >= from_time)
if to_time:
query = query.where(Operation.time <= to_time) # type: ignore
query = query.where(self.table.c.time <= to_time)
return list(await self.session.scalars(query))
23 changes: 13 additions & 10 deletions src/costy/adapters/db/user_gateway.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
from sqlalchemy import select
from adaptix import Retort
from sqlalchemy import Table, select
from sqlalchemy.ext.asyncio import AsyncSession

from costy.application.common.user_gateway import UserReader, UserSaver
from costy.domain.models.user import User, UserId


class UserGateway(UserSaver, UserReader):
def __init__(self, session: AsyncSession):
def __init__(self, session: AsyncSession, table: Table, retort: Retort):
self.session = session
self.table = table
self.retort = retort

async def save_user(self, user: User) -> None:
self.session.add(user)
await self.session.flush(objects=[user])

async def get_user_by_id(self, user_id: UserId) -> User | None:
query = select(User).where(User.id == user_id) # type: ignore
result: User | None = await self.session.scalar(query)
return result

async def get_user_by_email(self, email: str) -> User | None:
query = select(User).where(User.email == email) # type: ignore
result: User | None = await self.session.scalar(query)
return result
query = select(self.table).where(self.table.c.id == user_id)
result = await self.session.scalar(query)
try:
data = next(result.mapping())
user: User = self.retort.load(data, User)
return user
except StopIteration:
return None
16 changes: 7 additions & 9 deletions src/costy/application/authenticate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from dataclasses import dataclass

from costy.application.common.auth_gateway import AuthLoger
from costy.application.common.interactor import Interactor
from costy.application.common.uow import UoW
from costy.application.common.user_gateway import UserReader
from costy.domain.models.user import UserId


@dataclass
Expand All @@ -12,12 +11,11 @@ class LoginInputDTO:
password: str


class Authenticate(Interactor[LoginInputDTO, UserId | None]):
def __init__(self, user_gateway: UserReader, uow: UoW):
self.user_gateway = user_gateway
class Authenticate(Interactor[LoginInputDTO, str | None]):
def __init__(self, auth_gateway: AuthLoger, uow: UoW):
self.auth_gateway = auth_gateway
self.uow = uow

async def __call__(self, data: LoginInputDTO) -> UserId | None:
user = await self.user_gateway.get_user_by_email(data.email)
# TODO: compare hashed passwords
return user.id if user else None
async def __call__(self, data: LoginInputDTO) -> str | None:
token = await self.auth_gateway.authenticate(data.email, data.password)
return token
Loading

0 comments on commit 7c3e1cd

Please sign in to comment.