Skip to content

Commit

Permalink
Replaced get_token_info calls with the request.auth.application (#…
Browse files Browse the repository at this point in the history
…3528)

* Replace get_token_info with request.auth.application

Co-authored-by: Kanishka Bansode <96020697+kb-0311@users.noreply.github.com>

* Remove usage of request.user for anon request checks

* Invert logic for checking authorization

* Update test to reflect that the request is anonymous

Previously the request in this test was considered not anonymous,
because `request` was None. With the update it is identified as
anonymous, and therefore the authority boost is not applied.

---------

Co-authored-by: sarayourfriend <git@sarayourfriend.pictures>
Co-authored-by: Staci Cooper <staci.cooper@automattic.com>
  • Loading branch information
3 people authored Feb 27, 2024
1 parent 157bb6c commit 2cb8b5e
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 97 deletions.
17 changes: 10 additions & 7 deletions api/api/middleware/response_headers_middleware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from api.utils.oauth2_helper import get_token_info
from rest_framework.request import Request

from api.models.oauth import ThrottledApplication


def response_headers_middleware(get_response):
Expand All @@ -11,14 +13,15 @@ def response_headers_middleware(get_response):
to identify malicious requesters or request patterns.
"""

def middleware(request):
def middleware(request: Request):
response = get_response(request)

if hasattr(request, "auth") and request.auth:
token_info = get_token_info(str(request.auth))
if token_info:
response["x-ov-client-application-name"] = token_info.application_name
response["x-ov-client-application-verified"] = token_info.verified
if not (hasattr(request, "auth") and hasattr(request.auth, "application")):
return response

application: ThrottledApplication = request.auth.application
response["x-ov-client-application-name"] = application.name
response["x-ov-client-application-verified"] = application.verified

return response

Expand Down
4 changes: 2 additions & 2 deletions api/api/serializers/media_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class PaginatedRequestSerializer(serializers.Serializer):

def validate_page_size(self, value):
request = self.context.get("request")
is_anonymous = bool(request and request.user and request.user.is_anonymous)
is_anonymous = getattr(request, "auth", None) is None
max_value = (
settings.MAX_ANONYMOUS_PAGE_SIZE
if is_anonymous
Expand Down Expand Up @@ -247,7 +247,7 @@ class MediaSearchRequestSerializer(PaginatedRequestSerializer):

def is_request_anonymous(self):
request = self.context.get("request")
return bool(request and request.user and request.user.is_anonymous)
return getattr(request, "auth", None) is None

@staticmethod
def _truncate(value):
Expand Down
65 changes: 0 additions & 65 deletions api/api/utils/oauth2_helper.py

This file was deleted.

20 changes: 11 additions & 9 deletions api/api/utils/throttle.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

from redis.exceptions import ConnectionError

from api.utils.oauth2_helper import get_token_info


parent_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -47,8 +45,11 @@ def has_valid_token(self, request):
if not request.auth:
return False

token_info = get_token_info(str(request.auth))
return token_info and token_info.valid
application = getattr(request.auth, "application", None)
if application is None:
return False

return application.client_id and application.verified

def get_cache_key(self, request, view):
return self.cache_format % {
Expand Down Expand Up @@ -146,15 +147,16 @@ class AbstractOAuth2IdRateThrottle(SimpleRateThrottle, metaclass=abc.ABCMeta):

def get_cache_key(self, request, view):
# Find the client ID associated with the access token.
auth = str(request.auth)
token_info = get_token_info(auth)
if not (token_info and token_info.valid):
if not self.has_valid_token(request):
return None

if token_info.rate_limit_model not in self.applies_to_rate_limit_model:
# `self.has_valid_token` call earlier ensures accessing `application` will not fail
application = request.auth.application

if application.rate_limit_model not in self.applies_to_rate_limit_model:
return None

return self.cache_format % {"scope": self.scope, "ident": token_info.client_id}
return self.cache_format % {"scope": self.scope, "ident": application.client_id}


class OAuth2IdThumbnailRateThrottle(AbstractOAuth2IdRateThrottle):
Expand Down
20 changes: 7 additions & 13 deletions api/api/views/oauth2_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from django.conf import settings
from django.core.cache import cache
from django.core.mail import send_mail
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.views import APIView
Expand All @@ -22,7 +23,6 @@
OAuth2KeyInfoSerializer,
OAuth2RegistrationSerializer,
)
from api.utils.oauth2_helper import get_token_info
from api.utils.throttle import OnePerSecond, TenPerDay


Expand Down Expand Up @@ -169,7 +169,7 @@ class CheckRates(APIView):
throttle_classes = (OnePerSecond,)

@key_info
def get(self, request, format=None):
def get(self, request: Request, format=None):
"""
Get information about your API key.
Expand All @@ -181,23 +181,17 @@ def get(self, request, format=None):
"""

# TODO: Replace 403 responses with DRF `authentication_classes`.
if not request.auth:
if not request.auth or not hasattr(request.auth, "application"):
return Response(status=403, data="Forbidden")

access_token = str(request.auth)
token_info = get_token_info(access_token)
application: ThrottledApplication = request.auth.application

if not token_info:
# This shouldn't happen if `request.auth` was true above,
# but better safe than sorry
return Response(status=403, data="Forbidden")

client_id = token_info.client_id
client_id = application.client_id

if not client_id:
return Response(status=403, data="Forbidden")

throttle_type = token_info.rate_limit_model
throttle_type = application.rate_limit_model
throttle_key = "throttle_{scope}_{client_id}"
if throttle_type == "standard":
sustained_throttle_key = throttle_key.format(
Expand Down Expand Up @@ -242,7 +236,7 @@ def get(self, request, format=None):
"requests_this_minute": burst_requests,
"requests_today": sustained_requests,
"rate_limit_model": throttle_type,
"verified": token_info.verified,
"verified": application.verified,
}
)
return Response(status=status, data=response_data.data)
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ def test_create_search_query_q_search_with_filters(image_media_type_config):
}
},
{"rank_feature": {"boost": 10000, "field": "standardized_popularity"}},
{"rank_feature": {"boost": 25000, "field": "authority_boost"}},
],
}

Expand Down
1 change: 1 addition & 0 deletions api/test/unit/utils/test_throttle.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def enable_throttles(settings):
def access_token():
token = AccessTokenFactory.create()
token.application.verified = True
token.application.client_id = 123
token.application.save()
return token

Expand Down

0 comments on commit 2cb8b5e

Please sign in to comment.