Skip to content

Commit

Permalink
fix(connection): Simplify connection header parsing (#2398)
Browse files Browse the repository at this point in the history
* fix(connection): use Headers for connection headers
* Use MultiDict in route handler
* Make changes backwards compatible


---------

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>
  • Loading branch information
provinzkraut authored Oct 7, 2023
1 parent 2db93dd commit 3dfd6fe
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 75 deletions.
19 changes: 9 additions & 10 deletions litestar/_kwargs/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from collections import defaultdict
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Callable, Coroutine, NamedTuple, cast
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Mapping, NamedTuple, cast

from litestar._multipart import parse_multipart_form
from litestar._parsers import (
parse_headers,
parse_query_string,
parse_url_encoded_form_data,
)
from litestar.datastructures import Headers
from litestar.datastructures.upload_file import UploadFile
from litestar.enums import ParamType, RequestEncodingType
from litestar.exceptions import ValidationException
Expand Down Expand Up @@ -78,7 +78,7 @@ def create_connection_value_extractor(
kwargs_model: KwargsModel,
connection_key: str,
expected_params: set[ParameterDefinition],
parser: Callable[[ASGIConnection, KwargsModel], dict[str, Any]] | None = None,
parser: Callable[[ASGIConnection, KwargsModel], Mapping[str, Any]] | None = None,
) -> Callable[[dict[str, Any], ASGIConnection], None]:
"""Create a kwargs extractor function.
Expand Down Expand Up @@ -155,20 +155,17 @@ def parse_connection_query_params(connection: ASGIConnection, kwargs_model: Kwar
)


def parse_connection_headers(connection: ASGIConnection, _: KwargsModel) -> dict[str, Any]:
def parse_connection_headers(connection: ASGIConnection, _: KwargsModel) -> Headers:
"""Parse header parameters and cache the result in scope.
Args:
connection: The ASGI connection instance.
_: The KwargsModel instance.
Returns:
A dictionary of parsed values
A Headers instance
"""
parsed_headers = connection.scope["_headers"] = ( # type: ignore
connection._headers if connection._headers is not Empty else parse_headers(tuple(connection.scope["headers"]))
)
return cast("dict[str, Any]", parsed_headers)
return Headers.from_scope(connection.scope)


def state_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
Expand All @@ -194,7 +191,9 @@ def headers_extractor(values: dict[str, Any], connection: ASGIConnection) -> Non
Returns:
None
"""
values["headers"] = connection.headers
# TODO: This should be removed in 3.0 and instead Headers should be injected
# directly. We are only keeping this one around to not break things
values["headers"] = dict(connection.headers.items())


def cookies_extractor(values: dict[str, Any], connection: ASGIConnection) -> None:
Expand Down
25 changes: 1 addition & 24 deletions litestar/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections import defaultdict
from functools import lru_cache
from http.cookies import _unquote as unquote_cookie
from typing import Iterable
from urllib.parse import unquote

try:
Expand All @@ -15,7 +14,7 @@ def parse_qsl(qs: bytes, separator: str) -> list[tuple[str, str]]:
return _parse_qsl(qs.decode("latin-1"), keep_blank_values=True, separator=separator)


__all__ = ("parse_cookie_string", "parse_headers", "parse_query_string", "parse_url_encoded_form_data")
__all__ = ("parse_cookie_string", "parse_query_string", "parse_url_encoded_form_data")


@lru_cache(1024)
Expand Down Expand Up @@ -66,25 +65,3 @@ def parse_cookie_string(cookie_string: str) -> dict[str, str]:
)
}
return output


@lru_cache(1024)
def _parse_headers(headers: tuple[tuple[bytes, bytes], ...]) -> dict[str, str]:
"""Parse ASGI headers into a dict of string keys and values.
Args:
headers: A tuple of bytes two tuples.
Returns:
A string / string dict.
"""
return {k.decode(): v.decode() for k, v in headers}


def parse_headers(headers: Iterable[tuple[bytes, bytes] | list[bytes]]) -> dict[str, str]:
"""Parse ASGI headers into a dict of string keys and values.
Since the ASGI protocol only allows for lists (not tuples) which cannot be hashed,
this function will convert the headers to a tuple of tuples before invoking the cached function.
"""
return _parse_headers(tuple(tuple(h) for h in headers))
11 changes: 3 additions & 8 deletions litestar/connection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast

from litestar._parsers import parse_cookie_string, parse_headers, parse_query_string
from litestar._parsers import parse_cookie_string, parse_query_string
from litestar.datastructures.headers import Headers
from litestar.datastructures.multi_dicts import MultiDict
from litestar.datastructures.state import State
Expand Down Expand Up @@ -55,7 +55,7 @@ async def empty_send(_: Message) -> NoReturn: # pragma: no cover
class ASGIConnection(Generic[HandlerT, UserT, AuthT, StateT]):
"""The base ASGI connection container."""

