From 3de576c916ea122d21d2883abdb1fb959c98d3f8 Mon Sep 17 00:00:00 2001 From: Jamie Cockburn Date: Tue, 28 Feb 2023 17:37:20 +0000 Subject: [PATCH 01/10] Added middleware to refresh access tokens --- django_auth_adfs/backend.py | 36 +++++++++++- django_auth_adfs/config.py | 1 + django_auth_adfs/middleware.py | 16 ++++++ tests/test_authentication.py | 102 ++++++++++++++++----------------- 4 files changed, 101 insertions(+), 54 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index 4574cb42..1c6d91e4 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -1,11 +1,13 @@ import logging +from datetime import datetime, timedelta import jwt -from django.contrib.auth import get_user_model +from django.contrib.auth import get_user_model, logout from django.contrib.auth.backends import ModelBackend from django.contrib.auth.models import Group from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist, PermissionDenied) +from requests import HTTPError from django_auth_adfs import signals from django_auth_adfs.config import provider_config, settings @@ -398,10 +400,38 @@ def authenticate(self, request=None, authorization_code=None, **kwargs): provider_config.load_config() adfs_response = self.exchange_auth_code(authorization_code, request) - access_token = adfs_response["access_token"] - user = self.process_access_token(access_token, adfs_response) + user = self._process_adfs_response(request, adfs_response) return user + def _process_adfs_response(self, request, adfs_response): + user = self.process_access_token(adfs_response['access_token'], adfs_response) + request.session['adfs_access_token'] = adfs_response['access_token'] + expiry = datetime.now() + timedelta(seconds=adfs_response['expires_in']) + request.session['adfs_token_expiry'] = expiry.isoformat() + if 'refresh_token' in adfs_response: + request.session['adfs_refresh_token'] = adfs_response['refresh_token'] + request.session.save() + return user + + def process_request(self, request): + now = datetime.now() + settings.REFRESH_THRESHOLD + expiry = datetime.fromisoformat(request.session['adfs_token_expiry']) + if now > expiry: + try: + self._refresh_access_token(request, request.session['adfs_refresh_token']) + except (PermissionDenied, HTTPError): + logout(request) + + def _refresh_access_token(self, request, refresh_token): + provider_config.load_config() + response = provider_config.session.post( + provider_config.token_endpoint, + data=f'grant_type=refresh_token&refresh_token={refresh_token}' + ) + response.raise_for_status() + adfs_response = response.json() + self._process_adfs_response(request, adfs_response) + class AdfsAccessTokenBackend(AdfsBaseBackend): """ diff --git a/django_auth_adfs/config.py b/django_auth_adfs/config.py index 12c36dc9..9c7ebb74 100644 --- a/django_auth_adfs/config.py +++ b/django_auth_adfs/config.py @@ -72,6 +72,7 @@ def __init__(self): self.USERNAME_CLAIM = "winaccountname" self.GUEST_USERNAME_CLAIM = None self.JWT_LEEWAY = 0 + self.REFRESH_THRESHOLD = timedelta(minutes=5) self.CUSTOM_FAILED_RESPONSE_VIEW = lambda request, error_message, status: render( request, 'django_auth_adfs/login_failed.html', {'error_message': error_message}, status=status ) diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index 649a2390..0b4c50eb 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -4,9 +4,11 @@ from re import compile from django.conf import settings as django_settings +from django.contrib import auth from django.contrib.auth.views import redirect_to_login from django.urls import reverse +from django_auth_adfs.backend import AdfsAuthCodeBackend from django_auth_adfs.exceptions import MFARequired from django_auth_adfs.config import settings @@ -49,3 +51,17 @@ def __call__(self, request): return redirect_to_login('django_auth_adfs:login-force-mfa') return self.get_response(request) + + +def adfs_refresh_middleware(get_response): + def middleware(request): + try: + backend_str = request.session[auth.BACKEND_SESSION_KEY] + except KeyError: + pass + else: + backend = auth.load_backend(backend_str) + if isinstance(backend, AdfsAuthCodeBackend): + backend.process_request(request) + return get_response() + return middleware diff --git a/tests/test_authentication.py b/tests/test_authentication.py index c16691fc..6d822a8f 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,5 +1,7 @@ import base64 +from django.urls import reverse + from django_auth_adfs.exceptions import MFARequired try: @@ -16,7 +18,6 @@ from mock import Mock, patch from django_auth_adfs import signals -from django_auth_adfs.backend import AdfsAuthCodeBackend from django_auth_adfs.config import ProviderConfig, Settings from .models import Profile @@ -34,14 +35,13 @@ def setUp(self): @mock_adfs("2012") def test_post_authenticate_signal_send(self): - backend = AdfsAuthCodeBackend() - backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) self.assertEqual(self.signal_handler.call_count, 1) @mock_adfs("2012") def test_with_auth_code_2012(self): - backend = AdfsAuthCodeBackend() - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -52,8 +52,8 @@ def test_with_auth_code_2012(self): @mock_adfs("2016") def test_with_auth_code_2016(self): - backend = AdfsAuthCodeBackend() - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -64,9 +64,15 @@ def test_with_auth_code_2016(self): @mock_adfs("2016", mfa_error=True) def test_mfa_error_backends(self): - with self.assertRaises(MFARequired): - backend = AdfsAuthCodeBackend() - backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + self.assertEqual(response.status_code, 302) + self.assertEqual( + response['Location'], + "https://adfs.example.com/adfs/oauth2/authorize/?response_type=code&" + "client_id=your-configured-client-id&resource=your-adfs-RPT-name&" + "redirect_uri=http%3A%2F%2Ftestserver%2Foauth2%2Fcallback&state=Lw%3D%3D&scope=openid&" + "amr_values=ngcmfa" + ) @mock_adfs("azure") def test_with_auth_code_azure(self): @@ -77,8 +83,8 @@ def test_with_auth_code_azure(self): with patch("django_auth_adfs.config.django_settings", settings): with patch("django_auth_adfs.config.settings", Settings()): with patch("django_auth_adfs.backend.provider_config", ProviderConfig()): - backend = AdfsAuthCodeBackend() - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -100,9 +106,8 @@ def test_with_auth_code_azure_guest_block(self): with patch('django_auth_adfs.backend.settings', Settings()): with patch("django_auth_adfs.config.settings", Settings()): with patch("django_auth_adfs.backend.provider_config", ProviderConfig()): - with self.assertRaises(PermissionDenied, msg=''): - backend = AdfsAuthCodeBackend() - _ = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + self.assertEqual(response.status_code, 401) @mock_adfs("azure", guest=True) def test_with_auth_code_azure_guest_no_block(self): @@ -117,8 +122,8 @@ def test_with_auth_code_azure_guest_no_block(self): with patch('django_auth_adfs.backend.settings', Settings()): with patch("django_auth_adfs.config.settings", Settings()): with patch("django_auth_adfs.backend.provider_config", ProviderConfig()): - backend = AdfsAuthCodeBackend() - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -139,8 +144,8 @@ def test_version_two_endpoint_calls_correct_url(self): with patch('django_auth_adfs.backend.settings', Settings()): with patch("django_auth_adfs.config.settings", Settings()): with patch("django_auth_adfs.backend.provider_config", ProviderConfig()): - backend = AdfsAuthCodeBackend() - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -151,14 +156,15 @@ def test_version_two_endpoint_calls_correct_url(self): @mock_adfs("2016") def test_empty(self): - backend = AdfsAuthCodeBackend() - self.assertIsNone(backend.authenticate(self.request)) + response = self.client.get(reverse('django_auth_adfs:callback')) + user = response.wsgi_request.user + self.assertTrue(user.is_anonymous) @mock_adfs("2016") def test_group_claim(self): - backend = AdfsAuthCodeBackend() with patch("django_auth_adfs.backend.settings.GROUPS_CLAIM", "nonexisting"): - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -167,9 +173,9 @@ def test_group_claim(self): @mock_adfs("2016") def test_no_group_claim(self): - backend = AdfsAuthCodeBackend() with patch("django_auth_adfs.backend.settings.GROUPS_CLAIM", None): - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -181,9 +187,9 @@ def test_group_claim_with_mirror_groups(self): # Remove one group Group.objects.filter(name="group1").delete() - backend = AdfsAuthCodeBackend() with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", True): - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -197,9 +203,9 @@ def test_group_claim_without_mirror_groups(self): # Remove one group Group.objects.filter(name="group1").delete() - backend = AdfsAuthCodeBackend() with patch("django_auth_adfs.backend.settings.MIRROR_GROUPS", False): - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -210,9 +216,9 @@ def test_group_claim_without_mirror_groups(self): @mock_adfs("2016", empty_keys=True) def test_empty_keys(self): - backend = AdfsAuthCodeBackend() with patch("django_auth_adfs.config.provider_config.signing_keys", []): - self.assertRaises(PermissionDenied, backend.authenticate, self.request, authorization_code='testcode') + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) + self.assertEqual(response.status_code, 401) @mock_adfs("2016") def test_group_removal(self): @@ -227,9 +233,8 @@ def test_group_removal(self): self.assertEqual(user.groups.all()[0].name, "group3") self.assertEqual(len(user.groups.all()), 1) - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -253,9 +258,8 @@ def test_group_removal_overlap(self): self.assertEqual(user.groups.all()[1].name, "group3") self.assertEqual(len(user.groups.all()), 2) - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -272,9 +276,8 @@ def test_group_to_flag_mapping(self): } with patch("django_auth_adfs.backend.settings.GROUP_TO_FLAG_MAPPING", group_to_flag_mapping): with patch("django_auth_adfs.backend.settings.BOOLEAN_CLAIM_MAPPING", {}): - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -289,9 +292,8 @@ def test_boolean_claim_mapping(self): "is_superuser": "user_is_superuser", } with patch("django_auth_adfs.backend.settings.BOOLEAN_CLAIM_MAPPING", boolean_claim_mapping): - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -312,9 +314,8 @@ def test_extended_model_claim_mapping_missing_instance(self): }, } with patch("django_auth_adfs.backend.settings.CLAIM_MAPPING", claim_mapping): - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -340,9 +341,8 @@ def create_profile(sender, instance, created, **kwargs): }, } with patch("django_auth_adfs.backend.settings.CLAIM_MAPPING", claim_mapping): - backend = AdfsAuthCodeBackend() - - user = backend.authenticate(self.request, authorization_code="dummycode") + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + user = response.wsgi_request.user self.assertIsInstance(user, User) self.assertEqual(user.first_name, "John") self.assertEqual(user.last_name, "Doe") @@ -493,5 +493,5 @@ def test_nonexisting_user(self): settings.AUTH_ADFS["CREATE_NEW_USERS"] = False with patch("django_auth_adfs.config.django_settings", settings),\ patch("django_auth_adfs.backend.settings", Settings()): - backend = AdfsAuthCodeBackend() - self.assertRaises(PermissionDenied, backend.authenticate, self.request, authorization_code='testcode') + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) + self.assertEqual(response.status_code, 401) From 9abbcbdd14f2f6fd9abd474247565f8dff468784 Mon Sep 17 00:00:00 2001 From: Jamie Cockburn Date: Tue, 28 Feb 2023 18:55:37 +0000 Subject: [PATCH 02/10] Added middleware to refresh access tokens --- django_auth_adfs/backend.py | 15 ++++++++----- django_auth_adfs/middleware.py | 2 +- tests/settings.py | 1 + tests/test_authentication.py | 40 ++++++++++++++++++++++++++++++---- tests/urls.py | 5 ++++- tests/utils.py | 37 ++++++++++++++++++++++++------- tests/views.py | 9 ++++++++ 7 files changed, 90 insertions(+), 19 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index 1c6d91e4..8f75c948 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -396,6 +396,11 @@ def authenticate(self, request=None, authorization_code=None, **kwargs): logger.debug("Authentication backend was called but no authorization code was received") return + # If there's no request object, we pass control to the next authentication backend + if request is None: + logger.debug("Authentication backend was called without request") + return + # If loaded data is too old, reload it again provider_config.load_config() @@ -405,20 +410,20 @@ def authenticate(self, request=None, authorization_code=None, **kwargs): def _process_adfs_response(self, request, adfs_response): user = self.process_access_token(adfs_response['access_token'], adfs_response) - request.session['adfs_access_token'] = adfs_response['access_token'] + request.session['_adfs_access_token'] = adfs_response['access_token'] expiry = datetime.now() + timedelta(seconds=adfs_response['expires_in']) - request.session['adfs_token_expiry'] = expiry.isoformat() + request.session['_adfs_token_expiry'] = expiry.isoformat() if 'refresh_token' in adfs_response: - request.session['adfs_refresh_token'] = adfs_response['refresh_token'] + request.session['_adfs_refresh_token'] = adfs_response['refresh_token'] request.session.save() return user def process_request(self, request): now = datetime.now() + settings.REFRESH_THRESHOLD - expiry = datetime.fromisoformat(request.session['adfs_token_expiry']) + expiry = datetime.fromisoformat(request.session['_adfs_token_expiry']) if now > expiry: try: - self._refresh_access_token(request, request.session['adfs_refresh_token']) + self._refresh_access_token(request, request.session['_adfs_refresh_token']) except (PermissionDenied, HTTPError): logout(request) diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index 0b4c50eb..0163ea78 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -63,5 +63,5 @@ def middleware(request): backend = auth.load_backend(backend_str) if isinstance(backend, AdfsAuthCodeBackend): backend.process_request(request) - return get_response() + return get_response(request) return middleware diff --git a/tests/settings.py b/tests/settings.py index 81d397c7..121507e0 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -35,6 +35,7 @@ 'django.contrib.messages.middleware.MessageMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware', + 'django_auth_adfs.middleware.adfs_refresh_middleware', 'django_auth_adfs.middleware.LoginRequiredMiddleware', ) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 6d822a8f..edd1b3bc 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,5 +1,7 @@ import base64 +from datetime import datetime, timedelta + from django.urls import reverse from django_auth_adfs.exceptions import MFARequired @@ -12,9 +14,9 @@ from copy import deepcopy from django.contrib.auth.models import Group, User -from django.core.exceptions import ObjectDoesNotExist, PermissionDenied +from django.core.exceptions import ObjectDoesNotExist from django.db.models.signals import post_save -from django.test import RequestFactory, TestCase +from django.test import TestCase from mock import Mock, patch from django_auth_adfs import signals @@ -29,13 +31,12 @@ def setUp(self): Group.objects.create(name='group1') Group.objects.create(name='group2') Group.objects.create(name='group3') - self.request = RequestFactory().get('/oauth2/callback') self.signal_handler = Mock() signals.post_authenticate.connect(self.signal_handler) @mock_adfs("2012") def test_post_authenticate_signal_send(self): - response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) + self.client.get(reverse('django_auth_adfs:callback'), data={'code': "dummycode"}) self.assertEqual(self.signal_handler.call_count, 1) @mock_adfs("2012") @@ -495,3 +496,34 @@ def test_nonexisting_user(self): patch("django_auth_adfs.backend.settings", Settings()): response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) self.assertEqual(response.status_code, 401) + + @mock_adfs("2016") + def test_access_token_unexpired(self): + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) + self.assertFalse(response.wsgi_request.user.is_anonymous) + response = self.client.get(reverse('test')) + self.assertEqual(response.status_code, 200) + + @mock_adfs("2016") + def test_access_token_expired(self): + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) + self.assertFalse(response.wsgi_request.user.is_anonymous) + fromisoformat = datetime.fromisoformat + with patch('django_auth_adfs.backend.datetime') as dt: + dt.fromisoformat = fromisoformat + dt.now.return_value = datetime.now() + timedelta(hours=1) + response = self.client.get(reverse('test')) + self.assertEqual(response.status_code, 200) + + @mock_adfs("2016", refresh_token_expired=True) + def test_refresh_token_expired(self): + response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) + self.assertFalse(response.wsgi_request.user.is_anonymous) + fromisoformat = datetime.fromisoformat + with patch('django_auth_adfs.backend.datetime') as dt: + dt.fromisoformat = fromisoformat + dt.now.return_value = datetime.now() + timedelta(hours=1) + response = self.client.get(reverse('test')) + self.assertEqual(response.status_code, 302) + self.assertEqual(response['Location'], f"{reverse('django_auth_adfs:login')}?next=/") + self.assertTrue(response.wsgi_request.user.is_anonymous) diff --git a/tests/urls.py b/tests/urls.py index e3a608df..9ad8a6e7 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,6 +1,9 @@ -from django.urls import include, re_path +from django.urls import include, re_path, path + +from tests.views import TestView urlpatterns = [ + path('', TestView.as_view(), name='test'), re_path(r'^oauth2/', include('django_auth_adfs.urls')), re_path(r'^oauth2/', include('django_auth_adfs.drf_urls')), ] diff --git a/tests/utils.py b/tests/utils.py index f6040d27..bda61c6d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,7 @@ import time from datetime import datetime, tzinfo, timedelta from functools import partial +from urllib.parse import parse_qs import jwt import responses @@ -98,9 +99,14 @@ def build_access_token_azure_groups_in_claim_source(request): return do_build_access_token(request, issuer, groups_in_claim_names=True) +def build_access_token_adfs_expired(request): + issuer = "http://adfs.example.com/adfs/services/trust" + return do_build_access_token(request, issuer, refresh_token_expired=True) + + def do_build_mfa_error(request): response = {'error_description': 'AADSTS50076'} - return 400, [], json.dumps(response) + return 400, {}, json.dumps(response) def do_build_graph_response(request): @@ -111,7 +117,11 @@ def do_build_graph_response_no_group_perm(request): return do_build_ms_graph_groups(request, missing_group_names=True) -def do_build_access_token(request, issuer, schema=None, no_upn=False, idp=None, groups_in_claim_names=False): +def do_build_access_token(request, issuer, schema=None, no_upn=False, idp=None, groups_in_claim_names=False, + refresh_token_expired=False): + data = parse_qs(request.body) + if data.get('grant_type') == ['refresh_token'] and data.get('refresh_token') == ['expired_refresh_token']: + return 401, {}, None issued_at = int(time.time()) expires = issued_at + 3600 auth_time = datetime.utcnow() @@ -159,16 +169,20 @@ def do_build_access_token(request, issuer, schema=None, no_upn=False, idp=None, } } token = jwt.encode(claims, signing_key_b, algorithm="RS256") + if refresh_token_expired: + refresh_token = 'expired_refresh_token' + else: + refresh_token = 'random_refresh_token' response = { 'resource': 'django_website.adfs.relying_party_id', 'token_type': 'bearer', 'refresh_token_expires_in': 28799, - 'refresh_token': 'random_refresh_token', + 'refresh_token': refresh_token, 'expires_in': 3600, 'id_token': 'not_used', 'access_token': token.decode() if isinstance(token, bytes) else token # PyJWT>=2 returns a str instead of bytes } - return 200, [], json.dumps(response) + return 200, {}, json.dumps(response) def do_build_obo_access_token(request): @@ -228,7 +242,7 @@ def do_build_obo_access_token(request): 'refresh_token': 'not_used', 'access_token': token.decode() if isinstance(token, bytes) else token # PyJWT>=2 returns a str instead of bytes } - return 200, [], json.dumps(response) + return 200, {}, json.dumps(response) def do_build_ms_graph_groups(request, missing_group_names=False): @@ -308,7 +322,7 @@ def do_build_ms_graph_groups(request, missing_group_names=False): if missing_group_names: for group in response["value"]: group["displayName"] = None - return 200, [], json.dumps(response) + return 200, {}, json.dumps(response) def build_openid_keys(request, empty_keys=False): @@ -337,7 +351,7 @@ def build_openid_keys(request, empty_keys=False): }, ] } - return 200, [], json.dumps(keys) + return 200, {}, json.dumps(keys) def build_adfs_meta(request): @@ -345,7 +359,7 @@ def build_adfs_meta(request): data = "".join(f.readlines()) data = data.replace("REPLACE_WITH_CERT_A", base64.b64encode(signing_cert_a).decode()) data = data.replace("REPLACE_WITH_CERT_B", base64.b64encode(signing_cert_b).decode()) - return 200, [], data + return 200, {}, data def mock_adfs( @@ -356,6 +370,7 @@ def mock_adfs( version=None, requires_obo=False, missing_graph_group_perm=False, + refresh_token_expired=False, ): if adfs_version not in ["2012", "2016", "azure"]: raise NotImplementedError("This version of ADFS is not implemented") @@ -465,6 +480,12 @@ def wrapper(*original_args, **original_kwargs): callback=do_build_mfa_error, content_type='application/json', ) + elif refresh_token_expired: + rsps.add_callback( + rsps.POST, token_endpoint, + callback=build_access_token_adfs_expired, + content_type='application/json', + ) else: rsps.add_callback( rsps.POST, token_endpoint, diff --git a/tests/views.py b/tests/views.py index b16e4025..7bb0bedd 100644 --- a/tests/views.py +++ b/tests/views.py @@ -1,2 +1,11 @@ +from django.http import HttpResponse +from django.views import View + + def test_failed_response(request, error_message, status): pass + + +class TestView(View): + def get(self, request): + return HttpResponse('okay') From b43f25906c9f2242f7e7ccf55c3b9e15db7ca2e2 Mon Sep 17 00:00:00 2001 From: Dominik Vogt Date: Tue, 2 Jul 2024 20:06:07 +0200 Subject: [PATCH 03/10] Moved refresh access token check into middleware --- django_auth_adfs/backend.py | 14 ++------------ django_auth_adfs/middleware.py | 16 +++++++++++++++- tests/test_authentication.py | 2 +- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index 8f75c948..95640173 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -2,12 +2,11 @@ from datetime import datetime, timedelta import jwt -from django.contrib.auth import get_user_model, logout +from django.contrib.auth import get_user_model from django.contrib.auth.backends import ModelBackend from django.contrib.auth.models import Group from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist, PermissionDenied) -from requests import HTTPError from django_auth_adfs import signals from django_auth_adfs.config import provider_config, settings @@ -418,16 +417,7 @@ def _process_adfs_response(self, request, adfs_response): request.session.save() return user - def process_request(self, request): - now = datetime.now() + settings.REFRESH_THRESHOLD - expiry = datetime.fromisoformat(request.session['_adfs_token_expiry']) - if now > expiry: - try: - self._refresh_access_token(request, request.session['_adfs_refresh_token']) - except (PermissionDenied, HTTPError): - logout(request) - - def _refresh_access_token(self, request, refresh_token): + def refresh_access_token(self, request, refresh_token): provider_config.load_config() response = provider_config.session.post( provider_config.token_endpoint, diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index 0163ea78..d39d9f0a 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -1,12 +1,17 @@ """ Based on https://djangosnippets.org/snippets/1179/ """ +import logging +from datetime import datetime from re import compile from django.conf import settings as django_settings from django.contrib import auth +from django.contrib.auth import logout from django.contrib.auth.views import redirect_to_login +from django.core.exceptions import PermissionDenied from django.urls import reverse +from requests import HTTPError from django_auth_adfs.backend import AdfsAuthCodeBackend from django_auth_adfs.exceptions import MFARequired @@ -21,6 +26,8 @@ if hasattr(settings, 'LOGIN_EXEMPT_URLS'): LOGIN_EXEMPT_URLS += [compile(expr) for expr in settings.LOGIN_EXEMPT_URLS] +logger = logging.getLogger("django_auth_adfs") + class LoginRequiredMiddleware: """ @@ -62,6 +69,13 @@ def middleware(request): else: backend = auth.load_backend(backend_str) if isinstance(backend, AdfsAuthCodeBackend): - backend.process_request(request) + now = datetime.now() + settings.REFRESH_THRESHOLD + expiry = datetime.fromisoformat(request.session['_adfs_token_expiry']) + if now > expiry: + try: + backend.refresh_access_token(request, request.session['_adfs_refresh_token']) + except (PermissionDenied, HTTPError) as error: + logger.debug("Error refreshing access token: %s", error) + logout(request) return get_response(request) return middleware diff --git a/tests/test_authentication.py b/tests/test_authentication.py index edd1b3bc..dfc788bd 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -520,7 +520,7 @@ def test_refresh_token_expired(self): response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) self.assertFalse(response.wsgi_request.user.is_anonymous) fromisoformat = datetime.fromisoformat - with patch('django_auth_adfs.backend.datetime') as dt: + with patch('django_auth_adfs.middleware.datetime') as dt: dt.fromisoformat = fromisoformat dt.now.return_value = datetime.now() + timedelta(hours=1) response = self.client.get(reverse('test')) From 2eb52e92aafddccf364bb26b29c93e54b7f5d248 Mon Sep 17 00:00:00 2001 From: sdev95 <> Date: Tue, 3 Jun 2025 23:41:14 +0200 Subject: [PATCH 04/10] - AdfsMiddleware checks SESSION_ENGINE - Adfs refresh/access_tokens only stored when middlware activated --- django_auth_adfs/backend.py | 23 ++++++++++++++--------- django_auth_adfs/middleware.py | 33 +++++++++++++++++++++++++++------ 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index c417b513..18976638 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -2,6 +2,7 @@ from datetime import datetime, timedelta import jwt +from django.conf import settings as django_settings from django.contrib.auth import get_user_model from django.contrib.auth.backends import ModelBackend from django.contrib.auth.models import Group @@ -425,18 +426,22 @@ def authenticate(self, request=None, authorization_code=None, **kwargs): provider_config.load_config() adfs_response = self.exchange_auth_code(authorization_code, request) - user = self._process_adfs_response(request, adfs_response) + access_token = adfs_response["access_token"] + user = self.process_access_token(access_token, adfs_response) + if ("django_auth_adfs.middleware.AdfsRefreshMiddleware" in django_settings.MIDDLEWARE): + self._store_adfs_tokens_in_session(request, adfs_response) return user - def _process_adfs_response(self, request, adfs_response): - user = self.process_access_token(adfs_response['access_token'], adfs_response) - request.session['_adfs_access_token'] = adfs_response['access_token'] - expiry = datetime.now() + timedelta(seconds=adfs_response['expires_in']) - request.session['_adfs_token_expiry'] = expiry.isoformat() - if 'refresh_token' in adfs_response: - request.session['_adfs_refresh_token'] = adfs_response['refresh_token'] + def _store_adfs_tokens_in_session(self, request, adfs_response): + assert "refresh_token" in adfs_response, ( + "AdfsRefreshMiddleware requires a refresh token to function correctly. " + "Make sure your ADFS server is configured to return a refresh token." + ) + request.session["_adfs_access_token"] = adfs_response["access_token"] + expiry = datetime.now() + timedelta(seconds=int(adfs_response["expires_in"])) + request.session["_adfs_token_expiry"] = expiry.isoformat() + request.session["_adfs_refresh_token"] = adfs_response["refresh_token"] request.session.save() - return user def refresh_access_token(self, request, refresh_token): provider_config.load_config() diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index d39d9f0a..a1f348c9 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -60,8 +60,28 @@ def __call__(self, request): return self.get_response(request) -def adfs_refresh_middleware(get_response): - def middleware(request): +class AdfsRefreshMiddleware: + """ + Middleware that refreshes the access token for the user if it is close to + expiring. This is done by checking the session for the '_adfs_token_expiry' + key and comparing it with the current time plus a threshold defined in + settings.REFRESH_THRESHOLD. + """ + + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + if hasattr(django_settings, "SESSION_ENGINE"): + assert ( + django_settings.SESSION_ENGINE + != "django.contrib.sessions.backends.signed_cookies" + ), ( + "You are trying to use ADFS Refresh middleware with signed cookie-based sessions. " + "For security reasons, we do not recommend this configuration. " + "Please change SESSION_ENGINE to a different backend, such as 'django.contrib.sessions.backends.db' " + ) + try: backend_str = request.session[auth.BACKEND_SESSION_KEY] except KeyError: @@ -70,12 +90,13 @@ def middleware(request): backend = auth.load_backend(backend_str) if isinstance(backend, AdfsAuthCodeBackend): now = datetime.now() + settings.REFRESH_THRESHOLD - expiry = datetime.fromisoformat(request.session['_adfs_token_expiry']) + expiry = datetime.fromisoformat(request.session["_adfs_token_expiry"]) if now > expiry: try: - backend.refresh_access_token(request, request.session['_adfs_refresh_token']) + backend.refresh_access_token( + request, request.session["_adfs_refresh_token"] + ) except (PermissionDenied, HTTPError) as error: logger.debug("Error refreshing access token: %s", error) logout(request) - return get_response(request) - return middleware + return self.get_response(request) From 66c2303f549caf8c9b51f65ca568a59346e2c5d5 Mon Sep 17 00:00:00 2001 From: sdev95 <> Date: Wed, 4 Jun 2025 11:18:45 +0200 Subject: [PATCH 05/10] Correct middleware in tests --- tests/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/settings.py b/tests/settings.py index 121507e0..5f6346e2 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -35,7 +35,7 @@ 'django.contrib.messages.middleware.MessageMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware', - 'django_auth_adfs.middleware.adfs_refresh_middleware', + 'django_auth_adfs.middleware.AdfsRefreshMiddleware', 'django_auth_adfs.middleware.LoginRequiredMiddleware', ) From eed531ee2cfef2571fce63a391fe743bf1a02cc3 Mon Sep 17 00:00:00 2001 From: sdev95 <> Date: Wed, 4 Jun 2025 15:16:06 +0200 Subject: [PATCH 06/10] Added AdfsAuthCodeRefreshBackend --- django_auth_adfs/backend.py | 63 ++++++++++++++++++++++++++-------- django_auth_adfs/middleware.py | 20 +++++------ 2 files changed, 56 insertions(+), 27 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index 18976638..5b30adbe 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -3,11 +3,12 @@ import jwt from django.conf import settings as django_settings -from django.contrib.auth import get_user_model +from django.contrib.auth import get_user_model, logout from django.contrib.auth.backends import ModelBackend from django.contrib.auth.models import Group from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist, PermissionDenied) +from requests import HTTPError from django_auth_adfs import signals from django_auth_adfs.config import provider_config, settings @@ -411,6 +412,26 @@ class AdfsAuthCodeBackend(AdfsBaseBackend): Microsoft ADFS server with an authorization code. """ + def authenticate(self, request=None, authorization_code=None, **kwargs): + # If there's no token or code, we pass control to the next authentication backend + if authorization_code is None or authorization_code == '': + logger.debug("Authentication backend was called but no authorization code was received") + return + + # If loaded data is too old, reload it again + provider_config.load_config() + + adfs_response = self.exchange_auth_code(authorization_code, request) + access_token = adfs_response["access_token"] + user = self.process_access_token(access_token, adfs_response) + return user + + +class AdfsAuthCodeRefreshBackend(AdfsBaseBackend): + """ + Authentication backend that supports storing and refreshing ADFS tokens in the session. + Use this backend in conjunction with AdfsRefreshMiddleware. + """ def authenticate(self, request=None, authorization_code=None, **kwargs): # If there's no token or code, we pass control to the next authentication backend if authorization_code is None or authorization_code == '': @@ -428,22 +449,24 @@ def authenticate(self, request=None, authorization_code=None, **kwargs): adfs_response = self.exchange_auth_code(authorization_code, request) access_token = adfs_response["access_token"] user = self.process_access_token(access_token, adfs_response) - if ("django_auth_adfs.middleware.AdfsRefreshMiddleware" in django_settings.MIDDLEWARE): - self._store_adfs_tokens_in_session(request, adfs_response) + self._store_adfs_tokens_in_session(request, adfs_response) return user - def _store_adfs_tokens_in_session(self, request, adfs_response): - assert "refresh_token" in adfs_response, ( - "AdfsRefreshMiddleware requires a refresh token to function correctly. " - "Make sure your ADFS server is configured to return a refresh token." - ) - request.session["_adfs_access_token"] = adfs_response["access_token"] - expiry = datetime.now() + timedelta(seconds=int(adfs_response["expires_in"])) - request.session["_adfs_token_expiry"] = expiry.isoformat() - request.session["_adfs_refresh_token"] = adfs_response["refresh_token"] - request.session.save() + + def ensure_valid_access_token(self, request): + now = datetime.now() + settings.REFRESH_THRESHOLD + expiry = datetime.fromisoformat(request.session["_adfs_token_expiry"]) + if now > expiry: + try: + adfs_refresh_response = self._refresh_access_token( + request.session["_adfs_refresh_token"] + ) + self._store_adfs_tokens_in_session(request, adfs_refresh_response) + except (PermissionDenied, HTTPError) as error: + logger.debug("Error refreshing access token: %s", error) + logout(request) - def refresh_access_token(self, request, refresh_token): + def _refresh_access_token(self, refresh_token): provider_config.load_config() response = provider_config.session.post( provider_config.token_endpoint, @@ -451,8 +474,18 @@ def refresh_access_token(self, request, refresh_token): ) response.raise_for_status() adfs_response = response.json() - self._process_adfs_response(request, adfs_response) + return adfs_response + def _store_adfs_tokens_in_session(self, request, adfs_response): + assert "refresh_token" in adfs_response, ( + "AdfsAuthCodeRefreshBackend requires a refresh token to function correctly. " + "Make sure your ADFS server is configured to return a refresh token." + ) + request.session["_adfs_access_token"] = adfs_response["access_token"] + expiry = datetime.now() + timedelta(seconds=int(adfs_response["expires_in"])) + request.session["_adfs_token_expiry"] = expiry.isoformat() + request.session["_adfs_refresh_token"] = adfs_response["refresh_token"] + request.session.save() class AdfsAccessTokenBackend(AdfsBaseBackend): """ diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index a1f348c9..79f7cca9 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -13,7 +13,7 @@ from django.urls import reverse from requests import HTTPError -from django_auth_adfs.backend import AdfsAuthCodeBackend +from django_auth_adfs.backend import AdfsAuthCodeRefreshBackend from django_auth_adfs.exceptions import MFARequired from django_auth_adfs.config import settings @@ -88,15 +88,11 @@ def __call__(self, request): pass else: backend = auth.load_backend(backend_str) - if isinstance(backend, AdfsAuthCodeBackend): - now = datetime.now() + settings.REFRESH_THRESHOLD - expiry = datetime.fromisoformat(request.session["_adfs_token_expiry"]) - if now > expiry: - try: - backend.refresh_access_token( - request, request.session["_adfs_refresh_token"] - ) - except (PermissionDenied, HTTPError) as error: - logger.debug("Error refreshing access token: %s", error) - logout(request) + if isinstance(backend, AdfsAuthCodeRefreshBackend): + backend.check_and_refresh_access_token(request) + else: + assert ( + "ADFS Refresh middleware is only applicable to AdfsAuthCodeRefreshBackend, " + "but found %s", backend_str + ) return self.get_response(request) From 104ed3345d24a1c3ad2d50929b5f3c65a5258104 Mon Sep 17 00:00:00 2001 From: sdev95 <> Date: Wed, 4 Jun 2025 15:22:30 +0200 Subject: [PATCH 07/10] Removed middleware from test settings --- tests/settings.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/settings.py b/tests/settings.py index 5f6346e2..81d397c7 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -35,7 +35,6 @@ 'django.contrib.messages.middleware.MessageMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware', - 'django_auth_adfs.middleware.AdfsRefreshMiddleware', 'django_auth_adfs.middleware.LoginRequiredMiddleware', ) From 41a87e232d0169b37f9c730f5c9337509f3adecf Mon Sep 17 00:00:00 2001 From: sdev95 <> Date: Wed, 4 Jun 2025 15:29:03 +0200 Subject: [PATCH 08/10] Solve lint issues --- django_auth_adfs/backend.py | 5 ++--- django_auth_adfs/middleware.py | 11 ++--------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index 5b30adbe..5b6ad689 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -2,7 +2,6 @@ from datetime import datetime, timedelta import jwt -from django.conf import settings as django_settings from django.contrib.auth import get_user_model, logout from django.contrib.auth.backends import ModelBackend from django.contrib.auth.models import Group @@ -451,8 +450,7 @@ def authenticate(self, request=None, authorization_code=None, **kwargs): user = self.process_access_token(access_token, adfs_response) self._store_adfs_tokens_in_session(request, adfs_response) return user - - + def ensure_valid_access_token(self, request): now = datetime.now() + settings.REFRESH_THRESHOLD expiry = datetime.fromisoformat(request.session["_adfs_token_expiry"]) @@ -487,6 +485,7 @@ def _store_adfs_tokens_in_session(self, request, adfs_response): request.session["_adfs_refresh_token"] = adfs_response["refresh_token"] request.session.save() + class AdfsAccessTokenBackend(AdfsBaseBackend): """ Authentication backend to allow authenticating users against a diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index 79f7cca9..9ba9193c 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -2,16 +2,12 @@ Based on https://djangosnippets.org/snippets/1179/ """ import logging -from datetime import datetime from re import compile from django.conf import settings as django_settings from django.contrib import auth -from django.contrib.auth import logout from django.contrib.auth.views import redirect_to_login -from django.core.exceptions import PermissionDenied from django.urls import reverse -from requests import HTTPError from django_auth_adfs.backend import AdfsAuthCodeRefreshBackend from django_auth_adfs.exceptions import MFARequired @@ -90,9 +86,6 @@ def __call__(self, request): backend = auth.load_backend(backend_str) if isinstance(backend, AdfsAuthCodeRefreshBackend): backend.check_and_refresh_access_token(request) - else: - assert ( - "ADFS Refresh middleware is only applicable to AdfsAuthCodeRefreshBackend, " - "but found %s", backend_str - ) + else: + assert "ADFS Refresh middleware is only applicable to AdfsAuthCodeRefreshBackend" return self.get_response(request) From a7bcd37815bb981b4e50c9b40a4bc4452b72df16 Mon Sep 17 00:00:00 2001 From: sdev95 <> Date: Wed, 4 Jun 2025 20:36:54 +0200 Subject: [PATCH 09/10] Rename AdfsAccessTokenRefreshBackend --- django_auth_adfs/backend.py | 9 +++++---- django_auth_adfs/middleware.py | 8 +++----- tests/test_authentication.py | 2 +- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index 5b6ad689..9726ccd8 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -426,7 +426,7 @@ def authenticate(self, request=None, authorization_code=None, **kwargs): return user -class AdfsAuthCodeRefreshBackend(AdfsBaseBackend): +class AdfsAccessTokenRefreshBackend(AdfsBaseBackend): """ Authentication backend that supports storing and refreshing ADFS tokens in the session. Use this backend in conjunction with AdfsRefreshMiddleware. @@ -450,7 +450,7 @@ def authenticate(self, request=None, authorization_code=None, **kwargs): user = self.process_access_token(access_token, adfs_response) self._store_adfs_tokens_in_session(request, adfs_response) return user - + def ensure_valid_access_token(self, request): now = datetime.now() + settings.REFRESH_THRESHOLD expiry = datetime.fromisoformat(request.session["_adfs_token_expiry"]) @@ -468,7 +468,8 @@ def _refresh_access_token(self, refresh_token): provider_config.load_config() response = provider_config.session.post( provider_config.token_endpoint, - data=f'grant_type=refresh_token&refresh_token={refresh_token}' + data=f'client_id={settings.CLIENT_ID}&client_secret={settings.CLIENT_SECRET}&grant_type=refresh_token' + + f'&refresh_token={refresh_token}' ) response.raise_for_status() adfs_response = response.json() @@ -476,7 +477,7 @@ def _refresh_access_token(self, refresh_token): def _store_adfs_tokens_in_session(self, request, adfs_response): assert "refresh_token" in adfs_response, ( - "AdfsAuthCodeRefreshBackend requires a refresh token to function correctly. " + "AdfsAccessTokenRefreshBackend requires a refresh token to function correctly. " "Make sure your ADFS server is configured to return a refresh token." ) request.session["_adfs_access_token"] = adfs_response["access_token"] diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index 9ba9193c..1baace55 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -9,7 +9,7 @@ from django.contrib.auth.views import redirect_to_login from django.urls import reverse -from django_auth_adfs.backend import AdfsAuthCodeRefreshBackend +from django_auth_adfs.backend import AdfsAccessTokenRefreshBackend from django_auth_adfs.exceptions import MFARequired from django_auth_adfs.config import settings @@ -84,8 +84,6 @@ def __call__(self, request): pass else: backend = auth.load_backend(backend_str) - if isinstance(backend, AdfsAuthCodeRefreshBackend): - backend.check_and_refresh_access_token(request) - else: - assert "ADFS Refresh middleware is only applicable to AdfsAuthCodeRefreshBackend" + if isinstance(backend, AdfsAccessTokenRefreshBackend): + backend.ensure_valid_access_token(request) return self.get_response(request) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 9ecba8c7..15e4ab5a 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -548,7 +548,7 @@ def test_refresh_token_expired(self): response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) self.assertFalse(response.wsgi_request.user.is_anonymous) fromisoformat = datetime.fromisoformat - with patch('django_auth_adfs.middleware.datetime') as dt: + with patch('django_auth_adfs.backend.datetime') as dt: dt.fromisoformat = fromisoformat dt.now.return_value = datetime.now() + timedelta(hours=1) response = self.client.get(reverse('test')) From b7f03fa7bb93e28e7bfc2bc6a5fd632d7f1ae4f2 Mon Sep 17 00:00:00 2001 From: sdev95 <> Date: Wed, 4 Jun 2025 22:02:05 +0200 Subject: [PATCH 10/10] AuthCodeRefresh and tests --- django_auth_adfs/backend.py | 19 +++++++------------ django_auth_adfs/middleware.py | 15 ++++++++++++--- tests/test_authentication.py | 17 +++++++++++++++-- 3 files changed, 34 insertions(+), 17 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index 9726ccd8..20abd24e 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -2,12 +2,11 @@ from datetime import datetime, timedelta import jwt -from django.contrib.auth import get_user_model, logout +from django.contrib.auth import get_user_model from django.contrib.auth.backends import ModelBackend from django.contrib.auth.models import Group from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist, PermissionDenied) -from requests import HTTPError from django_auth_adfs import signals from django_auth_adfs.config import provider_config, settings @@ -426,7 +425,7 @@ def authenticate(self, request=None, authorization_code=None, **kwargs): return user -class AdfsAccessTokenRefreshBackend(AdfsBaseBackend): +class AdfsAuthCodeRefreshBackend(AdfsBaseBackend): """ Authentication backend that supports storing and refreshing ADFS tokens in the session. Use this backend in conjunction with AdfsRefreshMiddleware. @@ -455,14 +454,10 @@ def ensure_valid_access_token(self, request): now = datetime.now() + settings.REFRESH_THRESHOLD expiry = datetime.fromisoformat(request.session["_adfs_token_expiry"]) if now > expiry: - try: - adfs_refresh_response = self._refresh_access_token( - request.session["_adfs_refresh_token"] - ) - self._store_adfs_tokens_in_session(request, adfs_refresh_response) - except (PermissionDenied, HTTPError) as error: - logger.debug("Error refreshing access token: %s", error) - logout(request) + adfs_refresh_response = self._refresh_access_token( + request.session["_adfs_refresh_token"] + ) + self._store_adfs_tokens_in_session(request, adfs_refresh_response) def _refresh_access_token(self, refresh_token): provider_config.load_config() @@ -477,7 +472,7 @@ def _refresh_access_token(self, refresh_token): def _store_adfs_tokens_in_session(self, request, adfs_response): assert "refresh_token" in adfs_response, ( - "AdfsAccessTokenRefreshBackend requires a refresh token to function correctly. " + "AdfsAuthCodeRefreshBackend requires a refresh token to function correctly. " "Make sure your ADFS server is configured to return a refresh token." ) request.session["_adfs_access_token"] = adfs_response["access_token"] diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index 1baace55..2c506fd5 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -3,13 +3,17 @@ """ import logging from re import compile +from requests import HTTPError from django.conf import settings as django_settings from django.contrib import auth from django.contrib.auth.views import redirect_to_login +from django.contrib.auth import logout +from django.core.exceptions import (PermissionDenied) + from django.urls import reverse -from django_auth_adfs.backend import AdfsAccessTokenRefreshBackend +from django_auth_adfs.backend import AdfsAuthCodeRefreshBackend from django_auth_adfs.exceptions import MFARequired from django_auth_adfs.config import settings @@ -84,6 +88,11 @@ def __call__(self, request): pass else: backend = auth.load_backend(backend_str) - if isinstance(backend, AdfsAccessTokenRefreshBackend): - backend.ensure_valid_access_token(request) + if isinstance(backend, AdfsAuthCodeRefreshBackend): + try: + backend.ensure_valid_access_token(request) + except (PermissionDenied, HTTPError) as error: + logger.debug("Error refreshing access token: %s", error) + logout(request) + return self.get_response(request) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 15e4ab5a..37fa615f 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -16,7 +16,7 @@ from django.contrib.auth.models import Group, User from django.core.exceptions import ObjectDoesNotExist from django.db.models.signals import post_save -from django.test import TestCase +from django.test import TestCase, override_settings from mock import Mock, patch from django_auth_adfs import signals @@ -543,6 +543,19 @@ def test_access_token_expired(self): response = self.client.get(reverse('test')) self.assertEqual(response.status_code, 200) + @override_settings(AUTHENTICATION_BACKENDS=['django_auth_adfs.backend.AdfsAuthCodeRefreshBackend']) + @override_settings( + MIDDLEWARE=[ + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.middleware.common.CommonMiddleware', + 'django.middleware.csrf.CsrfViewMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.messages.middleware.MessageMiddleware', + 'django.middleware.clickjacking.XFrameOptionsMiddleware', + "django_auth_adfs.middleware.AdfsRefreshMiddleware", + "django_auth_adfs.middleware.LoginRequiredMiddleware", + ] + ) @mock_adfs("2016", refresh_token_expired=True) def test_refresh_token_expired(self): response = self.client.get(reverse('django_auth_adfs:callback'), data={'code': "testcode"}) @@ -550,7 +563,7 @@ def test_refresh_token_expired(self): fromisoformat = datetime.fromisoformat with patch('django_auth_adfs.backend.datetime') as dt: dt.fromisoformat = fromisoformat - dt.now.return_value = datetime.now() + timedelta(hours=1) + dt.now.return_value = datetime.now() + timedelta(hours=2) response = self.client.get(reverse('test')) self.assertEqual(response.status_code, 302) self.assertEqual(response['Location'], f"{reverse('django_auth_adfs:login')}?next=/")