Skip to content

Commit

Permalink
[DOP-21574] Add keycloack provider (#123)
Browse files Browse the repository at this point in the history
* [DOP-21574] Add keycloack provider

* [DOP-21574] small provider refactoring

* [DOP-21574] add tests for dummy provider

* [DOP-21574] add tests for dummy provider

* [DOP-21574] Update docs

* [DOP-21574] fix docs

* [DOP-21574] fix test names
  • Loading branch information
TiGrib authored Dec 2, 2024
1 parent 952dfe1 commit a458768
Show file tree
Hide file tree
Showing 30 changed files with 1,118 additions and 71 deletions.
13 changes: 13 additions & 0 deletions .env.docker
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ DATA_RENTGEN__KAFKA__COMPRESSION=zstd
# See Frontend -> UI
DATA_RENTGEN__UI__API_BROWSER_URL=http://localhost:8000

# Session
DATA_RENTGEN__SERVER__SESSION__SECRET_KEY=session_secret_key

# Keycloak Auth
DATA_RENTGEN__AUTH__KEYCLOAK__SERVER_URL=http://keycloak:8080
DATA_RENTGEN__AUTH__KEYCLOAK__REALM_NAME=manually_created
DATA_RENTGEN__AUTH__KEYCLOAK__CLIENT_ID=manually_created
DATA_RENTGEN__AUTH__KEYCLOAK__CLIENT_SECRET=generated_by_keycloak
DATA_RENTGEN__AUTH__KEYCLOAK__REDIRECT_URI=http://localhost:8000/auth/callback
DATA_RENTGEN__AUTH__KEYCLOAK__SCOPE=email
DATA_RENTGEN__AUTH__KEYCLOAK__VERIFY_SSL=False
DATA_RENTGEN__AUTH__PROVIDER=data_rentgen.server.providers.auth.keycloak_provider.KeycloakAuthProvider

# Dummy Auth
DATA_RENTGEN__AUTH__PROVIDER=data_rentgen.server.providers.auth.dummy_provider.DummyAuthProvider
DATA_RENTGEN__AUTH__ACCESS_TOKEN__SECRET_KEY=secret
13 changes: 13 additions & 0 deletions .env.local
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,19 @@ export DATA_RENTGEN__SERVER__DEBUG=true

export DATA_RENTGEN__UI__API_BROWSER_URL=http://localhost:8000

# Session
export DATA_RENTGEN__SERVER__SESSION__SECRET_KEY=session_secret_key

# Keycloak Auth
export DATA_RENTGEN__AUTH__KEYCLOAK__SERVER_URL=http://keycloak:8080
export DATA_RENTGEN__AUTH__KEYCLOAK__REALM_NAME=manually_created
export DATA_RENTGEN__AUTH__KEYCLOAK__CLIENT_ID=manually_created
export DATA_RENTGEN__AUTH__KEYCLOAK__CLIENT_SECRET=generated_by_keycloak
export DATA_RENTGEN__AUTH__KEYCLOAK__REDIRECT_URI=http://localhost:8000/auth/callback
export DATA_RENTGEN__AUTH__KEYCLOAK__SCOPE=email
export DATA_RENTGEN__AUTH__KEYCLOAK__VERIFY_SSL=False
export DATA_RENTGEN__AUTH__PROVIDER=data_rentgen.server.providers.auth.keycloak_provider.KeycloakAuthProvider

# Dummy Auth
export DATA_RENTGEN__AUTH__PROVIDER=data_rentgen.server.providers.auth.dummy_provider.DummyAuthProvider
export DATA_RENTGEN__AUTH__ACCESS_TOKEN__SECRET_KEY=secret
2 changes: 1 addition & 1 deletion data_rentgen/db/repositories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ async def _paginate_by_query(
items_query = items_query.options(*options)

total_query = select(func.count()).select_from(query.subquery())

items_result: ScalarResult[Model] = await self._session.scalars(items_query)

total_count: int = await self._session.scalar(total_query) # type: ignore[assignment]
return PaginationDTO[model_type]( # type: ignore[valid-type]
items=list(items_result.all()),
Expand Down
4 changes: 2 additions & 2 deletions data_rentgen/db/scripts/create_partitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def get_parser() -> ArgumentParser:
parser.add_argument(
"--end",
type=isoparse,
default=datetime.now().replace(day=1, hour=0, minute=0, second=0, microsecond=0) + relativedelta(months=1),
default=datetime.now().replace(day=1, hour=0, minute=0, second=0, microsecond=0) + relativedelta(months=2),
nargs="?",
help="End date for partitions, default is the first day of next month.",
help="End date for partitions, default is the last day of next month.",
)
parser.add_argument(
"--granularity",
Expand Down
2 changes: 2 additions & 0 deletions data_rentgen/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from data_rentgen.exceptions.auth import ActionNotAllowedError, AuthorizationError
from data_rentgen.exceptions.base import ApplicationError
from data_rentgen.exceptions.entity import EntityNotFoundError
from data_rentgen.exceptions.redirect import RedirectError

__all__ = [
"AuthorizationError",
"ActionNotAllowedError",
"ApplicationError",
"EntityNotFoundError",
"RedirectError",
]
21 changes: 21 additions & 0 deletions data_rentgen/exceptions/redirect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: 2024 MTS PJSC
# SPDX-License-Identifier: Apache-2.0
from typing import Any

from data_rentgen.exceptions.base import ApplicationError


class RedirectError(ApplicationError):
"""Error which contains redirect url for authorization."""

def __init__(self, message: str, details: Any = None) -> None:
self._message = message
self._details = details

@property
def message(self) -> str:
return self._message

@property
def details(self) -> Any:
return self._details
8 changes: 7 additions & 1 deletion data_rentgen/server/api/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from asgi_correlation_id import correlation_id
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import RedirectResponse
from pydantic import ValidationError

from data_rentgen.exceptions import ApplicationError, AuthorizationError
from data_rentgen.exceptions import ApplicationError, AuthorizationError, RedirectError
from data_rentgen.server.errors.base import APIErrorSchema, BaseErrorSchema
from data_rentgen.server.errors.registration import get_response_for_exception
from data_rentgen.server.settings.server import ServerSettings
Expand Down Expand Up @@ -91,6 +92,10 @@ def application_exception_handler(request: Request, exc: ApplicationError) -> Re
)


def redirect_exception_handler(_: Request, exc: RedirectError) -> Response:
return RedirectResponse(url=exc.message)


def exception_json_response(
status: int,
content: BaseErrorSchema,
Expand All @@ -107,6 +112,7 @@ def exception_json_response(


def apply_exception_handlers(app: FastAPI) -> None:
app.add_exception_handler(RedirectError, redirect_exception_handler) # type: ignore[arg-type]
app.add_exception_handler(ApplicationError, application_exception_handler) # type: ignore[arg-type]
app.add_exception_handler(AuthorizationError, application_exception_handler) # type: ignore[arg-type]
app.add_exception_handler(
Expand Down
2 changes: 2 additions & 0 deletions data_rentgen/server/middlewares/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from data_rentgen.server.middlewares.openapi import apply_openapi_middleware
from data_rentgen.server.middlewares.request_id import apply_request_id_middleware
from data_rentgen.server.middlewares.session import apply_session_middleware
from data_rentgen.server.middlewares.static_files import apply_static_files
from data_rentgen.server.settings import ServerSettings

Expand All @@ -28,5 +29,6 @@ def apply_middlewares(
apply_application_version_middleware(application, settings.application_version)
apply_openapi_middleware(application, settings.openapi)
apply_static_files(application, settings.static_files)
apply_session_middleware(application, settings.session)

return application
16 changes: 16 additions & 0 deletions data_rentgen/server/middlewares/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-FileCopyrightText: 2024 MTS PJSC
# SPDX-License-Identifier: Apache-2.0
from fastapi import FastAPI
from starlette.middleware.sessions import SessionMiddleware

from data_rentgen.server.settings.session import SessionSettings


def apply_session_middleware(app: FastAPI, settings: SessionSettings) -> FastAPI:
"""Add SessionMiddleware middleware to the application."""

app.add_middleware(
SessionMiddleware,
**settings.dict(),
)
return app
2 changes: 2 additions & 0 deletions data_rentgen/server/providers/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
# SPDX-License-Identifier: Apache-2.0
from data_rentgen.server.providers.auth.base_provider import AuthProvider
from data_rentgen.server.providers.auth.dummy_provider import DummyAuthProvider
from data_rentgen.server.providers.auth.keycloak_provider import KeycloakAuthProvider

__all__ = [
"AuthProvider",
"DummyAuthProvider",
"KeycloakAuthProvider",
]
14 changes: 7 additions & 7 deletions data_rentgen/server/providers/auth/dummy_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from data_rentgen.db.models import User
from data_rentgen.dependencies import Stub
from data_rentgen.dto import UserDTO
from data_rentgen.exceptions import EntityNotFoundError
from data_rentgen.exceptions.auth import AuthorizationError
from data_rentgen.server.providers.auth.base_provider import AuthProvider
from data_rentgen.server.settings.auth.dummy import DummyAuthProviderSettings
Expand All @@ -36,12 +37,15 @@ def setup(cls, app: FastAPI) -> FastAPI:
app.dependency_overrides[DummyAuthProviderSettings] = lambda: settings
return app

async def get_current_user(self, access_token: str, *args, **kwargs) -> User | None:
async def get_current_user(self, access_token: str, *args, **kwargs) -> User:
if not access_token:
raise AuthorizationError("Missing auth credentials")

user_id = self._get_user_id_from_token(access_token)
return await self._uow.user.read_by_id(user_id)
user = await self._uow.user.read_by_id(user_id)
if user is None:
raise EntityNotFoundError("User", "user_id", user_id) # type: ignore[call-arg]
return user

async def get_token_password_grant(
self,
Expand All @@ -57,11 +61,7 @@ async def get_token_password_grant(

logger.info("Get/create user %r in database", login)
async with self._uow:
user = await self._uow.user._get(login) # noqa: WPS437
if not user:
user = await self._uow.user._create( # noqa: WPS437
UserDTO(name=login),
)
user = await self._uow.user.get_or_create(UserDTO(name=login))

logger.info("User with id %r found", user.id)
logger.info("Generate access token for user id %r", user.id)
Expand Down
127 changes: 127 additions & 0 deletions data_rentgen/server/providers/auth/keycloak_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# SPDX-FileCopyrightText: 2024 MTS PJSC
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Annotated, Any

from fastapi import Depends, FastAPI, Request
from keycloak import KeycloakOpenID

from data_rentgen.db.models import User
from data_rentgen.dependencies import Stub
from data_rentgen.dto import UserDTO
from data_rentgen.exceptions.auth import AuthorizationError
from data_rentgen.exceptions.redirect import RedirectError
from data_rentgen.server.providers.auth.base_provider import AuthProvider
from data_rentgen.server.settings.auth.keycloak import KeycloakAuthProviderSettings
from data_rentgen.services import UnitOfWork

logger = logging.getLogger(__name__)


class KeycloakAuthProvider(AuthProvider):
def __init__(
self,
settings: Annotated[KeycloakAuthProviderSettings, Depends(Stub(KeycloakAuthProviderSettings))],
unit_of_work: Annotated[UnitOfWork, Depends()],
) -> None:
self.settings = settings
self._uow = unit_of_work
self.keycloak_openid = KeycloakOpenID(
server_url=self.settings.keycloak.server_url,
client_id=self.settings.keycloak.client_id,
realm_name=self.settings.keycloak.realm_name,
client_secret_key=self.settings.keycloak.client_secret.get_secret_value(),
verify=self.settings.keycloak.verify_ssl,
)

@classmethod
def setup(cls, app: FastAPI) -> FastAPI:
settings = KeycloakAuthProviderSettings.model_validate(app.state.settings.auth.dict(exclude={"provider"}))
logger.info("Using %s provider with settings:\n%s", cls.__name__, settings)
app.dependency_overrides[AuthProvider] = cls
app.dependency_overrides[KeycloakAuthProviderSettings] = lambda: settings
return app

async def get_token_password_grant(
self,
grant_type: str | None = None,
login: str | None = None,
password: str | None = None,
scopes: list[str] | None = None,
client_id: str | None = None,
client_secret: str | None = None,
) -> dict[str, Any]:
raise NotImplementedError("Password grant is not supported by KeycloakAuthProvider.")

async def get_token_authorization_code_grant(
self,
code: str,
redirect_uri: str,
scopes: list[str] | None = None,
client_id: str | None = None,
client_secret: str | None = None,
) -> dict[str, Any]:
try:
redirect_uri = redirect_uri or self.settings.keycloak.redirect_uri
return self.keycloak_openid.token(
grant_type="authorization_code",
code=code,
redirect_uri=redirect_uri,
)
except Exception as e:
raise AuthorizationError("Failed to get token") from e

async def get_current_user(self, access_token: str, *args, **kwargs) -> User:
request: Request = kwargs["request"]
refresh_token = request.session.get("refresh_token")

if not access_token:
logger.debug("No access token found in session.")
self.redirect_to_auth()

# if user is disabled or blocked in Keycloak after the token is issued, he will
# remain authorized until the token expires (not more than 15 minutes in MTS SSO)
token_info = self.decode_token(access_token)

if token_info is None and refresh_token:
logger.debug("Access token invalid. Attempting to refresh.")
access_token, refresh_token = self.refresh_access_token(refresh_token)
request.session["access_token"] = access_token
request.session["refresh_token"] = refresh_token

token_info = self.decode_token(access_token)

if token_info is None:
# If there is no token_info after refresh user get redirect
self.redirect_to_auth()

# these names are hardcoded in keycloak:
# https://github.com/keycloak/keycloak/blob/3ca3a4ad349b4d457f6829eaf2ae05f1e01408be/core/src/main/java/org/keycloak/representations/IDToken.java
user_id = token_info.get("sub") # type: ignore[union-attr]
login = token_info.get("preferred_username") # type: ignore[union-attr]
if not user_id:
raise AuthorizationError("Invalid token payload")
return await self._uow.user.get_or_create(UserDTO(name=login)) # type: ignore[arg-type]

def decode_token(self, access_token: str) -> dict[str, Any] | None:
try:
return self.keycloak_openid.decode_token(token=access_token)
except Exception as err:
logger.info("Access token is invalid or expired: %s", err)
return None

def refresh_access_token(self, refresh_token: str) -> tuple[str, str]: # type: ignore[return]
try:
new_tokens = self.keycloak_openid.refresh_token(refresh_token)
logger.debug("Access token refreshed")
return new_tokens.get("access_token"), new_tokens.get("refresh_token")
except Exception as err:
logger.debug("Failed to refresh access token: %s", err)
self.redirect_to_auth()

def redirect_to_auth(self) -> None:
auth_url = self.keycloak_openid.auth_url(
redirect_uri=self.settings.keycloak.redirect_uri,
scope=self.settings.keycloak.scope,
)
raise RedirectError(message=auth_url, details="Authorize on provided url")
8 changes: 3 additions & 5 deletions data_rentgen/server/services/get_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from data_rentgen.db.models import User
from data_rentgen.dependencies import Stub
from data_rentgen.exceptions import EntityNotFoundError
from data_rentgen.server.providers.auth import AuthProvider

oauth_schema = OAuth2PasswordBearer(tokenUrl="v1/auth/token", auto_error=False)
Expand All @@ -20,13 +19,12 @@ async def wrapper(
auth_provider: Annotated[AuthProvider, Depends(Stub(AuthProvider))],
access_token: Annotated[str | None, Depends(oauth_schema)],
) -> User:
# keycloak provider patches session and store access_token in cookie,
# dummy auth stores access_token in "Authorization" header
user = await auth_provider.get_current_user(
access_token = request.session.get("access_token", "") or access_token
return await auth_provider.get_current_user( # type: ignore[return-value]
access_token=access_token,
request=request,
)
if user is None:
raise EntityNotFoundError("User not found") # type: ignore[call-arg]
return user

return wrapper
Loading

0 comments on commit a458768

Please sign in to comment.