Skip to content

Commit

Permalink
Merge pull request #67 from mts-ai/feature/separate_methods_for_fetch…
Browse files Browse the repository at this point in the history
…ing_related_objects

create separate methods for building query for fetching related objects
  • Loading branch information
mahenzon authored Dec 19, 2023
2 parents 0a08e71 + 8e34eb7 commit 71c3294
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 6 deletions.
43 changes: 43 additions & 0 deletions fastapi_jsonapi/data_layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,49 @@ async def delete_relationship(
"""
raise NotImplementedError

def get_related_model_query_base(
self,
related_model: Type[TypeModel],
):
"""
Prepare query for the related model
:param related_model: Related ORM model class (not instance)
:return:
"""
raise NotImplementedError

def get_related_object_query(
self,
related_model: Type[TypeModel],
related_id_field: str,
id_value: str,
):
"""
Prepare query to get related object
:param related_model:
:param related_id_field:
:param id_value:
:return:
"""
raise NotImplementedError

def get_related_objects_list_query(
self,
related_model: Type[TypeModel],
related_id_field: str,
ids: list[str],
):
"""
Prepare query to get related objects list
:param related_model:
:param related_id_field:
:param ids:
:return:
"""
raise NotImplementedError

# async def get_related_object_query(self):
async def get_related_object(
self,
related_model: Type[TypeModel],
Expand Down
49 changes: 43 additions & 6 deletions fastapi_jsonapi/data_layers/sqla_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,38 @@ async def delete_relationship(
:param view_kwargs: kwargs from the resource view.
"""

def get_related_model_query_base(
self,
related_model: Type[TypeModel],
) -> "Select":
"""
:param related_model:
:return:
"""
return select(related_model)

def get_related_object_query(
self,
related_model: Type[TypeModel],
related_id_field: str,
id_value: str,
):
id_field = getattr(related_model, related_id_field)
id_value = self.prepare_id_value(id_field, id_value)
stmt: "Select" = self.get_related_model_query_base(related_model)
return stmt.where(id_field == id_value)

def get_related_objects_list_query(
self,
related_model: Type[TypeModel],
related_id_field: str,
ids: list[str],
) -> Tuple["Select", list[str]]:
id_field = getattr(related_model, related_id_field)
prepared_ids = [self.prepare_id_value(id_field, _id) for _id in ids]
stmt: "Select" = self.get_related_model_query_base(related_model)
return stmt.where(id_field.in_(prepared_ids)), prepared_ids

async def get_related_object(
self,
related_model: Type[TypeModel],
Expand All @@ -532,9 +564,12 @@ async def get_related_object(
:param id_value: related object id value
:return: a related SQLA ORM object
"""
id_field = getattr(related_model, related_id_field)
id_value = self.prepare_id_value(id_field, id_value)
stmt = select(related_model).where(id_field == id_value)
stmt = self.get_related_object_query(
related_model=related_model,
related_id_field=related_id_field,
id_value=id_value,
)

try:
related_object = (await self.session.execute(stmt)).scalar_one()
except NoResultFound:
Expand All @@ -556,9 +591,11 @@ async def get_related_objects_list(
:param ids:
:return:
"""
id_field = getattr(related_model, related_id_field)
ids = [self.prepare_id_value(id_field, _id) for _id in ids]
stmt = select(related_model).where(id_field.in_(ids))
stmt, ids = self.get_related_objects_list_query(
related_model=related_model,
related_id_field=related_id_field,
ids=ids,
)

related_objects = (await self.session.execute(stmt)).scalars().all()
object_ids = [getattr(obj, related_id_field) for obj in related_objects]
Expand Down

0 comments on commit 71c3294

Please sign in to comment.