Skip to content

Commit

Permalink
chore: #120 pr
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancetnik committed Sep 5, 2024
1 parent f56d11b commit d4ffecb
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
19 changes: 15 additions & 4 deletions fast_depends/core/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,14 @@ def build_call_model(
elif get_origin(param.annotation) is Annotated:
annotated_args = get_args(param.annotation)
type_annotation = annotated_args[0]
custom_annotations = [
arg for arg in annotated_args[1:] if isinstance(arg, CUSTOM_ANNOTATIONS)
]

custom_annotations = []
regular_annotations = []
for arg in annotated_args[1:]:
if isinstance(arg, CUSTOM_ANNOTATIONS):
custom_annotations.append(arg)
else:
regular_annotations.append(arg)

assert (
len(custom_annotations) <= 1
Expand All @@ -107,7 +112,10 @@ def build_call_model(
else: # pragma: no cover
raise AssertionError("unreachable")

annotation = type_annotation
if regular_annotations:
annotation = param.annotation
else:
annotation = type_annotation
else:
annotation = param.annotation
else:
Expand All @@ -118,6 +126,8 @@ def build_call_model(
default = ()
elif param_name == "kwargs":
default = {}
elif param.default is inspect.Parameter.empty:
default = Ellipsis
else:
default = param.default

Expand Down Expand Up @@ -186,6 +196,7 @@ def build_call_model(
class_fields.append(OptionItem(
field_name=param_name,
field_type=annotation,
default_value=default,
source=custom,
))

Expand Down
32 changes: 32 additions & 0 deletions tests/pydantic_specific/test_custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Any, Dict

import pytest
from annotated_types import Ge
from typing_extensions import Annotated

from fast_depends import inject
from fast_depends.exceptions import ValidationError
from fast_depends.library import CustomField
from tests.marks import pydanticV2


class Header(CustomField):
def use(self, /, **kwargs: Any) -> Dict[str, Any]:
kwargs = super().use(**kwargs)
if v := kwargs.get("headers", {}).get(self.param_name):
kwargs[self.param_name] = v
return kwargs


@pydanticV2
def test_annotated_header_with_meta():
@inject
def sync_catch(key: Annotated[int, Header(), Ge(3)] = 3): # noqa: B008
return key

assert sync_catch(headers={"key": "4"}) == 4

assert sync_catch(headers={}) == 3

with pytest.raises(ValidationError):
sync_catch(headers={"key": "2"})

0 comments on commit d4ffecb

Please sign in to comment.