Skip to content

Commit

Permalink
added return_fields function, attempting to optionally limit fields r…
Browse files Browse the repository at this point in the history
…eturned by find
  • Loading branch information
savynorem committed Jul 5, 2024
1 parent 44cbeaf commit 2df1beb
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
26 changes: 25 additions & 1 deletion aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def __init__(
limit: Optional[int] = None,
page_size: int = DEFAULT_PAGE_SIZE,
sort_fields: Optional[List[str]] = None,
return_fields: Optional[List[str]] = None,
nocontent: bool = False,
):
if not has_redisearch(model.db()):
Expand All @@ -445,6 +446,11 @@ def __init__(
else:
self.sort_fields = []

if return_fields:
self.return_fields = self.validate_return_fields(return_fields)
else:
self.return_fields = []

self._expression = None
self._query: Optional[str] = None
self._pagination: List[str] = []
Expand Down Expand Up @@ -502,8 +508,19 @@ def query(self):
if self._query.startswith("(") or self._query == "*"
else f"({self._query})"
) + f"=>[{self.knn}]"
if self.return_fields:
self._query += f" RETURN {','.join(self.return_fields)}"
return self._query

def validate_return_fields(self, return_fields: List[str]):
for field in return_fields:
if field not in self.model.__fields__: # type: ignore
raise QueryNotSupportedError(
f"You tried to return the field {field}, but that field "
f"does not exist on the model {self.model}"
)
return return_fields

@property
def query_params(self):
params: List[Union[str, bytes]] = []
Expand Down Expand Up @@ -956,6 +973,11 @@ def sort_by(self, *fields: str):
if not fields:
return self
return self.copy(sort_fields=list(fields))

def return_fields(self, *fields: str):
if not fields:
return self
return self.copy(return_fields=list(fields))

async def update(self, use_transaction=True, **field_values):
"""
Expand Down Expand Up @@ -1531,7 +1553,9 @@ def find(
*expressions: Union[Any, Expression],
knn: Optional[KNNExpression] = None,
) -> FindQuery:
return FindQuery(expressions=expressions, knn=knn, model=cls)
return FindQuery(
expressions=expressions, knn=knn, model=cls
)

@classmethod
def from_redis(cls, res: Any):
Expand Down
14 changes: 13 additions & 1 deletion tests/test_json_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,9 +935,21 @@ class TypeWithUuid(JsonModel):

await item.save()

@py_test_mark_asyncio
async def test_return_specified_fields(members, m):
member1, member2, member3 = members
actual = await m.Member.find(
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
| (m.Member.last_name == "Smith")
).all()
assert actual == [
{"first_name": "Andrew", "last_name": "Brookins"},
{"first_name": "Andrew", "last_name": "Smith"},
]


@py_test_mark_asyncio
async def test_xfix_queries(m):
async def test_xfix_queries(m):4
await m.Member(
first_name="Steve",
last_name="Lorello",
Expand Down

0 comments on commit 2df1beb

Please sign in to comment.