Skip to content

Commit

Permalink
Add throttling (#59)
Browse files Browse the repository at this point in the history
* Add tests for SwaggerAutoSchema
* Add throttling
  • Loading branch information
andrey-berenda authored and mofr committed Feb 27, 2019
1 parent f7c8b63 commit 2439b15
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/controllers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,7 @@
from .controller_with_media_types_routing import ControllerWithMediaTypesRouting
from .controller_with_path_parameters import ControllerWithPathParameters
from .controller_with_serializer import ControllerWithSerializer
from .controller_with_throttling import ControllerWithThrottlingOnController
from .controller_with_throttling import ControllerWithThrottlingOnMethod
from .no_authentication_controller import NoAuthenticationController
from .simple_controller import SimpleController
28 changes: 28 additions & 0 deletions tests/controllers/controller_with_throttling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import winter.http
import winter

@winter.http.throttling('5/s')
@winter.route_get('with-throttling-on-controller/')
class ControllerWithThrottlingOnController:

@winter.route_get()
def simple_method(self):
return 1

@winter.http.throttling(None)
@winter.route_post()
def method_without_throttling(self):
return None


@winter.route_get('with-throttling-on-method/')
class ControllerWithThrottlingOnMethod:

@winter.route_get()
@winter.http.throttling('5/s')
def simple_method(self):
return 1

@winter.route_get('without-throttling/')
def method_without_throttling(self):
return None
7 changes: 7 additions & 0 deletions tests/entities/users.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import random
import uuid


class User:

@property
Expand All @@ -18,6 +22,9 @@ def is_anonymous(self):

class AuthorizedUser(User):

def __init__(self, pk=None):
self.pk = pk if pk is not None else uuid.uuid4()

@property
def is_authenticated(self):
return True
45 changes: 45 additions & 0 deletions tests/test_throttling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest
from rest_framework.test import APIClient

from .entities import AuthorizedUser


def test_throttling():
client = APIClient()
user = AuthorizedUser()
client.force_authenticate(user)

for i in range(1, 10):
response = client.get('/with-throttling-on-controller/')
if i > 5:
assert response.status_code == 429, i
else:
assert response.status_code == 200, i


def test_throttling_on_method():
client = APIClient()
user = AuthorizedUser()
client.force_authenticate(user)

for i in range(1, 10):
response = client.get('/with-throttling-on-method/')
if i > 5:
assert response.status_code == 429, i
else:
assert response.status_code == 200, i


@pytest.mark.parametrize(('url', 'method'), (
('/with-throttling-on-method/without-throttling/', 'get'),
('/with-throttling-on-controller/', 'post')
))
def test_throttling_without_throttling(url, method):
client = APIClient()
user = AuthorizedUser()
client.force_authenticate(user)

for i in range(1, 10):
client_method = getattr(client, method)
response = client_method(url)
assert response.status_code == 200, i
2 changes: 2 additions & 0 deletions tests/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@
*winter.django.create_django_urls(controllers.ControllerWithMediaTypesRouting),
*winter.django.create_django_urls(controllers.ControllerWithPathParameters),
*winter.django.create_django_urls(controllers.ControllerWithSerializer),
*winter.django.create_django_urls(controllers.ControllerWithThrottlingOnController),
*winter.django.create_django_urls(controllers.ControllerWithThrottlingOnMethod),
]
2 changes: 2 additions & 0 deletions winter/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .exceptions import exceptions_handler
from .exceptions import get_throws
from .exceptions import handle_winter_exception
from .http.throttling import create_throttle_classes
from .http.urls import rewrite_uritemplate_with_regexps
from .output_processor import get_output_processor
from .response_entity import ResponseEntity
Expand Down Expand Up @@ -58,6 +59,7 @@ def _create_django_view(controller, component, routes: List[Route]):
class WinterView(rest_framework.views.APIView):
authentication_classes = (SessionAuthentication,)
permission_classes = (IsAuthenticated,) if is_authentication_needed(component) else ()
throttle_classes = create_throttle_classes(component, routes)
swagger_schema = SwaggerAutoSchema

for route in routes:
Expand Down
1 change: 1 addition & 0 deletions winter/http/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .media_type import InvalidMediaTypeException
from .media_type import MediaType
from .urls import register_url_regexp
from .throttling import throttling
111 changes: 111 additions & 0 deletions winter/http/throttling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import time
import typing

import dataclasses
from django.core.cache import cache as default_cache
from rest_framework.throttling import BaseThrottle

from ..core import Component
from ..core import annotate

if typing.TYPE_CHECKING: # pragma: no cover
from ..routing import Route


@dataclasses.dataclass
class ThrottlingAnnotation:
rate: typing.Optional[str]
scope: typing.Optional[str]


@dataclasses.dataclass
class Throttling:
num_requests: int
duration: int
scope: str


def throttling(rate: typing.Optional[str], scope: typing.Optional[str] = None):
return annotate(ThrottlingAnnotation(rate, scope), single=True)


class BaseRateThrottle(BaseThrottle):
throttling_by_http_method: typing.Dict[str, Throttling] = {}
cache = default_cache
cache_format = 'throttle_{scope}_{ident}'

def allow_request(self, request, view) -> bool:
throttling_ = self._get_throttling(request)

if throttling_ is None:
return True

ident = self.get_ident(request)
key = self._get_cache_key(throttling_.scope, ident)

history = self.cache.get(key, [])
now = time.time()

while history and history[-1] <= now - throttling_.duration:
history.pop()

if len(history) >= throttling_.num_requests:
return False

history.insert(0, now)
self.cache.set(key, history, throttling_.duration)
return True

def _get_cache_key(self, scope: str, ident: str) -> str:
return self.cache_format.format(scope=scope, ident=ident)

def get_ident(self, request) -> str:
user_pk = request.user.pk if request.user.is_authenticated else None

if user_pk is not None:
return str(user_pk)

return super().get_ident(request)


def _get_throttling(self, request) -> typing.Optional[Throttling]:
return self.throttling_by_http_method.get(request.method.lower())


def _parse_rate(rate: str) -> typing.Tuple[int, int]:
"""
Given the request rate string, return a two tuple of:
<allowed number of requests>, <period of time in seconds>
"""
num, period = rate.split('/')
num_requests = int(num)
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
return num_requests, duration


def create_throttle_classes(
component: Component,
routes: typing.List['Route'],
) -> typing.Tuple[typing.Type[BaseRateThrottle], ...]:
base_throttling_annotation = component.annotations.get_one_or_none(ThrottlingAnnotation)
throttling_by_http_method_: typing.Dict[str, typing.Optional[Throttling]] = {}

for route in routes:

throttling_annotation = route.method.annotations.get_one_or_none(ThrottlingAnnotation)

if throttling_annotation is None:
throttling_annotation = base_throttling_annotation

if throttling_annotation is not None and throttling_annotation.rate is not None:
num_requests, duration = _parse_rate(throttling_annotation.rate)
throttling_ = Throttling(num_requests, duration, throttling_annotation.scope)
throttling_by_http_method_[route.http_method.lower()] = throttling_

if not throttling_by_http_method_:
return ()

class RateThrottle(BaseRateThrottle):
throttling_by_http_method = throttling_by_http_method_

return (RateThrottle,)

0 comments on commit 2439b15

Please sign in to comment.