Skip to content

Commit

Permalink
Add page size and pagination depth privileges
Browse files Browse the repository at this point in the history
  • Loading branch information
sarayourfriend committed May 23, 2024
1 parent 010b494 commit 53b2423
Show file tree
Hide file tree
Showing 20 changed files with 499 additions and 236 deletions.
53 changes: 53 additions & 0 deletions api/api/admin/oauth.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,65 @@
from django import forms
from django.contrib import admin

from oauth2_provider.models import AccessToken

from api.constants.privilege import Privilege
from api.models.oauth import ThrottledApplication


def register(site):
site.register(ThrottledApplication, ThrottledApplicationAdmin)
site.register(AccessToken, AccessTokenAdmin)


class ThrottledApplicationAdminForm(forms.ModelForm):
class Meta:
model = ThrottledApplication
exclude = (
"client_type",
"redirect_uris",
"post_logout_redirect_uris",
"skip_authorization",
"algorithm",
"user",
)

# ArrayField doesn't have a good default field, so use a multiple choice field, but
# override default widget of multi-<select>, which is much more annoying to use
# and easy to accidentally un-select an option. The multi-checkbox is much easier to use
privileges = forms.MultipleChoiceField(
choices=Privilege.choices,
required=False,
widget=forms.CheckboxSelectMultiple,
)


class ThrottledApplicationAdmin(admin.ModelAdmin):
form = ThrottledApplicationAdminForm
view_on_site = False

search_fields = ("client_id", "name", "rate_limit_model", "privileges")
list_display = ("client_id", "name", "created", "rate_limit_model", "privileges")
ordering = ("-created",)

readonly_fields = (
"name",
"created",
"client_id",
"verified",
"authorization_grant_type",
"client_secret",
)

def has_delete_permission(self, *args, **kwargs):
"""
Disallow deleting throttled applications. Use ``revoke`` instead.
This also hides the delete button on the change view.
"""
return False


class AccessTokenAdmin(admin.ModelAdmin):
search_fields = ("token", "id")
list_display = ("token", "id", "created", "scope", "expires")
Expand Down
72 changes: 72 additions & 0 deletions api/api/constants/privilege.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import typing
from dataclasses import dataclass

from rest_framework.request import Request


ANONYMOUS: typing.Literal["anonymous"] = "anonymous"
AUTHENTICATED: typing.Literal["authenticated"] = "authenticated"
PRIVILEGED: typing.Literal["privileged"] = "privileged"

Level = typing.Literal[ANONYMOUS, AUTHENTICATED, PRIVILEGED]
LEVELS = typing.get_args(Level)


_PRIVILEGES = {}


@dataclass
class Privilege:
"""
Privileges granted to applications upon approved request to Openverse maintainers.
Maintainers review requests for increased privileges on a per-case basis.
Distinct from ``rate_limit_model`` which only affects access rates rather than privileges.
"""

slug: str
anonymous: typing.Any
authenticated: typing.Any
privileged: typing.Any

def __post_init__(self):
_PRIVILEGES[self.slug] = self

def request_level(self, request: None | Request) -> tuple[Level, typing.Any]:
"""Retrieve the level of any request in relation to the privilege."""
if request is None or request.auth is None:
return ANONYMOUS, self.anonymous

if self.slug in request.auth.application.privileges:
return PRIVILEGED, self.privileged

return AUTHENTICATED, self.authenticated

@classmethod
@property
def choices(cls):
return ((slug,) * 2 for slug in _PRIVILEGES)


PAGE_SIZE = Privilege(
"page_size",
anonymous=20,
authenticated=50,
# Max out privileged page size at the maximum authenticated
# pagination depth, otherwise privileged page size limit can
# contradict pagination depth limit.
privileged=240,
)

