diff --git a/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py b/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py index 510732f3..866345c8 100644 --- a/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py +++ b/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py @@ -1,4 +1,5 @@ """Helper to create sqlalchemy filters according to filter querystring parameter""" +import inspect import logging from typing import ( Any, @@ -133,10 +134,10 @@ def create_filter(self, schema_field: ModelField, model_column, operator, value) pydantic_types, userspace_types = self._separate_types(types) if pydantic_types: + func = self._cast_value_with_pydantic if isinstance(value, list): - clear_value, errors = self._cast_iterable_with_pydantic(pydantic_types, value) - else: - clear_value, errors = self._cast_value_with_pydantic(pydantic_types, value) + func = self._cast_iterable_with_pydantic + clear_value, errors = func(pydantic_types, value, schema_field) if clear_value is None and userspace_types: log.warning("Filtering by user type values is not properly tested yet. Use this on your own risk.") @@ -151,7 +152,10 @@ def create_filter(self, schema_field: ModelField, model_column, operator, value) # Если None, при этом поле обязательное (среди типов в аннотации нет None, то кидаем ошибку) if clear_value is None and not can_be_none: - raise InvalidType(detail=", ".join(errors)) + raise InvalidType( + detail=", ".join(errors), + pointer=schema_field.name, + ) return getattr(model_column, self.operator)(clear_value) @@ -179,24 +183,53 @@ def _separate_types(self, types: List[Type]) -> Tuple[List[Type], List[Type]]: ] return pydantic_types, userspace_types + def _validator_requires_model_field(self, validator: Callable) -> bool: + """ + Check if validator accepts the `field` param + + :param validator: + :return: + """ + signature = inspect.signature(validator) + parameters = signature.parameters + + if "field" not in parameters: + return False + + field_param = parameters["field"] + field_type = field_param.annotation + + return field_type == "ModelField" or field_type is ModelField + def _cast_value_with_pydantic( self, types: List[Type], value: Any, + schema_field: ModelField, ) -> Tuple[Optional[Any], List[str]]: result_value, errors = None, [] for type_to_cast in types: for validator in find_validators(type_to_cast, BaseConfig): + args = [value] + # TODO: some other way to get all the validator's dependencies? + if self._validator_requires_model_field(validator): + args.append(schema_field) try: - result_value = validator(value) - return result_value, errors + result_value = validator(*args) except Exception as ex: errors.append(str(ex)) + else: + return result_value, errors return None, errors - def _cast_iterable_with_pydantic(self, types: List[Type], values: List) -> Tuple[List, List[str]]: + def _cast_iterable_with_pydantic( + self, + types: List[Type], + values: List, + schema_field: ModelField, + ) -> Tuple[List, List[str]]: type_cast_failed = False failed_values = [] @@ -204,7 +237,11 @@ def _cast_iterable_with_pydantic(self, types: List[Type], values: List) -> Tuple errors: List[str] = [] for value in values: - casted_value, cast_errors = self._cast_value_with_pydantic(types, value) + casted_value, cast_errors = self._cast_value_with_pydantic( + types, + value, + schema_field, + ) errors.extend(cast_errors) if casted_value is None: @@ -217,7 +254,7 @@ def _cast_iterable_with_pydantic(self, types: List[Type], values: List) -> Tuple if type_cast_failed: msg = f"Can't parse items {failed_values} of value {values}" - raise InvalidFilters(msg) + raise InvalidFilters(msg, pointer=schema_field.name) return result_values, errors diff --git a/fastapi_jsonapi/exceptions/json_api.py b/fastapi_jsonapi/exceptions/json_api.py index 4e58770e..0fb548ba 100644 --- a/fastapi_jsonapi/exceptions/json_api.py +++ b/fastapi_jsonapi/exceptions/json_api.py @@ -53,7 +53,11 @@ def __init__( parameter = parameter or self.parameter if not errors: if pointer: - pointer = pointer if pointer.startswith("/") else "/data/" + pointer + pointer = ( + pointer + if pointer.startswith("/") + else "/data/" + (pointer if pointer == "id" else "attributes/" + pointer) + ) self.source = {"pointer": pointer} elif parameter: self.source = {"parameter": parameter} diff --git a/tests/fixtures/app.py b/tests/fixtures/app.py index 82c6f9f0..9968e6d8 100644 --- a/tests/fixtures/app.py +++ b/tests/fixtures/app.py @@ -15,6 +15,7 @@ from tests.models import ( Child, Computer, + CustomUUIDItem, Parent, ParentToChildAssociation, Post, @@ -30,6 +31,7 @@ ComputerInSchema, ComputerPatchSchema, ComputerSchema, + CustomUUIDItemSchema, ParentPatchSchema, ParentSchema, ParentToChildAssociationSchema, @@ -178,6 +180,17 @@ def add_routers(app_plain: FastAPI): schema_in_post=TaskInSchema, ) + RoutersJSONAPI( + router=router, + path="/custom-uuid-item", + tags=["Custom UUID Item"], + class_detail=DetailViewBaseGeneric, + class_list=ListViewBaseGeneric, + model=CustomUUIDItem, + schema=CustomUUIDItemSchema, + resource_type="custom_uuid_item", + ) + atomic = AtomicOperations() app_plain.include_router(router, prefix="") diff --git a/tests/models.py b/tests/models.py index f6596bd1..51b73b05 100644 --- a/tests/models.py +++ b/tests/models.py @@ -253,6 +253,9 @@ def load_dialect_impl(self, dialect): return CHAR(32) def process_bind_param(self, value, dialect): + if value is None: + return value + if not isinstance(value, UUID): msg = f"Incorrect type got {type(value).__name__}, expected {UUID.__name__}" raise Exception(msg) @@ -260,7 +263,7 @@ def process_bind_param(self, value, dialect): return str(value) def process_result_value(self, value, dialect): - return UUID(value) + return value and UUID(value) @property def python_type(self): @@ -278,9 +281,16 @@ def python_type(self): raise ValueError(msg) -class IdCast(Base): +class CustomUUIDItem(Base): + __tablename__ = "custom_uuid_item" id = Column(UUIDType, primary_key=True) + extra_id = Column( + UUIDType, + nullable=True, + unique=True, + ) + class SelfRelationship(Base): id = Column(Integer, primary_key=True) diff --git a/tests/schemas.py b/tests/schemas.py index c2553e3f..ba5824f3 100644 --- a/tests/schemas.py +++ b/tests/schemas.py @@ -389,7 +389,14 @@ class TaskSchema(TaskBaseSchema): # uuid below -class IdCastSchema(BaseModel): +class CustomUUIDItemAttributesSchema(BaseModel): + extra_id: Optional[UUID] = None + + class Config: + orm_mode = True + + +class CustomUUIDItemSchema(CustomUUIDItemAttributesSchema): id: UUID = Field(client_can_set_id=True) diff --git a/tests/test_api/test_api_sqla_with_includes.py b/tests/test_api/test_api_sqla_with_includes.py index 721e3425..af609bd9 100644 --- a/tests/test_api/test_api_sqla_with_includes.py +++ b/tests/test_api/test_api_sqla_with_includes.py @@ -3,10 +3,11 @@ from collections import defaultdict from datetime import datetime, timezone from itertools import chain, zip_longest -from json import dumps -from typing import Dict, List +from json import dumps, loads +from typing import Dict, List, Literal from uuid import UUID, uuid4 +import pytest from fastapi import FastAPI, status from httpx import AsyncClient from pydantic import BaseModel, Field @@ -22,7 +23,7 @@ from tests.models import ( Computer, ContainsTimestamp, - IdCast, + CustomUUIDItem, Post, PostComment, SelfRelationship, @@ -32,7 +33,7 @@ ) from tests.schemas import ( CustomUserAttributesSchema, - IdCastSchema, + CustomUUIDItemAttributesSchema, PostAttributesBaseSchema, PostCommentAttributesBaseSchema, SelfRelationshipSchema, @@ -1112,35 +1113,45 @@ async def test_create_id_by_client(self): "meta": None, } - async def test_create_id_by_client_uuid_type(self): - resource_type = fake.word() - app = build_app_custom( - model=IdCast, - schema=IdCastSchema, - resource_type=resource_type, - ) + async def test_create_id_by_client_uuid_type( + self, + app: FastAPI, + client: AsyncClient, + ): + """ + Create id (custom) + + also creates UUID field (just for testing) + + :param app: + :param client: + :return: + """ + resource_type = "custom_uuid_item" new_id = str(uuid4()) + create_attributes = CustomUUIDItemAttributesSchema( + extra_id=uuid4(), + ) create_body = { "data": { - "attributes": {}, + "attributes": loads(create_attributes.json()), "id": new_id, }, } - async with AsyncClient(app=app, base_url="http://test") as client: - url = app.url_path_for(f"get_{resource_type}_list") - res = await client.post(url, json=create_body) - assert res.status_code == status.HTTP_201_CREATED, res.text - assert res.json() == { - "data": { - "attributes": {}, - "id": new_id, - "type": resource_type, - }, - "jsonapi": {"version": "1.0"}, - "meta": None, - } + url = app.url_path_for(f"get_{resource_type}_list") + res = await client.post(url, json=create_body) + assert res.status_code == status.HTTP_201_CREATED, res.text + assert res.json() == { + "data": { + "attributes": loads(create_attributes.json()), + "id": new_id, + "type": resource_type, + }, + "jsonapi": {"version": "1.0"}, + "meta": None, + } async def test_create_with_relationship_to_the_same_table(self): resource_type = "self_relationship" @@ -1264,6 +1275,7 @@ class ContainsTimestampAttrsSchema(BaseModel): }, } + # noinspection PyTypeChecker stms = select(ContainsTimestamp).where(ContainsTimestamp.id == int(entity_id)) (await async_session.execute(stms)).scalar_one() @@ -2368,6 +2380,116 @@ async def test_join_by_relationships_does_not_duplicating_response_entities( "meta": {"count": 1, "totalPages": 1}, } + @pytest.mark.parametrize("filter_kind", ["small", "full"]) + async def test_filter_by_field_of_uuid_type( + self, + app: FastAPI, + client: AsyncClient, + async_session: AsyncSession, + filter_kind: Literal["small", "full"], + ): + resource_type = "custom_uuid_item" + + new_id = uuid4() + extra_id = uuid4() + item = CustomUUIDItem( + id=new_id, + extra_id=extra_id, + ) + another_item = CustomUUIDItem( + id=uuid4(), + extra_id=uuid4(), + ) + async_session.add(item) + async_session.add(another_item) + await async_session.commit() + + # + params = {} + if filter_kind == "small": + params.update( + { + "filter[extra_id]": str(extra_id), + }, + ) + else: + params.update( + { + "filter": dumps( + [ + { + "name": "extra_id", + "op": "eq", + "val": str(extra_id), + }, + ], + ), + }, + ) + + url = app.url_path_for(f"get_{resource_type}_list") + res = await client.get(url, params=params) + assert res.status_code == status.HTTP_200_OK, res.text + assert res.json() == { + "data": [ + { + "attributes": loads(CustomUUIDItemAttributesSchema.from_orm(item).json()), + "id": str(new_id), + "type": resource_type, + }, + ], + "jsonapi": {"version": "1.0"}, + "meta": {"count": 1, "totalPages": 1}, + } + + async def test_filter_invalid_uuid( + self, + app: FastAPI, + client: AsyncClient, + ): + resource_type = "custom_uuid_item" + + extra_id = str(uuid4()) + params = { + "filter[extra_id]": str(extra_id) + "z", + } + + url = app.url_path_for(f"get_{resource_type}_list") + res = await client.get(url, params=params) + assert res.status_code >= status.HTTP_400_BAD_REQUEST, res.text + + async def test_filter_none_instead_of_uuid( + self, + app: FastAPI, + client: AsyncClient, + ): + resource_type = "custom_uuid_item" + + params = { + "filter": dumps( + [ + { + "name": "id", + "op": "eq", + "val": None, + }, + ], + ), + } + url = app.url_path_for(f"get_{resource_type}_list") + res = await client.get(url, params=params) + assert res.status_code == status.HTTP_400_BAD_REQUEST, res.text + assert res.json() == { + "errors": [ + { + "detail": "The field `id` can't be null", + "source": {"parameter": "filters"}, + "status_code": status.HTTP_400_BAD_REQUEST, + "title": "Invalid filters querystring parameter.", + }, + ], + } + ASCENDING = "" DESCENDING = "-"