Skip to content

Commit

Permalink
Use retrieven_current_databases in django_db marked tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
shaleh committed Aug 16, 2024
1 parent 8072cc7 commit 86b2519
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 41 deletions.
5 changes: 3 additions & 2 deletions tests/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from . import presets
from .common_testing import OAuth2ProviderTestCase as TestCase
from .common_testing import retrieve_current_databases
from .utils import get_basic_auth_header, spy_on


Expand Down Expand Up @@ -1319,7 +1320,7 @@ def test_pre_auth_default_scopes(self):
self.assertEqual(form["client_id"].value(), self.application.client_id)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
def test_id_token_nonce_in_token_response(oauth2_settings, test_user, hybrid_application, client, oidc_key):
client.force_login(test_user)
Expand Down Expand Up @@ -1368,7 +1369,7 @@ def test_id_token_nonce_in_token_response(oauth2_settings, test_user, hybrid_app
assert claims["nonce"] == "random_nonce_string"


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
def test_claims_passed_to_code_generation(
oauth2_settings, test_user, hybrid_application, client, mocker, oidc_key
Expand Down
13 changes: 7 additions & 6 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from . import presets
from .common_testing import OAuth2ProviderTestCase as TestCase
from .common_testing import retrieve_current_databases


CLEARTEXT_SECRET = "1234567890abcdefghijklmnopqrstuvwxyz"
Expand Down Expand Up @@ -466,7 +467,7 @@ def test_clear_expired_tokens_with_tokens(self):
assert remaining_gt_count == initial_gt_count // 2, "half the remaining grants should still exist."


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
def test_id_token_methods(oidc_tokens, rf):
id_token = IDToken.objects.get()
Expand Down Expand Up @@ -501,7 +502,7 @@ def test_id_token_methods(oidc_tokens, rf):
assert IDToken.objects.filter(jti=id_token.jti).count() == 0


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
def test_clear_expired_id_tokens(oauth2_settings, oidc_tokens, rf):
id_token = IDToken.objects.get()
Expand Down Expand Up @@ -540,7 +541,7 @@ def test_clear_expired_id_tokens(oauth2_settings, oidc_tokens, rf):
assert not IDToken.objects.filter(jti=id_token.jti).exists()


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
def test_application_key(oauth2_settings, application):
# RS256 key
Expand All @@ -565,7 +566,7 @@ def test_application_key(oauth2_settings, application):
assert "This application does not support signed tokens" == str(exc.value)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
def test_application_clean(oauth2_settings, application):
# RS256, RSA key is configured
Expand Down Expand Up @@ -605,15 +606,15 @@ def test_application_clean(oauth2_settings, application):
application.clean()


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_DEFAULT)
def test_application_origin_allowed_default_https(oauth2_settings, cors_application):
"""Test that http schemes are not allowed because ALLOWED_SCHEMES allows only https"""
assert cors_application.origin_allowed("https://example.com")
assert not cors_application.origin_allowed("http://example.com")


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_HTTP)
def test_application_origin_allowed_http(oauth2_settings, cors_application):
"""Test that http schemes are allowed because http was added to ALLOWED_SCHEMES"""
Expand Down
7 changes: 4 additions & 3 deletions tests/test_oauth2_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from . import presets
from .common_testing import OAuth2ProviderTestCase as TestCase
from .common_testing import OAuth2ProviderTransactionTestCase as TransactionTestCase
from .common_testing import retrieve_current_databases
from .utils import get_basic_auth_header


Expand Down Expand Up @@ -546,7 +547,7 @@ def test_get_jwt_bearer_token(oauth2_settings, mocker):
assert mock_get_id_token.call_args[1] == {}


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
def test_validate_id_token_expired_jwt(oauth2_settings, mocker, oidc_tokens):
mocker.patch("oauth2_provider.oauth2_validators.jwt.JWT", side_effect=jwt.JWTExpired)
Expand All @@ -562,7 +563,7 @@ def test_validate_id_token_no_token(oauth2_settings, mocker):
assert status is False


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
def test_validate_id_token_app_removed(oauth2_settings, mocker, oidc_tokens):
oidc_tokens.application.delete()
Expand All @@ -571,7 +572,7 @@ def test_validate_id_token_app_removed(oauth2_settings, mocker, oidc_tokens):
assert status is False


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
def test_validate_id_token_bad_token_no_aud(oauth2_settings, mocker, oidc_key):
token = jwt.JWT(header=json.dumps({"alg": "RS256"}), claims=json.dumps({"bad": "token"}))
Expand Down
61 changes: 31 additions & 30 deletions tests/test_oidc_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from . import presets
from .common_testing import OAuth2ProviderTestCase as TestCase
from .common_testing import retrieve_current_databases


