Skip to content

Commit

Permalink
[DOP-23122] Use async methods of Keycloak client
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus committed Dec 23, 2024
1 parent 162b0a0 commit 7aa47ff
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 57 deletions.
23 changes: 10 additions & 13 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ onetl = {extras = ["spark", "s3", "hdfs"], version = "^0.12.0"}
faker = ">=28.4.1,<34.0.0"
coverage = "^7.6.1"
gevent = "^24.2.1"
responses = "*"
respx = "*"

[tool.poetry.group.dev.dependencies]
mypy = "^1.11.2"
Expand Down
2 changes: 1 addition & 1 deletion syncmaster/server/providers/auth/base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
...

@abstractmethod
async def get_current_user(self, access_token: Any, *args, **kwargs) -> User:
async def get_current_user(self, access_token: str | None, **kwargs) -> User:
"""
This method should return currently logged in user.
Expand Down
2 changes: 1 addition & 1 deletion syncmaster/server/providers/auth/dummy_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def setup(cls, app: FastAPI) -> FastAPI:
app.dependency_overrides[DummyAuthProviderSettings] = lambda: settings
return app

async def get_current_user(self, access_token: str, *args, **kwargs) -> User:
async def get_current_user(self, access_token: str | None, **kwargs) -> User:
if not access_token:
raise AuthorizationError("Missing auth credentials")

Expand Down
30 changes: 14 additions & 16 deletions syncmaster/server/providers/auth/keycloak_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fastapi import Depends, FastAPI, Request
from keycloak import KeycloakOpenID

from syncmaster.db.models import User
from syncmaster.exceptions import EntityNotFoundError
from syncmaster.exceptions.auth import AuthorizationError
from syncmaster.exceptions.redirect import RedirectException
Expand Down Expand Up @@ -63,7 +64,7 @@ async def get_token_authorization_code_grant(
) -> dict[str, Any]:
try:
redirect_uri = redirect_uri or self.settings.keycloak.redirect_uri
token = self.keycloak_openid.token(
token = await self.keycloak_openid.a_token(
grant_type="authorization_code",
code=code,
redirect_uri=redirect_uri,
Expand All @@ -72,10 +73,8 @@ async def get_token_authorization_code_grant(
except Exception as e:
raise AuthorizationError("Failed to get token") from e

async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
async def get_current_user(self, access_token: str | None, **kwargs) -> User:
request: Request = kwargs["request"]
refresh_token = request.session.get("refresh_token")

if not access_token:
log.debug("No access token found in session.")
self.redirect_to_auth(request.url.path)
Expand All @@ -86,8 +85,9 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
token_info = self.keycloak_openid.decode_token(token=access_token)
except Exception as e:
log.info("Access token is invalid or expired: %s", e)
token_info = None
token_info = {}

refresh_token = request.session.get("refresh_token")
if not token_info and refresh_token:
log.debug("Access token invalid. Attempting to refresh.")

Expand All @@ -99,9 +99,7 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
request.session["access_token"] = new_access_token
request.session["refresh_token"] = new_refresh_token

token_info = self.keycloak_openid.decode_token(
token=new_access_token,
)
token_info = self.keycloak_openid.decode_token(token=new_access_token)
log.debug("Access token refreshed and decoded successfully.")
except Exception as e:
log.debug("Failed to refresh access token: %s", e)
Expand All @@ -110,19 +108,19 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
# these names are hardcoded in keycloak:
# https://github.com/keycloak/keycloak/blob/3ca3a4ad349b4d457f6829eaf2ae05f1e01408be/core/src/main/java/org/keycloak/representations/IDToken.java
user_id = token_info.get("sub")
if not user_id:
raise AuthorizationError("Invalid token payload")

login = token_info.get("preferred_username")
email = token_info.get("email")
first_name = token_info.get("given_name")
middle_name = token_info.get("middle_name")
last_name = token_info.get("family_name")

if not user_id:
raise AuthorizationError("Invalid token payload")

async with self._uow:
try:
user = await self._uow.user.read_by_username(login)
except EntityNotFoundError:
try:
user = await self._uow.user.read_by_username(login)
except EntityNotFoundError:
async with self._uow:
user = await self._uow.user.create(
username=login,
email=email,
Expand All @@ -134,7 +132,7 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
return user

async def refresh_access_token(self, refresh_token: str) -> dict[str, Any]:
new_tokens = self.keycloak_openid.refresh_token(refresh_token)
new_tokens = await self.keycloak_openid.a_refresh_token(refresh_token)
return new_tokens

def redirect_to_auth(self, path: str) -> None:
Expand Down
27 changes: 27 additions & 0 deletions syncmaster/server/settings/auth/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: 2023-2024 MTS PJSC
# SPDX-License-Identifier: Apache-2.0

from pydantic import BaseModel, Field, ImportString


class AuthSettings(BaseModel):
"""Authorization-related settings.
Here you can set auth provider class.
Examples
--------
.. code-block:: bash
SYNCMASTER__AUTH__PROVIDER=syncmaster.server.providers.auth.dummy_provider.DummyAuthProvider
"""

provider: ImportString = Field( # type: ignore[assignment]
default="syncmaster.server.providers.auth.dummy_provider.DummyAuthProvider",
description="Full name of auth provider class",
validate_default=True,
)

class Config:
extra = "allow"
40 changes: 21 additions & 19 deletions tests/test_unit/test_auth/auth_fixtures/keycloak_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from base64 import b64encode

import pytest
import responses
import respx
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
from itsdangerous import TimestampSigner
from jose import jwt

from syncmaster.server.settings.auth.keycloak import KeycloakSettings


@pytest.fixture(scope="session")
def rsa_keys():
Expand Down Expand Up @@ -80,14 +82,15 @@ def _create_session_cookie(user, expire_in_msec=5000) -> str:


@pytest.fixture
@respx.mock
def mock_keycloak_well_known(settings):
server_url = settings.auth.dict()["keycloak"]["server_url"]
realm_name = settings.auth.dict()["keycloak"]["client_id"]
keycloak_settings = KeycloakSettings.model_validate(settings.auth.dict()["keycloak"])
server_url = keycloak_settings.server_url
realm_name = keycloak_settings.realm_name
well_known_url = f"{server_url}/realms/{realm_name}/.well-known/openid-configuration"

responses.add(
responses.GET,
well_known_url,
respx.get(well_known_url).respond(
status_code=200,
json={
"authorization_endpoint": f"{server_url}/realms/{realm_name}/protocol/openid-connect/auth",
"token_endpoint": f"{server_url}/realms/{realm_name}/protocol/openid-connect/token",
Expand All @@ -96,36 +99,37 @@ def mock_keycloak_well_known(settings):
"jwks_uri": f"{server_url}/realms/{realm_name}/protocol/openid-connect/certs",
"issuer": f"{server_url}/realms/{realm_name}",
},
status=200,
content_type="application/json",
)


@pytest.fixture
@respx.mock
def mock_keycloak_realm(settings, rsa_keys):
server_url = settings.auth.dict()["keycloak"]["server_url"]
realm_name = settings.auth.dict()["keycloak"]["client_id"]
keycloak_settings = KeycloakSettings.model_validate(settings.auth.dict()["keycloak"])
server_url = keycloak_settings.server_url
realm_name = keycloak_settings.realm_name
realm_url = f"{server_url}/realms/{realm_name}"
public_pem_str = get_public_key_pem(rsa_keys["public_key"])

responses.add(
responses.GET,
realm_url,
respx.get(realm_url).respond(
status_code=200,
json={
"realm": realm_name,
"public_key": public_pem_str,
"token-service": f"{server_url}/realms/{realm_name}/protocol/openid-connect/token",
"account-service": f"{server_url}/realms/{realm_name}/account",
},
status=200,
content_type="application/json",
)


@pytest.fixture
@respx.mock
def mock_keycloak_token_refresh(settings, rsa_keys):
server_url = settings.auth.dict()["keycloak"]["server_url"]
realm_name = settings.auth.dict()["keycloak"]["client_id"]
keycloak_settings = KeycloakSettings.model_validate(settings.auth.dict()["keycloak"])
server_url = keycloak_settings.server_url
realm_name = keycloak_settings.realm_name
token_url = f"{server_url}/realms/{realm_name}/protocol/openid-connect/token"

# generate new access and refresh tokens
Expand All @@ -144,15 +148,13 @@ def mock_keycloak_token_refresh(settings, rsa_keys):
new_access_token = jwt.encode(payload, private_pem, algorithm="RS256")
new_refresh_token = "mock_new_refresh_token"

responses.add(
responses.POST,
token_url,
respx.post(token_url).respond(
status_code=200,
json={
"access_token": new_access_token,
"refresh_token": new_refresh_token,
"token_type": "bearer",
"expires_in": expires_in,
},
status=200,
content_type="application/json",
)
6 changes: 0 additions & 6 deletions tests/test_unit/test_auth/test_auth_keycloak.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging

import pytest
import responses
from httpx import AsyncClient

from syncmaster.server.settings import ServerAppSettings as Settings
Expand All @@ -11,7 +10,6 @@
pytestmark = [pytest.mark.asyncio, pytest.mark.server]


@responses.activate
@pytest.mark.parametrize(
"settings",
[
Expand All @@ -33,7 +31,6 @@ async def test_get_keycloak_user_unauthorized(client: AsyncClient, mock_keycloak
)


@responses.activate
@pytest.mark.parametrize(
"settings",
[
Expand Down Expand Up @@ -71,7 +68,6 @@ async def test_get_keycloak_user_authorized(
}


@responses.activate
@pytest.mark.parametrize(
"settings",
[
Expand Down Expand Up @@ -116,7 +112,6 @@ async def test_get_keycloak_user_expired_access_token(
}


@responses.activate
@pytest.mark.parametrize(
"settings",
[
Expand Down Expand Up @@ -155,7 +150,6 @@ async def test_get_keycloak_deleted_user(
}


@responses.activate
@pytest.mark.parametrize(
"settings",
[
Expand Down

0 comments on commit 7aa47ff

Please sign in to comment.