__slots__ = ("scope", "receive", "send", "_base_url", "_url", "_parsed_query", "_headers", "_cookies")
__slots__ = ("scope", "receive", "send", "_base_url", "_url", "_parsed_query", "_cookies")

scope: Scope
"""The ASGI scope attached to the connection."""
Expand All @@ -79,7 +79,6 @@ def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send =
self._url: Any = scope.get("_url", Empty)
self._parsed_query: Any = scope.get("_parsed_query", Empty)
self._cookies: Any = scope.get("_cookies", Empty)
self._headers: Any = scope.get("_headers", Empty)

@property
def app(self) -> Litestar:
Expand Down Expand Up @@ -146,11 +145,7 @@ def headers(self) -> Headers:
Returns:
A Headers instance with the request's scope["headers"] value.
"""
if self._headers is Empty:
self.scope.setdefault("headers", [])
self._headers = self.scope["_headers"] = parse_headers(tuple(self.scope["headers"])) # type: ignore[typeddict-unknown-key]

return Headers(self._headers)
return Headers.from_scope(self.scope)

@property
def query_params(self) -> MultiDict[Any]:
Expand Down
10 changes: 5 additions & 5 deletions litestar/datastructures/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from typing_extensions import get_type_hints

from litestar._multipart import parse_content_header
from litestar._parsers import parse_headers
from litestar.datastructures.multi_dicts import MultiMixin
from litestar.dto.base_dto import AbstractDTO
from litestar.exceptions import ImproperlyConfiguredException, ValidationException
Expand All @@ -51,7 +50,7 @@ def _encode_headers(headers: Iterable[Tuple[str, str]]) -> "RawHeadersList":


class Headers(CIMultiDictProxy[str], MultiMixin[str]):
"""An immutable, case-insensitive for HTTP headers."""
"""An immutable, case-insensitive multi dict for HTTP headers."""

def __init__(self, headers: Optional[Union[Mapping[str, str], "RawHeaders", MultiMapping]] = None) -> None:
"""Initialize ``Headers``.
Expand Down Expand Up @@ -85,9 +84,10 @@ def from_scope(cls, scope: "HeaderScope") -> "Headers":
Raises:
ValueError: If the message does not have a ``headers`` key
"""
if "_headers" not in scope:
scope["_headers"] = parse_headers(tuple(scope["headers"])) # type: ignore
return cls(scope["_headers"]) # type: ignore
if (headers := scope.get("_headers")) is None:
headers = scope["_headers"] = cls(scope["headers"]) # type: ignore[typeddict-unknown-key]

return cast("Headers", headers)

def to_header_list(self) -> "RawHeadersList":
"""Raw header value.
Expand Down
4 changes: 1 addition & 3 deletions litestar/datastructures/multi_dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ class MultiDict(BaseMultiDict[T], MultiMixin[T], Generic[T]):
"""MultiDict, using :class:`MultiDict <multidict.MultiDictProxy>`."""

def __init__(self, args: MultiMapping | Mapping[str, T] | Iterable[tuple[str, T]] | None = None) -> None:
"""Initialize ``MultiDict`` from a.
``MultiMapping``, :class:`Mapping <typing.Mapping>` or an iterable of tuples.
"""Initialize ``MultiDict`` from a`MultiMapping``, :class:`Mapping <typing.Mapping>` or an iterable of tuples.
Args:
args: Mapping-like structure to create the ``MultiDict`` from
Expand Down
25 changes: 0 additions & 25 deletions tests/unit/test_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

from litestar import HttpMethod
from litestar._parsers import (
_parse_headers,
parse_cookie_string,
parse_headers,
parse_query_string,
parse_url_encoded_form_data,
)
Expand Down Expand Up @@ -107,26 +105,3 @@ def test_query_parsing_of_escaped_values(values: Tuple[Tuple[str, str], Tuple[st
request = client.build_request(method=HttpMethod.GET, url="http://www.example.com", params=dict(values))
parsed_query = parse_query_string(request.url.query)
assert parsed_query == values


def test_parse_headers() -> None:
"""Test that headers are parsed correctly."""
headers = [
[b"Host", b"localhost:8000"],
[b"User-Agent", b"curl/7.64.1"],
[b"Accept", b"*/*"],
[b"Cookie", b"foo=bar; bar=baz"],
[b"Content-Type", b"application/x-www-form-urlencoded"],
[b"Content-Length", b"12"],
]
parsed = parse_headers(headers)
assert parsed["Host"] == "localhost:8000"
assert parsed["User-Agent"] == "curl/7.64.1"
assert parsed["Accept"] == "*/*"
assert parsed["Cookie"] == "foo=bar; bar=baz"
assert parsed["Content-Type"] == "application/x-www-form-urlencoded"
assert parsed["Content-Length"] == "12"
# demonstrate that calling the private function with lists (as ASGI specifies)
# does raise an error
with pytest.raises(TypeError):
_parse_headers(headers) # type: ignore[arg-type]

0 comments on commit 3dfd6fe

Please sign in to comment.