Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mongo operators to F function #58

Merged
merged 6 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 108 additions & 22 deletions aggify/aggify.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools
from typing import Any, Dict, Type, Union, List, TypeVar, Callable
from typing import Any, Dict, Type, Union, List, TypeVar, Callable, Tuple

from mongoengine import Document, EmbeddedDocument, fields as mongoengine_fields
from mongoengine.base import TopLevelDocumentMetaclass
Expand All @@ -13,6 +13,7 @@
OutStageError,
InvalidArgument,
InvalidProjection,
InvalidAnnotateExpression,
)
from aggify.types import QueryParams, CollectionType
from aggify.utilty import (
Expand Down Expand Up @@ -362,7 +363,11 @@ def unwind(
return self

def annotate(
self, annotate_name: str, accumulator: str, f: Union[Union[str, Dict], F, int]
self,
annotate_name: Union[str, None] = None,
accumulator: Union[str, None] = None,
f: Union[Union[str, Dict], F, int, None] = None,
**kwargs,
) -> "Aggify":
"""
Annotate a MongoDB aggregation pipeline with a new field.
Expand All @@ -372,6 +377,7 @@ def annotate(
annotate_name (str): The name of the new annotated field.
accumulator (str): The aggregation accumulator operator (e.g., "$sum", "$avg").
f (Union[str, Dict] | F | int): The value for the annotated field.
kwargs: Use F expressions.

Returns:
self.
Expand All @@ -381,11 +387,60 @@ def annotate(

Example:
annotate("totalSales", "sum", "sales")
or
annotate(first_field = F('field').first())
"""

try:
stage = list(self.pipelines[-1].keys())[0]
if stage != "$group":
raise AnnotationError(
f"Annotations apply only to $group, not to {stage}"
)
except IndexError:
raise AnnotationError(
"Annotations apply only to $group, your pipeline is empty"
)

# Check whether F expressions (keyword arguments) were supplied.
base_model_fields = self.base_model._fields # noqa
if not kwargs:
field_type, acc = self._get_field_type_and_accumulator(accumulator)

# Get the annotation value: If the value is a string object, then it will be validated in the case of
# embedded fields; otherwise, if it is an F expression object, simply return it.
value = self._get_annotate_value(f)
annotate = {annotate_name: {acc: value}}
# Determine the data type based on the aggregation operator
if not base_model_fields.get(annotate_name, None):
base_model_fields[annotate_name] = field_type
else:
annotate, fields = self._do_annotate_with_expression(
kwargs, base_model_fields
)

self.pipelines[-1]["$group"].update(annotate)
return self

@staticmethod
def _get_field_type_and_accumulator(
accumulator: str,
) -> Tuple[Type, str]:
"""
Retrieves the accumulator name and returns corresponding MongoDB accumulator field type and name.

Args:
accumulator (str): The name of the accumulator.

Returns: (Tuple): containing the field type and MongoDB accumulator string.

Raises:
AnnotationError: If the accumulator name is invalid.
"""

# Some of the accumulator field types below might be incorrect and should be double-checked.
# noinspection SpellCheckingInspection
aggregation_mapping: Dict[str, Type] = {
aggregation_mapping: Dict[str, Tuple] = {
"sum": (mongoengine_fields.FloatField(), "$sum"),
"avg": (mongoengine_fields.FloatField(), "$avg"),
"stdDevPop": (mongoengine_fields.FloatField(), "$stdDevPop"),
Expand Down Expand Up @@ -428,23 +483,26 @@ def annotate(
"lastN": (mongoengine_fields.ListField(), "$lastN"),
"maxN": (mongoengine_fields.ListField(), "$maxN"),
}

try:
stage = list(self.pipelines[-1].keys())[0]
if stage != "$group":
raise AnnotationError(
f"Annotations apply only to $group, not to {stage}"
)
except IndexError:
raise AnnotationError(
"Annotations apply only to $group, your pipeline is empty"
)

try:
field_type, acc = aggregation_mapping[accumulator]
return aggregation_mapping[accumulator]
except KeyError as error:
raise AnnotationError(f"Invalid accumulator: {accumulator}") from error

def _get_annotate_value(self, f: Union[F, str]) -> Union[Dict, str]:
"""
Determines the annotation value based on the type of the input 'f'.

If 'f' is an instance of F, it converts it to a dictionary.
If 'f' is a string, it attempts to retrieve the corresponding field name recursively.
If it encounters an InvalidField exception, it retains 'f' as the value.
Otherwise, 'f' is returned as is.

Args:
f: The input value, which can be an instance of F, a string, or any other type.

Returns:
The determined annotation value, which could be a dictionary, a formatted string, or the original input.
"""
if isinstance(f, F):
value = f.to_dict()
else:
Expand All @@ -455,13 +513,41 @@ def annotate(
value = f
else:
value = f
return value

# Determine the data type based on the aggregation operator
self.pipelines[-1]["$group"].update({annotate_name: {acc: value}})
base_model_fields = self.base_model._fields # noqa
if not base_model_fields.get(annotate_name, None):
base_model_fields[annotate_name] = field_type
return self
@staticmethod
def _do_annotate_with_expression(
    annotate: Dict[str, Dict[str, Any]], base_model_fields: Dict[str, Any]
) -> Tuple[Dict[str, Dict[str, Any]], List[str]]:
    """
    Validate keyword-style annotations and register field types for new names.

    Every value in ``annotate`` must be a non-empty dictionary whose first key is
    the accumulator (e.g. ``{"$first": "$name"}``). For each annotated name that is
    not already present in ``base_model_fields``, the field type looked up for that
    accumulator is stored under that name; ``base_model_fields`` is mutated in place.

    Args:
        annotate (Dict[str, Dict[str, Any]]): A dictionary containing field names and their corresponding
            accumulator expressions.
        base_model_fields (Dict[str, Any]): A dictionary representing the base model fields.

    Returns: Tuple[Dict[str, Dict[str, Any]], List[str]]: A tuple containing the (unchanged) annotations and a
        list of the annotated field names.

    Raises:
        InvalidAnnotateExpression: If an expression is not a dictionary, or is an empty dictionary.
    """
    # Validate the shape of every expression before touching base_model_fields.
    for expression in annotate.values():
        # An empty dict carries no accumulator key; reject it here rather than
        # letting next() below leak a bare StopIteration to the caller.
        if not isinstance(expression, dict) or not expression:
            raise InvalidAnnotateExpression()

    # Extract field names
    fields = list(annotate)

    # Register a field type for every annotated name the model does not know yet.
    for field_name in fields:
        if field_name not in base_model_fields:
            # The first key of the expression is the accumulator, e.g. "$first" -> "first".
            accumulator = next(iter(annotate[field_name])).replace("$", "")
            field_type, _ = Aggify._get_field_type_and_accumulator(accumulator)
            base_model_fields[field_name] = field_type

    return annotate, fields

def __match(self, matches: Dict[str, Any]):
"""
Expand Down
18 changes: 18 additions & 0 deletions aggify/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,24 @@ def is_suitable_for_match(key: str) -> bool:
return False
return True

def first(self):
    """Return a ``$first`` expression dictionary whose operand is ``self.field``."""
    operand = self.field
    return {"$first": operand}

def last(self):
    """Return a ``$last`` expression dictionary whose operand is ``self.field``."""
    operand = self.field
    return {"$last": operand}

def min(self):
    """Return a ``$min`` expression dictionary whose operand is ``self.field``."""
    operand = self.field
    return {"$min": operand}

def max(self):
    """Return a ``$max`` expression dictionary whose operand is ``self.field``."""
    operand = self.field
    return {"$max": operand}

def sum(self):
    """Return a ``$sum`` expression dictionary whose operand is ``self.field``."""
    operand = self.field
    return {"$sum": operand}

def avg(self):
    """Return an ``$avg`` expression dictionary whose operand is ``self.field``."""
    operand = self.field
    return {"$avg": operand}


class Cond:
"""
Expand Down
6 changes: 6 additions & 0 deletions aggify/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,9 @@ class InvalidProjection(AggifyBaseException):
def __init__(self):
self.message = "You can't use inclusion and exclusion together."
super().__init__(self.message)


class InvalidAnnotateExpression(AggifyBaseException):
    """Exception carrying a fixed message about an invalid annotate expression."""

    def __init__(self):
        # Store the message on the instance first, then hand the same text to the base class.
        text = "Invalid expression passed to annotate."
        self.message = text
        super().__init__(text)
15 changes: 15 additions & 0 deletions tests/test_aggify.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
InvalidEmbeddedField,
MongoIndexError,
InvalidProjection,
InvalidAnnotateExpression,
)


Expand Down Expand Up @@ -666,3 +667,17 @@ def test_lookup_raw_let(self):
def test_group_multi_expressions(self):
thing = list(Aggify(BaseModel).group(["name", "age"]))
assert thing[0]["$group"] == {"_id": {"name": "$name", "age": "$age"}}

def test_annotate_with_f_expressions(self):
    """Keyword annotation built from ``F(...).first()`` is merged into the ``$group`` stage."""
    grouped = Aggify(BaseModel).group(["name", "age"])
    first_name = F("name").first()
    stages = list(grouped.annotate(first=first_name))

    expected_group = {
        "_id": {"name": "$name", "age": "$age"},
        "first": {"$first": "$name"},
    }
    assert stages[0]["$group"] == expected_group

def test_annotate_with_invalid_f_expression(self):
    """Passing a non-dictionary keyword value to ``annotate`` raises InvalidAnnotateExpression."""
    with pytest.raises(InvalidAnnotateExpression):
        # The result is irrelevant; only the raised exception is checked.
        list(Aggify(BaseModel).group(["name", "age"]).annotate(first=""))
17 changes: 17 additions & 0 deletions tests/test_f.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from aggify import F # Import from your actual module


Expand Down Expand Up @@ -72,3 +74,18 @@ def test_multiplication_with_multiple_fields(self):
f3 = F("nano")
f_combined = f1 * f2 * f3
assert f_combined.to_dict() == {"$multiply": ["$quantity", "$price", "$nano"]}

@pytest.mark.parametrize(
    "method",
    ("first", "last", "min", "max", "sum", "avg"),
)
def test_f_operator_methods(self, method):
    """Each operator method of ``F`` returns ``{"$<method>": "$quantity"}`` for ``F("quantity")``."""
    operator = getattr(F("quantity"), method)
    assert operator() == {f"${method}": "$quantity"}