Skip to content

Commit

Permalink
refactor(oauth): change to approach of registering client in login view
Browse files Browse the repository at this point in the history
the register method seems to be idempotent, so it's fine to call it like
this. the approach of registering during view import made importing the
module complicated (e.g. when importing from the module for our tests)
  • Loading branch information
angela-tran committed Jun 10, 2024
1 parent 715b3b4 commit fea7087
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 70 deletions.
28 changes: 11 additions & 17 deletions benefits/oauth/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@

from authlib.integrations.django_client import OAuth

from benefits.core.models import AuthProvider


logger = logging.getLogger(__name__)

oauth = OAuth()
Expand Down Expand Up @@ -42,23 +39,20 @@ def _authorize_params(scheme):
return params


def register_providers(oauth_registry):
def register_provider(oauth_registry, provider):
"""
Register OAuth clients into the given registry, using configuration from AuthProvider models.
Register OAuth clients into the given registry, using configuration from AuthProvider model.
Adapted from https://stackoverflow.com/a/64174413.
"""
logger.info("Registering OAuth clients")

providers = AuthProvider.objects.all()
logger.debug(f"Registering OAuth client: {provider.client_name}")

for provider in providers:
logger.debug(f"Registering OAuth client: {provider.client_name}")
client = oauth_registry.register(
provider.client_name,
client_id=provider.client_id,
server_metadata_url=_server_metadata_url(provider.authority),
client_kwargs=_client_kwargs(provider.scope),
authorize_params=_authorize_params(provider.scheme),
)

oauth_registry.register(
provider.client_name,
client_id=provider.client_id,
server_metadata_url=_server_metadata_url(provider.authority),
client_kwargs=_client_kwargs(provider.scope),
authorize_params=_authorize_params(provider.scheme),
)
return client
8 changes: 3 additions & 5 deletions benefits/oauth/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from benefits.core import session
from . import analytics, redirects
from .client import oauth, register_providers
from .client import oauth, register_provider
from .middleware import VerifierUsesAuthVerificationSessionRequired


Expand All @@ -20,14 +20,12 @@
ROUTE_POST_LOGOUT = "oauth:post_logout"


register_providers(oauth)


@decorator_from_middleware(VerifierUsesAuthVerificationSessionRequired)
def login(request):
"""View implementing OIDC authorize_redirect."""
verifier = session.verifier(request)
oauth_client = oauth.create_client(verifier.auth_provider.client_name)

oauth_client = register_provider(oauth, verifier.auth_provider)

if not oauth_client:
raise Exception(f"oauth_client not registered: {verifier.auth_provider.client_name}")
Expand Down
23 changes: 0 additions & 23 deletions tests/pytest/oauth/test_app.py

This file was deleted.

38 changes: 13 additions & 25 deletions tests/pytest/oauth/test_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from benefits.core.models import AuthProvider
from benefits.oauth.client import _client_kwargs, _server_metadata_url, _authorize_params, register_providers
from benefits.oauth.client import _client_kwargs, _server_metadata_url, _authorize_params, register_provider


def test_client_kwargs():
Expand Down Expand Up @@ -39,33 +39,21 @@ def test_authorize_params_no_scheme():


@pytest.mark.django_db
def test_register_providers(mocker, mocked_oauth_registry):
mock_providers = []

for i in range(3):
p = mocker.Mock(spec=AuthProvider)
p.client_name = f"client_name_{i}"
p.client_id = f"client_id_{i}"
mock_providers.append(p)

mocked_client_provider = mocker.patch("benefits.oauth.client.AuthProvider")
mocked_client_provider.objects.all.return_value = mock_providers
def test_register_provider(mocker, mocked_oauth_registry):
mocked_client_provider = mocker.Mock(spec=AuthProvider)
mocked_client_provider.client_name = "client_name_1"
mocked_client_provider.client_id = "client_id_1"

mocker.patch("benefits.oauth.client._client_kwargs", return_value={"client": "kwargs"})
mocker.patch("benefits.oauth.client._server_metadata_url", return_value="https://metadata.url")
mocker.patch("benefits.oauth.client._authorize_params", return_value={"scheme": "test_scheme"})

register_providers(mocked_oauth_registry)

mocked_client_provider.objects.all.assert_called_once()

for provider in mock_providers:
i = mock_providers.index(provider)
register_provider(mocked_oauth_registry, mocked_client_provider)

mocked_oauth_registry.register.assert_any_call(
f"client_name_{i}",
client_id=f"client_id_{i}",
server_metadata_url="https://metadata.url",
client_kwargs={"client": "kwargs"},
authorize_params={"scheme": "test_scheme"},
)
mocked_oauth_registry.register.assert_any_call(
"client_name_1",
client_id="client_id_1",
server_metadata_url="https://metadata.url",
client_kwargs={"client": "kwargs"},
authorize_params={"scheme": "test_scheme"},
)

0 comments on commit fea7087

Please sign in to comment.