@pytest.mark.usefixtures("oauth2_settings")
Expand Down Expand Up @@ -221,7 +222,7 @@ def mock_request_for(user):
return request


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_validate_logout_request(oidc_tokens, public_application, rp_settings):
oidc_tokens = oidc_tokens
application = oidc_tokens.application
Expand Down Expand Up @@ -299,7 +300,7 @@ def test_validate_logout_request(oidc_tokens, public_application, rp_settings):
)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.parametrize("ALWAYS_PROMPT", [True, False])
def test_must_prompt(oidc_tokens, other_user, rp_settings, ALWAYS_PROMPT):
rp_settings.OIDC_RP_INITIATED_LOGOUT_ALWAYS_PROMPT = ALWAYS_PROMPT
Expand All @@ -320,14 +321,14 @@ def is_logged_in(client):
return get_user(client).is_authenticated


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_rp_initiated_logout_get(logged_in_client, rp_settings):
rsp = logged_in_client.get(reverse("oauth2_provider:rp-initiated-logout"), data={})
assert rsp.status_code == 200
assert is_logged_in(logged_in_client)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_rp_initiated_logout_get_id_token(logged_in_client, oidc_tokens, rp_settings):
rsp = logged_in_client.get(
reverse("oauth2_provider:rp-initiated-logout"), data={"id_token_hint": oidc_tokens.id_token}
Expand All @@ -337,7 +338,7 @@ def test_rp_initiated_logout_get_id_token(logged_in_client, oidc_tokens, rp_sett
assert not is_logged_in(logged_in_client)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_rp_initiated_logout_get_revoked_id_token(logged_in_client, oidc_tokens, rp_settings):
validator = oauth2_settings.OAUTH2_VALIDATOR_CLASS()
validator._load_id_token(oidc_tokens.id_token).revoke()
Expand All @@ -348,7 +349,7 @@ def test_rp_initiated_logout_get_revoked_id_token(logged_in_client, oidc_tokens,
assert is_logged_in(logged_in_client)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_rp_initiated_logout_get_id_token_redirect(logged_in_client, oidc_tokens, rp_settings):
rsp = logged_in_client.get(
reverse("oauth2_provider:rp-initiated-logout"),
Expand All @@ -359,7 +360,7 @@ def test_rp_initiated_logout_get_id_token_redirect(logged_in_client, oidc_tokens
assert not is_logged_in(logged_in_client)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_rp_initiated_logout_get_id_token_redirect_with_state(logged_in_client, oidc_tokens, rp_settings):
rsp = logged_in_client.get(
reverse("oauth2_provider:rp-initiated-logout"),
Expand All @@ -374,7 +375,7 @@ def test_rp_initiated_logout_get_id_token_redirect_with_state(logged_in_client,
assert not is_logged_in(logged_in_client)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_rp_initiated_logout_get_id_token_missmatch_client_id(
logged_in_client, oidc_tokens, public_application, rp_settings
):
Expand All @@ -386,7 +387,7 @@ def test_rp_initiated_logout_get_id_token_missmatch_client_id(
assert is_logged_in(logged_in_client)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_rp_initiated_logout_public_client_redirect_client_id(
logged_in_client, oidc_non_confidential_tokens, public_application, rp_settings
):
Expand All @@ -402,7 +403,7 @@ def test_rp_initiated_logout_public_client_redirect_client_id(
assert not is_logged_in(logged_in_client)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_rp_initiated_logout_public_client_strict_redirect_client_id(
logged_in_client, oidc_non_confidential_tokens, public_application, oauth2_settings
):
Expand All @@ -419,7 +420,7 @@ def test_rp_initiated_logout_public_client_strict_redirect_client_id(
assert is_logged_in(logged_in_client)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_rp_initiated_logout_get_client_id(logged_in_client, oidc_tokens, rp_settings):
rsp = logged_in_client.get(
reverse("oauth2_provider:rp-initiated-logout"), data={"client_id": oidc_tokens.application.client_id}
Expand All @@ -428,7 +429,7 @@ def test_rp_initiated_logout_get_client_id(logged_in_client, oidc_tokens, rp_set
assert is_logged_in(logged_in_client)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_rp_initiated_logout_post(logged_in_client, oidc_tokens, rp_settings):
form_data = {
"client_id": oidc_tokens.application.client_id,
Expand All @@ -438,7 +439,7 @@ def test_rp_initiated_logout_post(logged_in_client, oidc_tokens, rp_settings):
assert is_logged_in(logged_in_client)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_rp_initiated_logout_post_allowed(logged_in_client, oidc_tokens, rp_settings):
form_data = {"client_id": oidc_tokens.application.client_id, "allow": True}
rsp = logged_in_client.post(reverse("oauth2_provider:rp-initiated-logout"), form_data)
Expand All @@ -447,7 +448,7 @@ def test_rp_initiated_logout_post_allowed(logged_in_client, oidc_tokens, rp_sett
assert not is_logged_in(logged_in_client)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_rp_initiated_logout_post_no_session(client, oidc_tokens, rp_settings):
form_data = {"client_id": oidc_tokens.application.client_id, "allow": True}
rsp = client.post(reverse("oauth2_provider:rp-initiated-logout"), form_data)
Expand All @@ -456,7 +457,7 @@ def test_rp_initiated_logout_post_no_session(client, oidc_tokens, rp_settings):
assert not is_logged_in(client)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT)
def test_rp_initiated_logout_expired_tokens_accept(logged_in_client, application, expired_id_token):
# Accepting expired (but otherwise valid and signed by us) tokens is enabled. Logout should go through.
Expand All @@ -471,7 +472,7 @@ def test_rp_initiated_logout_expired_tokens_accept(logged_in_client, application
assert not is_logged_in(logged_in_client)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_DENY_EXPIRED)
def test_rp_initiated_logout_expired_tokens_deny(logged_in_client, application, expired_id_token):
# Expired tokens should not be accepted by default.
Expand All @@ -486,30 +487,30 @@ def test_rp_initiated_logout_expired_tokens_deny(logged_in_client, application,
assert is_logged_in(logged_in_client)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT)
def test_load_id_token_accept_expired(expired_id_token):
id_token, _ = _load_id_token(expired_id_token)
assert isinstance(id_token, get_id_token_model())


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT)
def test_load_id_token_wrong_aud(id_token_wrong_aud):
id_token, claims = _load_id_token(id_token_wrong_aud)
assert id_token is None
assert claims is None


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_DENY_EXPIRED)
def test_load_id_token_deny_expired(expired_id_token):
id_token, claims = _load_id_token(expired_id_token)
assert id_token is None
assert claims is None


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT)
def test_validate_claims_wrong_iss(id_token_wrong_iss):
id_token, claims = _load_id_token(id_token_wrong_iss)
Expand All @@ -518,15 +519,15 @@ def test_validate_claims_wrong_iss(id_token_wrong_iss):
assert not _validate_claims(mock_request(), claims)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT)
def test_validate_claims(oidc_tokens):
id_token, claims = _load_id_token(oidc_tokens.id_token)
assert claims is not None
assert _validate_claims(mock_request_for(oidc_tokens.user), claims)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.parametrize("method", ["get", "post"])
def test_userinfo_endpoint(oidc_tokens, client, method):
auth_header = "Bearer %s" % oidc_tokens.access_token
Expand All @@ -539,7 +540,7 @@ def test_userinfo_endpoint(oidc_tokens, client, method):
assert data["sub"] == str(oidc_tokens.user.pk)


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_userinfo_endpoint_bad_token(oidc_tokens, client):
# No access token
rsp = client.get(reverse("oauth2_provider:user-info"))
Expand All @@ -552,7 +553,7 @@ def test_userinfo_endpoint_bad_token(oidc_tokens, client):
assert rsp.status_code == 401


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_token_deletion_on_logout(oidc_tokens, logged_in_client, rp_settings):
AccessToken = get_access_token_model()
IDToken = get_id_token_model()
Expand All @@ -575,7 +576,7 @@ def test_token_deletion_on_logout(oidc_tokens, logged_in_client, rp_settings):
assert all([token.revoked <= timezone.now() for token in RefreshToken.objects.all()])


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_token_deletion_on_logout_expired_session(oidc_tokens, client, rp_settings):
AccessToken = get_access_token_model()
IDToken = get_id_token_model()
Expand Down Expand Up @@ -616,7 +617,7 @@ def test_token_deletion_on_logout_expired_session(oidc_tokens, client, rp_settin
assert all(token.revoked <= timezone.now() for token in RefreshToken.objects.all())


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_KEEP_TOKENS)
def test_token_deletion_on_logout_disabled(oidc_tokens, logged_in_client, rp_settings):
rp_settings.OIDC_RP_INITIATED_LOGOUT_DELETE_TOKENS = False
Expand Down Expand Up @@ -652,7 +653,7 @@ def claim_user_email(request):
return EXAMPLE_EMAIL


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_userinfo_endpoint_custom_claims_callable(oidc_tokens, client, oauth2_settings):
class CustomValidator(OAuth2Validator):
oidc_claim_scope = None
Expand Down Expand Up @@ -680,7 +681,7 @@ def get_additional_claims(self):
assert data["email"] == EXAMPLE_EMAIL


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_userinfo_endpoint_custom_claims_email_scope_callable(
oidc_email_scope_tokens, client, oauth2_settings
):
Expand All @@ -707,7 +708,7 @@ def get_additional_claims(self):
assert data["email"] == EXAMPLE_EMAIL


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_userinfo_endpoint_custom_claims_plain(oidc_tokens, client, oauth2_settings):
class CustomValidator(OAuth2Validator):
oidc_claim_scope = None
Expand Down Expand Up @@ -735,7 +736,7 @@ def get_additional_claims(self, request):
assert data["email"] == EXAMPLE_EMAIL


@pytest.mark.django_db
@pytest.mark.django_db(databases=retrieve_current_databases())
def test_userinfo_endpoint_custom_claims_email_scopeplain(oidc_email_scope_tokens, client, oauth2_settings):
class CustomValidator(OAuth2Validator):
def get_additional_claims(self, request):
Expand Down

0 comments on commit 86b2519

Please sign in to comment.