From 6f5abd121d3de11cb0634fcd50aebdc17dc54863 Mon Sep 17 00:00:00 2001 From: Pastukhov Nikita Date: Tue, 14 Nov 2023 18:19:42 +0300 Subject: [PATCH] fix (#26): respect original call locals (#27) * fix: respect original call locals * fix: collect all locals --- fast_depends/__about__.py | 2 +- fast_depends/core/build.py | 35 +++++++++++++++-------------- fast_depends/utils.py | 46 +++++++++++++++++++++++++++++--------- tests/test_locals.py | 20 +++++++++++++++++ 4 files changed, 75 insertions(+), 28 deletions(-) create mode 100644 tests/test_locals.py diff --git a/fast_depends/__about__.py b/fast_depends/__about__.py index bd612ce..f7df68c 100644 --- a/fast_depends/__about__.py +++ b/fast_depends/__about__.py @@ -1,3 +1,3 @@ """FastDepends - extracted and cleared from HTTP domain FastAPI Dependency Injection System""" -__version__ = "2.2.2" +__version__ = "2.2.3" diff --git a/fast_depends/core/build.py b/fast_depends/core/build.py index 54083f6..811641b 100644 --- a/fast_depends/core/build.py +++ b/fast_depends/core/build.py @@ -62,7 +62,8 @@ def build_call_model( custom_fields: Dict[str, CustomField] = {} positional_args: List[str] = [] keyword_args: List[str] = [] - for param in typed_params: + + for param_name, param in typed_params.parameters.items(): dep: Optional[Depends] = None custom: Optional[CustomField] = None @@ -78,7 +79,7 @@ def build_call_model( assert ( len(custom_annotations) <= 1 - ), f"Cannot specify multiple `Annotated` Custom arguments for `{param.name}`!" + ), f"Cannot specify multiple `Annotated` Custom arguments for `{param_name}`!" next_custom = next(iter(custom_annotations), None) if next_custom is not None: @@ -96,9 +97,9 @@ def build_call_model( annotation = param.annotation default: Any - if param.name == "args": + if param_name == "args": default = () - elif param.name == "kwargs": + elif param_name == "kwargs": default = {} else: default = param.default @@ -116,16 +117,16 @@ def build_call_model( custom = default elif default is inspect._empty: - class_fields[param.name] = (annotation, ...) + class_fields[param_name] = (annotation, ...) else: - class_fields[param.name] = (annotation, default) + class_fields[param_name] = (annotation, default) if dep: if not cast: dep.cast = False - dependencies[param.name] = build_call_model( + dependencies[param_name] = build_call_model( dep.dependency, cast=dep.cast, use_cache=dep.use_cache, @@ -133,31 +134,31 @@ def build_call_model( ) if dep.cast is True: - class_fields[param.name] = (annotation, ...) - keyword_args.append(param.name) + class_fields[param_name] = (annotation, ...) + keyword_args.append(param_name) elif custom: assert not ( is_sync and is_coroutine_callable(custom.use) ), f"You cannot use async custom field `{type(custom).__name__}` at sync `{name}`" - custom.set_param_name(param.name) - custom_fields[param.name] = custom + custom.set_param_name(param_name) + custom_fields[param_name] = custom if custom.cast is False: annotation = Any if custom.required: - class_fields[param.name] = (annotation, ...) + class_fields[param_name] = (annotation, ...) else: - class_fields[param.name] = (Optional[annotation], None) - keyword_args.append(param.name) + class_fields[param_name] = (Optional[annotation], None) + keyword_args.append(param_name) else: if param.kind is param.KEYWORD_ONLY: - keyword_args.append(param.name) - elif param.name not in ("args", "kwargs"): - positional_args.append(param.name) + keyword_args.append(param_name) + elif param_name not in ("args", "kwargs"): + positional_args.append(param_name) func_model = create_model( # type: ignore[call-overload] name, diff --git a/fast_depends/utils.py b/fast_depends/utils.py index 1f539d7..042f0e2 100644 --- a/fast_depends/utils.py +++ b/fast_depends/utils.py @@ -11,7 +11,6 @@ ContextManager, Dict, ForwardRef, - List, Tuple, Union, cast, @@ -65,27 +64,54 @@ def solve_generator_sync( return stack.enter_context(cm) -def get_typed_signature( - call: Callable[..., Any] -) -> Tuple[List[inspect.Parameter], Any]: +def get_typed_signature(call: Callable[..., Any]) -> Tuple[inspect.Signature, Any]: signature = inspect.signature(call) - globalns = getattr(call, "__globals__", {}) - return [ + locals = collect_outer_stack_locals() + + globalns = getattr(call, "__globals__", {}) + typed_params = [ inspect.Parameter( name=param.name, kind=param.kind, default=param.default, - annotation=get_typed_annotation(param.annotation, globalns), + annotation=get_typed_annotation( + param.annotation, + globalns, + locals, + ), ) for param in signature.parameters.values() - ], signature.return_annotation + ] + + return inspect.Signature(typed_params), get_typed_annotation( + signature.return_annotation, + globalns, + locals, + ) + +def collect_outer_stack_locals() -> Dict[str, Any]: + frame = inspect.currentframe() -def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: + locals = {} + while frame is not None: + if "fast_depends" not in frame.f_code.co_filename: + locals.update(frame.f_locals) + + frame = frame.f_back + + return locals + + +def get_typed_annotation( + annotation: Any, + globalns: Dict[str, Any], + locals: Dict[str, Any], +) -> Any: if isinstance(annotation, str): annotation = ForwardRef(annotation) - annotation = evaluate_forwardref(annotation, globalns, globalns) + annotation = evaluate_forwardref(annotation, globalns, locals) return annotation diff --git a/tests/test_locals.py b/tests/test_locals.py new file mode 100644 index 0000000..f1ef3d3 --- /dev/null +++ b/tests/test_locals.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from pydantic import BaseModel + +from fast_depends import inject + + +def wrap(func): + return inject(func) + + +def test_localns(): + class M(BaseModel): + a: str + + @wrap + def m(a: M) -> M: + return a + + m(a={"a": "Hi!"})