Skip to content

Commit

Permalink
Support distinct to Aggregate (#312)
Browse files Browse the repository at this point in the history
  • Loading branch information
pjongy authored Mar 12, 2020
1 parent 4a7b0ca commit e1ab02f
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ New features:
- Allow usage of ``F`` expressions to in annotations. (#301)
- Now negative number with ``limit(...)`` and ``offset(...)`` raise ``ParamsError``. (#306)
- Allow usage of Function to ``queryset.update()``. (#308)
- Add ability to supply ``distinct`` flag to Aggregate (#312)


Bugfixes:
^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions CONTRIBUTORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Contributors
* Weiming Dong ``@dongweiming``
* Jinlong Peng ``@long2ice``
* Sang-Heon Jeon ``@lntuition``
* Jong-Yeop Park ``@pjongy``

Special Thanks
==============
Expand Down
27 changes: 26 additions & 1 deletion tests/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from tests.testmodels import Event, Team, Tournament
from tests.testmodels import Event, MinRelation, Team, Tournament
from tortoise.contrib import test
from tortoise.exceptions import ConfigurationError
from tortoise.functions import Count, Min, Sum
Expand Down Expand Up @@ -49,3 +49,28 @@ async def test_aggregation(self):

with self.assertRaisesRegex(ConfigurationError, "name__id not resolvable"):
await Event.all().annotate(tournament_test_id=Sum("name__id")).first()

async def test_aggregation_with_distinct(self):
tournament = await Tournament.create(name="New Tournament")
await Event.create(name="Event 1", tournament=tournament)
await Event.create(name="Event 2", tournament=tournament)
await MinRelation.create(tournament=tournament)

tournament_2 = await Tournament.create(name="New Tournament")
await Event.create(name="Event 1", tournament=tournament_2)
await Event.create(name="Event 2", tournament=tournament_2)
await Event.create(name="Event 3", tournament=tournament_2)
await MinRelation.create(tournament=tournament_2)
await MinRelation.create(tournament=tournament_2)

school_with_distinct_count = (
await Tournament.filter(id=tournament_2.id)
.annotate(
events_count=Count("events", distinct=True),
minrelations_count=Count("minrelations", distinct=True),
)
.first()
)

self.assertEqual(school_with_distinct_count.events_count, 3)
self.assertEqual(school_with_distinct_count.minrelations_count, 2)
27 changes: 24 additions & 3 deletions tortoise/functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import TYPE_CHECKING, Any, Optional, Type, Union, cast

from pypika import Table, functions
from pypika.terms import AggregateFunction, ArithmeticExpression
from pypika.functions import DistinctOptionFunction
from pypika.terms import ArithmeticExpression
from pypika.terms import Function as BaseFunction

from tortoise.exceptions import ConfigurationError
Expand All @@ -12,6 +13,7 @@
from tortoise.models import Model
from tortoise.fields.base import Field


##############################################################################
# Base
##############################################################################
Expand Down Expand Up @@ -46,6 +48,9 @@ def __init__(self, field: Union[str, F, ArithmeticExpression], *default_values:
self.field_object: "Optional[Field]" = None
self.default_values = default_values

def _get_function_field(self, field: str, *default_values):
return self.database_func(field, *default_values)

def _resolve_field_for_model(
self, model: "Type[Model]", table: Table, field: str, *default_values: Any
) -> dict:
Expand Down Expand Up @@ -73,7 +78,7 @@ def _resolve_field_for_model(
if func:
field = func(self.field_object, field)

function_field = self.database_func(field, *default_values)
function_field = self._get_function_field(field, *default_values)
return {"joins": function_joins, "field": function_field}

if field_split[0] not in model._meta.fetch_fields:
Expand Down Expand Up @@ -114,9 +119,25 @@ def resolve(self, model: "Type[Model]", table: Table) -> dict:
class Aggregate(Function):
"""
Base for SQL Aggregates.
:param field: Field name
:param default_values: Extra parameters to the function.
:param is_distinct: Flag for aggregate with distinction
"""

database_func = AggregateFunction
database_func = DistinctOptionFunction

def __init__(
self, field: Union[str, F, ArithmeticExpression], *default_values: Any, distinct=False
) -> None:
super().__init__(field, *default_values)
self.distinct = distinct

def _get_function_field(self, field: str, *default_values):
if self.distinct:
return self.database_func(field, *default_values).distinct()
else:
return self.database_func(field, *default_values)


##############################################################################
Expand Down

0 comments on commit e1ab02f

Please sign in to comment.