From 2f2df02d79758f4e51994b2af41ab158c1c08eb4 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Wed, 29 Nov 2023 00:41:56 +1000 Subject: [PATCH] fix: DTO factory narrowed with a generic alias. This PR is an attempt at handling DTOs that are narrowed with a `_GenericAlias` of a type supported by the DTO factory type. Closes #2500 --- litestar/dto/_backend.py | 3 +- litestar/dto/base_dto.py | 7 +++- litestar/dto/dataclass_dto.py | 7 +++- litestar/typing.py | 29 ++++++++++++++- litestar/utils/typing.py | 4 +- .../test_dto/test_factory/test_base_dto.py | 37 ++++++++++++++++--- tests/unit/test_dto/test_integration.py | 27 +++++++++++++- 7 files changed, 100 insertions(+), 14 deletions(-) diff --git a/litestar/dto/_backend.py b/litestar/dto/_backend.py index aae962f56e..db2a10981a 100644 --- a/litestar/dto/_backend.py +++ b/litestar/dto/_backend.py @@ -18,6 +18,7 @@ ) from msgspec import UNSET, Struct, UnsetType, convert, defstruct, field +from typing_extensions import get_origin from litestar.dto._types import ( CollectionType, @@ -110,7 +111,7 @@ def __init__( rename_fields=self.dto_factory.config.rename_fields, ) self.transfer_model_type = self.create_transfer_model_type( - model_name=model_type.__name__, field_definitions=self.parsed_field_definitions + model_name=(get_origin(model_type) or model_type).__name__, field_definitions=self.parsed_field_definitions ) self.dto_data_type: type[DTOData] | None = None diff --git a/litestar/dto/base_dto.py b/litestar/dto/base_dto.py index 6e60b6c252..f43ef8bc00 100644 --- a/litestar/dto/base_dto.py +++ b/litestar/dto/base_dto.py @@ -5,7 +5,7 @@ from inspect import getmodule from typing import TYPE_CHECKING, Collection, Generic, TypeVar -from typing_extensions import NotRequired, TypedDict, get_type_hints +from typing_extensions import NotRequired, TypedDict from litestar.dto._backend import DTOBackend from litestar.dto._codegen_backend import DTOCodegenBackend @@ -17,6 +17,7 @@ from litestar.types.builtin_types import NoneType from litestar.types.composite_types import TypeEncodersMap from litestar.typing import FieldDefinition +from litestar.utils.typing import get_type_hints_with_generics_resolved if TYPE_CHECKING: from typing import Any, ClassVar, Generator @@ -267,7 +268,9 @@ def get_model_type_hints( return { k: FieldDefinition.from_kwarg(annotation=v, name=k) - for k, v in get_type_hints(model_type, localns=namespace, include_extras=True).items() + for k, v in get_type_hints_with_generics_resolved( + model_type, localns=namespace, include_extras=True + ).items() } @staticmethod diff --git a/litestar/dto/dataclass_dto.py b/litestar/dto/dataclass_dto.py index 554b0f3343..03863ea8af 100644 --- a/litestar/dto/dataclass_dto.py +++ b/litestar/dto/dataclass_dto.py @@ -3,6 +3,8 @@ from dataclasses import MISSING, fields, replace from typing import TYPE_CHECKING, Generic, TypeVar +from typing_extensions import get_origin + from litestar.dto.base_dto import AbstractDTO from litestar.dto.data_structures import DTOFieldDefinition from litestar.dto.field import DTO_FIELD_META_KEY, DTOField @@ -29,7 +31,8 @@ class DataclassDTO(AbstractDTO[T], Generic[T]): def generate_field_definitions( cls, model_type: type[DataclassProtocol] ) -> Generator[DTOFieldDefinition, None, None]: - dc_fields = {f.name: f for f in fields(model_type)} + model_origin = get_origin(model_type) or model_type + dc_fields = {f.name: f for f in fields(model_origin)} for key, field_definition in cls.get_model_type_hints(model_type).items(): if not (dc_field := dc_fields.get(key)): continue @@ -41,7 +44,7 @@ def generate_field_definitions( field_definition=field_definition, default_factory=default_factory, dto_field=dc_field.metadata.get(DTO_FIELD_META_KEY, DTOField()), - model_name=model_type.__name__, + model_name=model_origin.__name__, ), name=key, default=default, diff --git a/litestar/typing.py b/litestar/typing.py index 1048296592..af4db31091 100644 --- a/litestar/typing.py +++ b/litestar/typing.py @@ -4,7 +4,21 @@ from copy import deepcopy from dataclasses import dataclass, is_dataclass, replace from inspect import Parameter, Signature -from typing import Any, AnyStr, Callable, Collection, ForwardRef, Literal, Mapping, Protocol, Sequence, TypeVar, cast +from typing import ( # type: ignore[attr-defined] + Any, + AnyStr, + Callable, + ClassVar, + Collection, + ForwardRef, + Literal, + Mapping, + Protocol, + Sequence, + TypeVar, + _GenericAlias, # pyright: ignore + cast, +) from msgspec import UnsetType from typing_extensions import NotRequired, Required, Self, get_args, get_origin, get_type_hints, is_typeddict @@ -442,6 +456,19 @@ def is_subclass_of(self, cl: type[Any] | tuple[type[Any], ...]) -> bool: if self.origin in UnionTypes: return all(t.is_subclass_of(cl) for t in self.inner_types) + if isinstance(self.annotation, _GenericAlias) and self.origin not in (ClassVar, Literal): + cl_args = get_args(cl) + cl_origin = get_origin(cl) or cl + return ( + issubclass(self.origin, cl_origin) + and (len(cl_args) == len(self.args) if cl_args else True) + and ( + all(t.is_subclass_of(cl_arg) for t, cl_arg in zip(self.inner_types, cl_args)) + if cl_args + else True + ) + ) + return self.origin not in UnionTypes and is_class_and_subclass(self.origin, cl) if self.annotation is AnyStr: diff --git a/litestar/utils/typing.py b/litestar/utils/typing.py index a99a623b2c..fcad379e48 100644 --- a/litestar/utils/typing.py +++ b/litestar/utils/typing.py @@ -249,7 +249,9 @@ def get_type_hints_with_generics_resolved( if origin is None: # Implies the generic types have not been specified in the annotation type_hints = get_type_hints(annotation, globalns=globalns, localns=localns, include_extras=include_extras) - typevar_map = {p: p for p in annotation.__parameters__} + if not (parameters := getattr(annotation, "__parameters__", None)): + return type_hints + typevar_map = {p: p for p in parameters} else: type_hints = get_type_hints(origin, globalns=globalns, localns=localns, include_extras=include_extras) # the __parameters__ is only available on the origin itself and not the annotation diff --git a/tests/unit/test_dto/test_factory/test_base_dto.py b/tests/unit/test_dto/test_factory/test_base_dto.py index c8ada7d2c3..65012087df 100644 --- a/tests/unit/test_dto/test_factory/test_base_dto.py +++ b/tests/unit/test_dto/test_factory/test_base_dto.py @@ -2,7 +2,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Tuple, TypeVar, Union +from typing import TYPE_CHECKING, Generic, Tuple, TypeVar, Union import pytest from typing_extensions import Annotated @@ -10,6 +10,7 @@ from litestar import Request from litestar.dto import DataclassDTO, DTOConfig from litestar.exceptions.dto_exceptions import InvalidAnnotationException +from litestar.types.empty import Empty from litestar.typing import FieldDefinition from . import Model @@ -19,7 +20,8 @@ from litestar.dto._backend import DTOBackend -T = TypeVar("T", bound=Model) +T = TypeVar("T") +ModelT = TypeVar("ModelT", bound=Model) def get_backend(dto_type: type[DataclassDTO[Any]]) -> DTOBackend: @@ -77,7 +79,7 @@ def test_extra_annotated_metadata_ignored() -> None: def test_overwrite_config() -> None: first = DTOConfig(exclude={"a"}) - generic_dto = DataclassDTO[Annotated[T, first]] # pyright: ignore + generic_dto = DataclassDTO[Annotated[ModelT, first]] # pyright: ignore second = DTOConfig(exclude={"b"}) dto = generic_dto[Annotated[Model, second]] # pyright: ignore assert dto.config is second @@ -86,13 +88,13 @@ def test_overwrite_config() -> None: def test_existing_config_not_overwritten() -> None: assert getattr(DataclassDTO, "_config", None) is None first = DTOConfig(exclude={"a"}) - generic_dto = DataclassDTO[Annotated[T, first]] # pyright: ignore + generic_dto = DataclassDTO[Annotated[ModelT, first]] # pyright: ignore dto = generic_dto[Model] # pyright: ignore assert dto.config is first def test_config_assigned_via_subclassing() -> None: - class CustomGenericDTO(DataclassDTO[T]): + class CustomGenericDTO(DataclassDTO[ModelT]): config = DTOConfig(exclude={"a"}) concrete_dto = CustomGenericDTO[Model] @@ -161,3 +163,28 @@ class SubType(Model): assert ( dto_type._dto_backends["handler_id"]["data_backend"].parsed_field_definitions[-1].name == "c" # pyright: ignore ) + + +def test_type_narrowing_with_generic_type() -> None: + @dataclass + class Foo(Generic[T]): + foo: T + + hints = DataclassDTO.get_model_type_hints(Foo[int]) + assert hints == { + "foo": FieldDefinition( + raw=int, + annotation=int, + type_wrappers=(), + origin=None, + args=(), + metadata=(), + instantiable_origin=None, + safe_generic_origin=None, + inner_types=(), + default=Empty, + extra={}, + kwarg_definition=None, + name="foo", + ) + } diff --git a/tests/unit/test_dto/test_integration.py b/tests/unit/test_dto/test_integration.py index 6e90055e78..f8ba6d07f1 100644 --- a/tests/unit/test_dto/test_integration.py +++ b/tests/unit/test_dto/test_integration.py @@ -1,19 +1,22 @@ from __future__ import annotations -from typing import Dict +from dataclasses import dataclass +from typing import Dict, Generic, TypeVar from unittest.mock import MagicMock import pytest from litestar import Controller, Litestar, Router, post from litestar.config.app import ExperimentalFeatures -from litestar.dto import AbstractDTO, DTOConfig +from litestar.dto import AbstractDTO, DataclassDTO, DTOConfig, DTOData from litestar.dto._backend import DTOBackend from litestar.dto._codegen_backend import DTOCodegenBackend from litestar.testing import create_test_client from . import Model +T = TypeVar("T") + @pytest.fixture() def experimental_features(use_experimental_dto_backend: bool) -> list[ExperimentalFeatures] | None: @@ -153,3 +156,23 @@ def handler(data: Model) -> Model: backend = handler.resolve_data_dto()._dto_backends[handler.handler_id]["data_backend"] # type: ignore[union-attr] assert isinstance(backend, DTOBackend) + + +def test_dto_for_generic_model() -> None: + @dataclass + class Foo(Generic[T]): + foo: T + + FooDTO = DataclassDTO[Foo[int]] + + @post("/foo", dto=FooDTO, signature_types=[Foo]) + async def foo_handler(data: DTOData[Foo[int]]) -> Foo[int]: + return data.create_instance() + + with create_test_client(route_handlers=foo_handler) as client: + response = client.post("/foo", json={"foo": 1}) + assert response.status_code == 201 + assert response.json() == {"foo": 1} + response = client.post("/foo", json={"foo": "1"}) + assert response.status_code == 400 + assert response.json() == {"status_code": 400, "detail": "Expected `int`, got `str` - at `$.foo`"}