-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add tests for SwaggerAutoSchema * Add throttling
- Loading branch information
1 parent
f7c8b63
commit 2439b15
Showing
8 changed files
with
198 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import winter.http | ||
import winter | ||
|
||
@winter.http.throttling('5/s') | ||
@winter.route_get('with-throttling-on-controller/') | ||
class ControllerWithThrottlingOnController: | ||
|
||
@winter.route_get() | ||
def simple_method(self): | ||
return 1 | ||
|
||
@winter.http.throttling(None) | ||
@winter.route_post() | ||
def method_without_throttling(self): | ||
return None | ||
|
||
|
||
@winter.route_get('with-throttling-on-method/') | ||
class ControllerWithThrottlingOnMethod: | ||
|
||
@winter.route_get() | ||
@winter.http.throttling('5/s') | ||
def simple_method(self): | ||
return 1 | ||
|
||
@winter.route_get('without-throttling/') | ||
def method_without_throttling(self): | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import pytest | ||
from rest_framework.test import APIClient | ||
|
||
from .entities import AuthorizedUser | ||
|
||
|
||
def test_throttling(): | ||
client = APIClient() | ||
user = AuthorizedUser() | ||
client.force_authenticate(user) | ||
|
||
for i in range(1, 10): | ||
response = client.get('/with-throttling-on-controller/') | ||
if i > 5: | ||
assert response.status_code == 429, i | ||
else: | ||
assert response.status_code == 200, i | ||
|
||
|
||
def test_throttling_on_method(): | ||
client = APIClient() | ||
user = AuthorizedUser() | ||
client.force_authenticate(user) | ||
|
||
for i in range(1, 10): | ||
response = client.get('/with-throttling-on-method/') | ||
if i > 5: | ||
assert response.status_code == 429, i | ||
else: | ||
assert response.status_code == 200, i | ||
|
||
|
||
@pytest.mark.parametrize(('url', 'method'), ( | ||
('/with-throttling-on-method/without-throttling/', 'get'), | ||
('/with-throttling-on-controller/', 'post') | ||
)) | ||
def test_throttling_without_throttling(url, method): | ||
client = APIClient() | ||
user = AuthorizedUser() | ||
client.force_authenticate(user) | ||
|
||
for i in range(1, 10): | ||
client_method = getattr(client, method) | ||
response = client_method(url) | ||
assert response.status_code == 200, i |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .media_type import InvalidMediaTypeException | ||
from .media_type import MediaType | ||
from .urls import register_url_regexp | ||
from .throttling import throttling |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import time | ||
import typing | ||
|
||
import dataclasses | ||
from django.core.cache import cache as default_cache | ||
from rest_framework.throttling import BaseThrottle | ||
|
||
from ..core import Component | ||
from ..core import annotate | ||
|
||
if typing.TYPE_CHECKING: # pragma: no cover | ||
from ..routing import Route | ||
|
||
|
||
@dataclasses.dataclass | ||
class ThrottlingAnnotation: | ||
rate: typing.Optional[str] | ||
scope: typing.Optional[str] | ||
|
||
|
||
@dataclasses.dataclass | ||
class Throttling: | ||
num_requests: int | ||
duration: int | ||
scope: str | ||
|
||
|
||
def throttling(rate: typing.Optional[str], scope: typing.Optional[str] = None): | ||
return annotate(ThrottlingAnnotation(rate, scope), single=True) | ||
|
||
|
||
class BaseRateThrottle(BaseThrottle): | ||
throttling_by_http_method: typing.Dict[str, Throttling] = {} | ||
cache = default_cache | ||
cache_format = 'throttle_{scope}_{ident}' | ||
|
||
def allow_request(self, request, view) -> bool: | ||
throttling_ = self._get_throttling(request) | ||
|
||
if throttling_ is None: | ||
return True | ||
|
||
ident = self.get_ident(request) | ||
key = self._get_cache_key(throttling_.scope, ident) | ||
|
||
history = self.cache.get(key, []) | ||
now = time.time() | ||
|
||
while history and history[-1] <= now - throttling_.duration: | ||
history.pop() | ||
|
||
if len(history) >= throttling_.num_requests: | ||
return False | ||
|
||
history.insert(0, now) | ||
self.cache.set(key, history, throttling_.duration) | ||
return True | ||
|
||
def _get_cache_key(self, scope: str, ident: str) -> str: | ||
return self.cache_format.format(scope=scope, ident=ident) | ||
|
||
def get_ident(self, request) -> str: | ||
user_pk = request.user.pk if request.user.is_authenticated else None | ||
|
||
if user_pk is not None: | ||
return str(user_pk) | ||
|
||
return super().get_ident(request) | ||
|
||
|
||
def _get_throttling(self, request) -> typing.Optional[Throttling]: | ||
return self.throttling_by_http_method.get(request.method.lower()) | ||
|
||
|
||
def _parse_rate(rate: str) -> typing.Tuple[int, int]: | ||
""" | ||
Given the request rate string, return a two tuple of: | ||
<allowed number of requests>, <period of time in seconds> | ||
""" | ||
num, period = rate.split('/') | ||
num_requests = int(num) | ||
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] | ||
return num_requests, duration | ||
|
||
|
||
def create_throttle_classes( | ||
component: Component, | ||
routes: typing.List['Route'], | ||
) -> typing.Tuple[typing.Type[BaseRateThrottle], ...]: | ||
base_throttling_annotation = component.annotations.get_one_or_none(ThrottlingAnnotation) | ||
throttling_by_http_method_: typing.Dict[str, typing.Optional[Throttling]] = {} | ||
|
||
for route in routes: | ||
|
||
throttling_annotation = route.method.annotations.get_one_or_none(ThrottlingAnnotation) | ||
|
||
if throttling_annotation is None: | ||
throttling_annotation = base_throttling_annotation | ||
|
||
if throttling_annotation is not None and throttling_annotation.rate is not None: | ||
num_requests, duration = _parse_rate(throttling_annotation.rate) | ||
throttling_ = Throttling(num_requests, duration, throttling_annotation.scope) | ||
throttling_by_http_method_[route.http_method.lower()] = throttling_ | ||
|
||
if not throttling_by_http_method_: | ||
return () | ||
|
||
class RateThrottle(BaseRateThrottle): | ||
throttling_by_http_method = throttling_by_http_method_ | ||
|
||
return (RateThrottle,) |