diff --git a/README.md b/README.md index 567165e..50bf593 100644 --- a/README.md +++ b/README.md @@ -126,7 +126,7 @@ from fast_depends import inject from fast_depends.library import CustomField class Header(CustomField): - def use(self, **kwargs: AnyDict) -> AnyDict: + def use(self, /, **kwargs: AnyDict) -> AnyDict: kwargs = super().use(**kwargs) kwargs[self.param_name] = kwargs["headers"][self.param_name] return kwargs diff --git a/docs/docs/index.md b/docs/docs/index.md index 49b1aeb..95dcb0c 100644 --- a/docs/docs/index.md +++ b/docs/docs/index.md @@ -91,7 +91,7 @@ from fast_depends import inject from fast_depends.library import CustomField class Header(CustomField): - def use(self, **kwargs: AnyDict) -> AnyDict: + def use(self, /, **kwargs: AnyDict) -> AnyDict: kwargs = super().use(**kwargs) kwargs[self.param_name] = kwargs["headers"][self.param_name] return kwargs diff --git a/docs/docs_src/advanced/custom/class_declaration.py b/docs/docs_src/advanced/custom/class_declaration.py index b1447e9..155074c 100644 --- a/docs/docs_src/advanced/custom/class_declaration.py +++ b/docs/docs_src/advanced/custom/class_declaration.py @@ -1,7 +1,7 @@ from fast_depends.library import CustomField class Header(CustomField): - def use(self, **kwargs): + def use(self, /, **kwargs): kwargs = super().use(**kwargs) kwargs[self.param_name] = kwargs["headers"][self.param_name] return kwargs diff --git a/docs/docs_src/advanced/custom/starlette.py b/docs/docs_src/advanced/custom/starlette.py index b484666..0935e8e 100644 --- a/docs/docs_src/advanced/custom/starlette.py +++ b/docs/docs_src/advanced/custom/starlette.py @@ -6,7 +6,7 @@ from starlette.routing import Route class Path(CustomField): - def use(self, *, request, **kwargs): + def use(self, /, *, request, **kwargs): return { **super().use(request=request, **kwargs), self.param_name: request.path_params.get(self.param_name) diff --git a/fast_depends/__about__.py b/fast_depends/__about__.py index 3ee6f73..7dc15cd 100644 --- a/fast_depends/__about__.py +++ b/fast_depends/__about__.py @@ -1,3 +1,3 @@ """FastDepends - extracted and cleared from HTTP domain FastAPI Dependency Injection System""" -__version__ = "2.2.4" +__version__ = "2.2.5" diff --git a/fast_depends/core/model.py b/fast_depends/core/model.py index 4d8daf0..d9e2ec1 100644 --- a/fast_depends/core/model.py +++ b/fast_depends/core/model.py @@ -78,25 +78,26 @@ class CallModel(Generic[P, T]): ) @property - def call_name(self__) -> str: - return getattr(self__.call, "__name__", type(self__.call).__name__) + def call_name(self) -> str: + return getattr(self.call, "__name__", type(self.call).__name__) @property - def real_params(self__) -> Dict[str, FieldInfo]: - params = self__.params.copy() - for name in self__.custom_fields.keys(): + def real_params(self) -> Dict[str, FieldInfo]: + params = self.params.copy() + for name in self.custom_fields.keys(): params.pop(name, None) return params @property - def flat_params(self__) -> Dict[str, FieldInfo]: - params = self__.real_params - for d in self__.dependencies.values(): + def flat_params(self) -> Dict[str, FieldInfo]: + params = self.real_params + for d in self.dependencies.values(): params.update(d.flat_params) return params def __init__( - self__, + self, + /, call: Union[ Callable[P, T], Callable[P, Awaitable[T]], @@ -112,39 +113,40 @@ def __init__( positional_args: Optional[List[str]] = None, custom_fields: Optional[Dict[str, CustomField]] = None, ): - self__.call = call - self__.model = model - self__.response_model = response_model + self.call = call + self.model = model + self.response_model = response_model fields: Dict[str, FieldInfo] if PYDANTIC_V2: - fields = self__.model.model_fields + fields = self.model.model_fields else: - fields = self__.model.__fields__ # type: ignore + fields = self.model.__fields__ # type: ignore - self__.dependencies = dependencies or {} - self__.extra_dependencies = extra_dependencies or [] - self__.custom_fields = custom_fields or {} + self.dependencies = dependencies or {} + self.extra_dependencies = extra_dependencies or [] + self.custom_fields = custom_fields or {} - self__.alias_arguments = [f.alias or name for name, f in fields.items()] - self__.keyword_args = tuple(keyword_args or ()) - self__.positional_args = tuple(positional_args or ()) + self.alias_arguments = [f.alias or name for name, f in fields.items()] + self.keyword_args = tuple(keyword_args or ()) + self.positional_args = tuple(positional_args or ()) - self__.params = fields.copy() - for name in self__.dependencies.keys(): - self__.params.pop(name, None) + self.params = fields.copy() + for name in self.dependencies.keys(): + self.params.pop(name, None) - self__.use_cache = use_cache - self__.cast = cast - self__.is_async = ( - is_async or is_coroutine_callable(call) or is_async_gen_callable(self__.call) + self.use_cache = use_cache + self.cast = cast + self.is_async = ( + is_async or is_coroutine_callable(call) or is_async_gen_callable(self.call) ) - self__.is_generator = is_gen_callable(self__.call) or is_async_gen_callable( - self__.call + self.is_generator = is_gen_callable(self.call) or is_async_gen_callable( + self.call ) def _solve( - self__, + self, + /, *args: P.args, cache_dependencies: Dict[ Union[ @@ -168,30 +170,30 @@ def _solve( **kwargs: P.kwargs, ) -> Generator[Tuple[Iterable[Any], Dict[str, Any]], Any, T]: if dependency_overrides: - self__.call = dependency_overrides.get(self__.call, self__.call) - assert self__.is_async or not is_coroutine_callable( - self__.call - ), f"You cannot use async dependency `{self__.call_name}` at sync main" + self.call = dependency_overrides.get(self.call, self.call) + assert self.is_async or not is_coroutine_callable( + self.call + ), f"You cannot use async dependency `{self.call_name}` at sync main" - if self__.use_cache and self__.call in cache_dependencies: - return cache_dependencies[self__.call] + if self.use_cache and self.call in cache_dependencies: + return cache_dependencies[self.call] kw = {} - for arg in self__.keyword_args: + for arg in self.keyword_args: v = kwargs.pop(arg, inspect._empty) if v is not inspect._empty: kw[arg] = v - if "kwargs" in self__.alias_arguments: + if "kwargs" in self.alias_arguments: kw["kwargs"] = kwargs else: kw.update(kwargs) - has_args = "args" in self__.alias_arguments + has_args = "args" in self.alias_arguments - for arg in self__.positional_args: + for arg in self.positional_args: if args: kw[arg], args = args[0], args[1:] @@ -199,7 +201,7 @@ def _solve( kw["args"] = args else: - for arg in self__.keyword_args: + for arg in self.keyword_args: if args: kw[arg], args = args[0], args[1:] @@ -207,17 +209,17 @@ def _solve( solved_kw = yield (), kw casted_model: object - if self__.cast: - casted_model = self__.model(**solved_kw) + if self.cast: + casted_model = self.model(**solved_kw) else: casted_model = object() kwargs_ = { arg: getattr(casted_model, arg, solved_kw.get(arg)) for arg in ( - self__.keyword_args + self__.positional_args + self.keyword_args + self.positional_args if not has_args - else self__.keyword_args + else self.keyword_args ) } kwargs_.update(getattr(casted_model, "kwargs", {})) @@ -226,7 +228,7 @@ def _solve( if has_args: args_ = [ getattr(casted_model, arg, solved_kw.get(arg)) - for arg in self__.positional_args + for arg in self.positional_args ] args_.extend(getattr(casted_model, "args", ())) else: @@ -235,22 +237,23 @@ def _solve( response: T response = yield args_, kwargs_ - if self__.cast and not self__.is_generator: - response = self__._cast_response(response) + if self.cast and not self.is_generator: + response = self._cast_response(response) - if self__.use_cache: # pragma: no branch - cache_dependencies[self__.call] = response + if self.use_cache: # pragma: no branch + cache_dependencies[self.call] = response return response - def _cast_response(self__, value: Any) -> Any: - if self__.response_model is not None and self__.cast: - return self__.response_model(response=value).response + def _cast_response(self, /, value: Any) -> Any: + if self.response_model is not None and self.cast: + return self.response_model(response=value).response else: return value def solve( - self__, + self, + /, *args: P.args, stack: ExitStack, cache_dependencies: Dict[ @@ -275,7 +278,7 @@ def solve( nested: bool = False, **kwargs: P.kwargs, ) -> T: - cast_gen = self__._solve( + cast_gen = self._solve( *args, cache_dependencies=cache_dependencies, dependency_overrides=dependency_overrides, @@ -287,7 +290,7 @@ def solve( cached_value: T = e.value return cached_value - for dep in self__.extra_dependencies: + for dep in self.extra_dependencies: dep.solve( stack=stack, cache_dependencies=cache_dependencies, @@ -296,7 +299,7 @@ def solve( **kwargs, ) - for dep_arg, dep in self__.dependencies.items(): + for dep_arg, dep in self.dependencies.items(): kwargs[dep_arg] = dep.solve( stack=stack, cache_dependencies=cache_dependencies, @@ -305,37 +308,38 @@ def solve( **kwargs, ) - for custom in self__.custom_fields.values(): + for custom in self.custom_fields.values(): kwargs = custom.use(**kwargs) final_args, final_kwargs = cast_gen.send(kwargs) - if self__.is_generator and nested: + if self.is_generator and nested: response = solve_generator_sync( *final_args, - call=self__.call, + call=self.call, stack=stack, **final_kwargs, ) else: - response = self__.call(*final_args, **final_kwargs) + response = self.call(*final_args, **final_kwargs) try: cast_gen.send(response) except StopIteration as e: value: T = e.value - if not self__.cast or nested or not self__.is_generator: + if not self.cast or nested or not self.is_generator: return value else: - return map(self__._cast_response, value) # type: ignore[no-any-return, call-overload] + return map(self._cast_response, value) # type: ignore[no-any-return, call-overload] assert_never(response) # pragma: no cover async def asolve( - self__, + self, + /, *args: P.args, stack: AsyncExitStack, cache_dependencies: Dict[ @@ -360,7 +364,7 @@ async def asolve( nested: bool = False, **kwargs: P.kwargs, ) -> T: - cast_gen = self__._solve( + cast_gen = self._solve( *args, cache_dependencies=cache_dependencies, dependency_overrides=dependency_overrides, @@ -372,7 +376,7 @@ async def asolve( cached_value: T = e.value return cached_value - for dep in self__.extra_dependencies: + for dep in self.extra_dependencies: await dep.asolve( stack=stack, cache_dependencies=cache_dependencies, @@ -381,7 +385,7 @@ async def asolve( **kwargs, ) - for dep_arg, dep in self__.dependencies.items(): + for dep_arg, dep in self.dependencies.items(): kwargs[dep_arg] = await dep.asolve( stack=stack, cache_dependencies=cache_dependencies, @@ -390,30 +394,30 @@ async def asolve( **kwargs, ) - for custom in self__.custom_fields.values(): + for custom in self.custom_fields.values(): kwargs = await run_async(custom.use, **kwargs) final_args, final_kwargs = cast_gen.send(kwargs) - if self__.is_generator and nested: + if self.is_generator and nested: response = await solve_generator_async( *final_args, - call=self__.call, + call=self.call, stack=stack, **final_kwargs, ) else: - response = await run_async(self__.call, *final_args, **final_kwargs) + response = await run_async(self.call, *final_args, **final_kwargs) try: cast_gen.send(response) except StopIteration as e: value: T = e.value - if not self__.cast or nested or not self__.is_generator: + if not self.cast or nested or not self.is_generator: return value else: - return async_map(self__._cast_response, value) # type: ignore[return-value, arg-type] + return async_map(self._cast_response, value) # type: ignore[return-value, arg-type] assert_never(response) # pragma: no cover diff --git a/fast_depends/library/model.py b/fast_depends/library/model.py index f118dec..13a721f 100644 --- a/fast_depends/library/model.py +++ b/fast_depends/library/model.py @@ -9,7 +9,12 @@ class CustomField(ABC): cast: bool required: bool - def __init__(self, *, cast: bool = True, required: bool = True) -> None: + def __init__( + self, + *, + cast: bool = True, + required: bool = True, + ) -> None: self.cast = cast self.param_name = None self.required = required @@ -18,6 +23,6 @@ def set_param_name(self: Cls, name: str) -> Cls: self.param_name = name return self - def use(self, **kwargs: Dict[str, Any]) -> Dict[str, Any]: + def use(self, /, **kwargs: Dict[str, Any]) -> Dict[str, Any]: assert self.param_name, "You should specify `param_name` before using" return kwargs diff --git a/pyproject.toml b/pyproject.toml index a02f57f..7e892a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,12 +127,13 @@ known_third_party = ["pydantic", "anyio"] [tool.ruff] select = [ - "E", # pycodestyle errors - "W", # pycodestyle warnings - "F", # pyflakes - "I", # isort - "C", # flake8-comprehensions - "B", # flake8-bugbear + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "C", # flake8-comprehensions + "B", # flake8-bugbear + "T20", # flake8-print ] ignore = [ "E501", # line too long, handled by black diff --git a/tests/async/test_class.py b/tests/async/test_class.py index 7d5ddd9..bf59554 100644 --- a/tests/async/test_class.py +++ b/tests/async/test_class.py @@ -9,11 +9,11 @@ def _get_var(): class Class: @inject - def __init__(self, a = Depends(_get_var)) -> None: + def __init__(self, a=Depends(_get_var)) -> None: self.a = a @inject - async def calc(self, a = Depends(_get_var)) -> int: + async def calc(self, a=Depends(_get_var)) -> int: return a + self.a diff --git a/tests/library/test_custom.py b/tests/library/test_custom.py index 7aecb18..5e68c66 100644 --- a/tests/library/test_custom.py +++ b/tests/library/test_custom.py @@ -10,7 +10,7 @@ class Header(CustomField): - def use(self, **kwargs: Callable[..., Any]) -> Callable[..., Any]: + def use(self, /, **kwargs: Callable[..., Any]) -> Callable[..., Any]: kwargs = super().use(**kwargs) if kwargs.get("headers", {}).get(self.param_name): kwargs[self.param_name] = kwargs.get("headers", {}).get(self.param_name) @@ -30,6 +30,15 @@ def sync_catch(key: int = Header()): # noqa: B008 assert sync_catch(headers={"key": "1"}) == 1 +def test_custom_with_class(): + class T: + @inject + def __init__(self, key: int = Header()): + self.key = key + + assert T(headers={"key": "1"}).key == 1 + + @pytest.mark.anyio async def test_header_async(): @inject diff --git a/tests/sync/test_class.py b/tests/sync/test_class.py index ddc9f3a..93a94f1 100644 --- a/tests/sync/test_class.py +++ b/tests/sync/test_class.py @@ -7,11 +7,11 @@ def _get_var(): class Class: @inject - def __init__(self, a = Depends(_get_var)) -> None: + def __init__(self, a=Depends(_get_var)) -> None: self.a = a @inject - def calc(self, a = Depends(_get_var)) -> int: + def calc(self, a=Depends(_get_var)) -> int: return a + self.a