diff --git a/falcon/hooks.py b/falcon/hooks.py index 45c477d24..b2f6c4ce8 100644 --- a/falcon/hooks.py +++ b/falcon/hooks.py @@ -19,29 +19,64 @@ from inspect import getmembers from inspect import iscoroutinefunction import re -import typing as t +import typing from falcon.constants import COMBINED_METHODS from falcon.util.misc import get_argnames from falcon.util.sync import _wrap_non_coroutine_unsafe -if t.TYPE_CHECKING: # pragma: no cover +if typing.TYPE_CHECKING: # pragma: no cover import falcon as wsgi from falcon import asgi + ResponderParams = typing.ParamSpec('ResponderParams') + + class SyncResponder(typing.Protocol[ResponderParams]): + def __call__( + self, + responder: SyncResponderOrResource, + req: wsgi.Request, + resp: wsgi.Response, + *args: ResponderParams.args, + **kwargs: ResponderParams.kwargs, + ) -> None: + ... + + class AsyncResponder(typing.Protocol): + async def __call__( + self, + responder: AsyncResponderOrResource, + req: asgi.Request, + resp: asgi.Response, + *args: ResponderParams.args, + **kwargs: ResponderParams.kwargs, + ) -> None: + ... + + Responder = typing.Union[SyncResponder, AsyncResponder] + Resource = object + SyncResponderOrResource = typing.Union[SyncResponder, Resource] + AsyncResponderOrResource = typing.Union[AsyncResponder, Resource] + ResponderOrResource = typing.Union[Responder, Resource] + SynchronousAction = typing.Callable[..., typing.Any] + AsynchronousAction = typing.Callable[..., typing.Awaitable[typing.Any]] + Action = typing.Union[SynchronousAction, AsynchronousAction] +else: + Resource = object + SynchronousAction = typing.Callable[..., typing.Any] + AsynchronousAction = typing.Callable[..., typing.Awaitable[typing.Any]] + SyncResponder = typing.Callable + AsyncResponder = typing.Awaitable + Responder = typing.Union[SyncResponder, AsyncResponder] + _DECORABLE_METHOD_NAME = re.compile( r'^on_({})(_\w+)?$'.format('|'.join(method.lower() for method in COMBINED_METHODS)) ) -Resource = object -Responder = t.Callable -ResponderOrResource = t.Union[Responder, Resource] -Action = t.Callable - def before( - action: Action, *args: t.Any, is_async: bool = False, **kwargs: t.Any -) -> t.Callable[[ResponderOrResource], ResponderOrResource]: + action: Action, *args: typing.Any, is_async: bool = False, **kwargs: typing.Any +) -> typing.Callable[[ResponderOrResource], ResponderOrResource]: """Execute the given action function *before* the responder. The `params` argument that is passed to the hook @@ -93,29 +128,28 @@ def do_something(req, resp, resource, params): def _before(responder_or_resource: ResponderOrResource) -> ResponderOrResource: if isinstance(responder_or_resource, type): - resource = responder_or_resource - - for responder_name, responder in getmembers(resource, callable): + for responder_name, responder in getmembers( + responder_or_resource, callable + ): if _DECORABLE_METHOD_NAME.match(responder_name): # This pattern is necessary to capture the current value of # responder in the do_before_all closure; otherwise, they # will capture the same responder variable that is shared # between iterations of the for loop, above. - responder = t.cast(Responder, responder) - def let(responder: Responder = responder) -> None: + def let(responder: typing.Callable = responder) -> None: do_before_all = _wrap_with_before( responder, action, args, kwargs, is_async ) - setattr(resource, responder_name, do_before_all) + setattr(responder_or_resource, responder_name, do_before_all) let() - return resource + return responder_or_resource else: - responder = t.cast(Responder, responder_or_resource) + responder = typing.cast(Responder, responder_or_resource) do_before_one = _wrap_with_before(responder, action, args, kwargs, is_async) return do_before_one @@ -124,8 +158,8 @@ def let(responder: Responder = responder) -> None: def after( - action: Action, *args: t.Any, is_async: bool = False, **kwargs: t.Any -) -> t.Callable[[ResponderOrResource], ResponderOrResource]: + action: Action, *args: typing.Any, is_async: bool = False, **kwargs: typing.Any +) -> typing.Callable[[ResponderOrResource], ResponderOrResource]: """Execute the given action function *after* the responder. Args: @@ -160,25 +194,24 @@ def after( def _after(responder_or_resource: ResponderOrResource) -> ResponderOrResource: if isinstance(responder_or_resource, type): - resource = t.cast(Resource, responder_or_resource) - - for responder_name, responder in getmembers(resource, callable): + for responder_name, responder in getmembers( + responder_or_resource, callable + ): if _DECORABLE_METHOD_NAME.match(responder_name): - responder = t.cast(Responder, responder) - def let(responder: Responder = responder) -> None: + def let(responder: Responder | typing.Callable = responder) -> None: do_after_all = _wrap_with_after( responder, action, args, kwargs, is_async ) - setattr(resource, responder_name, do_after_all) + setattr(responder_or_resource, responder_name, do_after_all) let() - return resource + return responder_or_resource else: - responder = t.cast(Responder, responder_or_resource) + responder = typing.cast(Responder, responder_or_resource) do_after_one = _wrap_with_after(responder, action, args, kwargs, is_async) return do_after_one @@ -194,8 +227,8 @@ def let(responder: Responder = responder) -> None: def _wrap_with_after( responder: Responder, action: Action, - action_args: t.Any, - action_kwargs: t.Any, + action_args: typing.Any, + action_kwargs: typing.Any, is_async: bool, ) -> Responder: """Execute the given action function after a responder method. @@ -214,40 +247,44 @@ def _wrap_with_after( responder_argnames = get_argnames(responder) extra_argnames = responder_argnames[2:] # Skip req, resp + do_after_responder: Responder if is_async or iscoroutinefunction(responder): # NOTE(kgriffs): I manually verified that the implicit "else" branch # is actually covered, but coverage isn't tracking it for # some reason. if not is_async: # pragma: nocover - async_action = _wrap_non_coroutine_unsafe(action) + async_action = typing.cast( + AsynchronousAction, _wrap_non_coroutine_unsafe(action) + ) else: - async_action = action + async_action = typing.cast(AsynchronousAction, action) + async_responder = typing.cast(AsyncResponder, responder) @wraps(responder) async def do_after( - self: ResponderOrResource, + self: AsyncResponderOrResource, req: asgi.Request, resp: asgi.Response, - *args: t.Any, - **kwargs: t.Any, + *args: typing.Any, + **kwargs: typing.Any, ) -> None: if args: _merge_responder_args(args, kwargs, extra_argnames) - - await responder(self, req, resp, **kwargs) - assert async_action + await async_responder(self, req, resp, **kwargs) await async_action(req, resp, self, *action_args, **action_kwargs) + do_after_responder = typing.cast(AsyncResponder, do_after) else: + responder = typing.cast(SyncResponder, responder) @wraps(responder) def do_after( - self: ResponderOrResource, + self: SyncResponderOrResource, req: wsgi.Request, resp: wsgi.Response, - *args: t.Any, - **kwargs: t.Any, + *args: typing.Any, + **kwargs: typing.Any, ) -> None: if args: _merge_responder_args(args, kwargs, extra_argnames) @@ -255,16 +292,17 @@ def do_after( responder(self, req, resp, **kwargs) action(req, resp, self, *action_args, **action_kwargs) - return do_after + do_after_responder = typing.cast(SyncResponder, do_after) + return do_after_responder def _wrap_with_before( responder: Responder, action: Action, - action_args: t.Tuple[t.Any, ...], - action_kwargs: t.Dict[str, t.Any], + action_args: typing.Tuple[typing.Any, ...], + action_kwargs: typing.Dict[str, typing.Any], is_async: bool, -) -> t.Union[t.Callable[..., t.Awaitable[None]], t.Callable[..., None]]: +) -> Responder: """Execute the given action function before a responder method. Args: @@ -281,40 +319,45 @@ def _wrap_with_before( responder_argnames = get_argnames(responder) extra_argnames = responder_argnames[2:] # Skip req, resp + do_before_responder: Responder if is_async or iscoroutinefunction(responder): # NOTE(kgriffs): I manually verified that the implicit "else" branch # is actually covered, but coverage isn't tracking it for # some reason. if not is_async: # pragma: nocover - async_action = _wrap_non_coroutine_unsafe(action) + async_action = typing.cast( + AsynchronousAction, _wrap_non_coroutine_unsafe(action) + ) else: - async_action = action + async_action = typing.cast(AsynchronousAction, action) + async_responder = typing.cast(AsyncResponder, responder) @wraps(responder) async def do_before( - self: ResponderOrResource, + self: AsyncResponderOrResource, req: asgi.Request, resp: asgi.Response, - *args: t.Any, - **kwargs: t.Any, + *args: typing.Any, + **kwargs: typing.Any, ) -> None: if args: _merge_responder_args(args, kwargs, extra_argnames) - assert async_action await async_action(req, resp, self, kwargs, *action_args, **action_kwargs) - await responder(self, req, resp, **kwargs) + await async_responder(self, req, resp, **kwargs) + do_before_responder = typing.cast(AsyncResponder, do_before) else: + responder = typing.cast(SyncResponder, responder) @wraps(responder) def do_before( - self: ResponderOrResource, + self: SyncResponderOrResource, req: wsgi.Request, resp: wsgi.Response, - *args: t.Any, - **kwargs: t.Any, + *args: typing.Any, + **kwargs: typing.Any, ) -> None: if args: _merge_responder_args(args, kwargs, extra_argnames) @@ -322,11 +365,14 @@ def do_before( action(req, resp, self, kwargs, *action_args, **action_kwargs) responder(self, req, resp, **kwargs) - return do_before + do_before_responder = typing.cast(SyncResponder, do_before) + return do_before_responder def _merge_responder_args( - args: t.Tuple[t.Any, ...], kwargs: t.Dict[str, t.Any], argnames: t.List[str] + args: typing.Tuple[typing.Any, ...], + kwargs: typing.Dict[str, typing.Any], + argnames: typing.List[str], ) -> None: """Merge responder args into kwargs. diff --git a/falcon/testing/resource.py b/falcon/testing/resource.py index c20854a3e..bb738cf5f 100644 --- a/falcon/testing/resource.py +++ b/falcon/testing/resource.py @@ -22,13 +22,26 @@ resource = testing.SimpleTestResource() """ +from __future__ import annotations from json import dumps as json_dumps +import typing import falcon +if typing.TYPE_CHECKING: # pragma: no cover + from falcon import app as wsgi + from falcon.asgi import app as asgi + from falcon.hooks import ResponderOrResource + from falcon.typing import RawHeaders -def capture_responder_args(req, resp, resource, params): + +def capture_responder_args( + req: wsgi.Request, + resp: wsgi.Response, + resource: ResponderOrResource, + params: typing.Mapping[str, str], +) -> None: """Before hook for capturing responder arguments. Adds the following attributes to the hooked responder's resource @@ -49,41 +62,53 @@ def capture_responder_args(req, resp, resource, params): * `capture-req-media` """ - resource.captured_req = req - resource.captured_resp = resp - resource.captured_kwargs = params + simple_resource = typing.cast(SimpleTestResource, resource) + simple_resource.captured_req = req + simple_resource.captured_resp = resp + simple_resource.captured_kwargs = params - resource.captured_req_media = None - resource.captured_req_body = None + simple_resource.captured_req_media = None + simple_resource.captured_req_body = None num_bytes = req.get_header('capture-req-body-bytes') if num_bytes: - resource.captured_req_body = req.stream.read(int(num_bytes)) + simple_resource.captured_req_body = req.stream.read(int(num_bytes)) elif req.get_header('capture-req-media'): - resource.captured_req_media = req.get_media() + simple_resource.captured_req_media = req.get_media() -async def capture_responder_args_async(req, resp, resource, params): +async def capture_responder_args_async( + req: asgi.Request, + resp: asgi.Response, + resource: ResponderOrResource, + params: typing.Mapping[str, str], +) -> None: """Before hook for capturing responder arguments. An asynchronous version of :meth:`~falcon.testing.capture_responder_args`. """ - resource.captured_req = req - resource.captured_resp = resp - resource.captured_kwargs = params + simple_resource = typing.cast(SimpleTestResource, resource) + simple_resource.captured_req = req + simple_resource.captured_resp = resp + simple_resource.captured_kwargs = params - resource.captured_req_media = None - resource.captured_req_body = None + simple_resource.captured_req_media = None + simple_resource.captured_req_body = None num_bytes = req.get_header('capture-req-body-bytes') if num_bytes: - resource.captured_req_body = await req.stream.read(int(num_bytes)) + simple_resource.captured_req_body = await req.stream.read(int(num_bytes)) elif req.get_header('capture-req-media'): - resource.captured_req_media = await req.get_media() + simple_resource.captured_req_media = await req.get_media() -def set_resp_defaults(req, resp, resource, params): +def set_resp_defaults( + req: wsgi.Request, + resp: wsgi.Response, + resource: ResponderOrResource, + params: typing.Mapping[str, str], +) -> None: """Before hook for setting default response properties. This hook simply sets the the response body, status, @@ -92,18 +117,23 @@ def set_resp_defaults(req, resp, resource, params): that are assumed to be defined on the resource object. """ + simple_resource = typing.cast(SimpleTestResource, resource) + if simple_resource._default_status is not None: + resp.status = simple_resource._default_status - if resource._default_status is not None: - resp.status = resource._default_status - - if resource._default_body is not None: - resp.text = resource._default_body + if simple_resource._default_body is not None: + resp.text = simple_resource._default_body - if resource._default_headers is not None: - resp.set_headers(resource._default_headers) + if simple_resource._default_headers is not None: + resp.set_headers(simple_resource._default_headers) -async def set_resp_defaults_async(req, resp, resource, params): +async def set_resp_defaults_async( + req: asgi.Request, + resp: asgi.Response, + resource: ResponderOrResource, + params: typing.Mapping[str, str], +) -> None: """Wrap :meth:`~falcon.testing.set_resp_defaults` in a coroutine.""" set_resp_defaults(req, resp, resource, params) @@ -145,7 +175,13 @@ class SimpleTestResource: responder methods. """ - def __init__(self, status=None, body=None, json=None, headers=None): + def __init__( + self, + status: typing.Optional[str] = None, + body: typing.Optional[str] = None, + json: typing.Optional[dict[str, str]] = None, + headers: typing.Optional[RawHeaders] = None, + ): self._default_status = status self._default_headers = headers @@ -154,14 +190,22 @@ def __init__(self, status=None, body=None, json=None, headers=None): msg = 'Either json or body may be specified, but not both' raise ValueError(msg) - self._default_body = json_dumps(json, ensure_ascii=False) + self._default_body: typing.Optional[str] = json_dumps( + json, ensure_ascii=False + ) else: self._default_body = body - self.captured_req = None - self.captured_resp = None - self.captured_kwargs = None + self.captured_req: typing.Optional[ + typing.Union[wsgi.Request, asgi.Request] + ] = None + self.captured_resp: typing.Optional[ + typing.Union[wsgi.Response, asgi.Response] + ] = None + self.captured_kwargs: typing.Optional[typing.Any] = None + self.captured_req_media: typing.Optional[typing.Any] = None + self.captured_req_body: typing.Optional[str] = None @property def called(self): @@ -169,12 +213,16 @@ def called(self): @falcon.before(capture_responder_args) @falcon.before(set_resp_defaults) - def on_get(self, req, resp, **kwargs): + def on_get( + self, req: wsgi.Request, resp: wsgi.Response, **kwargs: typing.Any + ) -> None: pass @falcon.before(capture_responder_args) @falcon.before(set_resp_defaults) - def on_post(self, req, resp, **kwargs): + def on_post( + self, req: wsgi.Request, resp: wsgi.Response, **kwargs: typing.Any + ) -> None: pass @@ -218,10 +266,14 @@ class SimpleTestResourceAsync(SimpleTestResource): @falcon.before(capture_responder_args_async) @falcon.before(set_resp_defaults_async) - async def on_get(self, req, resp, **kwargs): + async def on_get( + self, req: asgi.Request, resp: asgi.Response, **kwargs: typing.Any + ) -> None: pass @falcon.before(capture_responder_args_async) @falcon.before(set_resp_defaults_async) - async def on_post(self, req, resp, **kwargs): + async def on_post( + self, req: asgi.Request, resp: asgi.Response, **kwargs: typing.Any + ) -> None: pass diff --git a/falcon/typing.py b/falcon/typing.py index 540a9070a..aff77e039 100644 --- a/falcon/typing.py +++ b/falcon/typing.py @@ -27,11 +27,11 @@ from typing import Union if TYPE_CHECKING: + from typing import Protocol + from falcon.request import Request from falcon.response import Response - from typing import Protocol - class Serializer(Protocol): def serialize( self, diff --git a/tests/test_after_hooks.py b/tests/test_after_hooks.py index 4f95914b7..06eb2adfd 100644 --- a/tests/test_after_hooks.py +++ b/tests/test_after_hooks.py @@ -1,10 +1,13 @@ import functools import json +import typing import pytest import falcon +from falcon import app as wsgi from falcon import testing +from falcon.hooks import Resource from _util import create_app, create_resp # NOQA @@ -346,8 +349,9 @@ class ResourceAwareGameHook: VALUES = ('rock', 'scissors', 'paper') @classmethod - def __call__(cls, req, resp, resource): + def __call__(cls, req: wsgi.Request, resp: wsgi.Response, resource: Resource): assert resource + resource = typing.cast(HandGame, resource) assert resource.seed in cls.VALUES assert resp.text == 'Responder called.'