From 86b251989c0fb51804a28762db9a21d0c6f9728d Mon Sep 17 00:00:00 2001 From: Sean 'Shaleh' Perry Date: Fri, 16 Aug 2024 07:23:32 -0700 Subject: [PATCH] Use retrieven_current_databases in django_db marked tests. --- tests/test_hybrid.py | 5 +-- tests/test_models.py | 13 +++---- tests/test_oauth2_validators.py | 7 ++-- tests/test_oidc_views.py | 61 +++++++++++++++++---------------- 4 files changed, 45 insertions(+), 41 deletions(-) diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py index 87c4b0ad9..67c29a54e 100644 --- a/tests/test_hybrid.py +++ b/tests/test_hybrid.py @@ -22,6 +22,7 @@ from . import presets from .common_testing import OAuth2ProviderTestCase as TestCase +from .common_testing import retrieve_current_databases from .utils import get_basic_auth_header, spy_on @@ -1319,7 +1320,7 @@ def test_pre_auth_default_scopes(self): self.assertEqual(form["client_id"].value(), self.application.client_id) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_id_token_nonce_in_token_response(oauth2_settings, test_user, hybrid_application, client, oidc_key): client.force_login(test_user) @@ -1368,7 +1369,7 @@ def test_id_token_nonce_in_token_response(oauth2_settings, test_user, hybrid_app assert claims["nonce"] == "random_nonce_string" -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_claims_passed_to_code_generation( oauth2_settings, test_user, hybrid_application, client, mocker, oidc_key diff --git a/tests/test_models.py b/tests/test_models.py index cd4b7342c..58765db69 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -20,6 +20,7 @@ from . import presets from .common_testing import OAuth2ProviderTestCase as TestCase +from .common_testing import retrieve_current_databases CLEARTEXT_SECRET = "1234567890abcdefghijklmnopqrstuvwxyz" @@ -466,7 +467,7 @@ def test_clear_expired_tokens_with_tokens(self): assert remaining_gt_count == initial_gt_count // 2, "half the remaining grants should still exist." -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_id_token_methods(oidc_tokens, rf): id_token = IDToken.objects.get() @@ -501,7 +502,7 @@ def test_id_token_methods(oidc_tokens, rf): assert IDToken.objects.filter(jti=id_token.jti).count() == 0 -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_clear_expired_id_tokens(oauth2_settings, oidc_tokens, rf): id_token = IDToken.objects.get() @@ -540,7 +541,7 @@ def test_clear_expired_id_tokens(oauth2_settings, oidc_tokens, rf): assert not IDToken.objects.filter(jti=id_token.jti).exists() -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_application_key(oauth2_settings, application): # RS256 key @@ -565,7 +566,7 @@ def test_application_key(oauth2_settings, application): assert "This application does not support signed tokens" == str(exc.value) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_application_clean(oauth2_settings, application): # RS256, RSA key is configured @@ -605,7 +606,7 @@ def test_application_clean(oauth2_settings, application): application.clean() -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_DEFAULT) def test_application_origin_allowed_default_https(oauth2_settings, cors_application): """Test that http schemes are not allowed because ALLOWED_SCHEMES allows only https""" @@ -613,7 +614,7 @@ def test_application_origin_allowed_default_https(oauth2_settings, cors_applicat assert not cors_application.origin_allowed("http://example.com") -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_HTTP) def test_application_origin_allowed_http(oauth2_settings, cors_application): """Test that http schemes are allowed because http was added to ALLOWED_SCHEMES""" diff --git a/tests/test_oauth2_validators.py b/tests/test_oauth2_validators.py index bf06b73a8..468e05598 100644 --- a/tests/test_oauth2_validators.py +++ b/tests/test_oauth2_validators.py @@ -17,6 +17,7 @@ from . import presets from .common_testing import OAuth2ProviderTestCase as TestCase from .common_testing import OAuth2ProviderTransactionTestCase as TransactionTestCase +from .common_testing import retrieve_current_databases from .utils import get_basic_auth_header @@ -546,7 +547,7 @@ def test_get_jwt_bearer_token(oauth2_settings, mocker): assert mock_get_id_token.call_args[1] == {} -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_validate_id_token_expired_jwt(oauth2_settings, mocker, oidc_tokens): mocker.patch("oauth2_provider.oauth2_validators.jwt.JWT", side_effect=jwt.JWTExpired) @@ -562,7 +563,7 @@ def test_validate_id_token_no_token(oauth2_settings, mocker): assert status is False -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_validate_id_token_app_removed(oauth2_settings, mocker, oidc_tokens): oidc_tokens.application.delete() @@ -571,7 +572,7 @@ def test_validate_id_token_app_removed(oauth2_settings, mocker, oidc_tokens): assert status is False -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_validate_id_token_bad_token_no_aud(oauth2_settings, mocker, oidc_key): token = jwt.JWT(header=json.dumps({"alg": "RS256"}), claims=json.dumps({"bad": "token"})) diff --git a/tests/test_oidc_views.py b/tests/test_oidc_views.py index 8949f41e7..8bdf18360 100644 --- a/tests/test_oidc_views.py +++ b/tests/test_oidc_views.py @@ -19,6 +19,7 @@ from . import presets from .common_testing import OAuth2ProviderTestCase as TestCase +from .common_testing import retrieve_current_databases @pytest.mark.usefixtures("oauth2_settings") @@ -221,7 +222,7 @@ def mock_request_for(user): return request -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_validate_logout_request(oidc_tokens, public_application, rp_settings): oidc_tokens = oidc_tokens application = oidc_tokens.application @@ -299,7 +300,7 @@ def test_validate_logout_request(oidc_tokens, public_application, rp_settings): ) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.parametrize("ALWAYS_PROMPT", [True, False]) def test_must_prompt(oidc_tokens, other_user, rp_settings, ALWAYS_PROMPT): rp_settings.OIDC_RP_INITIATED_LOGOUT_ALWAYS_PROMPT = ALWAYS_PROMPT @@ -320,14 +321,14 @@ def is_logged_in(client): return get_user(client).is_authenticated -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_get(logged_in_client, rp_settings): rsp = logged_in_client.get(reverse("oauth2_provider:rp-initiated-logout"), data={}) assert rsp.status_code == 200 assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_get_id_token(logged_in_client, oidc_tokens, rp_settings): rsp = logged_in_client.get( reverse("oauth2_provider:rp-initiated-logout"), data={"id_token_hint": oidc_tokens.id_token} @@ -337,7 +338,7 @@ def test_rp_initiated_logout_get_id_token(logged_in_client, oidc_tokens, rp_sett assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_get_revoked_id_token(logged_in_client, oidc_tokens, rp_settings): validator = oauth2_settings.OAUTH2_VALIDATOR_CLASS() validator._load_id_token(oidc_tokens.id_token).revoke() @@ -348,7 +349,7 @@ def test_rp_initiated_logout_get_revoked_id_token(logged_in_client, oidc_tokens, assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_get_id_token_redirect(logged_in_client, oidc_tokens, rp_settings): rsp = logged_in_client.get( reverse("oauth2_provider:rp-initiated-logout"), @@ -359,7 +360,7 @@ def test_rp_initiated_logout_get_id_token_redirect(logged_in_client, oidc_tokens assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_get_id_token_redirect_with_state(logged_in_client, oidc_tokens, rp_settings): rsp = logged_in_client.get( reverse("oauth2_provider:rp-initiated-logout"), @@ -374,7 +375,7 @@ def test_rp_initiated_logout_get_id_token_redirect_with_state(logged_in_client, assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_get_id_token_missmatch_client_id( logged_in_client, oidc_tokens, public_application, rp_settings ): @@ -386,7 +387,7 @@ def test_rp_initiated_logout_get_id_token_missmatch_client_id( assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_public_client_redirect_client_id( logged_in_client, oidc_non_confidential_tokens, public_application, rp_settings ): @@ -402,7 +403,7 @@ def test_rp_initiated_logout_public_client_redirect_client_id( assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_public_client_strict_redirect_client_id( logged_in_client, oidc_non_confidential_tokens, public_application, oauth2_settings ): @@ -419,7 +420,7 @@ def test_rp_initiated_logout_public_client_strict_redirect_client_id( assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_get_client_id(logged_in_client, oidc_tokens, rp_settings): rsp = logged_in_client.get( reverse("oauth2_provider:rp-initiated-logout"), data={"client_id": oidc_tokens.application.client_id} @@ -428,7 +429,7 @@ def test_rp_initiated_logout_get_client_id(logged_in_client, oidc_tokens, rp_set assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_post(logged_in_client, oidc_tokens, rp_settings): form_data = { "client_id": oidc_tokens.application.client_id, @@ -438,7 +439,7 @@ def test_rp_initiated_logout_post(logged_in_client, oidc_tokens, rp_settings): assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_post_allowed(logged_in_client, oidc_tokens, rp_settings): form_data = {"client_id": oidc_tokens.application.client_id, "allow": True} rsp = logged_in_client.post(reverse("oauth2_provider:rp-initiated-logout"), form_data) @@ -447,7 +448,7 @@ def test_rp_initiated_logout_post_allowed(logged_in_client, oidc_tokens, rp_sett assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_post_no_session(client, oidc_tokens, rp_settings): form_data = {"client_id": oidc_tokens.application.client_id, "allow": True} rsp = client.post(reverse("oauth2_provider:rp-initiated-logout"), form_data) @@ -456,7 +457,7 @@ def test_rp_initiated_logout_post_no_session(client, oidc_tokens, rp_settings): assert not is_logged_in(client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_rp_initiated_logout_expired_tokens_accept(logged_in_client, application, expired_id_token): # Accepting expired (but otherwise valid and signed by us) tokens is enabled. Logout should go through. @@ -471,7 +472,7 @@ def test_rp_initiated_logout_expired_tokens_accept(logged_in_client, application assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_DENY_EXPIRED) def test_rp_initiated_logout_expired_tokens_deny(logged_in_client, application, expired_id_token): # Expired tokens should not be accepted by default. @@ -486,14 +487,14 @@ def test_rp_initiated_logout_expired_tokens_deny(logged_in_client, application, assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_load_id_token_accept_expired(expired_id_token): id_token, _ = _load_id_token(expired_id_token) assert isinstance(id_token, get_id_token_model()) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_load_id_token_wrong_aud(id_token_wrong_aud): id_token, claims = _load_id_token(id_token_wrong_aud) @@ -501,7 +502,7 @@ def test_load_id_token_wrong_aud(id_token_wrong_aud): assert claims is None -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_DENY_EXPIRED) def test_load_id_token_deny_expired(expired_id_token): id_token, claims = _load_id_token(expired_id_token) @@ -509,7 +510,7 @@ def test_load_id_token_deny_expired(expired_id_token): assert claims is None -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_validate_claims_wrong_iss(id_token_wrong_iss): id_token, claims = _load_id_token(id_token_wrong_iss) @@ -518,7 +519,7 @@ def test_validate_claims_wrong_iss(id_token_wrong_iss): assert not _validate_claims(mock_request(), claims) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_validate_claims(oidc_tokens): id_token, claims = _load_id_token(oidc_tokens.id_token) @@ -526,7 +527,7 @@ def test_validate_claims(oidc_tokens): assert _validate_claims(mock_request_for(oidc_tokens.user), claims) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.parametrize("method", ["get", "post"]) def test_userinfo_endpoint(oidc_tokens, client, method): auth_header = "Bearer %s" % oidc_tokens.access_token @@ -539,7 +540,7 @@ def test_userinfo_endpoint(oidc_tokens, client, method): assert data["sub"] == str(oidc_tokens.user.pk) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_userinfo_endpoint_bad_token(oidc_tokens, client): # No access token rsp = client.get(reverse("oauth2_provider:user-info")) @@ -552,7 +553,7 @@ def test_userinfo_endpoint_bad_token(oidc_tokens, client): assert rsp.status_code == 401 -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_token_deletion_on_logout(oidc_tokens, logged_in_client, rp_settings): AccessToken = get_access_token_model() IDToken = get_id_token_model() @@ -575,7 +576,7 @@ def test_token_deletion_on_logout(oidc_tokens, logged_in_client, rp_settings): assert all([token.revoked <= timezone.now() for token in RefreshToken.objects.all()]) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_token_deletion_on_logout_expired_session(oidc_tokens, client, rp_settings): AccessToken = get_access_token_model() IDToken = get_id_token_model() @@ -616,7 +617,7 @@ def test_token_deletion_on_logout_expired_session(oidc_tokens, client, rp_settin assert all(token.revoked <= timezone.now() for token in RefreshToken.objects.all()) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_KEEP_TOKENS) def test_token_deletion_on_logout_disabled(oidc_tokens, logged_in_client, rp_settings): rp_settings.OIDC_RP_INITIATED_LOGOUT_DELETE_TOKENS = False @@ -652,7 +653,7 @@ def claim_user_email(request): return EXAMPLE_EMAIL -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_userinfo_endpoint_custom_claims_callable(oidc_tokens, client, oauth2_settings): class CustomValidator(OAuth2Validator): oidc_claim_scope = None @@ -680,7 +681,7 @@ def get_additional_claims(self): assert data["email"] == EXAMPLE_EMAIL -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_userinfo_endpoint_custom_claims_email_scope_callable( oidc_email_scope_tokens, client, oauth2_settings ): @@ -707,7 +708,7 @@ def get_additional_claims(self): assert data["email"] == EXAMPLE_EMAIL -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_userinfo_endpoint_custom_claims_plain(oidc_tokens, client, oauth2_settings): class CustomValidator(OAuth2Validator): oidc_claim_scope = None @@ -735,7 +736,7 @@ def get_additional_claims(self, request): assert data["email"] == EXAMPLE_EMAIL -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_userinfo_endpoint_custom_claims_email_scopeplain(oidc_email_scope_tokens, client, oauth2_settings): class CustomValidator(OAuth2Validator): def get_additional_claims(self, request):