Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: increase the default throttle rate for enterprise users (POC) #4394

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions course_discovery/apps/api/tests/jwt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@
from django.conf import settings


def generate_jwt_payload(user):
def generate_jwt_payload(user, payload=None):
"""Generate a valid JWT payload given a user."""
now = int(time())
ttl = 5
return {
jwt_payload = {
'iss': settings.JWT_AUTH['JWT_ISSUER'],
'aud': settings.JWT_AUTH['JWT_AUDIENCE'],
'username': user.username,
'email': user.email,
'iat': now,
'exp': now + ttl
}
if payload:
jwt_payload.update(payload)
return jwt_payload


def generate_jwt_token(payload):
Expand All @@ -29,8 +32,8 @@ def generate_jwt_header(token):
return f'JWT {token}'


def generate_jwt_header_for_user(user):
payload = generate_jwt_payload(user)
def generate_jwt_header_for_user(user, payload=None):
payload = generate_jwt_payload(user, payload)
token = generate_jwt_token(payload)

return generate_jwt_header(token)
43 changes: 36 additions & 7 deletions course_discovery/apps/core/tests/test_throttles.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from unittest.mock import patch

import ddt
from django.urls import reverse
from rest_framework.test import APITestCase

from course_discovery.apps.api.tests.jwt_utils import generate_jwt_header_for_user
from course_discovery.apps.api.tests.mixins import SiteMixin
from course_discovery.apps.core.models import UserThrottleRate
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.throttles import OverridableUserRateThrottle, throttling_cache
from course_discovery.apps.publisher.tests.factories import GroupFactory


@ddt.ddt
class RateLimitingExceededTest(SiteMixin, APITestCase):
"""
Testing rate limiting of API calls.
Expand All @@ -26,7 +29,15 @@ def tearDown(self):
super().tearDown()
throttling_cache().clear()

def _make_requests(self, count=None):
def _build_jwt_headers(self, user, payload=None):
"""
Helper function for creating headers for the JWT authentication.
"""
token = generate_jwt_header_for_user(user, payload)
headers = {'HTTP_AUTHORIZATION': token}
return headers

def _make_requests(self, count=None, **headers):
""" Make multiple requests until the throttle's limit is exceeded.

Returns
Expand All @@ -37,19 +48,19 @@ def _make_requests(self, count=None):
with patch('rest_framework.views.APIView.throttle_classes', (OverridableUserRateThrottle,)):
with patch.object(OverridableUserRateThrottle, 'THROTTLE_RATES', user_rates):
for __ in range(count - 1):
response = self.client.get(self.url)
response = self.client.get(self.url, **headers)
assert response.status_code == 200
response = self.client.get(self.url)
response = self.client.get(self.url, **headers)
return response

def assert_rate_limit_successfully_exceeded(self):
def assert_rate_limit_successfully_exceeded(self, **headers):
""" Asserts that the throttle's rate limit can be exceeded without encountering an error. """
response = self._make_requests()
response = self._make_requests(**headers)
assert response.status_code == 200

def assert_rate_limited(self, count=None):
def assert_rate_limited(self, count=None, **headers):
""" Asserts that the throttle's rate limit was exceeded and we were denied. """
response = self._make_requests(count)
response = self._make_requests(count, **headers)
assert response.status_code == 429

def test_rate_limiting(self):
Expand Down Expand Up @@ -84,3 +95,21 @@ def test_staff_with_user_throttle_rate(self):
self.user.save()
UserThrottleRate.objects.create(user=self.user, rate='10/hour')
self.assert_rate_limited(11)

@ddt.data(
([], True),
(['enterprise_learner:*'], False),
(['enterprise_admin:*'], False),
(['enterprise_openedx_operator:*'], False),
)
@ddt.unpack
def test_enterprise_user_throttling_with_jwt_authentication(self, jwt_roles, is_rate_limited):
""" Verify enterprise users are throttled at a higher rate. """
payload = {
'roles': jwt_roles,
}
headers = self._build_jwt_headers(self.user, payload)
if is_rate_limited:
self.assert_rate_limited(**headers)
else:
self.assert_rate_limit_successfully_exceeded(**headers)
21 changes: 20 additions & 1 deletion course_discovery/apps/core/throttles.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Custom API throttles."""
from django.core.cache import InvalidCacheBackendError, caches
from edx_rest_framework_extensions.auth.jwt.decoder import configured_jwt_decode_handler
from rest_framework.throttling import UserRateThrottle

from course_discovery.apps.core.models import UserThrottleRate
Expand All @@ -16,6 +17,19 @@ def throttling_cache():
return caches['default']


def is_enterprise_user(request):
"""
Determine whether a JWT-authenticated user is an enterprise user based on the `roles` in
the decoded JWT token associated with the request (e.g., `enterprise_learner`).
"""
jwt_token = request.auth
if not jwt_token:
return False
decoded_jwt = configured_jwt_decode_handler(jwt_token)
roles = decoded_jwt.get('roles', [])
return any('enterprise' in role for role in roles)


class OverridableUserRateThrottle(UserRateThrottle):
"""Rate throttling of requests, overridable on a per-user basis."""
cache = throttling_cache()
Expand All @@ -28,10 +42,15 @@ def allow_request(self, request, view):
# Override this throttle's rate if applicable
user_throttle = UserThrottleRate.objects.get(user=user)
self.rate = user_throttle.rate
self.num_requests, self.duration = self.parse_rate(self.rate)
except UserThrottleRate.DoesNotExist:
# If we don't have a custom user override, skip throttling if they are a privileged user
if user.is_superuser or user.is_staff or is_publisher_user(user):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[inform/context] We already remove the throttling based on whether the authenticated user is a "publisher user" (aka is a member of a Django Group). This POC is largely extending this to account for an is_enterprise_user check as well.

return True

# If the user is not a privileged user, increase throttling rate if they are an enterprise user
if is_enterprise_user(request):
self.rate = '600/hour'

self.num_requests, self.duration = self.parse_rate(self.rate)

return super().allow_request(request, view)
Loading