Skip to content

Commit

Permalink
Added default to order by (#110)
Browse files Browse the repository at this point in the history
* Added default to order by

* Chenged version

* Fixed tests

* Refactor order by

* Refactor order by

* Deleted not need code
  • Loading branch information
andrey-berenda authored Jun 26, 2019
1 parent 956420b commit a5912f3
Show file tree
Hide file tree
Showing 14 changed files with 142 additions and 67 deletions.
11 changes: 11 additions & 0 deletions tests/pagination/test_order_by.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest
from rest_framework import exceptions

import winter


def test_with_default_not_in_allowed_fields():
with pytest.raises(exceptions.ParseError) as exception:
winter.pagination.order_by(['id'], default_sort=('uid',))

assert exception.value.args == ('Fields do not allowed as order by fields: "uid"',)
11 changes: 7 additions & 4 deletions tests/pagination/test_page_position_argument_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import winter
from winter.pagination import PagePosition
from winter.pagination import PagePositionArgumentsInspector
from winter.pagination import PagePositionArgumentResolver
from winter.pagination import PagePositionArgumentsInspector
from winter.routing import get_route


Expand Down Expand Up @@ -38,10 +38,12 @@ def method(self, arg1: argument_type):
assert parameters == expected_parameters


def test_page_position_argument_inspector_with_allowed_order_by_fields():
@pytest.mark.parametrize(('default_sort', 'default_in_parameter'), ((None, None), (('id',), 'id')))
def test_page_position_argument_inspector_with_allowed_order_by_fields(default_sort, default_in_parameter):

class SimpleController:
@winter.route_get('')
@winter.pagination.order_by(['id'])
@winter.pagination.order_by(['id'], default_sort=default_sort)
def method(self, arg1: PagePosition):
return arg1

Expand All @@ -52,11 +54,12 @@ def method(self, arg1: PagePosition):

order_by_parameter = openapi.Parameter(
name=resolver.order_by_name,
description='Comma separated order by fields. Allowed fields: id',
description='Comma separated order by fields. Allowed fields: id.',
required=False,
in_=openapi.IN_QUERY,
type=openapi.TYPE_ARRAY,
items={'type': openapi.TYPE_STRING},
default=default_in_parameter,
)

expected_parameters = [
Expand Down
42 changes: 32 additions & 10 deletions tests/pagination/test_page_position_argument_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,14 @@ def func(arg1: argument_type):
('', PagePosition(None, None)),
('offset=0', PagePosition(None, 0)),
('limit=10&offset=20&order_by=-id,name', PagePosition(10, 20, Sort.by('id').desc().and_(Sort.by('name')))),
('order_by= x', PagePosition(None, None, Sort.by(' x'))),
('order_by=- x', PagePosition(None, None, Sort.by(' x').desc())),
('order_by= -x', PagePosition(None, None, Sort.by(' -x'))),
('order_by=', PagePosition(None, None)),
))
def test_resolve_argument_ok_in_page_position_argument_resolver(query_string, expected_page_position):
def func(arg1: int):
return arg1
@winter.pagination.order_by(['name', 'id', 'email', 'x',])
def method(page_position: PagePosition):
return page_position

method = ComponentMethod(func)
argument = method.get_argument('arg1')
argument = method.get_argument('page_position')

resolver = PagePositionArgumentResolver(allow_any_order_by_field=True)

Expand All @@ -64,15 +61,40 @@ def func(arg1: int):
assert page_position == expected_page_position


@pytest.mark.parametrize(('query_string', 'default_sort', 'expected_page_position'), (
('limit=1&offset=3', ('-name',), PagePosition(1, 3, Sort.by('name').desc())),
))
def test_resolve_argument_ok_in_page_position_argument_resolver_with_default(
query_string,
default_sort,
expected_page_position,
):
@winter.pagination.order_by(['name', 'id', 'email',], default_sort=default_sort)
def method(page_position: PagePosition):
return page_position

argument = method.get_argument('page_position')

resolver = PagePositionArgumentResolver(allow_any_order_by_field=True)

request = Mock(spec=DRFRequest)
request.query_params = QueryDict(query_string)

# Act
page_position = resolver.resolve_argument(argument, request)

# Assert
assert page_position == expected_page_position

