diff --git a/tests/pagination/test_order_by.py b/tests/pagination/test_order_by.py new file mode 100644 index 00000000..483837c6 --- /dev/null +++ b/tests/pagination/test_order_by.py @@ -0,0 +1,11 @@ +import pytest +from rest_framework import exceptions + +import winter + + +def test_with_default_not_in_allowed_fields(): + with pytest.raises(exceptions.ParseError) as exception: + winter.pagination.order_by(['id'], default_sort=('uid',)) + + assert exception.value.args == ('Fields do not allowed as order by fields: "uid"',) diff --git a/tests/pagination/test_page_position_argument_inspector.py b/tests/pagination/test_page_position_argument_inspector.py index 0a456511..333788eb 100644 --- a/tests/pagination/test_page_position_argument_inspector.py +++ b/tests/pagination/test_page_position_argument_inspector.py @@ -3,8 +3,8 @@ import winter from winter.pagination import PagePosition -from winter.pagination import PagePositionArgumentsInspector from winter.pagination import PagePositionArgumentResolver +from winter.pagination import PagePositionArgumentsInspector from winter.routing import get_route @@ -38,10 +38,12 @@ def method(self, arg1: argument_type): assert parameters == expected_parameters -def test_page_position_argument_inspector_with_allowed_order_by_fields(): +@pytest.mark.parametrize(('default_sort', 'default_in_parameter'), ((None, None), (('id',), 'id'))) +def test_page_position_argument_inspector_with_allowed_order_by_fields(default_sort, default_in_parameter): + class SimpleController: @winter.route_get('') - @winter.pagination.order_by(['id']) + @winter.pagination.order_by(['id'], default_sort=default_sort) def method(self, arg1: PagePosition): return arg1 @@ -52,11 +54,12 @@ def method(self, arg1: PagePosition): order_by_parameter = openapi.Parameter( name=resolver.order_by_name, - description='Comma separated order by fields. Allowed fields: id', + description='Comma separated order by fields. Allowed fields: id.', required=False, in_=openapi.IN_QUERY, type=openapi.TYPE_ARRAY, items={'type': openapi.TYPE_STRING}, + default=default_in_parameter, ) expected_parameters = [ diff --git a/tests/pagination/test_page_position_argument_resolver.py b/tests/pagination/test_page_position_argument_resolver.py index 888cca5f..8131396f 100644 --- a/tests/pagination/test_page_position_argument_resolver.py +++ b/tests/pagination/test_page_position_argument_resolver.py @@ -40,17 +40,14 @@ def func(arg1: argument_type): ('', PagePosition(None, None)), ('offset=0', PagePosition(None, 0)), ('limit=10&offset=20&order_by=-id,name', PagePosition(10, 20, Sort.by('id').desc().and_(Sort.by('name')))), - ('order_by= x', PagePosition(None, None, Sort.by(' x'))), - ('order_by=- x', PagePosition(None, None, Sort.by(' x').desc())), - ('order_by= -x', PagePosition(None, None, Sort.by(' -x'))), ('order_by=', PagePosition(None, None)), )) def test_resolve_argument_ok_in_page_position_argument_resolver(query_string, expected_page_position): - def func(arg1: int): - return arg1 + @winter.pagination.order_by(['name', 'id', 'email', 'x',]) + def method(page_position: PagePosition): + return page_position - method = ComponentMethod(func) - argument = method.get_argument('arg1') + argument = method.get_argument('page_position') resolver = PagePositionArgumentResolver(allow_any_order_by_field=True) @@ -64,15 +61,40 @@ def func(arg1: int): assert page_position == expected_page_position +@pytest.mark.parametrize(('query_string', 'default_sort', 'expected_page_position'), ( + ('limit=1&offset=3', ('-name',), PagePosition(1, 3, Sort.by('name').desc())), +)) +def test_resolve_argument_ok_in_page_position_argument_resolver_with_default( + query_string, + default_sort, + expected_page_position, +): + @winter.pagination.order_by(['name', 'id', 'email',], default_sort=default_sort) + def method(page_position: PagePosition): + return page_position + + argument = method.get_argument('page_position') + + resolver = PagePositionArgumentResolver(allow_any_order_by_field=True) + + request = Mock(spec=DRFRequest) + request.query_params = QueryDict(query_string) + + # Act + page_position = resolver.resolve_argument(argument, request) + + # Assert + assert page_position == expected_page_position + @pytest.mark.parametrize(('query_string', 'exception_type', 'message'), ( ('limit=none', ParseError, 'Invalid "limit" query parameter value: "none"'), ('offset=-20', ValidationError, 'Invalid "offset" query parameter value: "-20"'), - ('order_by=id,', ParseError, 'An empty sorting part found'), - ('order_by=-', ParseError, 'An empty sorting part found'), + ('order_by=id,', ParseError, 'Invalid field for order: ""'), + ('order_by=-', ParseError, 'Invalid field for order: "-"'), ( 'order_by=not_allowed_order_by_field', ParseError, - 'Field "not_allowed_order_by_field" does not allowed as order by field', + 'Fields do not allowed as order by fields: "not_allowed_order_by_field"', ), )) def test_resolve_argument_fails_in_page_position_argument_resolver(query_string, exception_type, message): diff --git a/tests/test_controller.py b/tests/test_controller.py index 229d35d5..9cc537a4 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -69,7 +69,7 @@ def test_page_response(): response = client.get('/winter-simple/page-response/', data=data) # Assert - assert response.status_code == HTTPStatus.OK + assert response.status_code == HTTPStatus.OK, response.content assert response.json() == expected_body diff --git a/winter/__version__.py b/winter/__version__.py index 522ba08c..6c371de9 100644 --- a/winter/__version__.py +++ b/winter/__version__.py @@ -1 +1 @@ -__version__ = '1.11.1' +__version__ = '1.11.2' diff --git a/winter/pagination/page_position_argument_inspector.py b/winter/pagination/page_position_argument_inspector.py index c92e829e..a2cc297d 100644 --- a/winter/pagination/page_position_argument_inspector.py +++ b/winter/pagination/page_position_argument_inspector.py @@ -3,9 +3,9 @@ from drf_yasg import openapi -from winter.pagination.sort import get_allowed_order_by_fields from .page_position import PagePosition from .page_position_argument_resolver import PagePositionArgumentResolver +from .sort import OrderByAnnotation from ..schema import MethodArgumentsInspector if TYPE_CHECKING: @@ -39,16 +39,25 @@ def inspect_parameters(self, route: 'Route') -> List[openapi.Parameter]: parameters.append(self.limit_parameter) parameters.append(self.offset_parameter) - allowed_order_by_fields = get_allowed_order_by_fields(route.method) - if allowed_order_by_fields: - allowed_order_by_fields = ','.join(map(str, allowed_order_by_fields)) + order_by_annotation = route.method.annotations.get_one_or_none(OrderByAnnotation) + if order_by_annotation: + allowed_order_by_fields = ','.join(map(str, order_by_annotation.allowed_fields)) + default_sort = ( + str(order_by_annotation.default_sort) + if order_by_annotation.default_sort is not None else + None + ) order_by_parameter = openapi.Parameter( name=self._page_position_argument_resolver.order_by_name, - description=f'Comma separated order by fields. Allowed fields: {allowed_order_by_fields}', + description=( + f'Comma separated order by fields. ' + f'Allowed fields: {allowed_order_by_fields}.' + ), required=False, in_=openapi.IN_QUERY, type=openapi.TYPE_ARRAY, items={'type': openapi.TYPE_STRING}, + default=default_sort, ) parameters.append(order_by_parameter) diff --git a/winter/pagination/page_position_argument_resolver.py b/winter/pagination/page_position_argument_resolver.py index 59d4e7c7..319c11ad 100644 --- a/winter/pagination/page_position_argument_resolver.py +++ b/winter/pagination/page_position_argument_resolver.py @@ -8,10 +8,10 @@ from .limits import LimitsAnnotation from .limits import MaximumLimitValueExceeded from .page_position import PagePosition -from .sort import Order +from .sort import OrderByAnnotation from .sort import Sort -from .sort import SortDirection -from .sort import get_allowed_order_by_fields +from .sort.check_sort import check_sort +from .sort.parse_sort import parse_sort from ..argument_resolver import ArgumentResolver from ..core import ComponentMethod from ..core import ComponentMethodArgument @@ -66,7 +66,7 @@ def _get_limits(self, method: ComponentMethod) -> Limits: def _parse_page_position(self, argument: ComponentMethodArgument, http_request: Request) -> PagePosition: raw_limit = http_request.query_params.get(self.limit_name) raw_offset = http_request.query_params.get(self.offset_name) - raw_order_by = http_request.query_params.get(self.order_by_name) + raw_order_by = http_request.query_params.get(self.order_by_name, '') limit = self._parse_int_param(raw_limit, self.limit_name) offset = self._parse_int_param(raw_offset, self.offset_name) sort = self._parse_sort_properties(raw_order_by, argument) @@ -87,25 +87,11 @@ def _parse_int_param(raw_param_value: str, param_name: str) -> typing.Optional[i return param_value def _parse_sort_properties(self, raw_param_value: str, argument: ComponentMethodArgument) -> typing.Optional[Sort]: - if not raw_param_value: - return None - - sort_parts = raw_param_value.split(',') - allowed_order_by_fields = get_allowed_order_by_fields(argument.method) - orders = (self._parse_order(sort_part, allowed_order_by_fields) for sort_part in sort_parts) - return Sort(*orders) - - def _parse_order(self, field: str, allowed_order_by_fields: typing.FrozenSet[str]) -> Order: - is_desc = False - if field.startswith('-'): - is_desc = True - field = field[1:] - - if not field: - raise exceptions.ParseError('An empty sorting part found') + sort = parse_sort(raw_param_value) + order_by_annotations = argument.method.annotations.get_one_or_none(OrderByAnnotation) - if field not in allowed_order_by_fields and (not self.allow_any_order_by_field or allowed_order_by_fields): - raise exceptions.ParseError(f'Field "{field}" does not allowed as order by field') + if sort is None: + return order_by_annotations and order_by_annotations.default_sort + check_sort(sort, order_by_annotations.allowed_fields) - direction = SortDirection.DESC if is_desc else SortDirection.ASC - return Order(direction=direction, field=field) + return sort diff --git a/winter/pagination/sort/__init__.py b/winter/pagination/sort/__init__.py index 3de34494..0f0834a3 100644 --- a/winter/pagination/sort/__init__.py +++ b/winter/pagination/sort/__init__.py @@ -1,6 +1,5 @@ from .order_by import order_by from .order_by_annotation import OrderByAnnotation -from .order_by_annotation import get_allowed_order_by_fields from .sort import Order from .sort import Sort from .sort import SortDirection diff --git a/winter/pagination/sort/check_sort.py b/winter/pagination/sort/check_sort.py new file mode 100644 index 00000000..8b15a239 --- /dev/null +++ b/winter/pagination/sort/check_sort.py @@ -0,0 +1,17 @@ +import typing + +from rest_framework import exceptions + +if typing.TYPE_CHECKING: + from .sort import Sort + + +def check_sort(sort: 'Sort', allowed_fields: typing.FrozenSet[str]): + not_allowed_fields = [ + order.field + for order in sort.orders + if order.field not in allowed_fields + ] + if not_allowed_fields: + not_allowed_fields = ','.join(not_allowed_fields) + raise exceptions.ParseError(f'Fields do not allowed as order by fields: "{not_allowed_fields}"') diff --git a/winter/pagination/sort/order_by.py b/winter/pagination/sort/order_by.py index 225001cc..3e6953d7 100644 --- a/winter/pagination/sort/order_by.py +++ b/winter/pagination/sort/order_by.py @@ -1,10 +1,15 @@ import typing +from .check_sort import check_sort from .order_by_annotation import OrderByAnnotation +from .parse_sort import parse_sort from ...core.annotation_decorator import annotate_method -def order_by(allowed_fields: typing.Iterable[str]): +def order_by(allowed_fields: typing.Iterable[str], default_sort: typing.Tuple[str] = None): allowed_fields = frozenset(allowed_fields) - annotation = OrderByAnnotation(allowed_fields) + if default_sort is not None: + default_sort = parse_sort(','.join(default_sort)) + check_sort(default_sort, allowed_fields) + annotation = OrderByAnnotation(allowed_fields, default_sort) return annotate_method(annotation, single=True) diff --git a/winter/pagination/sort/order_by_annotation.py b/winter/pagination/sort/order_by_annotation.py index e1427074..321becc6 100644 --- a/winter/pagination/sort/order_by_annotation.py +++ b/winter/pagination/sort/order_by_annotation.py @@ -2,16 +2,11 @@ import dataclasses -from winter.core import ComponentMethod +if typing.TYPE_CHECKING: + from .sort import Sort @dataclasses.dataclass class OrderByAnnotation: allowed_fields: typing.FrozenSet[str] - - -def get_allowed_order_by_fields(method: ComponentMethod) -> typing.FrozenSet[str]: - order_by_annotation = method.annotations.get_one_or_none(OrderByAnnotation) - if order_by_annotation is None: - return frozenset() - return order_by_annotation.allowed_fields + default_sort: typing.Optional['Sort'] = None diff --git a/winter/pagination/sort/parse_order.py b/winter/pagination/sort/parse_order.py new file mode 100644 index 00000000..7332fc79 --- /dev/null +++ b/winter/pagination/sort/parse_order.py @@ -0,0 +1,19 @@ +import re + +from rest_framework import exceptions + +from .sort import Order +from .sort import SortDirection + +_field_pattern = re.compile(r'(-?)(\w+)') + + +def parse_order(field: str): + match = _field_pattern.match(field) + + if match is None: + raise exceptions.ParseError(f'Invalid field for order: "{field}"') + + direction, field = match.groups() + direction = SortDirection.DESC if direction == '-' else SortDirection.ASC + return Order(field, direction) diff --git a/winter/pagination/sort/parse_sort.py b/winter/pagination/sort/parse_sort.py new file mode 100644 index 00000000..bd88cd17 --- /dev/null +++ b/winter/pagination/sort/parse_sort.py @@ -0,0 +1,12 @@ +import typing + +from .parse_order import parse_order +from .sort import Sort + + +def parse_sort(str_sort: typing.Optional[str]) -> typing.Optional[Sort]: + if not str_sort: + return None + sort_parts = str_sort.split(',') + orders = (parse_order(sort_part) for sort_part in sort_parts) + return Sort(*orders) diff --git a/winter/pagination/sort/sort.py b/winter/pagination/sort/sort.py index 6d50ceb5..3c749932 100644 --- a/winter/pagination/sort/sort.py +++ b/winter/pagination/sort/sort.py @@ -1,4 +1,5 @@ import itertools +import typing from enum import Enum from dataclasses import dataclass @@ -18,10 +19,12 @@ def __str__(self): return ('-' if self.direction == SortDirection.DESC else '') + self.field +@dataclass(frozen=True, init=False, repr=False) class Sort: + orders: typing.Tuple[Order] def __init__(self, *orders: Order): - self.orders = orders + object.__setattr__(self, 'orders', orders) @staticmethod def by(*fields: str) -> 'Sort': @@ -43,14 +46,8 @@ def desc(self) -> 'Sort': orders = (Order(field=order.field, direction=SortDirection.DESC) for order in self.orders) return Sort(*orders) - def __repr__(self): - sort_fields = ','.join(map(str, self.orders)) - return f"Sort('{sort_fields}')" - - def __eq__(self, other): - if not isinstance(other, Sort): - return False - return self.orders == other.orders + def __str__(self): + return ','.join(map(str, self.orders)) - def __hash__(self): - return hash(self.orders) + def __repr__(self): + return f"Sort('{self}')"