diff --git a/litestar/utils/signature.py b/litestar/utils/signature.py index eb585990e0..c387b7f93b 100644 --- a/litestar/utils/signature.py +++ b/litestar/utils/signature.py @@ -14,7 +14,7 @@ from litestar.exceptions import ImproperlyConfiguredException from litestar.types import Empty from litestar.typing import FieldDefinition -from litestar.utils.typing import unwrap_annotation +from litestar.utils.typing import expand_type_var_in_type_hint, unwrap_annotation if TYPE_CHECKING: from typing import Sequence @@ -212,8 +212,9 @@ def from_fn(cls, fn: AnyCallable, signature_namespace: dict[str, Any]) -> Self: """ signature = Signature.from_callable(fn) fn_type_hints = get_fn_type_hints(fn, namespace=signature_namespace) + expanded_type_hints = expand_type_var_in_type_hint(fn_type_hints, signature_namespace) - return cls.from_signature(signature, fn_type_hints) + return cls.from_signature(signature, expanded_type_hints) @classmethod def from_signature(cls, signature: Signature, fn_type_hints: dict[str, type]) -> Self: diff --git a/litestar/utils/typing.py b/litestar/utils/typing.py index 9da6c2a6f6..cae445a90e 100644 --- a/litestar/utils/typing.py +++ b/litestar/utils/typing.py @@ -262,6 +262,21 @@ def get_type_hints_with_generics_resolved( return {n: _substitute_typevars(type_, typevar_map) for n, type_ in type_hints.items()} +def expand_type_var_in_type_hint(type_hint: dict[str, Any], namespace: dict[str, Any] | None) -> dict[str, Any]: + """Expand TypeVar for any parameters in type_hint + + Args: + type_hint: mapping of parameter to type obtained from calling `get_type_hints` or `get_fn_type_hints` + namespace: mapping of TypeVar to concrete type + + Returns: + type_hint with any TypeVar parameter expanded + """ + if namespace: + return {name: _substitute_typevars(hint, namespace) for name, hint in type_hint.items()} + return type_hint + + def _substitute_typevars(obj: Any, typevar_map: Mapping[Any, Any]) -> Any: if params := getattr(obj, "__parameters__", None): args = tuple(_substitute_typevars(typevar_map.get(p, p), typevar_map) for p in params) diff --git a/tests/unit/test_utils/test_signature.py b/tests/unit/test_utils/test_signature.py index 1f9767b8c2..8f0de7c38d 100644 --- a/tests/unit/test_utils/test_signature.py +++ b/tests/unit/test_utils/test_signature.py @@ -5,11 +5,12 @@ import inspect from inspect import Parameter from types import ModuleType -from typing import Any, Callable, List, Optional, TypeVar, Union +from typing import Any, Callable, Generic, List, Optional, TypeVar, Union import pytest from typing_extensions import Annotated, NotRequired, Required, TypedDict, get_args, get_type_hints +from litestar import Controller, Router, post from litestar.exceptions import ImproperlyConfiguredException from litestar.file_system import BaseLocalFileSystem from litestar.static_files import StaticFiles @@ -20,6 +21,10 @@ from litestar.utils.signature import ParsedSignature, add_types_to_signature_namespace, get_fn_type_hints T = TypeVar("T") +U = TypeVar("U") + + +class ConcreteT: ... def test_get_fn_type_hints_asgi_app() -> None: @@ -161,3 +166,58 @@ def test_add_types_to_signature_namespace_with_existing_types_raises() -> None: """Test add_types_to_signature_namespace with existing types raises.""" with pytest.raises(ImproperlyConfiguredException): add_types_to_signature_namespace([int], {"int": int}) + + +@pytest.mark.parametrize( + ("namespace", "expected"), + ( + ({T: int}, {"data": int, "return": int}), + ({}, {"data": T, "return": T}), + ({T: ConcreteT}, {"data": ConcreteT, "return": ConcreteT}), + ), +) +def test_using_generics_in_fn_annotations(namespace: dict[str, Any], expected: dict[str, Any]) -> None: + @post(signature_namespace=namespace) + def create_item(data: T) -> T: + return data + + signature = create_item.parsed_fn_signature + actual = {"data": signature.parameters["data"].annotation, "return": signature.return_type.annotation} + assert actual == expected + + +class GenericController(Controller, Generic[T]): + model_class: T + + def __class_getitem__(cls, model_class: type) -> type: + cls_dict = {"model_class": model_class} + return type(f"GenericController[{model_class.__name__}", (cls,), cls_dict) + + def __init__(self, owner: Router) -> None: + super().__init__(owner) + self.signature_namespace[T] = self.model_class # type: ignore[misc] + + +class BaseController(GenericController[T]): + @post() + async def create(self, data: T) -> T: + return data + + +@pytest.mark.parametrize( + ("annotation_type", "expected"), + ( + (int, {"data": int, "return": int}), + (float, {"data": float, "return": float}), + (ConcreteT, {"data": ConcreteT, "return": ConcreteT}), + ), +) +def test_using_generics_in_controller_annotations(annotation_type: type, expected: dict[str, Any]) -> None: + class ConcreteController(BaseController[annotation_type]): # type: ignore[valid-type] + path = "/" + + controller_object = ConcreteController(owner=None) # type: ignore[arg-type] + + signature = controller_object.get_route_handlers()[0].parsed_fn_signature + actual = {"data": signature.parameters["data"].annotation, "return": signature.return_type.annotation} + assert actual == expected diff --git a/tests/unit/test_utils/test_typing.py b/tests/unit/test_utils/test_typing.py index 38f4a44174..5b9fd95b3f 100644 --- a/tests/unit/test_utils/test_typing.py +++ b/tests/unit/test_utils/test_typing.py @@ -9,6 +9,7 @@ from typing_extensions import Annotated from litestar.utils.typing import ( + expand_type_var_in_type_hint, get_origin_or_inner_type, get_type_hints_with_generics_resolved, make_non_optional_union, @@ -134,3 +135,27 @@ class NestedFoo(Generic[T]): ) def test_get_type_hints_with_generics(annotation: Any, expected_type_hints: dict[str, Any]) -> None: assert get_type_hints_with_generics_resolved(annotation, include_extras=True) == expected_type_hints + + +class ConcreteT: ... + + +@pytest.mark.parametrize( + ("type_hint", "namespace", "expected"), + ( + ({"arg1": T, "return": int}, {}, {"arg1": T, "return": int}), + ({"arg1": T, "return": int}, None, {"arg1": T, "return": int}), + ({"arg1": T, "return": int}, {U: ConcreteT}, {"arg1": T, "return": int}), + ({"arg1": T, "return": int}, {T: ConcreteT}, {"arg1": ConcreteT, "return": int}), + ({"arg1": T, "return": int}, {T: int}, {"arg1": int, "return": int}), + ({"arg1": int, "return": int}, {}, {"arg1": int, "return": int}), + ({"arg1": int, "return": int}, None, {"arg1": int, "return": int}), + ({"arg1": int, "return": int}, {T: int}, {"arg1": int, "return": int}), + ({"arg1": T, "return": T}, {T: ConcreteT}, {"arg1": ConcreteT, "return": ConcreteT}), + ({"arg1": T, "return": T}, {T: int}, {"arg1": int, "return": int}), + ), +) +def test_expand_type_var_in_type_hints( + type_hint: dict[str, Any], namespace: dict[str, Any] | None, expected: dict[str, Any] +) -> None: + assert expand_type_var_in_type_hint(type_hint, namespace) == expected