diff --git a/oauth2_provider/migrations/0011_refreshtoken_token_family.py b/oauth2_provider/migrations/0011_refreshtoken_token_family.py new file mode 100644 index 000000000..082f444de --- /dev/null +++ b/oauth2_provider/migrations/0011_refreshtoken_token_family.py @@ -0,0 +1,18 @@ +# Generated by Django 5.2 on 2024-08-09 16:40 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('oauth2_provider', '0010_application_allowed_origins'), + ] + + operations = [ + migrations.AddField( + model_name='refreshtoken', + name='token_family', + field=models.UUIDField(blank=True, editable=False, null=True), + ), + ] diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index 661bd7dfc..9895528de 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -490,6 +490,7 @@ class AbstractRefreshToken(models.Model): null=True, related_name="refresh_token", ) + token_family = models.UUIDField(null=True, blank=True, editable=False) created = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now=True) diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index 47d65e851..d1cb8b9b6 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -15,7 +15,6 @@ from django.contrib.auth.hashers import check_password, identify_hasher from django.core.exceptions import ObjectDoesNotExist from django.db import transaction -from django.db.models import Q from django.http import HttpRequest from django.utils import dateformat, timezone from django.utils.crypto import constant_time_compare @@ -644,7 +643,9 @@ def save_bearer_token(self, token, request, *args, **kwargs): source_refresh_token=refresh_token_instance, ) - self._create_refresh_token(request, refresh_token_code, access_token) + self._create_refresh_token( + request, refresh_token_code, access_token, refresh_token_instance + ) else: # make sure that the token data we're returning matches # the existing token @@ -688,9 +689,17 @@ def _create_authorization_code(self, request, code, expires=None): claims=json.dumps(request.claims or {}), ) - def _create_refresh_token(self, request, refresh_token_code, access_token): + def _create_refresh_token(self, request, refresh_token_code, access_token, previous_refresh_token): + if previous_refresh_token: + token_family = previous_refresh_token.token_family + else: + token_family = uuid.uuid4() return RefreshToken.objects.create( - user=request.user, token=refresh_token_code, application=request.client, access_token=access_token + user=request.user, + token=refresh_token_code, + application=request.client, + access_token=access_token, + token_family=token_family, ) def revoke_token(self, token, token_type_hint, request, *args, **kwargs): @@ -752,22 +761,25 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs Also attach User instance to the request object """ - null_or_recent = Q(revoked__isnull=True) | Q( - revoked__gt=timezone.now() - timedelta(seconds=oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS) - ) - rt = ( - RefreshToken.objects.filter(null_or_recent, token=refresh_token) - .select_related("access_token") - .first() - ) + rt = RefreshToken.objects.filter(token=refresh_token).select_related("access_token").first() if not rt: return False + if rt.revoked is not None and rt.revoked <= timezone.now() - timedelta( + seconds=oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS + ): + if oauth2_settings.REFRESH_TOKEN_REUSE_PROTECTION and rt.token_family: + rt_token_family = RefreshToken.objects.filter(token_family=rt.token_family) + for related_rt in rt_token_family.all(): + related_rt.revoke() + return False + request.user = rt.user request.refresh_token = rt.token # Temporary store RefreshToken instance to be reused by get_original_scopes and save_bearer_token. request.refresh_token_instance = rt + return rt.application == client @transaction.atomic diff --git a/oauth2_provider/settings.py b/oauth2_provider/settings.py index 950ab5643..329a1b354 100644 --- a/oauth2_provider/settings.py +++ b/oauth2_provider/settings.py @@ -54,6 +54,7 @@ "ID_TOKEN_EXPIRE_SECONDS": 36000, "REFRESH_TOKEN_EXPIRE_SECONDS": None, "REFRESH_TOKEN_GRACE_PERIOD_SECONDS": 0, + "REFRESH_TOKEN_REUSE_PROTECTION": False, "ROTATE_REFRESH_TOKEN": True, "ERROR_RESPONSE_WITH_SCOPES": False, "APPLICATION_MODEL": APPLICATION_MODEL, diff --git a/tests/migrations/0006_basetestapplication_token_family.py b/tests/migrations/0006_basetestapplication_token_family.py new file mode 100644 index 000000000..1a4dd19e1 --- /dev/null +++ b/tests/migrations/0006_basetestapplication_token_family.py @@ -0,0 +1,43 @@ +# Generated by Django 5.2 on 2024-08-09 16:40 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('tests', '0005_basetestapplication_allowed_origins_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='samplerefreshtoken', + name='token_family', + field=models.UUIDField(blank=True, editable=False, null=True), + ), + migrations.AlterField( + model_name='basetestapplication', + name='allowed_origins', + field=models.TextField(blank=True, default='', help_text='Allowed origins list to enable CORS, space separated'), + ), + migrations.AlterField( + model_name='basetestapplication', + name='post_logout_redirect_uris', + field=models.TextField(blank=True, default='', help_text='Allowed Post Logout URIs list, space separated'), + ), + migrations.AlterField( + model_name='sampleaccesstoken', + name='token', + field=models.CharField(db_index=True, max_length=255, unique=True), + ), + migrations.AlterField( + model_name='sampleapplication', + name='allowed_origins', + field=models.TextField(blank=True, default='', help_text='Allowed origins list to enable CORS, space separated'), + ), + migrations.AlterField( + model_name='sampleapplication', + name='post_logout_redirect_uris', + field=models.TextField(blank=True, default='', help_text='Allowed Post Logout URIs list, space separated'), + ), + ] diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index fad55c124..ae6e7e76e 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -1013,6 +1013,7 @@ def test_refresh_repeating_requests_revokes_old_token(self): "refresh_token": content["refresh_token"], "scope": content["scope"], } + # First response works as usual response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) new_tokens = json.loads(response.content.decode("utf-8")) @@ -1021,7 +1022,7 @@ def test_refresh_repeating_requests_revokes_old_token(self): response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) - # Previously returned tokens are now invalid + # Previously returned tokens are now invalid as well new_token_request_data = { "grant_type": "refresh_token", "refresh_token": new_tokens["refresh_token"],