diff --git a/aggify/aggify.py b/aggify/aggify.py index e208889..7c313a3 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -1,5 +1,5 @@ import functools -from typing import Any, Dict, Type, Union, List, TypeVar, Callable +from typing import Any, Dict, Type, Union, List, TypeVar, Callable, Tuple from mongoengine import Document, EmbeddedDocument, fields as mongoengine_fields from mongoengine.base import TopLevelDocumentMetaclass @@ -13,6 +13,7 @@ OutStageError, InvalidArgument, InvalidProjection, + InvalidAnnotateExpression, ) from aggify.types import QueryParams, CollectionType from aggify.utilty import ( @@ -362,7 +363,11 @@ def unwind( return self def annotate( - self, annotate_name: str, accumulator: str, f: Union[Union[str, Dict], F, int] + self, + annotate_name: Union[str, None] = None, + accumulator: Union[str, None] = None, + f: Union[Union[str, Dict], F, int, None] = None, + **kwargs, ) -> "Aggify": """ Annotate a MongoDB aggregation pipeline with a new field. @@ -372,6 +377,7 @@ def annotate( annotate_name (str): The name of the new annotated field. accumulator (str): The aggregation accumulator operator (e.g., "$sum", "$avg"). f (Union[str, Dict] | F | int): The value for the annotated field. + kwargs: Use F expressions. Returns: self. @@ -381,11 +387,60 @@ def annotate( Example: annotate("totalSales", "sum", "sales") + or + annotate(first_field = F('field').first()) + """ + + try: + stage = list(self.pipelines[-1].keys())[0] + if stage != "$group": + raise AnnotationError( + f"Annotations apply only to $group, not to {stage}" + ) + except IndexError: + raise AnnotationError( + "Annotations apply only to $group, your pipeline is empty" + ) + + # Check either use F expression or not. + base_model_fields = self.base_model._fields # noqa + if not kwargs: + field_type, acc = self._get_field_type_and_accumulator(accumulator) + + # Get the annotation value: If the value is a string object, then it will be validated in the case of + # embedded fields; otherwise, if it is an F expression object, simply return it. + value = self._get_annotate_value(f) + annotate = {annotate_name: {acc: value}} + # Determine the data type based on the aggregation operator + if not base_model_fields.get(annotate_name, None): + base_model_fields[annotate_name] = field_type + else: + annotate, fields = self._do_annotate_with_expression( + kwargs, base_model_fields + ) + + self.pipelines[-1]["$group"].update(annotate) + return self + + @staticmethod + def _get_field_type_and_accumulator( + accumulator: str, + ) -> Tuple[Type, str]: + """ + Retrieves the accumulator name and returns corresponding MongoDB accumulator field type and name. + + Args: + accumulator (str): The name of the accumulator. + + Returns: (Tuple): containing the field type and MongoDB accumulator string. + + Raises: + AnnotationError: If the accumulator name is invalid. """ # Some of the accumulator fields might be false and should be checked. # noinspection SpellCheckingInspection - aggregation_mapping: Dict[str, Type] = { + aggregation_mapping: Dict[str, Tuple] = { "sum": (mongoengine_fields.FloatField(), "$sum"), "avg": (mongoengine_fields.FloatField(), "$avg"), "stdDevPop": (mongoengine_fields.FloatField(), "$stdDevPop"), @@ -428,23 +483,26 @@ def annotate( "lastN": (mongoengine_fields.ListField(), "$lastN"), "maxN": (mongoengine_fields.ListField(), "$maxN"), } - try: - stage = list(self.pipelines[-1].keys())[0] - if stage != "$group": - raise AnnotationError( - f"Annotations apply only to $group, not to {stage}" - ) - except IndexError: - raise AnnotationError( - "Annotations apply only to $group, your pipeline is empty" - ) - - try: - field_type, acc = aggregation_mapping[accumulator] + return aggregation_mapping[accumulator] except KeyError as error: raise AnnotationError(f"Invalid accumulator: {accumulator}") from error + def _get_annotate_value(self, f: Union[F, str]) -> Union[Dict, str]: + """ + Determines the annotation value based on the type of the input 'f'. + + If 'f' is an instance of F, it converts it to a dictionary. + If 'f' is a string, it attempts to retrieve the corresponding field name recursively. + If it encounters an InvalidField exception, it retains 'f' as the value. + Otherwise, 'f' is returned as is. + + Args: + f: The input value, which can be an instance of F, a string, or any other type. + + Returns: + The determined annotation value, which could be a dictionary, a formatted string, or the original input. + """ if isinstance(f, F): value = f.to_dict() else: @@ -455,13 +513,41 @@ def annotate( value = f else: value = f + return value - # Determine the data type based on the aggregation operator - self.pipelines[-1]["$group"].update({annotate_name: {acc: value}}) - base_model_fields = self.base_model._fields # noqa - if not base_model_fields.get(annotate_name, None): - base_model_fields[annotate_name] = field_type - return self + @staticmethod + def _do_annotate_with_expression( + annotate: Dict[str, Dict[str, Any]], base_model_fields: Dict[str, Any] + ) -> Tuple[Dict[str, Dict[str, Any]], List[str]]: + """ + Processes the annotation with an expression, updating the fields and annotation dictionary. + + Args: + annotate (Dict[str, Dict[str, Any]]): A dictionary containing field names and their corresponding F expressions. + base_model_fields (Dict[str, Any]): A dictionary representing the base model fields. + + Returns: Tuple[Dict[str, Dict[str, Any]], List[str]]: A tuple containing the updated annotations and a list + of field names. + + Raises: + InvalidAnnotateExpression: If the F expression is not a dictionary. + """ + # Check if all elements in kwargs were valid + for item in annotate.values(): + if not isinstance(item, dict): + raise InvalidAnnotateExpression() + + # Extract field names + fields = list(annotate.keys()) + + # Process base_model_fields + for field_name in fields: + if field_name not in base_model_fields: + accumulator = next(iter(annotate[field_name])).replace("$", "") + field_type, _ = Aggify._get_field_type_and_accumulator(accumulator) + base_model_fields[field_name] = field_type + + return annotate, fields def __match(self, matches: Dict[str, Any]): """ diff --git a/aggify/compiler.py b/aggify/compiler.py index 0f6752a..71eecf8 100644 --- a/aggify/compiler.py +++ b/aggify/compiler.py @@ -184,6 +184,24 @@ def is_suitable_for_match(key: str) -> bool: return False return True + def first(self): + return {"$first": self.field} + + def last(self): + return {"$last": self.field} + + def min(self): + return {"$min": self.field} + + def max(self): + return {"$max": self.field} + + def sum(self): + return {"$sum": self.field} + + def avg(self): + return {"$avg": self.field} + class Cond: """ diff --git a/aggify/exceptions.py b/aggify/exceptions.py index b5ab25d..de98c64 100644 --- a/aggify/exceptions.py +++ b/aggify/exceptions.py @@ -80,3 +80,9 @@ class InvalidProjection(AggifyBaseException): def __init__(self): self.message = "You can't use inclusion and exclusion together." super().__init__(self.message) + + +class InvalidAnnotateExpression(AggifyBaseException): + def __init__(self): + self.message = "Invalid expression passed to annotate." + super().__init__(self.message) diff --git a/tests/test_aggify.py b/tests/test_aggify.py index b73053e..4156ae6 100644 --- a/tests/test_aggify.py +++ b/tests/test_aggify.py @@ -13,6 +13,7 @@ InvalidEmbeddedField, MongoIndexError, InvalidProjection, + InvalidAnnotateExpression, ) @@ -666,3 +667,17 @@ def test_lookup_raw_let(self): def test_group_multi_expressions(self): thing = list(Aggify(BaseModel).group(["name", "age"])) assert thing[0]["$group"] == {"_id": {"name": "$name", "age": "$age"}} + + def test_annotate_with_f_expressions(self): + thing = list( + Aggify(BaseModel).group(["name", "age"]).annotate(first=F("name").first()) + ) + assert thing[0]["$group"] == { + "_id": {"name": "$name", "age": "$age"}, + "first": {"$first": "$name"}, + } + + def test_annotate_with_invalid_f_expression(self): + with pytest.raises(InvalidAnnotateExpression): + # noinspection PyUnusedLocal + thing = list(Aggify(BaseModel).group(["name", "age"]).annotate(first="")) diff --git a/tests/test_f.py b/tests/test_f.py index f69d230..c56cc2b 100644 --- a/tests/test_f.py +++ b/tests/test_f.py @@ -1,3 +1,5 @@ +import pytest + from aggify import F # Import from your actual module @@ -72,3 +74,18 @@ def test_multiplication_with_multiple_fields(self): f3 = F("nano") f_combined = f1 * f2 * f3 assert f_combined.to_dict() == {"$multiply": ["$quantity", "$price", "$nano"]} + + @pytest.mark.parametrize( + "method", + ( + "first", + "last", + "min", + "max", + "sum", + "avg", + ), + ) + def test_f_operator_methods(self, method): + f1 = F("quantity") + assert getattr(f1, method)() == {f"${method}": "$quantity"}