Skip to content

Commit

Permalink
Merge pull request #70 from mts-ai/run-validator-pass-model-field
Browse files Browse the repository at this point in the history
fix run validator: sometimes it requires model field
  • Loading branch information
mahenzon authored Dec 21, 2023
2 parents 7095b40 + 13a8228 commit 6c1f9fd
Show file tree
Hide file tree
Showing 7 changed files with 267 additions and 43 deletions.
55 changes: 46 additions & 9 deletions fastapi_jsonapi/data_layers/filtering/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Helper to create sqlalchemy filters according to filter querystring parameter"""
import inspect
import logging
from typing import (
Any,
Expand Down Expand Up @@ -133,10 +134,10 @@ def create_filter(self, schema_field: ModelField, model_column, operator, value)
pydantic_types, userspace_types = self._separate_types(types)

if pydantic_types:
func = self._cast_value_with_pydantic
if isinstance(value, list):
clear_value, errors = self._cast_iterable_with_pydantic(pydantic_types, value)
else:
clear_value, errors = self._cast_value_with_pydantic(pydantic_types, value)
func = self._cast_iterable_with_pydantic
clear_value, errors = func(pydantic_types, value, schema_field)

if clear_value is None and userspace_types:
log.warning("Filtering by user type values is not properly tested yet. Use this on your own risk.")
Expand All @@ -151,7 +152,10 @@ def create_filter(self, schema_field: ModelField, model_column, operator, value)

# Если None, при этом поле обязательное (среди типов в аннотации нет None, то кидаем ошибку)
if clear_value is None and not can_be_none:
raise InvalidType(detail=", ".join(errors))
raise InvalidType(
detail=", ".join(errors),
pointer=schema_field.name,
)

return getattr(model_column, self.operator)(clear_value)

Expand Down Expand Up @@ -179,32 +183,65 @@ def _separate_types(self, types: List[Type]) -> Tuple[List[Type], List[Type]]:
]
return pydantic_types, userspace_types

def _validator_requires_model_field(self, validator: Callable) -> bool:
"""
Check if validator accepts the `field` param
:param validator:
:return:
"""
signature = inspect.signature(validator)
parameters = signature.parameters

if "field" not in parameters:
return False

field_param = parameters["field"]
field_type = field_param.annotation

return field_type == "ModelField" or field_type is ModelField

def _cast_value_with_pydantic(
self,
types: List[Type],
value: Any,
schema_field: ModelField,
) -> Tuple[Optional[Any], List[str]]:
result_value, errors = None, []

for type_to_cast in types:
for validator in find_validators(type_to_cast, BaseConfig):
args = [value]
# TODO: some other way to get all the validator's dependencies?
if self._validator_requires_model_field(validator):
args.append(schema_field)
try:
result_value = validator(value)
return result_value, errors
result_value = validator(*args)
except Exception as ex:
errors.append(str(ex))
else:
return result_value, errors

return None, errors

def _cast_iterable_with_pydantic(self, types: List[Type], values: List) -> Tuple[List, List[str]]:
def _cast_iterable_with_pydantic(
self,
types: List[Type],
values: List,
schema_field: ModelField,
) -> Tuple[List, List[str]]:
type_cast_failed = False
failed_values = []

result_values: List[Any] = []
errors: List[str] = []

for value in values:
casted_value, cast_errors = self._cast_value_with_pydantic(types, value)
casted_value, cast_errors = self._cast_value_with_pydantic(
types,
value,
schema_field,
)
errors.extend(cast_errors)

if casted_value is None:
Expand All @@ -217,7 +254,7 @@ def _cast_iterable_with_pydantic(self, types: List[Type], values: List) -> Tuple

if type_cast_failed:
msg = f"Can't parse items {failed_values} of value {values}"
raise InvalidFilters(msg)
raise InvalidFilters(msg, pointer=schema_field.name)

return result_values, errors

Expand Down
6 changes: 5 additions & 1 deletion fastapi_jsonapi/exceptions/json_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ def __init__(
parameter = parameter or self.parameter
if not errors:
if pointer:
pointer = pointer if pointer.startswith("/") else "/data/" + pointer
pointer = (
pointer
if pointer.startswith("/")
else "/data/" + (pointer if pointer == "id" else "attributes/" + pointer)
)
self.source = {"pointer": pointer}
elif parameter:
self.source = {"parameter": parameter}
Expand Down
13 changes: 13 additions & 0 deletions tests/fixtures/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tests.models import (
Child,
Computer,
CustomUUIDItem,
Parent,
ParentToChildAssociation,
Post,
Expand All @@ -30,6 +31,7 @@
ComputerInSchema,
ComputerPatchSchema,
ComputerSchema,
CustomUUIDItemSchema,
ParentPatchSchema,
ParentSchema,
ParentToChildAssociationSchema,
Expand Down Expand Up @@ -178,6 +180,17 @@ def add_routers(app_plain: FastAPI):
schema_in_post=TaskInSchema,
)

RoutersJSONAPI(
router=router,
path="/custom-uuid-item",
tags=["Custom UUID Item"],
class_detail=DetailViewBaseGeneric,
class_list=ListViewBaseGeneric,
model=CustomUUIDItem,
schema=CustomUUIDItemSchema,
resource_type="custom_uuid_item",
)

atomic = AtomicOperations()

app_plain.include_router(router, prefix="")
Expand Down
3 changes: 2 additions & 1 deletion tests/fixtures/db_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,6 @@ async def async_session_plain(async_engine):

@async_fixture(scope="class")
async def async_session(async_session_plain):
async with async_session_plain() as session:
async with async_session_plain() as session: # type: AsyncSession
yield session
await session.rollback()
22 changes: 16 additions & 6 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlalchemy.orm import declared_attr, relationship
from sqlalchemy.types import CHAR, TypeDecorator

from tests.common import sqla_uri
from tests.common import is_postgres_tests, sqla_uri


class Base:
Expand Down Expand Up @@ -253,33 +253,43 @@ def load_dialect_impl(self, dialect):
return CHAR(32)

def process_bind_param(self, value, dialect):
if value is None:
return value

if not isinstance(value, UUID):
msg = f"Incorrect type got {type(value).__name__}, expected {UUID.__name__}"
raise Exception(msg)

return str(value)

def process_result_value(self, value, dialect):
return UUID(value)
return value and UUID(value)

@property
def python_type(self):
return UUID if self.as_uuid else str


db_uri = sqla_uri()
if "postgres" in db_uri:
if is_postgres_tests():
# noinspection PyPep8Naming
from sqlalchemy.dialects.postgresql import UUID as UUIDType
from sqlalchemy.dialects.postgresql.asyncpg import AsyncpgUUID as UUIDType
elif "sqlite" in db_uri:
UUIDType = CustomUUIDType
else:
msg = "unsupported dialect (custom uuid?)"
raise ValueError(msg)


class IdCast(Base):
id = Column(UUIDType, primary_key=True)
class CustomUUIDItem(Base):
__tablename__ = "custom_uuid_item"
id = Column(UUIDType(as_uuid=True), primary_key=True)

extra_id = Column(
UUIDType(as_uuid=True),
nullable=True,
unique=True,
)


class SelfRelationship(Base):
Expand Down
9 changes: 8 additions & 1 deletion tests/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,14 @@ class TaskSchema(TaskBaseSchema):
# uuid below


class IdCastSchema(BaseModel):
class CustomUUIDItemAttributesSchema(BaseModel):
extra_id: Optional[UUID] = None

class Config:
orm_mode = True


class CustomUUIDItemSchema(CustomUUIDItemAttributesSchema):
id: UUID = Field(client_can_set_id=True)


Expand Down
Loading

0 comments on commit 6c1f9fd

Please sign in to comment.