@pytest.mark.parametrize(('query_string', 'exception_type', 'message'), (
('limit=none', ParseError, 'Invalid "limit" query parameter value: "none"'),
('offset=-20', ValidationError, 'Invalid "offset" query parameter value: "-20"'),
('order_by=id,', ParseError, 'An empty sorting part found'),
('order_by=-', ParseError, 'An empty sorting part found'),
('order_by=id,', ParseError, 'Invalid field for order: ""'),
('order_by=-', ParseError, 'Invalid field for order: "-"'),
(
'order_by=not_allowed_order_by_field',
ParseError,
'Field "not_allowed_order_by_field" does not allowed as order by field',
'Fields do not allowed as order by fields: "not_allowed_order_by_field"',
),
))
def test_resolve_argument_fails_in_page_position_argument_resolver(query_string, exception_type, message):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_page_response():
response = client.get('/winter-simple/page-response/', data=data)

# Assert
assert response.status_code == HTTPStatus.OK
assert response.status_code == HTTPStatus.OK, response.content
assert response.json() == expected_body


Expand Down
2 changes: 1 addition & 1 deletion winter/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.11.1'
__version__ = '1.11.2'
19 changes: 14 additions & 5 deletions winter/pagination/page_position_argument_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from drf_yasg import openapi

from winter.pagination.sort import get_allowed_order_by_fields
from .page_position import PagePosition
from .page_position_argument_resolver import PagePositionArgumentResolver
from .sort import OrderByAnnotation
from ..schema import MethodArgumentsInspector

if TYPE_CHECKING:
Expand Down Expand Up @@ -39,16 +39,25 @@ def inspect_parameters(self, route: 'Route') -> List[openapi.Parameter]:
parameters.append(self.limit_parameter)
parameters.append(self.offset_parameter)

allowed_order_by_fields = get_allowed_order_by_fields(route.method)
if allowed_order_by_fields:
allowed_order_by_fields = ','.join(map(str, allowed_order_by_fields))
order_by_annotation = route.method.annotations.get_one_or_none(OrderByAnnotation)
if order_by_annotation:
allowed_order_by_fields = ','.join(map(str, order_by_annotation.allowed_fields))
default_sort = (
str(order_by_annotation.default_sort)
if order_by_annotation.default_sort is not None else
None
)
order_by_parameter = openapi.Parameter(
name=self._page_position_argument_resolver.order_by_name,
description=f'Comma separated order by fields. Allowed fields: {allowed_order_by_fields}',
description=(
f'Comma separated order by fields. '
f'Allowed fields: {allowed_order_by_fields}.'
),
required=False,
in_=openapi.IN_QUERY,
type=openapi.TYPE_ARRAY,
items={'type': openapi.TYPE_STRING},
default=default_sort,
)
parameters.append(order_by_parameter)

