Skip to content

Commit

Permalink
fix (#26): respect original call locals (#27)
Browse files Browse the repository at this point in the history
* fix: respect original call locals

* fix: collect all locals
  • Loading branch information
Lancetnik authored Nov 14, 2023
1 parent b60b812 commit 6f5abd1
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 28 deletions.
2 changes: 1 addition & 1 deletion fast_depends/__about__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""FastDepends - extracted and cleared from HTTP domain FastAPI Dependency Injection System"""

__version__ = "2.2.2"
__version__ = "2.2.3"
35 changes: 18 additions & 17 deletions fast_depends/core/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def build_call_model(
custom_fields: Dict[str, CustomField] = {}
positional_args: List[str] = []
keyword_args: List[str] = []
for param in typed_params:

for param_name, param in typed_params.parameters.items():
dep: Optional[Depends] = None
custom: Optional[CustomField] = None

Expand All @@ -78,7 +79,7 @@ def build_call_model(

assert (
len(custom_annotations) <= 1
), f"Cannot specify multiple `Annotated` Custom arguments for `{param.name}`!"
), f"Cannot specify multiple `Annotated` Custom arguments for `{param_name}`!"

next_custom = next(iter(custom_annotations), None)
if next_custom is not None:
Expand All @@ -96,9 +97,9 @@ def build_call_model(
annotation = param.annotation

default: Any
if param.name == "args":
if param_name == "args":
default = ()
elif param.name == "kwargs":
elif param_name == "kwargs":
default = {}
else:
default = param.default
Expand All @@ -116,48 +117,48 @@ def build_call_model(
custom = default

elif default is inspect._empty:
class_fields[param.name] = (annotation, ...)
class_fields[param_name] = (annotation, ...)

else:
class_fields[param.name] = (annotation, default)
class_fields[param_name] = (annotation, default)

if dep:
if not cast:
dep.cast = False

dependencies[param.name] = build_call_model(
dependencies[param_name] = build_call_model(
dep.dependency,
cast=dep.cast,
use_cache=dep.use_cache,
is_sync=is_sync,
)

if dep.cast is True:
class_fields[param.name] = (annotation, ...)
keyword_args.append(param.name)
class_fields[param_name] = (annotation, ...)
keyword_args.append(param_name)

elif custom:
assert not (
is_sync and is_coroutine_callable(custom.use)
), f"You cannot use async custom field `{type(custom).__name__}` at sync `{name}`"

custom.set_param_name(param.name)
custom_fields[param.name] = custom
custom.set_param_name(param_name)
custom_fields[param_name] = custom

if custom.cast is False:
annotation = Any

if custom.required:
class_fields[param.name] = (annotation, ...)
class_fields[param_name] = (annotation, ...)
else:
class_fields[param.name] = (Optional[annotation], None)
keyword_args.append(param.name)
class_fields[param_name] = (Optional[annotation], None)
keyword_args.append(param_name)

else:
if param.kind is param.KEYWORD_ONLY:
keyword_args.append(param.name)
elif param.name not in ("args", "kwargs"):
positional_args.append(param.name)
keyword_args.append(param_name)
elif param_name not in ("args", "kwargs"):
positional_args.append(param_name)

func_model = create_model( # type: ignore[call-overload]
name,
Expand Down
46 changes: 36 additions & 10 deletions fast_depends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
ContextManager,
Dict,
ForwardRef,
List,
Tuple,
Union,
cast,
Expand Down Expand Up @@ -65,27 +64,54 @@ def solve_generator_sync(
return stack.enter_context(cm)


def get_typed_signature(
call: Callable[..., Any]
) -> Tuple[List[inspect.Parameter], Any]:
def get_typed_signature(call: Callable[..., Any]) -> Tuple[inspect.Signature, Any]:
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})

return [
locals = collect_outer_stack_locals()

globalns = getattr(call, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_typed_annotation(param.annotation, globalns),
annotation=get_typed_annotation(
param.annotation,
globalns,
locals,
),
)
for param in signature.parameters.values()
], signature.return_annotation
]

return inspect.Signature(typed_params), get_typed_annotation(
signature.return_annotation,
globalns,
locals,
)


def collect_outer_stack_locals() -> Dict[str, Any]:
frame = inspect.currentframe()

def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
locals = {}
while frame is not None:
if "fast_depends" not in frame.f_code.co_filename:
locals.update(frame.f_locals)

frame = frame.f_back

return locals


def get_typed_annotation(
annotation: Any,
globalns: Dict[str, Any],
locals: Dict[str, Any],
) -> Any:
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns)
annotation = evaluate_forwardref(annotation, globalns, locals)
return annotation


Expand Down
20 changes: 20 additions & 0 deletions tests/test_locals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

from pydantic import BaseModel

from fast_depends import inject


def wrap(func):
return inject(func)


def test_localns():
class M(BaseModel):
a: str

@wrap
def m(a: M) -> M:
return a

m(a={"a": "Hi!"})

0 comments on commit 6f5abd1

Please sign in to comment.