PAGINATION_DEPTH = Privilege(
"pagination_depth",
# 12 pages of 20 results
# Both anon and authed are limited to the same depth
# authed users can request bigger pages, but still only the same
# overall number of results available
anonymous=12 * PAGE_SIZE.anonymous,
authenticated=12 * PAGE_SIZE.anonymous,
# 20 pages of maxed out page sizes for privileged apps
privileged=20 * PAGE_SIZE.privileged,
)
19 changes: 19 additions & 0 deletions api/api/migrations/0064_throttledapplication_privileges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Generated by Django 4.2.11 on 2024-05-23 01:41

import django.contrib.postgres.fields
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('api', '0063_merge_20240521_0843'),
]

operations = [
migrations.AddField(
model_name='throttledapplication',
name='privileges',
field=django.contrib.postgres.fields.ArrayField(base_field=models.CharField(choices=[('page_size', 'page_size'), ('pagination_depth', 'pagination_depth')]), default=list, help_text='Privileges granted to applications upon approved request to Openverse maintainers. Maintainers review requests for increased privileges on a per-case basis. Distinct from ``rate_limit_model`` which only affects access rates rather than privileges.', size=None),
),
]
13 changes: 13 additions & 0 deletions api/api/models/oauth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from textwrap import dedent

from django.contrib.postgres.fields import ArrayField
from django.db import models

from oauth2_provider.models import AbstractApplication

from api.constants import privilege


class OAuth2Registration(models.Model):
"""Information about API key applicants."""
Expand Down Expand Up @@ -40,6 +45,14 @@ class ThrottledApplication(AbstractApplication):
verified = models.BooleanField(default=False)
revoked = models.BooleanField(default=False)

privileges = ArrayField(
models.CharField(
choices=privilege.Privilege.choices,
),
default=list,
help_text=dedent(privilege.Privilege.__doc__).strip().replace("\n", " "),
)


