From d4ffecb4918f9c90c1e61c8ae2a270ac4b546124 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Thu, 5 Sep 2024 23:51:54 +0300 Subject: [PATCH] chore: #120 pr --- fast_depends/core/builder.py | 19 +++++++++++---- tests/pydantic_specific/test_custom.py | 32 ++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 4 deletions(-) create mode 100644 tests/pydantic_specific/test_custom.py diff --git a/fast_depends/core/builder.py b/fast_depends/core/builder.py index 6e64eab..22a07e9 100644 --- a/fast_depends/core/builder.py +++ b/fast_depends/core/builder.py @@ -90,9 +90,14 @@ def build_call_model( elif get_origin(param.annotation) is Annotated: annotated_args = get_args(param.annotation) type_annotation = annotated_args[0] - custom_annotations = [ - arg for arg in annotated_args[1:] if isinstance(arg, CUSTOM_ANNOTATIONS) - ] + + custom_annotations = [] + regular_annotations = [] + for arg in annotated_args[1:]: + if isinstance(arg, CUSTOM_ANNOTATIONS): + custom_annotations.append(arg) + else: + regular_annotations.append(arg) assert ( len(custom_annotations) <= 1 @@ -107,7 +112,10 @@ def build_call_model( else: # pragma: no cover raise AssertionError("unreachable") - annotation = type_annotation + if regular_annotations: + annotation = param.annotation + else: + annotation = type_annotation else: annotation = param.annotation else: @@ -118,6 +126,8 @@ def build_call_model( default = () elif param_name == "kwargs": default = {} + elif param.default is inspect.Parameter.empty: + default = Ellipsis else: default = param.default @@ -186,6 +196,7 @@ def build_call_model( class_fields.append(OptionItem( field_name=param_name, field_type=annotation, + default_value=default, source=custom, )) diff --git a/tests/pydantic_specific/test_custom.py b/tests/pydantic_specific/test_custom.py new file mode 100644 index 0000000..b92c157 --- /dev/null +++ b/tests/pydantic_specific/test_custom.py @@ -0,0 +1,32 @@ +from typing import Any, Dict + +import pytest +from annotated_types import Ge +from typing_extensions import Annotated + +from fast_depends import inject +from fast_depends.exceptions import ValidationError +from fast_depends.library import CustomField +from tests.marks import pydanticV2 + + +class Header(CustomField): + def use(self, /, **kwargs: Any) -> Dict[str, Any]: + kwargs = super().use(**kwargs) + if v := kwargs.get("headers", {}).get(self.param_name): + kwargs[self.param_name] = v + return kwargs + + +@pydanticV2 +def test_annotated_header_with_meta(): + @inject + def sync_catch(key: Annotated[int, Header(), Ge(3)] = 3): # noqa: B008 + return key + + assert sync_catch(headers={"key": "4"}) == 4 + + assert sync_catch(headers={}) == 3 + + with pytest.raises(ValidationError): + sync_catch(headers={"key": "2"})