diff --git a/fastapi_jsonapi/data_layers/base.py b/fastapi_jsonapi/data_layers/base.py index 8d90df05..b6fe10c1 100644 --- a/fastapi_jsonapi/data_layers/base.py +++ b/fastapi_jsonapi/data_layers/base.py @@ -229,6 +229,49 @@ async def delete_relationship( """ raise NotImplementedError + def get_related_model_query_base( + self, + related_model: Type[TypeModel], + ): + """ + Prepare query for the related model + + :param related_model: Related ORM model class (not instance) + :return: + """ + raise NotImplementedError + + def get_related_object_query( + self, + related_model: Type[TypeModel], + related_id_field: str, + id_value: str, + ): + """ + Prepare query to get related object + :param related_model: + :param related_id_field: + :param id_value: + :return: + """ + raise NotImplementedError + + def get_related_objects_list_query( + self, + related_model: Type[TypeModel], + related_id_field: str, + ids: list[str], + ): + """ + Prepare query to get related objects list + :param related_model: + :param related_id_field: + :param ids: + :return: + """ + raise NotImplementedError + + # async def get_related_object_query(self): async def get_related_object( self, related_model: Type[TypeModel], diff --git a/fastapi_jsonapi/data_layers/sqla_orm.py b/fastapi_jsonapi/data_layers/sqla_orm.py index 61d17491..67a2cb4b 100644 --- a/fastapi_jsonapi/data_layers/sqla_orm.py +++ b/fastapi_jsonapi/data_layers/sqla_orm.py @@ -518,6 +518,38 @@ async def delete_relationship( :param view_kwargs: kwargs from the resource view. """ + def get_related_model_query_base( + self, + related_model: Type[TypeModel], + ) -> "Select": + """ + :param related_model: + :return: + """ + return select(related_model) + + def get_related_object_query( + self, + related_model: Type[TypeModel], + related_id_field: str, + id_value: str, + ): + id_field = getattr(related_model, related_id_field) + id_value = self.prepare_id_value(id_field, id_value) + stmt: "Select" = self.get_related_model_query_base(related_model) + return stmt.where(id_field == id_value) + + def get_related_objects_list_query( + self, + related_model: Type[TypeModel], + related_id_field: str, + ids: list[str], + ) -> Tuple["Select", list[str]]: + id_field = getattr(related_model, related_id_field) + prepared_ids = [self.prepare_id_value(id_field, _id) for _id in ids] + stmt: "Select" = self.get_related_model_query_base(related_model) + return stmt.where(id_field.in_(prepared_ids)), prepared_ids + async def get_related_object( self, related_model: Type[TypeModel], @@ -532,9 +564,12 @@ async def get_related_object( :param id_value: related object id value :return: a related SQLA ORM object """ - id_field = getattr(related_model, related_id_field) - id_value = self.prepare_id_value(id_field, id_value) - stmt = select(related_model).where(id_field == id_value) + stmt = self.get_related_object_query( + related_model=related_model, + related_id_field=related_id_field, + id_value=id_value, + ) + try: related_object = (await self.session.execute(stmt)).scalar_one() except NoResultFound: @@ -556,9 +591,11 @@ async def get_related_objects_list( :param ids: :return: """ - id_field = getattr(related_model, related_id_field) - ids = [self.prepare_id_value(id_field, _id) for _id in ids] - stmt = select(related_model).where(id_field.in_(ids)) + stmt, ids = self.get_related_objects_list_query( + related_model=related_model, + related_id_field=related_id_field, + ids=ids, + ) related_objects = (await self.session.execute(stmt)).scalars().all() object_ids = [getattr(obj, related_id_field) for obj in related_objects]