Skip to content

Commit

Permalink
Merge pull request #51 from mohamadkhalaj/main
Browse files Browse the repository at this point in the history
Remove $ from newRoot and get db_field in filter
  • Loading branch information
mohamadkhalaj authored Nov 11, 2023
2 parents f9bbc5f + b9f5d33 commit 2a8310c
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 18 deletions.
10 changes: 5 additions & 5 deletions aggify/aggify.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from aggify.types import QueryParams, CollectionType
from aggify.utilty import (
to_mongo_positive_index,
check_fields_exist,
validate_field_existence,
replace_values_recursive,
convert_match_query,
check_field_exists,
check_field_already_exists,
get_db_field,
)

Expand Down Expand Up @@ -538,7 +538,7 @@ def get_field_name_recursively(self, field: str) -> str:
# Split the field based on double underscores and process each item
for index, item in enumerate(field.split("__")):
# Ensure the field exists at the current level of hierarchy
check_fields_exist(prev_base, [item]) # noqa
validate_field_existence(prev_base, [item]) # noqa

# Append the database field name to the field_name list
field_name.append(get_db_field(prev_base, item))
Expand Down Expand Up @@ -583,7 +583,7 @@ def lookup(
"""

lookup_stages = []
check_field_exists(self.base_model, as_name) # noqa
check_field_already_exists(self.base_model, as_name) # noqa
from_collection_name = from_collection._meta.get("collection") # noqa

if not (let or raw_let) and not (local_field and foreign_field):
Expand Down Expand Up @@ -723,7 +723,7 @@ def replace_root(
{key: mongoengine_fields.IntField() for key, value in merge.items()}
)
else:
new_root = {"$replaceRoot": {"$newRoot": name}}
new_root = {"$replaceRoot": {"newRoot": name}}
self.pipelines.append(new_root)

return self
Expand Down
16 changes: 10 additions & 6 deletions aggify/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from mongoengine.base import TopLevelDocumentMetaclass

from aggify.exceptions import InvalidOperator
from aggify.utilty import get_db_field
from aggify.utilty import get_db_field, get_nested_field_model


class Operators:
Expand Down Expand Up @@ -284,16 +284,20 @@ def compile(self, pipelines: list) -> Dict[str, Dict[str, list]]:
match_query[key] = value
continue

field, operator, *_ = key.split("__")
field, operator, *others = key.split("__")
if (
self.is_base_model_field(field)
and operator not in Operators.ALL_OPERATORS
):
pipelines.append(
Match({key.replace("__", ".", 1): value}, self.base_model).compile(
[]
)
field_db_name = get_db_field(self.base_model, field)

nested_field_name = get_db_field(
get_nested_field_model(self.base_model, field), operator
)
key = (
f"{field_db_name}.{nested_field_name}__" + "__".join(others)
).rstrip("__")
pipelines.append(Match({key: value}, self.base_model).compile([]))
continue

if operator not in Operators.ALL_OPERATORS:
Expand Down
40 changes: 35 additions & 5 deletions aggify/utilty.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Any, Union, List, Dict

from mongoengine import Document
from aggify.types import CollectionType

from aggify.exceptions import MongoIndexError, InvalidField, AlreadyExistsField
from aggify.types import CollectionType


def to_mongo_positive_index(index: Union[int, slice]) -> slice:
Expand All @@ -22,9 +23,9 @@ def to_mongo_positive_index(index: Union[int, slice]) -> slice:
return index


def check_fields_exist(model: CollectionType, fields_to_check: List[str]) -> None:
def validate_field_existence(model: CollectionType, fields_to_check: List[str]) -> None:
"""
Check if the specified fields exist in a model's fields.
The function checks a list of fields and raises an InvalidField exception if any are missing.
Args:
model: The model containing the fields to check.
Expand Down Expand Up @@ -105,7 +106,7 @@ def convert_match_query(
return d


def check_field_exists(model: CollectionType, field: str) -> None:
def check_field_already_exists(model: CollectionType, field: str) -> None:
"""
Check if a field exists in the given model.
Expand All @@ -116,7 +117,10 @@ def check_field_exists(model: CollectionType, field: str) -> None:
Raises:
AlreadyExistsField: If the field already exists in the model.
"""
if model._fields.get(field): # noqa
if field in [
f.db_field if hasattr(f, "db_field") else k
for k, f in model._fields.items() # noqa
]:
raise AlreadyExistsField(field=field)


Expand All @@ -138,3 +142,29 @@ def get_db_field(model: CollectionType, field: str, add_dollar_sign=False) -> st
return f"${db_field}" if add_dollar_sign else db_field
except AttributeError:
return field


def get_nested_field_model(model: CollectionType, field: str) -> CollectionType:
"""
Retrieves the nested field model for a specified field within a given model.
This function examines the provided model to determine if the specified field is
a nested field. If it is, the function returns the nested field's model.
Otherwise, it returns the original model.
Args:
model (CollectionType): The model to be inspected. This should be a class that
represents a collection or document in a database, typically
in an ORM or ODM framework.
field (str): The name of the field within the model to inspect for nestedness.
Returns:
CollectionType: The model of the nested field if the specified field is nested;
otherwise, returns the original model.
Raises:
KeyError: If the specified field is not found in the model.
"""
if model._fields[field].__dict__.get("__module__"): # noqa
return model
return model._fields[field].__dict__["document_type_obj"] # noqa
8 changes: 6 additions & 2 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class ParameterTestCase:
),
ParameterTestCase(
compiled_query=(Aggify(PostDocument).replace_root(embedded_field="stat")),
expected_query=[{"$replaceRoot": {"$newRoot": "$stat"}}],
expected_query=[{"$replaceRoot": {"newRoot": "$stat"}}],
),
ParameterTestCase(
compiled_query=(Aggify(PostDocument).replace_with(embedded_field="stat")),
Expand Down Expand Up @@ -525,9 +525,13 @@ class ParameterTestCase:
"localField": "end",
}
},
{"$replaceRoot": {"$newRoot": "$saved_post"}},
{"$replaceRoot": {"newRoot": "$saved_post"}},
],
),
ParameterTestCase(
compiled_query=(Aggify(PostDocument).filter(stat__like_count=2)),
expected_query=[{"$match": {"stat.like_count": 2}}],
),
]


Expand Down

0 comments on commit 2a8310c

Please sign in to comment.