Expand Down
34 changes: 10 additions & 24 deletions winter/pagination/page_position_argument_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from .limits import LimitsAnnotation
from .limits import MaximumLimitValueExceeded
from .page_position import PagePosition
from .sort import Order
from .sort import OrderByAnnotation
from .sort import Sort
from .sort import SortDirection
from .sort import get_allowed_order_by_fields
from .sort.check_sort import check_sort
from .sort.parse_sort import parse_sort
from ..argument_resolver import ArgumentResolver
from ..core import ComponentMethod
from ..core import ComponentMethodArgument
Expand Down Expand Up @@ -66,7 +66,7 @@ def _get_limits(self, method: ComponentMethod) -> Limits:
def _parse_page_position(self, argument: ComponentMethodArgument, http_request: Request) -> PagePosition:
raw_limit = http_request.query_params.get(self.limit_name)
raw_offset = http_request.query_params.get(self.offset_name)
raw_order_by = http_request.query_params.get(self.order_by_name)
raw_order_by = http_request.query_params.get(self.order_by_name, '')
limit = self._parse_int_param(raw_limit, self.limit_name)
offset = self._parse_int_param(raw_offset, self.offset_name)
sort = self._parse_sort_properties(raw_order_by, argument)
Expand All @@ -87,25 +87,11 @@ def _parse_int_param(raw_param_value: str, param_name: str) -> typing.Optional[i
return param_value

def _parse_sort_properties(self, raw_param_value: str, argument: ComponentMethodArgument) -> typing.Optional[Sort]:
if not raw_param_value:
return None

sort_parts = raw_param_value.split(',')
allowed_order_by_fields = get_allowed_order_by_fields(argument.method)
orders = (self._parse_order(sort_part, allowed_order_by_fields) for sort_part in sort_parts)
return Sort(*orders)

def _parse_order(self, field: str, allowed_order_by_fields: typing.FrozenSet[str]) -> Order:
is_desc = False
if field.startswith('-'):
is_desc = True
field = field[1:]

if not field:
raise exceptions.ParseError('An empty sorting part found')
sort = parse_sort(raw_param_value)
order_by_annotations = argument.method.annotations.get_one_or_none(OrderByAnnotation)

if field not in allowed_order_by_fields and (not self.allow_any_order_by_field or allowed_order_by_fields):
raise exceptions.ParseError(f'Field "{field}" does not allowed as order by field')
if sort is None:
return order_by_annotations and order_by_annotations.default_sort
check_sort(sort, order_by_annotations.allowed_fields)

direction = SortDirection.DESC if is_desc else SortDirection.ASC
return Order(direction=direction, field=field)
return sort
1 change: 0 additions & 1 deletion winter/pagination/sort/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .order_by import order_by
from .order_by_annotation import OrderByAnnotation
from .order_by_annotation import get_allowed_order_by_fields
from .sort import Order
from .sort import Sort
from .sort import SortDirection
17 changes: 17 additions & 0 deletions winter/pagination/sort/check_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import typing

from rest_framework import exceptions

if typing.TYPE_CHECKING:
from .sort import Sort


def check_sort(sort: 'Sort', allowed_fields: typing.FrozenSet[str]):
not_allowed_fields = [
order.field
for order in sort.orders
if order.field not in allowed_fields
]
if not_allowed_fields:
not_allowed_fields = ','.join(not_allowed_fields)
raise exceptions.ParseError(f'Fields do not allowed as order by fields: "{not_allowed_fields}"')
9 changes: 7 additions & 2 deletions winter/pagination/sort/order_by.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import typing

from .check_sort import check_sort
from .order_by_annotation import OrderByAnnotation
from .parse_sort import parse_sort
from ...core.annotation_decorator import annotate_method


def order_by(allowed_fields: typing.Iterable[str]):
def order_by(allowed_fields: typing.Iterable[str], default_sort: typing.Tuple[str] = None):
allowed_fields = frozenset(allowed_fields)
annotation = OrderByAnnotation(allowed_fields)
if default_sort is not None:
default_sort = parse_sort(','.join(default_sort))
check_sort(default_sort, allowed_fields)
annotation = OrderByAnnotation(allowed_fields, default_sort)
return annotate_method(annotation, single=True)
11 changes: 3 additions & 8 deletions winter/pagination/sort/order_by_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,11 @@

import dataclasses

from winter.core import ComponentMethod
if typing.TYPE_CHECKING:
from .sort import Sort


@dataclasses.dataclass
class OrderByAnnotation:
allowed_fields: typing.FrozenSet[str]


def get_allowed_order_by_fields(method: ComponentMethod) -> typing.FrozenSet[str]:
order_by_annotation = method.annotations.get_one_or_none(OrderByAnnotation)
if order_by_annotation is None:
return frozenset()
return order_by_annotation.allowed_fields
default_sort: typing.Optional['Sort'] = None
19 changes: 19 additions & 0 deletions winter/pagination/sort/parse_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import re

from rest_framework import exceptions

from .sort import Order
from .sort import SortDirection

_field_pattern = re.compile(r'(-?)(\w+)')


def parse_order(field: str):
match = _field_pattern.match(field)

if match is None:
raise exceptions.ParseError(f'Invalid field for order: "{field}"')

direction, field = match.groups()
direction = SortDirection.DESC if direction == '-' else SortDirection.ASC
return Order(field, direction)
12 changes: 12 additions & 0 deletions winter/pagination/sort/parse_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import typing

from .parse_order import parse_order
from .sort import Sort


def parse_sort(str_sort: typing.Optional[str]) -> typing.Optional[Sort]:
if not str_sort:
return None
sort_parts = str_sort.split(',')
orders = (parse_order(sort_part) for sort_part in sort_parts)
return Sort(*orders)
19 changes: 8 additions & 11 deletions winter/pagination/sort/sort.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import typing
from enum import Enum

from dataclasses import dataclass
Expand All @@ -18,10 +19,12 @@ def __str__(self):
return ('-' if self.direction == SortDirection.DESC else '') + self.field


@dataclass(frozen=True, init=False, repr=False)
class Sort:
orders: typing.Tuple[Order]

def __init__(self, *orders: Order):
self.orders = orders
object.__setattr__(self, 'orders', orders)

@staticmethod
def by(*fields: str) -> 'Sort':
Expand All @@ -43,14 +46,8 @@ def desc(self) -> 'Sort':
orders = (Order(field=order.field, direction=SortDirection.DESC) for order in self.orders)
return Sort(*orders)

def __repr__(self):
sort_fields = ','.join(map(str, self.orders))
return f"Sort('{sort_fields}')"

def __eq__(self, other):
if not isinstance(other, Sort):
return False
return self.orders == other.orders
def __str__(self):
return ','.join(map(str, self.orders))

def __hash__(self):
return hash(self.orders)
def __repr__(self):
return f"Sort('{self}')"

0 comments on commit a5912f3

Please sign in to comment.