Skip to content

Commit

Permalink
feat: add typevar expansion (#3242)
Browse files Browse the repository at this point in the history
* feat: add typevar expansion #3240

* chore: resolve all PR suggestion #3242

* chore: resolve import formatting

* chore: resolve import formatting
  • Loading branch information
harryle95 authored Mar 23, 2024
1 parent 4b79615 commit 96c59fe
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 3 deletions.
5 changes: 3 additions & 2 deletions litestar/utils/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from litestar.exceptions import ImproperlyConfiguredException
from litestar.types import Empty
from litestar.typing import FieldDefinition
from litestar.utils.typing import unwrap_annotation
from litestar.utils.typing import expand_type_var_in_type_hint, unwrap_annotation

if TYPE_CHECKING:
from typing import Sequence
Expand Down Expand Up @@ -212,8 +212,9 @@ def from_fn(cls, fn: AnyCallable, signature_namespace: dict[str, Any]) -> Self:
"""
signature = Signature.from_callable(fn)
fn_type_hints = get_fn_type_hints(fn, namespace=signature_namespace)
expanded_type_hints = expand_type_var_in_type_hint(fn_type_hints, signature_namespace)

return cls.from_signature(signature, fn_type_hints)
return cls.from_signature(signature, expanded_type_hints)

@classmethod
def from_signature(cls, signature: Signature, fn_type_hints: dict[str, type]) -> Self:
Expand Down
15 changes: 15 additions & 0 deletions litestar/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,21 @@ def get_type_hints_with_generics_resolved(
return {n: _substitute_typevars(type_, typevar_map) for n, type_ in type_hints.items()}


def expand_type_var_in_type_hint(type_hint: dict[str, Any], namespace: dict[str, Any] | None) -> dict[str, Any]:
"""Expand TypeVar for any parameters in type_hint
Args:
type_hint: mapping of parameter to type obtained from calling `get_type_hints` or `get_fn_type_hints`
namespace: mapping of TypeVar to concrete type
Returns:
type_hint with any TypeVar parameter expanded
"""
if namespace:
return {name: _substitute_typevars(hint, namespace) for name, hint in type_hint.items()}
return type_hint


def _substitute_typevars(obj: Any, typevar_map: Mapping[Any, Any]) -> Any:
if params := getattr(obj, "__parameters__", None):
args = tuple(_substitute_typevars(typevar_map.get(p, p), typevar_map) for p in params)
Expand Down
62 changes: 61 additions & 1 deletion tests/unit/test_utils/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import inspect
from inspect import Parameter
from types import ModuleType
from typing import Any, Callable, List, Optional, TypeVar, Union
from typing import Any, Callable, Generic, List, Optional, TypeVar, Union

import pytest
from typing_extensions import Annotated, NotRequired, Required, TypedDict, get_args, get_type_hints

from litestar import Controller, Router, post
from litestar.exceptions import ImproperlyConfiguredException
from litestar.file_system import BaseLocalFileSystem
from litestar.static_files import StaticFiles
Expand All @@ -20,6 +21,10 @@
from litestar.utils.signature import ParsedSignature, add_types_to_signature_namespace, get_fn_type_hints

T = TypeVar("T")
U = TypeVar("U")


class ConcreteT: ...


def test_get_fn_type_hints_asgi_app() -> None:
Expand Down Expand Up @@ -161,3 +166,58 @@ def test_add_types_to_signature_namespace_with_existing_types_raises() -> None:
"""Test add_types_to_signature_namespace with existing types raises."""
with pytest.raises(ImproperlyConfiguredException):
add_types_to_signature_namespace([int], {"int": int})


@pytest.mark.parametrize(
("namespace", "expected"),
(
({T: int}, {"data": int, "return": int}),
({}, {"data": T, "return": T}),
({T: ConcreteT}, {"data": ConcreteT, "return": ConcreteT}),
),
)
def test_using_generics_in_fn_annotations(namespace: dict[str, Any], expected: dict[str, Any]) -> None:
@post(signature_namespace=namespace)
def create_item(data: T) -> T:
return data

signature = create_item.parsed_fn_signature
actual = {"data": signature.parameters["data"].annotation, "return": signature.return_type.annotation}
assert actual == expected


class GenericController(Controller, Generic[T]):
model_class: T

def __class_getitem__(cls, model_class: type) -> type:
cls_dict = {"model_class": model_class}
return type(f"GenericController[{model_class.__name__}", (cls,), cls_dict)

def __init__(self, owner: Router) -> None:
super().__init__(owner)
self.signature_namespace[T] = self.model_class # type: ignore[misc]


class BaseController(GenericController[T]):
@post()
async def create(self, data: T) -> T:
return data


@pytest.mark.parametrize(
("annotation_type", "expected"),
(
(int, {"data": int, "return": int}),
(float, {"data": float, "return": float}),
(ConcreteT, {"data": ConcreteT, "return": ConcreteT}),
),
)
def test_using_generics_in_controller_annotations(annotation_type: type, expected: dict[str, Any]) -> None:
class ConcreteController(BaseController[annotation_type]): # type: ignore[valid-type]
path = "/"

controller_object = ConcreteController(owner=None) # type: ignore[arg-type]

signature = controller_object.get_route_handlers()[0].parsed_fn_signature
actual = {"data": signature.parameters["data"].annotation, "return": signature.return_type.annotation}
assert actual == expected
25 changes: 25 additions & 0 deletions tests/unit/test_utils/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing_extensions import Annotated

from litestar.utils.typing import (
expand_type_var_in_type_hint,
get_origin_or_inner_type,
get_type_hints_with_generics_resolved,
make_non_optional_union,
Expand Down Expand Up @@ -134,3 +135,27 @@ class NestedFoo(Generic[T]):
)
def test_get_type_hints_with_generics(annotation: Any, expected_type_hints: dict[str, Any]) -> None:
assert get_type_hints_with_generics_resolved(annotation, include_extras=True) == expected_type_hints


class ConcreteT: ...


@pytest.mark.parametrize(
("type_hint", "namespace", "expected"),
(
({"arg1": T, "return": int}, {}, {"arg1": T, "return": int}),
({"arg1": T, "return": int}, None, {"arg1": T, "return": int}),
({"arg1": T, "return": int}, {U: ConcreteT}, {"arg1": T, "return": int}),
({"arg1": T, "return": int}, {T: ConcreteT}, {"arg1": ConcreteT, "return": int}),
({"arg1": T, "return": int}, {T: int}, {"arg1": int, "return": int}),
({"arg1": int, "return": int}, {}, {"arg1": int, "return": int}),
({"arg1": int, "return": int}, None, {"arg1": int, "return": int}),
({"arg1": int, "return": int}, {T: int}, {"arg1": int, "return": int}),
({"arg1": T, "return": T}, {T: ConcreteT}, {"arg1": ConcreteT, "return": ConcreteT}),
({"arg1": T, "return": T}, {T: int}, {"arg1": int, "return": int}),
),
)
def test_expand_type_var_in_type_hints(
type_hint: dict[str, Any], namespace: dict[str, Any] | None, expected: dict[str, Any]
) -> None:
assert expand_type_var_in_type_hint(type_hint, namespace) == expected

0 comments on commit 96c59fe

Please sign in to comment.