diff --git a/asab/api/discovery.py b/asab/api/discovery.py index c9b86046..0a37b38d 100644 --- a/asab/api/discovery.py +++ b/asab/api/discovery.py @@ -20,7 +20,7 @@ jwcrypto = None from .. import Service -from ..contextvars import Tenant +from ..contextvars import Tenant, Request L = logging.getLogger(__name__) @@ -401,10 +401,17 @@ def session( ... """ _headers = {} - if isinstance(auth, aiohttp.web.Request): - # TODO: This should be the default option. Use contextvar to access the request. + + if auth is None: + # By default, use the authorization from the incoming request + request = Request.get(None) + if request is not None: + _headers["Authorization"] = request.headers.get("Authorization") + + elif isinstance(auth, aiohttp.web.Request): assert "Authorization" in auth.headers _headers["Authorization"] = auth.headers.get("Authorization") + elif auth == "internal": if jwcrypto is None: raise ModuleNotFoundError( @@ -412,8 +419,7 @@ def session( "Please run 'pip install jwcrypto' or install asab with 'authz' optional dependency." ) _headers["Authorization"] = "Bearer {}".format(self.InternalAuthToken.serialize()) - elif auth is None: - pass + else: raise ValueError( "Invalid 'auth' value. " diff --git a/asab/contextvars.py b/asab/contextvars.py index 18babb99..0ab01daa 100644 --- a/asab/contextvars.py +++ b/asab/contextvars.py @@ -1,3 +1,6 @@ import contextvars Tenant = contextvars.ContextVar("Tenant") + +# Contains aiohttp.web.Request +Request = contextvars.ContextVar("Request") diff --git a/asab/web/container.py b/asab/web/container.py index 10acd480..fd94cc5f 100644 --- a/asab/web/container.py +++ b/asab/web/container.py @@ -11,6 +11,7 @@ from ..tls import SSLContextBuilder from .service import WebService from ..application import Application +from ..contextvars import Request # @@ -136,6 +137,19 @@ def __init__(self, websvc: WebService, config_section_name: str, config: typing. preflight_paths = re.split(r"[,\s]+", preflight_str, re.MULTILINE) self.add_preflight_handlers(preflight_paths) + @aiohttp.web.middleware + async def set_request_context(request: aiohttp.web.Request, handler): + """ + Make sure that the incoming aiohttp.web.Request is available via Request context variable + """ + request_ctx = Request.set(request) + try: + return await handler(request) + finally: + Request.reset(request_ctx) + + self.WebApp.middlewares.append(set_request_context) + async def _start(self, app: Application): await self.WebAppRunner.setup()