class OAuth2Verification(models.Model):
"""
Expand Down
14 changes: 14 additions & 0 deletions api/api/serializers/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import typing

from rest_framework import serializers
from rest_framework.request import Request


class BaseModelSerializer(serializers.ModelSerializer):
Expand All @@ -19,3 +22,14 @@ def build_property_field(self, field_name, model_class):
if doc := getattr(model_class, field_name).__doc__:
kwargs.setdefault("help_text", doc)
return klass, kwargs


class BaseRequestSerializer(serializers.Serializer):
class Context(typing.TypedDict):
request: typing.Required[Request]

context: Context

@property
def request(self) -> Request | None:
return self.context.get("request", None)
104 changes: 73 additions & 31 deletions api/api/serializers/media_serializers.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
from collections import namedtuple
from typing import TypedDict
from math import floor

from django.conf import settings
from django.core.exceptions import ValidationError as DjangoValidationError
from django.core.validators import MaxValueValidator
from django.urls import reverse
from rest_framework import serializers
from rest_framework.exceptions import NotAuthenticated, ValidationError
from rest_framework.request import Request

from drf_spectacular.utils import extend_schema_serializer
from elasticsearch_dsl.response import Hit
from openverse_attribution.license import License

from api.constants import sensitivity
from api.constants import privilege, sensitivity
from api.constants.licenses import LICENSE_GROUPS
from api.constants.media_types import MediaType
from api.constants.parameters import COLLECTION, TAG
from api.constants.search import COLLECTIONS
from api.constants.sorting import DESCENDING, RELEVANCE, SORT_DIRECTIONS, SORT_FIELDS
from api.controllers import search_controller
from api.models.media import AbstractMedia
from api.serializers.base import BaseModelSerializer
from api.serializers.base import BaseModelSerializer, BaseRequestSerializer
from api.serializers.docs import (
COLLECTION_HELP_TEXT,
CREATOR_HELP_TEXT,
Expand All @@ -40,40 +39,36 @@
#######################


class PaginatedRequestSerializer(serializers.Serializer):
class PaginatedRequestSerializer(BaseRequestSerializer):
"""This serializer passes pagination parameters from the query string."""

_SUBJECT_TO_PAGINATION_LIMITS = (
"This parameter is subject to limitations based on authentication "
"and special privileges. For details, refer to [the authentication "
"documentation](#tag/auth)."
)

field_names = [
"page_size",
"page",
]
page_size = serializers.IntegerField(
label="page_size",
help_text=f"Number of results to return per page. "
f"Maximum is {settings.MAX_AUTHED_PAGE_SIZE} for authenticated "
f"requests, and {settings.MAX_ANONYMOUS_PAGE_SIZE} for "
f"unauthenticated requests.",
help_text=f"Number of results to return per page. {_SUBJECT_TO_PAGINATION_LIMITS}",
required=False,
default=settings.MAX_ANONYMOUS_PAGE_SIZE,
default=privilege.PAGE_SIZE.anonymous,
min_value=1,
)
page = serializers.IntegerField(
label="page",
help_text="The page of results to retrieve.",
help_text=f"The page of results to retrieve. {_SUBJECT_TO_PAGINATION_LIMITS}",
required=False,
default=1,
max_value=settings.MAX_PAGINATION_DEPTH,
min_value=1,
)

def validate_page_size(self, value):
request = self.context.get("request")
is_anonymous = getattr(request, "auth", None) is None
max_value = (
settings.MAX_ANONYMOUS_PAGE_SIZE
if is_anonymous
else settings.MAX_AUTHED_PAGE_SIZE
)
level, max_value = privilege.PAGE_SIZE.request_level(self.request)

validator = MaxValueValidator(
max_value,
Expand All @@ -82,19 +77,67 @@ def validate_page_size(self, value):
),
)

if is_anonymous:
try:
validator(value)
except (ValidationError, DjangoValidationError) as e:
raise NotAuthenticated(
detail=e.message,
code=e.code,
)
else:
try:
validator(value)
except (ValidationError, DjangoValidationError) as e:
if level == privilege.PRIVILEGED:
raise

raise NotAuthenticated(
detail=f"page_size may not exceed {max_value} for {level} requests",
code=e.code,
)

return value

def clamp_result_count(self, real_result_count):
_, max_depth = privilege.PAGINATION_DEPTH.request_level(self.request)

if real_result_count > max_depth:
return max_depth

return real_result_count

def clamp_page_count(self, real_page_count):
_, max_depth = privilege.PAGINATION_DEPTH.request_level(self.request)

page_size = self.data["page_size"]
max_possible_page_count = max_depth / page_size

if real_page_count > max_possible_page_count:
return floor(max_possible_page_count)

return real_page_count

def validate(self, data):
data = super().validate(data)

# pagination depth is validated as a combination of page and page size,
# and so cannot be validated in the individual field validation methods
level, max_depth = privilege.PAGINATION_DEPTH.request_level(self.request)

requested_pagination_depth = data["page"] * data["page_size"]

pagination_depth_validator = MaxValueValidator(
max_depth,
message=serializers.IntegerField.default_error_messages["max_value"].format(
max_value=max_depth
),
)

try:
pagination_depth_validator(requested_pagination_depth)
except (ValidationError, DjangoValidationError) as e:
if level == privilege.PRIVILEGED:
raise

raise NotAuthenticated(
detail=f"pagination depth may not exceed {max_depth} for {level} requests",
code=e.code,
)

return data


@extend_schema_serializer(
# Hide internal fields from documentation.
Expand Down Expand Up @@ -275,10 +318,9 @@ class MediaSearchRequestSerializer(PaginatedRequestSerializer):
required=False,
)

class Context(TypedDict, total=True):
class Context(BaseRequestSerializer.Context, total=True):
warnings: list[dict]
media_type: MediaType
request: Request

context: Context

Expand Down Expand Up @@ -401,7 +443,7 @@ def validate_source(self, value):
f"Refer to the source list for valid options: {sources_list}."
)
elif invalid_sources := (sources - valid_sources):
available_sources_uri = self.context["request"].build_absolute_uri(
available_sources_uri = self.request.build_absolute_uri(
reverse(f"{self.media_type}-stats")
)
self.context["warnings"].append(
Expand Down
Loading

0 comments on commit 53b2423

Please sign in to comment.