From 4b7e67f184eb967b84d3d45155ad087f30e80781 Mon Sep 17 00:00:00 2001 From: Adam Stankiewicz Date: Wed, 24 Jul 2024 12:48:03 -0400 Subject: [PATCH 1/3] fix: increase the default throttle rate for enterprise users --- course_discovery/apps/api/tests/jwt_utils.py | 11 +++-- .../apps/core/tests/test_throttles.py | 43 ++++++++++++++++--- course_discovery/apps/core/throttles.py | 21 ++++++++- 3 files changed, 63 insertions(+), 12 deletions(-) diff --git a/course_discovery/apps/api/tests/jwt_utils.py b/course_discovery/apps/api/tests/jwt_utils.py index 4382c5288c..60f26bb189 100644 --- a/course_discovery/apps/api/tests/jwt_utils.py +++ b/course_discovery/apps/api/tests/jwt_utils.py @@ -5,11 +5,11 @@ from django.conf import settings -def generate_jwt_payload(user): +def generate_jwt_payload(user, payload=None): """Generate a valid JWT payload given a user.""" now = int(time()) ttl = 5 - return { + jwt_payload = { 'iss': settings.JWT_AUTH['JWT_ISSUER'], 'aud': settings.JWT_AUTH['JWT_AUDIENCE'], 'username': user.username, @@ -17,6 +17,9 @@ def generate_jwt_payload(user): 'iat': now, 'exp': now + ttl } + if payload: + jwt_payload.update(payload) + return jwt_payload def generate_jwt_token(payload): @@ -29,8 +32,8 @@ def generate_jwt_header(token): return f'JWT {token}' -def generate_jwt_header_for_user(user): - payload = generate_jwt_payload(user) +def generate_jwt_header_for_user(user, payload=None): + payload = generate_jwt_payload(user, payload) token = generate_jwt_token(payload) return generate_jwt_header(token) diff --git a/course_discovery/apps/core/tests/test_throttles.py b/course_discovery/apps/core/tests/test_throttles.py index f76c703399..7c3b662c9d 100644 --- a/course_discovery/apps/core/tests/test_throttles.py +++ b/course_discovery/apps/core/tests/test_throttles.py @@ -1,8 +1,10 @@ from unittest.mock import patch +import ddt from django.urls import reverse from rest_framework.test import APITestCase +from course_discovery.apps.api.tests.jwt_utils import generate_jwt_header_for_user from course_discovery.apps.api.tests.mixins import SiteMixin from course_discovery.apps.core.models import UserThrottleRate from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory @@ -10,6 +12,7 @@ from course_discovery.apps.publisher.tests.factories import GroupFactory +@ddt.ddt class RateLimitingExceededTest(SiteMixin, APITestCase): """ Testing rate limiting of API calls. @@ -26,7 +29,15 @@ def tearDown(self): super().tearDown() throttling_cache().clear() - def _make_requests(self, count=None): + def _build_jwt_headers(self, user, payload=None): + """ + Helper function for creating headers for the JWT authentication. + """ + token = generate_jwt_header_for_user(user, payload) + headers = {'HTTP_AUTHORIZATION': token} + return headers + + def _make_requests(self, count=None, **headers): """ Make multiple requests until the throttle's limit is exceeded. Returns @@ -37,19 +48,19 @@ def _make_requests(self, count=None): with patch('rest_framework.views.APIView.throttle_classes', (OverridableUserRateThrottle,)): with patch.object(OverridableUserRateThrottle, 'THROTTLE_RATES', user_rates): for __ in range(count - 1): - response = self.client.get(self.url) + response = self.client.get(self.url, **headers) assert response.status_code == 200 - response = self.client.get(self.url) + response = self.client.get(self.url, **headers) return response - def assert_rate_limit_successfully_exceeded(self): + def assert_rate_limit_successfully_exceeded(self, **headers): """ Asserts that the throttle's rate limit can be exceeded without encountering an error. """ - response = self._make_requests() + response = self._make_requests(**headers) assert response.status_code == 200 - def assert_rate_limited(self, count=None): + def assert_rate_limited(self, count=None, **headers): """ Asserts that the throttle's rate limit was exceeded and we were denied. """ - response = self._make_requests(count) + response = self._make_requests(count, **headers) assert response.status_code == 429 def test_rate_limiting(self): @@ -84,3 +95,21 @@ def test_staff_with_user_throttle_rate(self): self.user.save() UserThrottleRate.objects.create(user=self.user, rate='10/hour') self.assert_rate_limited(11) + + @ddt.data( + ([], True), + (['enterprise_learner:*'], False), + (['enterprise_admin:*'], False), + (['enterprise_openedx_operator:*'], False), + ) + @ddt.unpack + def test_enterprise_user_throttling_with_jwt_authentication(self, jwt_roles, is_rate_limited): + """ Verify enterprise users are throttled at a higher rate. """ + payload = { + 'roles': jwt_roles, + } + headers = self._build_jwt_headers(self.user, payload) + if is_rate_limited: + self.assert_rate_limited(**headers) + else: + self.assert_rate_limit_successfully_exceeded(**headers) diff --git a/course_discovery/apps/core/throttles.py b/course_discovery/apps/core/throttles.py index c638bb3739..0f3f9a9685 100644 --- a/course_discovery/apps/core/throttles.py +++ b/course_discovery/apps/core/throttles.py @@ -1,5 +1,6 @@ """Custom API throttles.""" from django.core.cache import InvalidCacheBackendError, caches +from edx_rest_framework_extensions.auth.jwt.decoder import configured_jwt_decode_handler from rest_framework.throttling import UserRateThrottle from course_discovery.apps.core.models import UserThrottleRate @@ -16,6 +17,19 @@ def throttling_cache(): return caches['default'] +def is_enterprise_user(request): + """ + Determine whether a JWT-authenticated user is an enterprise user based on the `roles` in + the decoded JWT token associated with the request (e.g., `enterprise_learner`). + """ + jwt_token = request.auth + if not jwt_token: + return False + decoded_jwt = configured_jwt_decode_handler(jwt_token) + roles = decoded_jwt.get('roles', []) + return any('enterprise' in role for role in roles) + + class OverridableUserRateThrottle(UserRateThrottle): """Rate throttling of requests, overridable on a per-user basis.""" cache = throttling_cache() @@ -28,10 +42,15 @@ def allow_request(self, request, view): # Override this throttle's rate if applicable user_throttle = UserThrottleRate.objects.get(user=user) self.rate = user_throttle.rate - self.num_requests, self.duration = self.parse_rate(self.rate) except UserThrottleRate.DoesNotExist: # If we don't have a custom user override, skip throttling if they are a privileged user if user.is_superuser or user.is_staff or is_publisher_user(user): return True + # If the user is not a privileged user, increase throttling rate if they are an enterprise user + if is_enterprise_user(request): + self.rate = '300/hour' + + self.num_requests, self.duration = self.parse_rate(self.rate) + return super().allow_request(request, view) From fa15a59bb5c9c48fbe2329022afe8e8bc221fcaa Mon Sep 17 00:00:00 2001 From: Adam Stankiewicz Date: Wed, 24 Jul 2024 17:48:52 -0400 Subject: [PATCH 2/3] chore: force re-run of CI From 39117c2da3ff432391bb35180627b3cb50de3b57 Mon Sep 17 00:00:00 2001 From: Adam Stankiewicz Date: Wed, 24 Jul 2024 17:49:50 -0400 Subject: [PATCH 3/3] chore: increase throttle rate for enterprise users --- course_discovery/apps/core/throttles.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/course_discovery/apps/core/throttles.py b/course_discovery/apps/core/throttles.py index 0f3f9a9685..424e36f5fd 100644 --- a/course_discovery/apps/core/throttles.py +++ b/course_discovery/apps/core/throttles.py @@ -49,7 +49,7 @@ def allow_request(self, request, view): # If the user is not a privileged user, increase throttling rate if they are an enterprise user if is_enterprise_user(request): - self.rate = '300/hour' + self.rate = '600/hour' self.num_requests, self.duration = self.parse_rate(self.rate)