diff --git a/fast_depends/__about__.py b/fast_depends/__about__.py index 88e13a55..a2a124b5 100644 --- a/fast_depends/__about__.py +++ b/fast_depends/__about__.py @@ -1,3 +1,3 @@ """FastDepends - extracted and cleared from HTTP domain Fastapi Dependency Injection System""" -__version__ = "1.1.7" +__version__ = "2.0.0b" diff --git a/fast_depends/__init__.py b/fast_depends/__init__.py index f8b8b0df..538cce2e 100644 --- a/fast_depends/__init__.py +++ b/fast_depends/__init__.py @@ -1,8 +1,8 @@ -from fast_depends.provider import dependency_provider -from fast_depends.usage import Depends, inject +from fast_depends.dependencies import dependency_provider +from fast_depends.use import Depends, inject __all__ = ( - "inject", "Depends", "dependency_provider", + "inject", ) diff --git a/fast_depends/_compat.py b/fast_depends/_compat.py new file mode 100644 index 00000000..9febafc4 --- /dev/null +++ b/fast_depends/_compat.py @@ -0,0 +1,20 @@ +from pydantic.version import VERSION as PYDANTIC_VERSION + +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") + +from pydantic import BaseModel, create_model + +if PYDANTIC_V2: + from pydantic._internal._typing_extra import ( + eval_type_lenient as evaluate_forwardref, + ) +else: + from pydantic.typing import evaluate_forwardref + + +__all__ = ( + "BaseModel", + "create_model", + "evaluate_forwardref", + "PYDANTIC_V2", +) diff --git a/fast_depends/construct.py b/fast_depends/construct.py deleted file mode 100644 index 49dd634b..00000000 --- a/fast_depends/construct.py +++ /dev/null @@ -1,276 +0,0 @@ -import inspect -from typing import Any, ForwardRef, Optional, Tuple, Type, Union - -from pydantic import BaseConfig -from pydantic.fields import ( - SHAPE_FROZENSET, - SHAPE_LIST, - SHAPE_SEQUENCE, - SHAPE_SET, - SHAPE_TUPLE, - SHAPE_TUPLE_ELLIPSIS, - FieldInfo, - ModelField, - Required, - Undefined, - UndefinedType, -) -from pydantic.schema import get_annotation_from_field_info -from pydantic.typing import evaluate_forwardref, get_args, get_origin -from typing_extensions import Annotated - -from fast_depends import model -from fast_depends.library import CustomField -from fast_depends.types import AnyCallable, AnyDict -from fast_depends.utils import is_coroutine_callable - -sequence_shapes = { - SHAPE_LIST, - SHAPE_SET, - SHAPE_FROZENSET, - SHAPE_TUPLE, - SHAPE_SEQUENCE, - SHAPE_TUPLE_ELLIPSIS, -} -sequence_types = (list, set, tuple) - - -def get_dependant( - *, - path: str, - call: AnyCallable, - name: Optional[str] = None, - use_cache: bool = True, -) -> model.Dependant: - dependant = model.Dependant( - call=call, - path=path, - name=name, - use_cache=use_cache, - return_field=None, - ) - - is_async = is_coroutine_callable(call) - - endpoint_signature = get_typed_signature(call) - signature_params = endpoint_signature.parameters - - for param in signature_params.values(): - custom, depends, param_field = analyze_param( - param_name=param.name, - annotation=param.annotation, - default=param.default, - ) - - if param.name == model.RETURN_FIELD: - dependant.return_field = param_field - continue - - elif custom is not None: - dependant.custom.append(custom) - - elif depends is not None: - assert is_async or not is_coroutine_callable( - depends.dependency - ), f"You cannot use async dependency `{depends}` with sync `{dependant}`" - - sub_dependant = get_param_sub_dependant( - param_name=param.name, - depends=depends, - path=path, - ) - dependant.dependencies.append(sub_dependant) - - dependant.params.append(param_field) - - return dependant - - -def analyze_param( - *, - param_name: str, - annotation: Any, - default: Any, -) -> Tuple[Optional[CustomField], Optional[model.Depends], ModelField]: - depends = None - custom = None - field_info = None - - if ( - annotation is not inspect.Signature.empty - and get_origin(annotation) is Annotated # type: ignore[comparison-overlap] - ): - annotated_args = get_args(annotation) - custom_annotations = [ - arg - for arg in annotated_args[1:] - if isinstance(arg, (FieldInfo, model.Depends, CustomField)) - ] - - custom_annotations = next(iter(custom_annotations), None) - if isinstance(custom_annotations, FieldInfo): - field_info = custom_annotations - assert field_info.default is Undefined or field_info.default is Required, ( - f"`{field_info.__class__.__name__}` default value cannot be set in" - f" `Annotated` for {param_name!r}. Set the default value with `=` instead." - ) - field_info.default = Required - - elif isinstance(custom_annotations, model.Depends): - depends = custom_annotations - - elif isinstance(custom_annotations, CustomField): # pragma: no branch - custom_annotations.set_param_name(param_name) - custom = custom_annotations - if custom.cast is False: - annotation = Any - - if isinstance(default, model.Depends): - assert depends is None, ( - "Cannot specify `Depends` in `Annotated` and default value" - f" together for {param_name!r}" - ) - assert field_info is None, ( - "Cannot specify a annotation in `Annotated` and `Depends` as a" - f" default value together for {param_name!r}" - ) - depends = default - - elif isinstance(default, CustomField): - default.set_param_name(param_name) - custom = default - if custom.cast is False: - annotation = Any - - elif isinstance(default, FieldInfo): - assert field_info is None, ( - "Cannot specify annotations in `Annotated` and default value" - f" together for {param_name!r}" - ) - field_info = default - - if (depends or custom) is not None: - field = None - - if field_info is not None: - annotation = get_annotation_from_field_info( - annotation if annotation is not inspect.Signature.empty else Any, - field_info, - param_name, - ) - else: - field_info = FieldInfo(default=default) - - alias = field_info.alias or param_name - - if custom and custom.required is True: - required = True - else: - required = field_info.default in (Required, Undefined, inspect._empty) - - field = create_response_field( - name=param_name, - type_=Any if depends and depends.cast is False else annotation, - default=None if any((depends, custom)) else field_info.default, - alias=alias, - required=required, - field_info=field_info, - ) - - return custom, depends, field - - -def get_typed_signature(call: AnyCallable) -> inspect.Signature: - signature = inspect.signature(call) - globalns = getattr(call, "__globals__", {}) - typed_params = [ - inspect.Parameter( - name=param.name, - kind=param.kind, - default=param.default, - annotation=get_typed_annotation(param.annotation, globalns), - ) - for param in signature.parameters.values() - ] - - if signature.return_annotation is not signature.empty: - typed_params.append( - inspect.Parameter( - name=model.RETURN_FIELD, - kind=inspect._KEYWORD_ONLY, - annotation=get_typed_annotation(signature.return_annotation, globalns), - ) - ) - typed_signature = inspect.Signature(typed_params) - return typed_signature - - -def get_typed_annotation(annotation: Any, globalns: AnyDict) -> Any: - if isinstance(annotation, str): - try: - annotation = ForwardRef(annotation) - annotation = evaluate_forwardref(annotation, globalns, globalns) - except Exception: - raise ValueError( # noqa: B904 - f"Invalid filed annotation! Hint: check that {annotation} is a valid pydantic field type" - ) - return annotation - - -def get_param_sub_dependant( - *, - param_name: str, - depends: model.Depends, - path: str, -) -> model.Dependant: - assert depends.dependency - return get_sub_dependant( - depends=depends, - dependency=depends.dependency, - path=path, - name=param_name, - ) - - -def get_sub_dependant( - *, - depends: model.Depends, - dependency: AnyCallable, - path: str, - name: Optional[str] = None, -) -> model.Dependant: - sub_dependant = get_dependant( - path=path, - call=dependency, - name=name, - use_cache=depends.use_cache, - ) - return sub_dependant - - -def create_response_field( - name: str, - type_: Type[Any], - default: Optional[Any] = None, - required: Union[bool, UndefinedType] = True, - field_info: Optional[FieldInfo] = None, - alias: Optional[str] = None, -) -> ModelField: - """ - Create a new response field. Raises if type_ is invalid. - """ - try: - return ModelField( - name=name, - type_=type_ if type_ is not inspect._empty else Any, - default=default, - required=required, - class_validators={}, - model_config=BaseConfig, - alias=alias, - field_info=field_info or FieldInfo(), - ) - except RuntimeError: # pragma: no cover - raise ValueError( # noqa: B904 - f"Invalid args for response field! Hint: check that {type_} is a valid pydantic field type" - ) diff --git a/fast_depends/core/__init__.py b/fast_depends/core/__init__.py new file mode 100644 index 00000000..99799a77 --- /dev/null +++ b/fast_depends/core/__init__.py @@ -0,0 +1,7 @@ +from fast_depends.core.build import build_call_model +from fast_depends.core.model import CallModel + +__all__ = ( + "CallModel", + "build_call_model", +) diff --git a/fast_depends/core/build.py b/fast_depends/core/build.py new file mode 100644 index 00000000..549df329 --- /dev/null +++ b/fast_depends/core/build.py @@ -0,0 +1,125 @@ +import inspect +from typing import Any, Callable, Optional + +from typing_extensions import Annotated, assert_never, get_args, get_origin + +from fast_depends._compat import create_model +from fast_depends.core.model import CallModel +from fast_depends.dependencies import Depends +from fast_depends.library import CustomField +from fast_depends.utils import get_typed_signature, is_coroutine_callable + +CUSTOM_ANNOTATIONS = (Depends, CustomField) + + +def build_call_model( + call: Callable[..., Any], + *, + cast: bool = True, + use_cache: bool = True, + is_sync: Optional[bool] = None, +) -> CallModel: + name = getattr(call, "__name__", type(call).__name__) + + is_call_async = is_coroutine_callable(call) + if is_sync is None: + is_sync = not is_call_async + else: + assert not ( + is_sync and is_call_async + ), f"You cannot use async dependency `{name}` at sync main" + + typed_params, return_annotation = get_typed_signature(call) + + class_fields = {} + dependencies = {} + custom_fields = {} + for param in typed_params: + dep: Optional[Depends] = None + custom: Optional[CustomField] = None + + if param.annotation is inspect._empty: + annotation = Any + + elif get_origin(param.annotation) is Annotated: + annotated_args = get_args(param.annotation) + type_annotation = annotated_args[0] + custom_annotations = [ + arg for arg in annotated_args[1:] if isinstance(arg, CUSTOM_ANNOTATIONS) + ] + + assert ( + len(custom_annotations) <= 1 + ), f"Cannot specify multiple `Annotated` Custom arguments for `{param.name}`!" + + next_custom = next(iter(custom_annotations), None) + if next_custom is not None: + if isinstance(next_custom, Depends): + dep = next_custom + elif isinstance(next_custom, CustomField): + custom = next_custom + else: # pragma: no cover + assert_never() + + annotation = type_annotation + else: + annotation = param.annotation + else: + annotation = param.annotation + + default = param.default + if dep or isinstance(default, Depends): + dep = dep or default + + dependencies[param.name] = build_call_model( + dep.call, + cast=dep.cast, + use_cache=dep.use_cache, + is_sync=is_sync, + ) + + if dep.cast is True: + class_fields[param.name] = (annotation, ...) + + elif custom or isinstance(default, CustomField): + custom = custom or default + assert not ( + is_sync and is_coroutine_callable(custom.use) + ), f"You cannot use async custom field `{type(custom).__name__}` at sync `{name}`" + + custom.set_param_name(param.name) + custom_fields[param.name] = custom + + if custom.cast is False: + annotation = Any + + if custom.required: + class_fields[param.name] = (annotation, ...) + else: + class_fields[param.name] = (Optional[annotation], None) + + elif default is inspect._empty: + class_fields[param.name] = (annotation, ...) + + else: + class_fields[param.name] = (annotation, default) + + if return_annotation is not inspect._empty: + response_model = create_model( + "ResponseModel", response=(return_annotation, ...) + ) + else: + response_model = None + + func_model = create_model(name, **class_fields) + + return CallModel( + call=call, + model=func_model, + response_model=response_model, + cast=cast, + use_cache=use_cache, + is_async=is_call_async, + dependencies=dependencies, + custom_fields=custom_fields, + ) diff --git a/fast_depends/core/model.py b/fast_depends/core/model.py new file mode 100644 index 00000000..c45b7375 --- /dev/null +++ b/fast_depends/core/model.py @@ -0,0 +1,223 @@ +from contextlib import AsyncExitStack, ExitStack +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Generator, + List, + Optional, + Set, + Type, + Union, +) + +from typing_extensions import ParamSpec, TypeVar + +from fast_depends._compat import PYDANTIC_V2, BaseModel +from fast_depends.library import CustomField +from fast_depends.utils import ( + args_to_kwargs, + is_async_gen_callable, + is_coroutine_callable, + is_gen_callable, + run_async, + solve_generator_async, + solve_generator_sync, +) + +P = ParamSpec("P") +T = TypeVar("T") + + +class CallModel: + call: Union[ + Callable[P, T], + Callable[P, Awaitable[T]], + ] + is_async: bool + is_generator: bool + model: Type[BaseModel] + response_model: Optional[Type[BaseModel]] + arguments: List[str] + alias_arguments: List[str] + + dependencies: Dict[str, "CallModel"] + custom_fields: Dict[str, CustomField] + + # Dependencies and custom fields + use_cache: bool + cast: bool + + @property + def call_name(self): + return getattr(self.call, "__name__", type(self.call).__name__) + + @property + def real_params(self) -> Set[str]: + return set(self.arguments) - set(self.dependencies.keys()) + + @property + def flat_params(self) -> Set[str]: + params = set(self.real_params) + for d in self.dependencies.values(): + params |= set(d.flat_params) + return params + + def __init__( + self, + call: Union[ + Callable[P, T], + Callable[P, Awaitable[T]], + ], + model: BaseModel, + response_model: Optional[Type[BaseModel]] = None, + use_cache: bool = True, + cast: bool = True, + is_async: bool = False, + dependencies: Optional[Dict[str, "CallModel"]] = None, + custom_fields: Optional[Dict[str, CustomField]] = None, + ): + self.call = call + self.model = model + self.response_model = response_model + + self.arguments = [] + self.alias_arguments = [] + + if PYDANTIC_V2: + fields = self.model.model_fields + else: + fields = self.model.__fields__ + + for name, f in fields.items(): + self.arguments.append(name) + self.alias_arguments.append(f.alias or name) + + self.dependencies = dependencies or {} + self.custom_fields = custom_fields or {} + + self.use_cache = use_cache + self.cast = cast + self.is_async = is_async or is_coroutine_callable(call) + self.is_generator = is_gen_callable(self.call) or is_async_gen_callable( + self.call + ) + + def _cast_args( + self, + *args: P.args, + **kwargs: P.kwargs, + ) -> Generator[Dict[str, Any], Any, T,]: + kw = args_to_kwargs(self.alias_arguments, *args, **kwargs) + + kw_with_solved_dep = yield kw + + casted_model = self.model(**kw_with_solved_dep) + + casted_kw = { + arg: getattr(casted_model, arg, kw_with_solved_dep.get(arg)) + for arg in (*self.arguments, *self.dependencies.keys()) + } + + response = yield casted_kw + + if self.cast is True and self.response_model is not None: + casted_resp = self.response_model(response=response) + response = casted_resp.response + + return response + + def solve( + self, + *args: P.args, + stack: ExitStack, + cache_dependencies: Dict[str, Any], + dependency_overrides: Optional[Dict[Callable[..., Any], Any]] = None, + **kwargs: P.kwargs, + ) -> T: + if dependency_overrides: + self.call = dependency_overrides.get(self.call, self.call) + assert not is_coroutine_callable( + self.call + ), f"You cannot use async dependency `{self.call_name}` at sync main" + + if self.use_cache and cache_dependencies.get(self.call): + return cache_dependencies.get(self.call) + + cast_gen = self._cast_args(*args, **kwargs) + kwargs = next(cast_gen) + + for dep_arg, dep in self.dependencies.items(): + kwargs[dep_arg] = dep.solve( + stack=stack, + cache_dependencies=cache_dependencies, + dependency_overrides=dependency_overrides, + **kwargs, + ) + + for custom in self.custom_fields.values(): + kwargs = custom.use(**kwargs) + + final_kw = cast_gen.send(kwargs) + + if self.is_generator: + response = solve_generator_sync( + call=self.call, + stack=stack, + **final_kw, + ) + else: + response = self.call(**final_kw) + + try: + cast_gen.send(response) + except StopIteration as e: + if self.use_cache: # pragma: no branch + cache_dependencies[self.call] = e.value + return e.value + + async def asolve( + self, + *args: P.args, + stack: AsyncExitStack, + cache_dependencies: Dict[str, Any], + dependency_overrides: Optional[Dict[Callable[..., Any], Any]] = None, + **kwargs: P.kwargs, + ) -> T: + if dependency_overrides: + self.call = dependency_overrides.get(self.call, self.call) + + if self.use_cache and cache_dependencies.get(self.call): + return cache_dependencies.get(self.call) + + cast_gen = self._cast_args(*args, **kwargs) + kwargs = next(cast_gen) + + for dep_arg, dep in self.dependencies.items(): + kwargs[dep_arg] = await dep.asolve( + stack=stack, + cache_dependencies=cache_dependencies, + dependency_overrides=dependency_overrides, + **kwargs, + ) + + for custom in self.custom_fields.values(): + kwargs = await run_async(custom.use, **kwargs) + + final_kw = cast_gen.send(kwargs) + + if self.is_generator: + response = await solve_generator_async( + call=self.call, + stack=stack, + **final_kw, + ) + else: + response = await run_async(self.call, **final_kw) + try: + cast_gen.send(response) + except StopIteration as e: + if self.use_cache: # pragma: no branch + cache_dependencies[self.call] = e.value + return e.value diff --git a/fast_depends/dependencies/__init__.py b/fast_depends/dependencies/__init__.py new file mode 100644 index 00000000..ef93b458 --- /dev/null +++ b/fast_depends/dependencies/__init__.py @@ -0,0 +1,7 @@ +from fast_depends.dependencies.model import Depends +from fast_depends.dependencies.provider import dependency_provider + +__all__ = ( + "Depends", + "dependency_provider", +) diff --git a/fast_depends/dependencies/model.py b/fast_depends/dependencies/model.py new file mode 100644 index 00000000..58412516 --- /dev/null +++ b/fast_depends/dependencies/model.py @@ -0,0 +1,22 @@ +from typing import Any, Callable + + +class Depends: + use_cache: bool + cast: bool + + def __init__( + self, + call: Callable[..., Any], + *, + use_cache: bool = True, + cast: bool = True, + ) -> None: + self.call = call + self.use_cache = use_cache + self.cast = cast + + def __repr__(self) -> str: + attr = getattr(self.call, "__name__", type(self.call).__name__) + cache = "" if self.use_cache else ", use_cache=False" + return f"{self.__class__.__name__}({attr}{cache})" diff --git a/fast_depends/provider.py b/fast_depends/dependencies/provider.py similarity index 100% rename from fast_depends/provider.py rename to fast_depends/dependencies/provider.py diff --git a/fast_depends/injector.py b/fast_depends/injector.py deleted file mode 100644 index 8cbd9bb6..00000000 --- a/fast_depends/injector.py +++ /dev/null @@ -1,223 +0,0 @@ -from contextlib import AsyncExitStack, ExitStack -from copy import deepcopy -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, TypeVar, cast - -from pydantic.error_wrappers import ErrorList, ErrorWrapper -from pydantic.errors import MissingError -from pydantic.fields import ModelField - -from fast_depends.construct import get_dependant -from fast_depends.model import Dependant -from fast_depends.types import AnyCallable, AnyDict -from fast_depends.utils import ( - is_async_gen_callable, - is_coroutine_callable, - is_gen_callable, - run_async, - solve_generator_async, - solve_generator_sync, -) - -T = TypeVar("T") - - -async def solve_dependencies_async( - *, - dependant: Dependant, - stack: AsyncExitStack, - body: Optional[AnyDict] = None, - dependency_overrides_provider: Optional[Any] = None, - dependency_cache: Optional[Dict[Tuple[AnyCallable, Tuple[str]], Any]] = None, -) -> Tuple[Dict[str, Any], List[ErrorList], Dict[Tuple[AnyCallable, Tuple[str]], Any],]: - errors: List[ErrorList] = [] - - dependency_cache = dependency_cache or {} - - sub_dependant: Dependant - for sub_dependant in dependant.dependencies: - sub_dependant.call = cast(AnyCallable, sub_dependant.call) - sub_dependant.cache_key = cast( - Tuple[AnyCallable, Tuple[str]], sub_dependant.cache_key - ) - call = sub_dependant.call - use_sub_dependant = sub_dependant - if ( - dependency_overrides_provider - and dependency_overrides_provider.dependency_overrides - ): - call = getattr( - dependency_overrides_provider, "dependency_overrides", {} - ).get(sub_dependant.call) - if call is not None: # pragma: no branch - use_sub_dependant = get_dependant( - path=sub_dependant.path, - call=call, - name=sub_dependant.name, - ) - assert call - solved_result = await solve_dependencies_async( - dependant=use_sub_dependant, - body=body, - dependency_overrides_provider=dependency_overrides_provider, - dependency_cache=dependency_cache, - stack=stack, - ) - ( - sub_values, - sub_errors, - sub_dependency_cache, - ) = solved_result - - dependency_cache.update(sub_dependency_cache) - - if sub_errors: - errors.extend(sub_errors) - continue - - if ( - use_sub_dependant.use_cache - and use_sub_dependant.cache_key in dependency_cache - ): - solved = dependency_cache[use_sub_dependant.cache_key] - elif is_gen_callable(call) or is_async_gen_callable(call): - solved, sub_errors = use_sub_dependant.cast_response( - await solve_generator_async( - call=call, stack=stack, sub_values=sub_values - ) - ) - else: - solved, sub_errors = use_sub_dependant.cast_response( - await run_async(use_sub_dependant.call, **sub_values) - ) - - if sub_errors: - errors.append(sub_errors) - continue - - if use_sub_dependant.name is not None: # pragma: no branch - body[use_sub_dependant.name] = solved - - if use_sub_dependant.cache_key not in dependency_cache: - dependency_cache[use_sub_dependant.cache_key] = solved - - for custom in dependant.custom: - body = await run_async(custom.use, **(body or {})) - - params, main_errors = params_to_args(dependant.params, body or {}) - errors.extend(main_errors) - return params, errors, dependency_cache - - -def solve_dependencies_sync( - *, - dependant: Dependant, - stack: ExitStack, - body: Optional[AnyDict] = None, - dependency_overrides_provider: Optional[Any] = None, - dependency_cache: Optional[Dict[Tuple[AnyCallable, Tuple[str]], Any]] = None, -) -> Tuple[Dict[str, Any], List[ErrorList], Dict[Tuple[AnyCallable, Tuple[str]], Any],]: - assert not is_coroutine_callable(dependant.call) and not is_async_gen_callable( - dependant.call - ), f"You can't call async `{dependant.call.__name__}` at sync code" - - errors: List[ErrorList] = [] - - dependency_cache = dependency_cache or {} - - sub_dependant: Dependant - for sub_dependant in dependant.dependencies: - sub_dependant.call = cast(AnyCallable, sub_dependant.call) - sub_dependant.cache_key = cast( - Tuple[AnyCallable, Tuple[str]], sub_dependant.cache_key - ) - call = sub_dependant.call - use_sub_dependant = sub_dependant - if ( - dependency_overrides_provider - and dependency_overrides_provider.dependency_overrides - ): - call = getattr( - dependency_overrides_provider, "dependency_overrides", {} - ).get(sub_dependant.call) - if call is not None: # pragma: no branch - use_sub_dependant = get_dependant( - path=sub_dependant.path, - call=call, - name=sub_dependant.name, - ) - assert call - solved_result = solve_dependencies_sync( - dependant=use_sub_dependant, - body=body, - stack=stack, - dependency_overrides_provider=dependency_overrides_provider, - dependency_cache=dependency_cache, - ) - ( - sub_values, - sub_errors, - sub_dependency_cache, - ) = solved_result - - dependency_cache.update(sub_dependency_cache) - - if sub_errors: - errors.extend(sub_errors) - continue - - if ( - use_sub_dependant.use_cache - and use_sub_dependant.cache_key in dependency_cache - ): - solved = dependency_cache[sub_dependant.cache_key] - elif is_gen_callable(call): - solved, sub_errors = use_sub_dependant.cast_response( - solve_generator_sync(call=call, stack=stack, sub_values=sub_values) - ) - else: - solved, sub_errors = use_sub_dependant.cast_response(call(**sub_values)) - - if sub_errors: - errors.append(sub_errors) - continue - - if use_sub_dependant.name is not None: # pragma: no branch - body[sub_dependant.name] = solved - - if use_sub_dependant.cache_key not in dependency_cache: - dependency_cache[use_sub_dependant.cache_key] = solved - - for custom in dependant.custom: - assert not is_coroutine_callable(custom.use) and not is_async_gen_callable( - custom.use - ), f"You can't use async `{type(custom).__name__}` at sync code" - body = custom.use(**(body or {})) - - params, main_errors = params_to_args(dependant.params, body or {}) - errors.extend(main_errors) - return params, errors, dependency_cache - - -def params_to_args( - required_params: Sequence[ModelField], - received_params: Mapping[str, Any], -) -> Tuple[AnyDict, List[ErrorList]]: - values: AnyDict = {} - errors: List[ErrorList] = [] - for field in required_params: - value = received_params.get(field.alias) - if value is None: - if field.required: - errors.append(ErrorWrapper(MissingError(), loc=(field.alias,))) - else: - values[field.name] = deepcopy(field.default) - continue - - v_, errors_ = field.validate(value, values, loc=(field.alias,)) - if isinstance(errors_, ErrorWrapper): - errors.append(errors_) - elif isinstance(errors_, list): # pragma: no cover - errors.extend(errors_) - else: - values[field.name] = v_ - return values, errors diff --git a/fast_depends/library/model.py b/fast_depends/library/model.py index ca213563..f118decf 100644 --- a/fast_depends/library/model.py +++ b/fast_depends/library/model.py @@ -1,7 +1,5 @@ from abc import ABC -from typing import Optional, TypeVar - -from fast_depends.types import AnyDict +from typing import Any, Dict, Optional, TypeVar Cls = TypeVar("Cls", bound="CustomField") @@ -20,6 +18,6 @@ def set_param_name(self: Cls, name: str) -> Cls: self.param_name = name return self - def use(self, **kwargs: AnyDict) -> AnyDict: + def use(self, **kwargs: Dict[str, Any]) -> Dict[str, Any]: assert self.param_name, "You should specify `param_name` before using" return kwargs diff --git a/fast_depends/model.py b/fast_depends/model.py deleted file mode 100644 index d4bb6d3e..00000000 --- a/fast_depends/model.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import Any, List, Optional, Tuple, List -from itertools import chain - -from pydantic import create_model -from pydantic.error_wrappers import ErrorList -from pydantic.fields import ModelField - -from fast_depends.library import CustomField -from fast_depends.types import AnyCallable - -RETURN_FIELD = "custom_return" - - -class Dependant: - def __init__( - self, - *, - call: AnyCallable, - params: Optional[List[ModelField]] = None, - return_field: Optional[ModelField] = None, - dependencies: Optional[List["Dependant"]] = None, - custom: Optional[List[CustomField]] = None, - use_cache: bool = True, - path: Optional[str] = None, - name: Optional[str] = None, - ) -> None: - self.params = params or [] - self.return_field = return_field - self.dependencies = dependencies or [] - self.custom = custom or [] - self.call = call - self.use_cache = use_cache - # Parent argument name at subdependency - self.name = name - # Store the path to be able to re-generate a dependable from it in overrides - self.path = path - # Save the cache key at creation to optimize performance - self.cache_key = (self.call,) - self.error_model = create_model(getattr(call, "__name__", str(call))) - - @property - def real_params(self) -> List[ModelField]: - custom = tuple(chain( - (c.param_name for c in self.custom), - (d.name for d in self.dependencies) - )) - return list(filter(lambda x: x.name not in custom, self.params)) - - @property - def flat_params(self) -> List[ModelField]: - params = self.real_params - for d in self.dependencies: - params.extend(d.flat_params) - - params_unique_names = set() - params_unique = [] - for p in params: - if p.name not in params_unique_names: - params_unique.append(p) - params_unique_names.add(p.name) - - return params_unique - - def cast_response(self, response: Any) -> Tuple[Optional[Any], Optional[ErrorList]]: - if self.return_field is None: - return response, [] - return self.return_field.validate(response, {}, loc=RETURN_FIELD) - - -class Depends: - def __init__( - self, - dependency: AnyCallable, - *, - use_cache: bool = True, - cast: bool = True, - ) -> None: - self.dependency = dependency - self.use_cache = use_cache - self.cast = cast - - def __repr__(self) -> str: - attr = getattr(self.dependency, "__name__", type(self.dependency).__name__) - cache = "" if self.use_cache else ", use_cache=False" - return f"{self.__class__.__name__}({attr}{cache})" diff --git a/fast_depends/usage.py b/fast_depends/usage.py deleted file mode 100644 index 6d5ff260..00000000 --- a/fast_depends/usage.py +++ /dev/null @@ -1,107 +0,0 @@ -from contextlib import AsyncExitStack, ExitStack -from functools import partial, wraps -from typing import Any, Callable, Optional, TypeVar - -from pydantic import ValidationError - -from fast_depends import model -from fast_depends.construct import get_dependant -from fast_depends.injector import solve_dependencies_async, solve_dependencies_sync -from fast_depends.provider import dependency_provider -from fast_depends.types import AnyCallable, P -from fast_depends.utils import args_to_kwargs, is_coroutine_callable, run_async - -T = TypeVar("T") - - -def Depends( - dependency: AnyCallable, - *, - use_cache: bool = True, - cast: bool = True, -) -> Any: # noqa: N802 - return model.Depends(dependency=dependency, use_cache=use_cache, cast=cast) - - -def wrap_dependant(dependant: model.Dependant) -> model.Dependant: - return dependant - - -def inject( - func: Callable[P, T], - *, - dependency_overrides_provider: Optional[Any] = dependency_provider, - wrap_dependant: Callable[[model.Dependant], model.Dependant] = wrap_dependant, -) -> Callable[P, T]: - dependant = get_dependant(call=func, path=func.__name__) - - dependant = wrap_dependant(dependant) - - if is_coroutine_callable(func) is True: - f = async_typed_wrapper - else: - f = sync_typed_wrapper - - return wraps(func)( - partial( - f, - dependant=dependant, - dependency_overrides_provider=dependency_overrides_provider, - ) - ) - - -async def async_typed_wrapper( - *args: P.args, - dependant: model.Dependant, - dependency_overrides_provider: Optional[Any], - **kwargs: P.kwargs, -) -> Any: - kwargs = args_to_kwargs((x.name for x in dependant.params), *args, **kwargs) - - async with AsyncExitStack() as stack: - solved_result, errors, _ = await solve_dependencies_async( - body=kwargs, - dependant=dependant, - stack=stack, - dependency_overrides_provider=dependency_overrides_provider, - ) - - if errors: - raise ValidationError(errors, dependant.error_model) - - v, casted_errors = dependant.cast_response( - await run_async(dependant.call, **solved_result) - ) - - if casted_errors: - raise ValidationError(errors, dependant.error_model) - - return v - - -def sync_typed_wrapper( - *args: P.args, - dependant: model.Dependant, - dependency_overrides_provider: Optional[Any], - **kwargs: P.kwargs, -) -> Any: - kwargs = args_to_kwargs((x.name for x in dependant.params), *args, **kwargs) - - with ExitStack() as stack: - solved_result, errors, _ = solve_dependencies_sync( - body=kwargs, - dependant=dependant, - stack=stack, - dependency_overrides_provider=dependency_overrides_provider, - ) - - if errors: - raise ValidationError(errors, dependant.error_model) - - v, casted_errors = dependant.cast_response(dependant.call(**solved_result)) - - if casted_errors: - raise ValidationError(errors, dependant.error_model) - - return v diff --git a/fast_depends/use.py b/fast_depends/use.py new file mode 100644 index 00000000..dbe922f9 --- /dev/null +++ b/fast_depends/use.py @@ -0,0 +1,64 @@ +from contextlib import AsyncExitStack, ExitStack +from functools import wraps +from typing import Any, Awaitable, Callable, Optional, ParamSpec, TypeVar, Union + +from fast_depends.core import CallModel, build_call_model +from fast_depends.dependencies import dependency_provider, model + +P = ParamSpec("P") +T = TypeVar("T") + + +def Depends( + dependency: Union[Callable[P, T], Callable[P, Awaitable[T]]], + *, + use_cache: bool = True, + cast: bool = True, +) -> Any: # noqa: N802 + return model.Depends(call=dependency, use_cache=use_cache, cast=cast) + + +def inject( + func: Union[Callable[P, T], Callable[P, Awaitable[T]]], + *, + dependency_overrides_provider: Optional[Any] = dependency_provider, + wrap_dependant: Callable[[CallModel], CallModel] = lambda x: x, +) -> Union[Callable[P, T], Callable[P, Awaitable[T]]]: + model = wrap_dependant(build_call_model(func)) + + if ( + dependency_overrides_provider + and getattr(dependency_overrides_provider, "dependency_overrides", None) + is not None + ): + overrides = dependency_overrides_provider.dependency_overrides + else: + overrides = None + + if model.is_async: + + @wraps(func) + async def call_func(*args: P.args, **kwargs: P.kwargs) -> T: + async with AsyncExitStack() as stack: + return await model.asolve( + *args, + stack=stack, + dependency_overrides=overrides, + cache_dependencies={}, + **kwargs, + ) + + else: + + @wraps(func) + def call_func(*args: P.args, **kwargs: P.kwargs) -> T: + with ExitStack() as stack: + return model.solve( + *args, + stack=stack, + dependency_overrides=overrides, + cache_dependencies={}, + **kwargs, + ) + + return wraps(func)(call_func) diff --git a/fast_depends/utils.py b/fast_depends/utils.py index 8beb7e0a..0effce20 100644 --- a/fast_depends/utils.py +++ b/fast_depends/utils.py @@ -1,4 +1,3 @@ -import asyncio import functools import inspect from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager @@ -8,33 +7,25 @@ Callable, ContextManager, Dict, + ForwardRef, Iterable, + List, + ParamSpec, + Tuple, TypeVar, - cast, ) import anyio -from fast_depends.types import AnyCallable, AnyDict, P +from fast_depends._compat import evaluate_forwardref +P = ParamSpec("P") T = TypeVar("T") -def args_to_kwargs( - arguments: Iterable[str], *args: P.args, **kwargs: P.kwargs -) -> AnyDict: - if not args: - return kwargs - - unused = filter(lambda x: x not in kwargs, arguments) - - return dict((*zip(unused, args), *kwargs.items())) - - async def run_async(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: - if asyncio.iscoroutinefunction(func): - r = await func(*args, **kwargs) - return cast(T, r) + if is_coroutine_callable(func): + return await func(*args, **kwargs) else: return await run_in_threadpool(func, *args, **kwargs) @@ -47,31 +38,8 @@ async def run_in_threadpool( return await anyio.to_thread.run_sync(func, *args) -def is_async_gen_callable(call: AnyCallable) -> bool: - if inspect.isasyncgenfunction(call): - return True - dunder_call = getattr(call, "__call__", None) # noqa: B004 - return inspect.isasyncgenfunction(dunder_call) - - -def is_gen_callable(call: AnyCallable) -> bool: - if inspect.isgeneratorfunction(call): - return True - dunder_call = getattr(call, "__call__", None) # noqa: B004 - return inspect.isgeneratorfunction(dunder_call) - - -def is_coroutine_callable(call: AnyCallable) -> bool: - if inspect.isroutine(call): - return inspect.iscoroutinefunction(call) - if inspect.isclass(call): - return False - call_ = getattr(call, "__call__", None) # noqa: B004 - return inspect.iscoroutinefunction(call_) - - async def solve_generator_async( - *, call: AnyCallable, stack: AsyncExitStack, sub_values: Dict[str, Any] + *, call: Callable[..., Any], stack: AsyncExitStack, **sub_values: Any ) -> Any: if is_gen_callable(call): cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) @@ -81,12 +49,46 @@ async def solve_generator_async( def solve_generator_sync( - *, call: AnyCallable, stack: ExitStack, sub_values: Dict[str, Any] + *, call: Callable[..., Any], stack: ExitStack, **sub_values: Any ) -> Any: cm = contextmanager(call)(**sub_values) return stack.enter_context(cm) +def args_to_kwargs( + arguments: Iterable[str], *args: P.args, **kwargs: P.kwargs +) -> Dict[str, Any]: + if not args: + return kwargs + + unused = filter(lambda x: x not in kwargs, arguments) + + return dict((*zip(unused, args), *kwargs.items())) + + +def get_typed_signature( + call: Callable[..., Any] +) -> Tuple[List[inspect.Parameter], Any]: + signature = inspect.signature(call) + globalns = getattr(call, "__globals__", {}) + return [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=param.default, + annotation=get_typed_annotation(param.annotation, globalns), + ) + for param in signature.parameters.values() + ], signature.return_annotation + + +def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + annotation = evaluate_forwardref(annotation, globalns, globalns) + return annotation + + @asynccontextmanager async def contextmanager_in_threadpool( cm: ContextManager[T], @@ -106,3 +108,26 @@ async def contextmanager_in_threadpool( await anyio.to_thread.run_sync( cm.__exit__, None, None, None, limiter=exit_limiter ) + + +def is_gen_callable(call: Callable[..., Any]) -> bool: + if inspect.isgeneratorfunction(call): + return True + dunder_call = getattr(call, "__call__", None) # noqa: B004 + return inspect.isgeneratorfunction(dunder_call) + + +def is_async_gen_callable(call: Callable[..., Any]) -> bool: + if inspect.isasyncgenfunction(call): + return True + dunder_call = getattr(call, "__call__", None) # noqa: B004 + return inspect.isasyncgenfunction(dunder_call) + + +def is_coroutine_callable(call: Callable[..., Any]) -> bool: + if inspect.isroutine(call): + return inspect.iscoroutinefunction(call) + if inspect.isclass(call): + return False + call_ = getattr(call, "__call__", None) # noqa: B004 + return inspect.iscoroutinefunction(call_) diff --git a/pyproject.toml b/pyproject.toml index 746883d9..a6dd0c09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ classifiers = [ ] dependencies = [ - "pydantic>=1.8", + "pydantic==2.0b3", "anyio", ] diff --git a/tests/async/test_async.py b/tests/async/test_async.py index 959db404..7983d872 100644 --- a/tests/async/test_async.py +++ b/tests/async/test_async.py @@ -46,15 +46,6 @@ async def some_func(a: "int") -> float: assert isinstance(await some_func("1"), float) -@pytest.mark.asyncio -async def test_annotated_wrong(): - with pytest.raises(ValueError): - - @inject - async def some_func(b: "dsada"): # pragma: no cover - pass - - @pytest.mark.asyncio async def test_pydantic_types_casting(): class SomeModel(BaseModel): @@ -79,7 +70,7 @@ async def another_func(a=Field(..., alias="b")) -> float: assert isinstance(a, str) return a - assert isinstance(await some_func(b="2"), float) + assert isinstance(await some_func(b="2", c=3), float) assert isinstance(await another_func(b="2"), float) diff --git a/tests/async/test_async_depends.py b/tests/async/test_async_depends.py index 8d3a6f4c..602dd448 100644 --- a/tests/async/test_async_depends.py +++ b/tests/async/test_async_depends.py @@ -26,7 +26,7 @@ def dep_func(a: int) -> float: return a @inject - async def some_func(a: int, b: int, c=Depends(dep_func)) -> str: + async def some_func(a: int, b: int, c=Depends(dep_func)) -> float: assert isinstance(c, float) return a + b + c @@ -39,7 +39,7 @@ async def dep_func(a): return a @inject - async def some_func(a: int, b: int, c: int = Depends(dep_func)) -> str: + async def some_func(a: int, b: int, c: int = Depends(dep_func)) -> float: assert isinstance(c, int) return a + b + c @@ -73,7 +73,7 @@ async def dep_func(a): D = Annotated[int, Depends(dep_func)] @inject - async def some_func(a: int, b: int, c: D = None) -> str: + async def some_func(a: int, b: int, c: D = None) -> float: assert isinstance(c, int) return a + b + c @@ -82,7 +82,7 @@ async def another_func(a: int, c: D): return a + c assert await some_func("1", "2") - assert (await another_func(3)) == 6 + assert (await another_func(3)) == 6.0 @pytest.mark.asyncio diff --git a/tests/library/test_custom.py b/tests/library/test_custom.py index 1a9b8dbc..cd941ff2 100644 --- a/tests/library/test_custom.py +++ b/tests/library/test_custom.py @@ -12,7 +12,8 @@ class Header(CustomField): def use(self, **kwargs: AnyDict) -> AnyDict: kwargs = super().use(**kwargs) - kwargs[self.param_name] = kwargs.get("headers", {}).get(self.param_name) + if kwargs.get("headers", {}).get(self.param_name): + kwargs[self.param_name] = kwargs.get("headers", {}).get(self.param_name) return kwargs @@ -23,93 +24,92 @@ async def use(self, **kwargs: AnyDict) -> AnyDict: def test_header(): @inject - def catch(key: str = Header()): + def sync_catch(key: int = Header()): return key - assert catch(headers={"key": 1}) == "1" + assert sync_catch(headers={"key": "1"}) == 1 @pytest.mark.asyncio async def test_header_async(): @inject - async def catch(key: str = Header()): + async def async_catch(key: int = Header()): return key - assert (await catch(headers={"key": 1})) == "1" + assert (await async_catch(headers={"key": "1"})) == 1 def test_multiple_header(): @inject - def catch(key: str = Header(), key2: int = Header()): + def sync_catch(key: str = Header(), key2: int = Header()): assert key == "1" assert key2 == 2 - catch(headers={"key": 1, "key2": 2}) + sync_catch(headers={"key": "1", "key2": "2"}) @pytest.mark.asyncio async def test_async_header_async(): @inject - async def catch(key: str = AsyncHeader()): + async def async_catch(key: float = AsyncHeader()): return key - assert (await catch(headers={"key": 1})) == "1" + assert (await async_catch(headers={"key": "1"})) == 1.0 -def test_adync_header_sync(): - @inject - def catch(key: str = AsyncHeader()): # pragma: no cover - return key - +def test_async_header_sync(): with pytest.raises(AssertionError): - catch(headers={"key": 1}) == "1" + + @inject + def sync_catch(key: str = AsyncHeader()): # pragma: no cover + return key def test_header_annotated(): @inject - def catch(key: Annotated[str, Header()]): + def sync_catch(key: Annotated[int, Header()]): return key - assert catch(headers={"key": 1}) == "1" + assert sync_catch(headers={"key": "1"}) == 1 def test_header_required(): @inject - def catch(key2=Header()): # pragma: no cover + def sync_catch(key2=Header()): # pragma: no cover return key2 - with pytest.raises(pydantic.error_wrappers.ValidationError): - catch() + with pytest.raises(pydantic.ValidationError): + sync_catch() def test_header_not_required(): @inject - def catch(key2=Header(required=False)): + def sync_catch(key2=Header(required=False)): assert key2 is None - catch() + sync_catch() def test_depends(): - def dep(key: Annotated[str, Header()]): + def dep(key: Annotated[int, Header()]): return key @inject - def catch(k=Depends(dep)): + def sync_catch(k=Depends(dep)): return k - assert catch(headers={"key": 1}) == "1" + assert sync_catch(headers={"key": "1"}) == 1 def test_not_cast(): @inject - def catch(key: Annotated[str, Header(cast=False)]): + def sync_catch(key: Annotated[float, Header(cast=False)]): return key - assert catch(headers={"key": 1}) == 1 + assert sync_catch(headers={"key": 1}) == 1 @inject - def catch(key: logging.Logger = Header(cast=False)): + def sync_catch(key: logging.Logger = Header(cast=False)): return key - assert catch(headers={"key": 1}) == 1 + assert sync_catch(headers={"key": 1}) == 1 diff --git a/tests/sync/test_sync.py b/tests/sync/test_sync.py index a83bfa13..066dbe97 100644 --- a/tests/sync/test_sync.py +++ b/tests/sync/test_sync.py @@ -22,14 +22,6 @@ def some_func(a, b: int): assert isinstance(some_func(1, "2"), int) -def test_annotated_wrong(): - with pytest.raises(ValueError): - - @inject - def some_func(b: "dsada"): # pragma: no cover - pass - - def test_validation_error(): @inject def some_func(a, b: str = Field(..., max_length=1)): # pragma: no cover diff --git a/tests/sync/test_sync_depends.py b/tests/sync/test_sync_depends.py index acb48ad7..d7fc54ff 100644 --- a/tests/sync/test_sync_depends.py +++ b/tests/sync/test_sync_depends.py @@ -53,7 +53,7 @@ def dep_func(a): return a @inject - def some_func(a: int, b: int, c: int = Depends(dep_func)) -> str: + def some_func(a: int, b: int, c: int = Depends(dep_func)) -> float: assert isinstance(c, int) return a + b + c @@ -67,7 +67,7 @@ def dep_func(a): D = Annotated[int, Depends(dep_func)] @inject - def some_func(a: int, b: int, c: D = None) -> str: + def some_func(a: int, b: int, c: D = None) -> float: assert isinstance(c, int) return a + b + c @@ -76,7 +76,7 @@ def another_func(a: int, c: D): return a + c assert some_func("1", "2") - assert another_func("3") == 6 + assert another_func("3") == 6.0 def test_cash(): diff --git a/tests/test_params.py b/tests/test_params.py new file mode 100644 index 00000000..2dec066d --- /dev/null +++ b/tests/test_params.py @@ -0,0 +1,21 @@ +from fast_depends import Depends +from fast_depends.core import build_call_model + + +def test_params(): + def func1(a): + ... + + def func2(c, b=Depends(func1)): + ... + + def func3(b): + ... + + def main(a, b, m=Depends(func2), k=Depends(func3), c=Depends(func1)): + ... + + model = build_call_model(main) + + assert model.real_params == {"a", "b"} + assert model.flat_params == {"a", "b", "c"}