From 3dfd6fe3de6a246acb358304ae44872f64c2c9f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Sat, 7 Oct 2023 17:34:41 +0200 Subject: [PATCH] fix(connection): Simplify connection header parsing (#2398) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(connection): use Headers for connection headers * Use MultiDict in route handler * Make changes backwards compatible --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> --- litestar/_kwargs/extractors.py | 19 +++++++++---------- litestar/_parsers.py | 25 +------------------------ litestar/connection/base.py | 11 +++-------- litestar/datastructures/headers.py | 10 +++++----- litestar/datastructures/multi_dicts.py | 4 +--- tests/unit/test_parsers.py | 25 ------------------------- 6 files changed, 19 insertions(+), 75 deletions(-) diff --git a/litestar/_kwargs/extractors.py b/litestar/_kwargs/extractors.py index 9e9daadceb..df0845ce9f 100644 --- a/litestar/_kwargs/extractors.py +++ b/litestar/_kwargs/extractors.py @@ -2,14 +2,14 @@ from collections import defaultdict from functools import lru_cache -from typing import TYPE_CHECKING, Any, Callable, Coroutine, NamedTuple, cast +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Mapping, NamedTuple, cast from litestar._multipart import parse_multipart_form from litestar._parsers import ( - parse_headers, parse_query_string, parse_url_encoded_form_data, ) +from litestar.datastructures import Headers from litestar.datastructures.upload_file import UploadFile from litestar.enums import ParamType, RequestEncodingType from litestar.exceptions import ValidationException @@ -78,7 +78,7 @@ def create_connection_value_extractor( kwargs_model: KwargsModel, connection_key: str, expected_params: set[ParameterDefinition], - parser: Callable[[ASGIConnection, KwargsModel], dict[str, Any]] | None = None, + parser: Callable[[ASGIConnection, KwargsModel], Mapping[str, Any]] | None = None, ) -> Callable[[dict[str, Any], ASGIConnection], None]: """Create a kwargs extractor function. @@ -155,7 +155,7 @@ def parse_connection_query_params(connection: ASGIConnection, kwargs_model: Kwar ) -def parse_connection_headers(connection: ASGIConnection, _: KwargsModel) -> dict[str, Any]: +def parse_connection_headers(connection: ASGIConnection, _: KwargsModel) -> Headers: """Parse header parameters and cache the result in scope. Args: @@ -163,12 +163,9 @@ def parse_connection_headers(connection: ASGIConnection, _: KwargsModel) -> dict _: The KwargsModel instance. Returns: - A dictionary of parsed values + A Headers instance """ - parsed_headers = connection.scope["_headers"] = ( # type: ignore - connection._headers if connection._headers is not Empty else parse_headers(tuple(connection.scope["headers"])) - ) - return cast("dict[str, Any]", parsed_headers) + return Headers.from_scope(connection.scope) def state_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: @@ -194,7 +191,9 @@ def headers_extractor(values: dict[str, Any], connection: ASGIConnection) -> Non Returns: None """ - values["headers"] = connection.headers + # TODO: This should be removed in 3.0 and instead Headers should be injected + # directly. We are only keeping this one around to not break things + values["headers"] = dict(connection.headers.items()) def cookies_extractor(values: dict[str, Any], connection: ASGIConnection) -> None: diff --git a/litestar/_parsers.py b/litestar/_parsers.py index 49b795514d..9b9f459346 100644 --- a/litestar/_parsers.py +++ b/litestar/_parsers.py @@ -3,7 +3,6 @@ from collections import defaultdict from functools import lru_cache from http.cookies import _unquote as unquote_cookie -from typing import Iterable from urllib.parse import unquote try: @@ -15,7 +14,7 @@ def parse_qsl(qs: bytes, separator: str) -> list[tuple[str, str]]: return _parse_qsl(qs.decode("latin-1"), keep_blank_values=True, separator=separator) -__all__ = ("parse_cookie_string", "parse_headers", "parse_query_string", "parse_url_encoded_form_data") +__all__ = ("parse_cookie_string", "parse_query_string", "parse_url_encoded_form_data") @lru_cache(1024) @@ -66,25 +65,3 @@ def parse_cookie_string(cookie_string: str) -> dict[str, str]: ) } return output - - -@lru_cache(1024) -def _parse_headers(headers: tuple[tuple[bytes, bytes], ...]) -> dict[str, str]: - """Parse ASGI headers into a dict of string keys and values. - - Args: - headers: A tuple of bytes two tuples. - - Returns: - A string / string dict. - """ - return {k.decode(): v.decode() for k, v in headers} - - -def parse_headers(headers: Iterable[tuple[bytes, bytes] | list[bytes]]) -> dict[str, str]: - """Parse ASGI headers into a dict of string keys and values. - - Since the ASGI protocol only allows for lists (not tuples) which cannot be hashed, - this function will convert the headers to a tuple of tuples before invoking the cached function. - """ - return _parse_headers(tuple(tuple(h) for h in headers)) diff --git a/litestar/connection/base.py b/litestar/connection/base.py index b4310f0445..0d9abdd21a 100644 --- a/litestar/connection/base.py +++ b/litestar/connection/base.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast -from litestar._parsers import parse_cookie_string, parse_headers, parse_query_string +from litestar._parsers import parse_cookie_string, parse_query_string from litestar.datastructures.headers import Headers from litestar.datastructures.multi_dicts import MultiDict from litestar.datastructures.state import State @@ -55,7 +55,7 @@ async def empty_send(_: Message) -> NoReturn: # pragma: no cover class ASGIConnection(Generic[HandlerT, UserT, AuthT, StateT]): """The base ASGI connection container.""" - __slots__ = ("scope", "receive", "send", "_base_url", "_url", "_parsed_query", "_headers", "_cookies") + __slots__ = ("scope", "receive", "send", "_base_url", "_url", "_parsed_query", "_cookies") scope: Scope """The ASGI scope attached to the connection.""" @@ -79,7 +79,6 @@ def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = self._url: Any = scope.get("_url", Empty) self._parsed_query: Any = scope.get("_parsed_query", Empty) self._cookies: Any = scope.get("_cookies", Empty) - self._headers: Any = scope.get("_headers", Empty) @property def app(self) -> Litestar: @@ -146,11 +145,7 @@ def headers(self) -> Headers: Returns: A Headers instance with the request's scope["headers"] value. """ - if self._headers is Empty: - self.scope.setdefault("headers", []) - self._headers = self.scope["_headers"] = parse_headers(tuple(self.scope["headers"])) # type: ignore[typeddict-unknown-key] - - return Headers(self._headers) + return Headers.from_scope(self.scope) @property def query_params(self) -> MultiDict[Any]: diff --git a/litestar/datastructures/headers.py b/litestar/datastructures/headers.py index 59a75274c0..c87223da3e 100644 --- a/litestar/datastructures/headers.py +++ b/litestar/datastructures/headers.py @@ -24,7 +24,6 @@ from typing_extensions import get_type_hints from litestar._multipart import parse_content_header -from litestar._parsers import parse_headers from litestar.datastructures.multi_dicts import MultiMixin from litestar.dto.base_dto import AbstractDTO from litestar.exceptions import ImproperlyConfiguredException, ValidationException @@ -51,7 +50,7 @@ def _encode_headers(headers: Iterable[Tuple[str, str]]) -> "RawHeadersList": class Headers(CIMultiDictProxy[str], MultiMixin[str]): - """An immutable, case-insensitive for HTTP headers.""" + """An immutable, case-insensitive multi dict for HTTP headers.""" def __init__(self, headers: Optional[Union[Mapping[str, str], "RawHeaders", MultiMapping]] = None) -> None: """Initialize ``Headers``. @@ -85,9 +84,10 @@ def from_scope(cls, scope: "HeaderScope") -> "Headers": Raises: ValueError: If the message does not have a ``headers`` key """ - if "_headers" not in scope: - scope["_headers"] = parse_headers(tuple(scope["headers"])) # type: ignore - return cls(scope["_headers"]) # type: ignore + if (headers := scope.get("_headers")) is None: + headers = scope["_headers"] = cls(scope["headers"]) # type: ignore[typeddict-unknown-key] + + return cast("Headers", headers) def to_header_list(self) -> "RawHeadersList": """Raw header value. diff --git a/litestar/datastructures/multi_dicts.py b/litestar/datastructures/multi_dicts.py index 650a58f293..c570e8bf90 100644 --- a/litestar/datastructures/multi_dicts.py +++ b/litestar/datastructures/multi_dicts.py @@ -40,9 +40,7 @@ class MultiDict(BaseMultiDict[T], MultiMixin[T], Generic[T]): """MultiDict, using :class:`MultiDict `.""" def __init__(self, args: MultiMapping | Mapping[str, T] | Iterable[tuple[str, T]] | None = None) -> None: - """Initialize ``MultiDict`` from a. - - ``MultiMapping``, :class:`Mapping ` or an iterable of tuples. + """Initialize ``MultiDict`` from a`MultiMapping``, :class:`Mapping ` or an iterable of tuples. Args: args: Mapping-like structure to create the ``MultiDict`` from diff --git a/tests/unit/test_parsers.py b/tests/unit/test_parsers.py index 14c8e90a66..8908f03911 100644 --- a/tests/unit/test_parsers.py +++ b/tests/unit/test_parsers.py @@ -5,9 +5,7 @@ from litestar import HttpMethod from litestar._parsers import ( - _parse_headers, parse_cookie_string, - parse_headers, parse_query_string, parse_url_encoded_form_data, ) @@ -107,26 +105,3 @@ def test_query_parsing_of_escaped_values(values: Tuple[Tuple[str, str], Tuple[st request = client.build_request(method=HttpMethod.GET, url="http://www.example.com", params=dict(values)) parsed_query = parse_query_string(request.url.query) assert parsed_query == values - - -def test_parse_headers() -> None: - """Test that headers are parsed correctly.""" - headers = [ - [b"Host", b"localhost:8000"], - [b"User-Agent", b"curl/7.64.1"], - [b"Accept", b"*/*"], - [b"Cookie", b"foo=bar; bar=baz"], - [b"Content-Type", b"application/x-www-form-urlencoded"], - [b"Content-Length", b"12"], - ] - parsed = parse_headers(headers) - assert parsed["Host"] == "localhost:8000" - assert parsed["User-Agent"] == "curl/7.64.1" - assert parsed["Accept"] == "*/*" - assert parsed["Cookie"] == "foo=bar; bar=baz" - assert parsed["Content-Type"] == "application/x-www-form-urlencoded" - assert parsed["Content-Length"] == "12" - # demonstrate that calling the private function with lists (as ASGI specifies) - # does raise an error - with pytest.raises(TypeError): - _parse_headers(headers) # type: ignore[arg-type]