From e6797cef7b1b1ac9ac1a781254637b3c9286ceb3 Mon Sep 17 00:00:00 2001
From: Marcelo Trylesinski
Date: Sun, 1 Sep 2024 14:54:05 +0200
Subject: [PATCH 1/4] Set `line-length` to 120 on Ruff
---
pyproject.toml | 5 +-
starlette/_compat.py | 4 +-
starlette/_exception_handler.py | 4 +-
starlette/_utils.py | 12 +-
starlette/applications.py | 56 ++++-----
starlette/authentication.py | 16 +--
starlette/background.py | 8 +-
starlette/concurrency.py | 11 +-
starlette/config.py | 22 +---
starlette/datastructures.py | 44 ++-----
starlette/endpoints.py | 20 +---
starlette/formparsers.py | 36 ++----
starlette/middleware/__init__.py | 8 +-
starlette/middleware/authentication.py | 9 +-
starlette/middleware/base.py | 12 +-
starlette/middleware/cors.py | 12 +-
starlette/middleware/errors.py | 10 +-
starlette/middleware/exceptions.py | 9 +-
starlette/middleware/gzip.py | 12 +-
starlette/middleware/sessions.py | 4 +-
starlette/middleware/trustedhost.py | 4 +-
starlette/middleware/wsgi.py | 8 +-
starlette/requests.py | 36 ++----
starlette/responses.py | 20 +---
starlette/routing.py | 99 +++++-----------
starlette/schemas.py | 12 +-
starlette/staticfiles.py | 46 ++------
starlette/templating.py | 40 ++-----
starlette/testclient.py | 107 +++++------------
starlette/types.py | 12 +-
starlette/websockets.py | 48 ++------
tests/middleware/test_base.py | 44 ++-----
tests/middleware/test_cors.py | 33 ++----
tests/middleware/test_gzip.py | 4 +-
tests/middleware/test_session.py | 12 +-
tests/middleware/test_trusted_host.py | 10 +-
tests/test_applications.py | 37 ++----
tests/test_authentication.py | 38 ++-----
tests/test_background.py | 8 +-
tests/test_config.py | 8 +-
tests/test_convertors.py | 11 +-
tests/test_datastructures.py | 26 +----
tests/test_endpoints.py | 4 +-
tests/test_exceptions.py | 8 +-
tests/test_formparsers.py | 152 ++++++-------------------
tests/test_responses.py | 47 ++------
tests/test_routing.py | 132 ++++++---------------
tests/test_schemas.py | 28 +----
tests/test_staticfiles.py | 92 ++++-----------
tests/test_status.py | 6 +-
tests/test_templates.py | 40 ++-----
tests/test_testclient.py | 15 +--
tests/test_websockets.py | 5 +-
53 files changed, 388 insertions(+), 1118 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index f2721c870..04533cb49 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,8 +50,11 @@ Source = "https://github.com/encode/starlette"
[tool.hatch.version]
path = "starlette/__init__.py"
+[tool.ruff]
+line-length = 120
+
[tool.ruff.lint]
-select = ["E", "F", "I", "FA", "UP"]
+select = ["E", "F", "I", "FA", "UP", "RUF100"]
ignore = ["UP031"]
[tool.ruff.lint.isort]
diff --git a/starlette/_compat.py b/starlette/_compat.py
index 9087a7645..718bc9020 100644
--- a/starlette/_compat.py
+++ b/starlette/_compat.py
@@ -15,9 +15,7 @@
# that reject usedforsecurity=True
hashlib.md5(b"data", usedforsecurity=False) # type: ignore[call-arg]
- def md5_hexdigest(
- data: bytes, *, usedforsecurity: bool = True
- ) -> str: # pragma: no cover
+ def md5_hexdigest(data: bytes, *, usedforsecurity: bool = True) -> str: # pragma: no cover
return hashlib.md5( # type: ignore[call-arg]
data, usedforsecurity=usedforsecurity
).hexdigest()
diff --git a/starlette/_exception_handler.py b/starlette/_exception_handler.py
index 99cb6b64c..4fbc86394 100644
--- a/starlette/_exception_handler.py
+++ b/starlette/_exception_handler.py
@@ -22,9 +22,7 @@
StatusHandlers = typing.Dict[int, ExceptionHandler]
-def _lookup_exception_handler(
- exc_handlers: ExceptionHandlers, exc: Exception
-) -> ExceptionHandler | None:
+def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None:
for cls in type(exc).__mro__:
if cls in exc_handlers:
return exc_handlers[cls]
diff --git a/starlette/_utils.py b/starlette/_utils.py
index b6970542b..90bd346fd 100644
--- a/starlette/_utils.py
+++ b/starlette/_utils.py
@@ -37,26 +37,20 @@ def is_async_callable(obj: typing.Any) -> typing.Any:
while isinstance(obj, functools.partial):
obj = obj.func
- return asyncio.iscoroutinefunction(obj) or (
- callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
- )
+ return asyncio.iscoroutinefunction(obj) or (callable(obj) and asyncio.iscoroutinefunction(obj.__call__))
T_co = typing.TypeVar("T_co", covariant=True)
-class AwaitableOrContextManager(
- typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]
-): ...
+class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]): ...
class SupportsAsyncClose(typing.Protocol):
async def close(self) -> None: ... # pragma: no cover
-SupportsAsyncCloseType = typing.TypeVar(
- "SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False
-)
+SupportsAsyncCloseType = typing.TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)
class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
diff --git a/starlette/applications.py b/starlette/applications.py
index 913fd4c9d..f34e80ead 100644
--- a/starlette/applications.py
+++ b/starlette/applications.py
@@ -72,21 +72,15 @@ def __init__(
self.debug = debug
self.state = State()
- self.router = Router(
- routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan
- )
- self.exception_handlers = (
- {} if exception_handlers is None else dict(exception_handlers)
- )
+ self.router = Router(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)
+ self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers)
self.user_middleware = [] if middleware is None else list(middleware)
self.middleware_stack: ASGIApp | None = None
def build_middleware_stack(self) -> ASGIApp:
debug = self.debug
error_handler = None
- exception_handlers: dict[
- typing.Any, typing.Callable[[Request, Exception], Response]
- ] = {}
+ exception_handlers: dict[typing.Any, typing.Callable[[Request, Exception], Response]] = {}
for key, value in self.exception_handlers.items():
if key in (500, Exception):
@@ -97,11 +91,7 @@ def build_middleware_stack(self) -> ASGIApp:
middleware = (
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ self.user_middleware
- + [
- Middleware(
- ExceptionMiddleware, handlers=exception_handlers, debug=debug
- )
- ]
+ + [Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)]
)
app = self.router
@@ -163,9 +153,7 @@ def add_route(
name: str | None = None,
include_in_schema: bool = True,
) -> None: # pragma: no cover
- self.router.add_route(
- path, route, methods=methods, name=name, include_in_schema=include_in_schema
- )
+ self.router.add_route(path, route, methods=methods, name=name, include_in_schema=include_in_schema)
def add_websocket_route(
self,
@@ -175,16 +163,14 @@ def add_websocket_route(
) -> None: # pragma: no cover
self.router.add_websocket_route(path, route, name=name)
- def exception_handler(
- self, exc_class_or_status_code: int | type[Exception]
- ) -> typing.Callable: # type: ignore[type-arg]
+ def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> typing.Callable: # type: ignore[type-arg]
warnings.warn(
- "The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
- "Refer to https://www.starlette.io/exceptions/ for the recommended approach.", # noqa: E501
+ "The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. "
+ "Refer to https://www.starlette.io/exceptions/ for the recommended approach.",
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_exception_handler(exc_class_or_status_code, func)
return func
@@ -205,12 +191,12 @@ def route(
>>> app = Starlette(routes=routes)
"""
warnings.warn(
- "The `route` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
- "Refer to https://www.starlette.io/routing/ for the recommended approach.", # noqa: E501
+ "The `route` decorator is deprecated, and will be removed in version 1.0.0. "
+ "Refer to https://www.starlette.io/routing/ for the recommended approach.",
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.router.add_route(
path,
func,
@@ -231,18 +217,18 @@ def websocket_route(self, path: str, name: str | None = None) -> typing.Callable
>>> app = Starlette(routes=routes)
"""
warnings.warn(
- "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
- "Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.", # noqa: E501
+ "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. "
+ "Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.",
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.router.add_websocket_route(path, func, name=name)
return func
return decorator
- def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
+ def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
@@ -251,15 +237,13 @@ def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[t
>>> app = Starlette(middleware=middleware)
"""
warnings.warn(
- "The `middleware` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
- "Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.", # noqa: E501
+ "The `middleware` decorator is deprecated, and will be removed in version 1.0.0. "
+ "Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.",
DeprecationWarning,
)
- assert (
- middleware_type == "http"
- ), 'Currently only middleware("http") is supported.'
+ assert middleware_type == "http", 'Currently only middleware("http") is supported.'
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_middleware(BaseHTTPMiddleware, dispatch=func)
return func
diff --git a/starlette/authentication.py b/starlette/authentication.py
index f2586a042..4fd866412 100644
--- a/starlette/authentication.py
+++ b/starlette/authentication.py
@@ -31,9 +31,7 @@ def requires(
scopes: str | typing.Sequence[str],
status_code: int = 403,
redirect: str | None = None,
-) -> typing.Callable[
- [typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]
-]:
+) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]:
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
def decorator(
@@ -45,17 +43,13 @@ def decorator(
type_ = parameter.name
break
else:
- raise Exception(
- f'No "request" or "websocket" argument on function "{func}"'
- )
+ raise Exception(f'No "request" or "websocket" argument on function "{func}"')
if type_ == "websocket":
# Handle websocket functions. (Always async)
@functools.wraps(func)
async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
- websocket = kwargs.get(
- "websocket", args[idx] if idx < len(args) else None
- )
+ websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None)
assert isinstance(websocket, WebSocket)
if not has_required_scope(websocket, scopes_list):
@@ -107,9 +101,7 @@ class AuthenticationError(Exception):
class AuthenticationBackend:
- async def authenticate(
- self, conn: HTTPConnection
- ) -> tuple[AuthCredentials, BaseUser] | None:
+ async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
raise NotImplementedError() # pragma: no cover
diff --git a/starlette/background.py b/starlette/background.py
index 1cbed3b22..0430fc08b 100644
--- a/starlette/background.py
+++ b/starlette/background.py
@@ -15,9 +15,7 @@
class BackgroundTask:
- def __init__(
- self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
- ) -> None:
+ def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
self.func = func
self.args = args
self.kwargs = kwargs
@@ -34,9 +32,7 @@ class BackgroundTasks(BackgroundTask):
def __init__(self, tasks: typing.Sequence[BackgroundTask] | None = None):
self.tasks = list(tasks) if tasks else []
- def add_task(
- self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
- ) -> None:
+ def add_task(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
task = BackgroundTask(func, *args, **kwargs)
self.tasks.append(task)
diff --git a/starlette/concurrency.py b/starlette/concurrency.py
index 215e3a63b..22979404a 100644
--- a/starlette/concurrency.py
+++ b/starlette/concurrency.py
@@ -16,16 +16,15 @@
T = typing.TypeVar("T")
-async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg] # noqa: E501
+async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg]
warnings.warn(
- "run_until_first_complete is deprecated "
- "and will be removed in a future version.",
+ "run_until_first_complete is deprecated " "and will be removed in a future version.",
DeprecationWarning,
)
async with anyio.create_task_group() as task_group:
- async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg] # noqa: E501
+ async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg]
await func()
task_group.cancel_scope.cancel()
@@ -33,9 +32,7 @@ async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ign
task_group.start_soon(run, functools.partial(func, **kwargs))
-async def run_in_threadpool(
- func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
-) -> T:
+async def run_in_threadpool(func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
if kwargs: # pragma: no cover
# run_sync doesn't accept 'kwargs', so bind them in here
func = functools.partial(func, **kwargs)
diff --git a/starlette/config.py b/starlette/config.py
index 4c3dfe5b0..7b46e16fb 100644
--- a/starlette/config.py
+++ b/starlette/config.py
@@ -25,18 +25,12 @@ def __getitem__(self, key: str) -> str:
def __setitem__(self, key: str, value: str) -> None:
if key in self._has_been_read:
- raise EnvironError(
- f"Attempting to set environ['{key}'], but the value has already been "
- "read."
- )
+ raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been " "read.")
self._environ.__setitem__(key, value)
def __delitem__(self, key: str) -> None:
if key in self._has_been_read:
- raise EnvironError(
- f"Attempting to delete environ['{key}'], but the value has already "
- "been read."
- )
+ raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already " "been read.")
self._environ.__delitem__(key)
def __iter__(self) -> typing.Iterator[str]:
@@ -85,9 +79,7 @@ def __call__(
) -> T: ...
@typing.overload
- def __call__(
- self, key: str, cast: type[str] = ..., default: T = ...
- ) -> T | str: ...
+ def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ...
def __call__(
self,
@@ -138,13 +130,9 @@ def _perform_cast(
mapping = {"true": True, "1": True, "false": False, "0": False}
value = value.lower()
if value not in mapping:
- raise ValueError(
- f"Config '{key}' has value '{value}'. Not a valid bool."
- )
+ raise ValueError(f"Config '{key}' has value '{value}'. Not a valid bool.")
return mapping[value]
try:
return cast(value)
except (TypeError, ValueError):
- raise ValueError(
- f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}."
- )
+ raise ValueError(f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}.")
diff --git a/starlette/datastructures.py b/starlette/datastructures.py
index 54b5e54f3..90a7296a0 100644
--- a/starlette/datastructures.py
+++ b/starlette/datastructures.py
@@ -108,12 +108,7 @@ def is_secure(self) -> bool:
return self.scheme in ("https", "wss")
def replace(self, **kwargs: typing.Any) -> URL:
- if (
- "username" in kwargs
- or "password" in kwargs
- or "hostname" in kwargs
- or "port" in kwargs
- ):
+ if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
hostname = kwargs.pop("hostname", None)
port = kwargs.pop("port", self.port)
username = kwargs.pop("username", self.username)
@@ -264,17 +259,12 @@ def __init__(
value: typing.Any = args[0] if args else []
if kwargs:
- value = (
- ImmutableMultiDict(value).multi_items()
- + ImmutableMultiDict(kwargs).multi_items()
- )
+ value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()
if not value:
_items: list[tuple[typing.Any, typing.Any]] = []
elif hasattr(value, "multi_items"):
- value = typing.cast(
- ImmutableMultiDict[_KeyType, _CovariantValueType], value
- )
+ value = typing.cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
_items = list(value.multi_items())
elif hasattr(value, "items"):
value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value)
@@ -371,9 +361,7 @@ def append(self, key: typing.Any, value: typing.Any) -> None:
def update(
self,
- *args: MultiDict
- | typing.Mapping[typing.Any, typing.Any]
- | list[tuple[typing.Any, typing.Any]],
+ *args: MultiDict | typing.Mapping[typing.Any, typing.Any] | list[tuple[typing.Any, typing.Any]],
**kwargs: typing.Any,
) -> None:
value = MultiDict(*args, **kwargs)
@@ -403,9 +391,7 @@ def __init__(
if isinstance(value, str):
super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
elif isinstance(value, bytes):
- super().__init__(
- parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs
- )
+ super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs)
else:
super().__init__(*args, **kwargs) # type: ignore[arg-type]
self._list = [(str(k), str(v)) for k, v in self._list]
@@ -490,9 +476,7 @@ class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
def __init__(
self,
- *args: FormData
- | typing.Mapping[str, str | UploadFile]
- | list[tuple[str, str | UploadFile]],
+ *args: FormData | typing.Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
**kwargs: str | UploadFile,
) -> None:
super().__init__(*args, **kwargs)
@@ -518,10 +502,7 @@ def __init__(
if headers is not None:
assert raw is None, 'Cannot set both "headers" and "raw".'
assert scope is None, 'Cannot set both "headers" and "scope".'
- self._list = [
- (key.lower().encode("latin-1"), value.encode("latin-1"))
- for key, value in headers.items()
- ]
+ self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
elif raw is not None:
assert scope is None, 'Cannot set both "raw" and "scope".'
self._list = raw
@@ -541,18 +522,11 @@ def values(self) -> list[str]: # type: ignore[override]
return [value.decode("latin-1") for key, value in self._list]
def items(self) -> list[tuple[str, str]]: # type: ignore[override]
- return [
- (key.decode("latin-1"), value.decode("latin-1"))
- for key, value in self._list
- ]
+ return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]
def getlist(self, key: str) -> list[str]:
get_header_key = key.lower().encode("latin-1")
- return [
- item_value.decode("latin-1")
- for item_key, item_value in self._list
- if item_key == get_header_key
- ]
+ return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]
def mutablecopy(self) -> MutableHeaders:
return MutableHeaders(raw=self._list[:])
diff --git a/starlette/endpoints.py b/starlette/endpoints.py
index 57f718824..eb1dace42 100644
--- a/starlette/endpoints.py
+++ b/starlette/endpoints.py
@@ -30,15 +30,9 @@ def __await__(self) -> typing.Generator[typing.Any, None, None]:
async def dispatch(self) -> None:
request = Request(self.scope, receive=self.receive)
- handler_name = (
- "get"
- if request.method == "HEAD" and not hasattr(self, "head")
- else request.method.lower()
- )
-
- handler: typing.Callable[[Request], typing.Any] = getattr(
- self, handler_name, self.method_not_allowed
- )
+ handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower()
+
+ handler: typing.Callable[[Request], typing.Any] = getattr(self, handler_name, self.method_not_allowed)
is_async = is_async_callable(handler)
if is_async:
response = await handler(request)
@@ -81,9 +75,7 @@ async def dispatch(self) -> None:
data = await self.decode(websocket, message)
await self.on_receive(websocket, data)
elif message["type"] == "websocket.disconnect":
- close_code = int(
- message.get("code") or status.WS_1000_NORMAL_CLOSURE
- )
+ close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE)
break
except Exception as exc:
close_code = status.WS_1011_INTERNAL_ERROR
@@ -116,9 +108,7 @@ async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
raise RuntimeError("Malformed JSON data received.")
- assert (
- self.encoding is None
- ), f"Unsupported 'encoding' attribute {self.encoding}"
+ assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}"
return message["text"] if message.get("text") else message["bytes"]
async def on_connect(self, websocket: WebSocket) -> None:
diff --git a/starlette/formparsers.py b/starlette/formparsers.py
index 2e12c7faa..9be98626e 100644
--- a/starlette/formparsers.py
+++ b/starlette/formparsers.py
@@ -46,12 +46,8 @@ def __init__(self, message: str) -> None:
class FormParser:
- def __init__(
- self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
- ) -> None:
- assert (
- multipart is not None
- ), "The `python-multipart` library must be installed to use form parsing."
+ def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None:
+ assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.messages: list[tuple[FormMessage, bytes]] = []
@@ -128,9 +124,7 @@ def __init__(
max_files: int | float = 1000,
max_fields: int | float = 1000,
) -> None:
- assert (
- multipart is not None
- ), "The `python-multipart` library must be installed to use form parsing."
+ assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.max_files = max_files
@@ -181,30 +175,20 @@ def on_header_end(self) -> None:
field = self._current_partial_header_name.lower()
if field == b"content-disposition":
self._current_part.content_disposition = self._current_partial_header_value
- self._current_part.item_headers.append(
- (field, self._current_partial_header_value)
- )
+ self._current_part.item_headers.append((field, self._current_partial_header_value))
self._current_partial_header_name = b""
self._current_partial_header_value = b""
def on_headers_finished(self) -> None:
- disposition, options = parse_options_header(
- self._current_part.content_disposition
- )
+ disposition, options = parse_options_header(self._current_part.content_disposition)
try:
- self._current_part.field_name = _user_safe_decode(
- options[b"name"], self._charset
- )
+ self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
except KeyError:
- raise MultiPartException(
- 'The Content-Disposition header field "name" must be ' "provided."
- )
+ raise MultiPartException('The Content-Disposition header field "name" must be ' "provided.")
if b"filename" in options:
self._current_files += 1
if self._current_files > self.max_files:
- raise MultiPartException(
- f"Too many files. Maximum number of files is {self.max_files}."
- )
+ raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
filename = _user_safe_decode(options[b"filename"], self._charset)
tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
self._files_to_close_on_error.append(tempfile)
@@ -217,9 +201,7 @@ def on_headers_finished(self) -> None:
else:
self._current_fields += 1
if self._current_fields > self.max_fields:
- raise MultiPartException(
- f"Too many fields. Maximum number of fields is {self.max_fields}."
- )
+ raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
self._current_part.file = None
def on_end(self) -> None:
diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py
index d9e64f574..8566aac08 100644
--- a/starlette/middleware/__init__.py
+++ b/starlette/middleware/__init__.py
@@ -14,13 +14,9 @@
class _MiddlewareClass(Protocol[P]):
- def __init__(
- self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs
- ) -> None: ... # pragma: no cover
+ def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None: ... # pragma: no cover
- async def __call__(
- self, scope: Scope, receive: Receive, send: Send
- ) -> None: ... # pragma: no cover
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ... # pragma: no cover
class Middleware:
diff --git a/starlette/middleware/authentication.py b/starlette/middleware/authentication.py
index 966c639bb..8555ee078 100644
--- a/starlette/middleware/authentication.py
+++ b/starlette/middleware/authentication.py
@@ -18,14 +18,13 @@ def __init__(
self,
app: ASGIApp,
backend: AuthenticationBackend,
- on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response]
- | None = None,
+ on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
) -> None:
self.app = app
self.backend = backend
- self.on_error: typing.Callable[
- [HTTPConnection, AuthenticationError], Response
- ] = on_error if on_error is not None else self.default_on_error
+ self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = (
+ on_error if on_error is not None else self.default_on_error
+ )
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ["http", "websocket"]:
diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py
index 87c0f51f8..2ac6f7f7f 100644
--- a/starlette/middleware/base.py
+++ b/starlette/middleware/base.py
@@ -11,9 +11,7 @@
from starlette.types import ASGIApp, Message, Receive, Scope, Send
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
-DispatchFunction = typing.Callable[
- [Request, RequestResponseEndpoint], typing.Awaitable[Response]
-]
+DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
T = typing.TypeVar("T")
@@ -180,9 +178,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
if app_exc is not None:
raise app_exc
- response = _StreamingResponse(
- status_code=message["status"], content=body_stream(), info=info
- )
+ response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
response.raw_headers = message["headers"]
return response
@@ -192,9 +188,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
await response(scope, wrapped_receive, send)
response_sent.set()
- async def dispatch(
- self, request: Request, call_next: RequestResponseEndpoint
- ) -> Response:
+ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
raise NotImplementedError() # pragma: no cover
diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py
index 4b8e97bc9..61502691a 100644
--- a/starlette/middleware/cors.py
+++ b/starlette/middleware/cors.py
@@ -96,9 +96,7 @@ def is_allowed_origin(self, origin: str) -> bool:
if self.allow_all_origins:
return True
- if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
- origin
- ):
+ if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
return True
return origin in self.allow_origins
@@ -141,15 +139,11 @@ def preflight_response(self, request_headers: Headers) -> Response:
return PlainTextResponse("OK", status_code=200, headers=headers)
- async def simple_response(
- self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
- ) -> None:
+ async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None:
send = functools.partial(self.send, send=send, request_headers=request_headers)
await self.app(scope, receive, send)
- async def send(
- self, message: Message, send: Send, request_headers: Headers
- ) -> None:
+ async def send(self, message: Message, send: Send, request_headers: Headers) -> None:
if message["type"] != "http.response.start":
await send(message)
return
diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py
index 3fc4a4402..d8cb052ed 100644
--- a/starlette/middleware/errors.py
+++ b/starlette/middleware/errors.py
@@ -112,7 +112,7 @@
{code_context}
-""" # noqa: E501
+"""
LINE = """
@@ -186,9 +186,7 @@ async def _send(message: Message) -> None:
# to optionally raise the error within the test case.
raise exc
- def format_line(
- self, index: int, line: str, frame_lineno: int, frame_index: int
- ) -> str:
+ def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
values = {
# HTML escape - line could contain < or >
"line": html.escape(line).replace(" ", " "),
@@ -225,9 +223,7 @@ def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> s
return FRAME_TEMPLATE.format(**values)
def generate_html(self, exc: Exception, limit: int = 7) -> str:
- traceback_obj = traceback.TracebackException.from_exception(
- exc, capture_locals=True
- )
+ traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True)
exc_html = ""
is_collapsed = False
diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py
index b2bf88dbf..d708929e3 100644
--- a/starlette/middleware/exceptions.py
+++ b/starlette/middleware/exceptions.py
@@ -18,10 +18,7 @@ class ExceptionMiddleware:
def __init__(
self,
app: ASGIApp,
- handlers: typing.Mapping[
- typing.Any, typing.Callable[[Request, Exception], Response]
- ]
- | None = None,
+ handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None,
debug: bool = False,
) -> None:
self.app = app
@@ -68,9 +65,7 @@ def http_exception(self, request: Request, exc: Exception) -> Response:
assert isinstance(exc, HTTPException)
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
- return PlainTextResponse(
- exc.detail, status_code=exc.status_code, headers=exc.headers
- )
+ return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
assert isinstance(exc, WebSocketException)
diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py
index 0579e0410..127b91e7a 100644
--- a/starlette/middleware/gzip.py
+++ b/starlette/middleware/gzip.py
@@ -7,9 +7,7 @@
class GZipMiddleware:
- def __init__(
- self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
- ) -> None:
+ def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
self.app = app
self.minimum_size = minimum_size
self.compresslevel = compresslevel
@@ -18,9 +16,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http":
headers = Headers(scope=scope)
if "gzip" in headers.get("Accept-Encoding", ""):
- responder = GZipResponder(
- self.app, self.minimum_size, compresslevel=self.compresslevel
- )
+ responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
await responder(scope, receive, send)
return
await self.app(scope, receive, send)
@@ -35,9 +31,7 @@ def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> N
self.started = False
self.content_encoding_set = False
self.gzip_buffer = io.BytesIO()
- self.gzip_file = gzip.GzipFile(
- mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel
- )
+ self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.send = send
diff --git a/starlette/middleware/sessions.py b/starlette/middleware/sessions.py
index 5855912ca..5f9fcd883 100644
--- a/starlette/middleware/sessions.py
+++ b/starlette/middleware/sessions.py
@@ -61,7 +61,7 @@ async def send_wrapper(message: Message) -> None:
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
data = self.signer.sign(data)
headers = MutableHeaders(scope=message)
- header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa E501
+ header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
session_cookie=self.session_cookie,
data=data.decode("utf-8"),
path=self.path,
@@ -72,7 +72,7 @@ async def send_wrapper(message: Message) -> None:
elif not initial_session_was_empty:
# The session has been cleared.
headers = MutableHeaders(scope=message)
- header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( # noqa E501
+ header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
session_cookie=self.session_cookie,
data="null",
path=self.path,
diff --git a/starlette/middleware/trustedhost.py b/starlette/middleware/trustedhost.py
index 59e527363..2d1c999e2 100644
--- a/starlette/middleware/trustedhost.py
+++ b/starlette/middleware/trustedhost.py
@@ -41,9 +41,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
is_valid_host = False
found_www_redirect = False
for pattern in self.allowed_hosts:
- if host == pattern or (
- pattern.startswith("*") and host.endswith(pattern[1:])
- ):
+ if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
is_valid_host = True
break
elif "www." + host == pattern:
diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py
index c9a7e1328..71f4ab5de 100644
--- a/starlette/middleware/wsgi.py
+++ b/starlette/middleware/wsgi.py
@@ -89,9 +89,7 @@ def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None:
self.scope = scope
self.status = None
self.response_headers = None
- self.stream_send, self.stream_receive = anyio.create_memory_object_stream(
- math.inf
- )
+ self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
self.response_started = False
self.exc_info: typing.Any = None
@@ -151,6 +149,4 @@ def wsgi(
{"type": "http.response.body", "body": chunk, "more_body": True},
)
- anyio.from_thread.run(
- self.stream_send.send, {"type": "http.response.body", "body": b""}
- )
+ anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})
diff --git a/starlette/requests.py b/starlette/requests.py
index a2fdfd81e..23f8ac70a 100644
--- a/starlette/requests.py
+++ b/starlette/requests.py
@@ -104,9 +104,7 @@ def base_url(self) -> URL:
# This is used by request.url_for, it might be used inside a Mount which
# would have its own child scope with its own root_path, but the base URL
# for url_for should still be the top level app root path.
- app_root_path = base_url_scope.get(
- "app_root_path", base_url_scope.get("root_path", "")
- )
+ app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", ""))
path = app_root_path
if not path.endswith("/"):
path += "/"
@@ -153,23 +151,17 @@ def client(self) -> Address | None:
@property
def session(self) -> dict[str, typing.Any]:
- assert (
- "session" in self.scope
- ), "SessionMiddleware must be installed to access request.session"
+ assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
return self.scope["session"] # type: ignore[no-any-return]
@property
def auth(self) -> typing.Any:
- assert (
- "auth" in self.scope
- ), "AuthenticationMiddleware must be installed to access request.auth"
+ assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
return self.scope["auth"]
@property
def user(self) -> typing.Any:
- assert (
- "user" in self.scope
- ), "AuthenticationMiddleware must be installed to access request.user"
+ assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
return self.scope["user"]
@property
@@ -199,9 +191,7 @@ async def empty_send(message: Message) -> typing.NoReturn:
class Request(HTTPConnection):
_form: FormData | None
- def __init__(
- self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
- ):
+ def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
super().__init__(scope)
assert scope["type"] == "http"
self._receive = receive
@@ -252,9 +242,7 @@ async def json(self) -> typing.Any:
self._json = json.loads(body)
return self._json
- async def _get_form(
- self, *, max_files: int | float = 1000, max_fields: int | float = 1000
- ) -> FormData:
+ async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | float = 1000) -> FormData:
if self._form is None:
assert (
parse_options_header is not None
@@ -285,9 +273,7 @@ async def _get_form(
def form(
self, *, max_files: int | float = 1000, max_fields: int | float = 1000
) -> AwaitableOrContextManager[FormData]:
- return AwaitableOrContextManagerWrapper(
- self._get_form(max_files=max_files, max_fields=max_fields)
- )
+ return AwaitableOrContextManagerWrapper(self._get_form(max_files=max_files, max_fields=max_fields))
async def close(self) -> None:
if self._form is not None:
@@ -312,9 +298,5 @@ async def send_push_promise(self, path: str) -> None:
raw_headers: list[tuple[bytes, bytes]] = []
for name in SERVER_PUSH_HEADERS_TO_COPY:
for value in self.headers.getlist(name):
- raw_headers.append(
- (name.encode("latin-1"), value.encode("latin-1"))
- )
- await self._send(
- {"type": "http.response.push", "path": path, "headers": raw_headers}
- )
+ raw_headers.append((name.encode("latin-1"), value.encode("latin-1")))
+ await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})
diff --git a/starlette/responses.py b/starlette/responses.py
index 4f15404ca..06d6ce5ca 100644
--- a/starlette/responses.py
+++ b/starlette/responses.py
@@ -54,10 +54,7 @@ def init_headers(self, headers: typing.Mapping[str, str] | None = None) -> None:
populate_content_length = True
populate_content_type = True
else:
- raw_headers = [
- (k.lower().encode("latin-1"), v.encode("latin-1"))
- for k, v in headers.items()
- ]
+ raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()]
keys = [h[0] for h in raw_headers]
populate_content_length = b"content-length" not in keys
populate_content_type = b"content-type" not in keys
@@ -73,10 +70,7 @@ def init_headers(self, headers: typing.Mapping[str, str] | None = None) -> None:
content_type = self.media_type
if content_type is not None and populate_content_type:
- if (
- content_type.startswith("text/")
- and "charset=" not in content_type.lower()
- ):
+ if content_type.startswith("text/") and "charset=" not in content_type.lower():
content_type += "; charset=" + self.charset
raw_headers.append((b"content-type", content_type.encode("latin-1")))
@@ -201,9 +195,7 @@ def __init__(
headers: typing.Mapping[str, str] | None = None,
background: BackgroundTask | None = None,
) -> None:
- super().__init__(
- content=b"", status_code=status_code, headers=headers, background=background
- )
+ super().__init__(content=b"", status_code=status_code, headers=headers, background=background)
self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
@@ -299,11 +291,9 @@ def __init__(
if self.filename is not None:
content_disposition_filename = quote(self.filename)
if content_disposition_filename != self.filename:
- content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}" # noqa: E501
+ content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}"
else:
- content_disposition = (
- f'{content_disposition_type}; filename="{self.filename}"'
- )
+ content_disposition = f'{content_disposition_type}; filename="{self.filename}"'
self.headers.setdefault("content-disposition", content_disposition)
self.stat_result = stat_result
if stat_result is not None:
diff --git a/starlette/routing.py b/starlette/routing.py
index 481b13f5d..300711626 100644
--- a/starlette/routing.py
+++ b/starlette/routing.py
@@ -47,8 +47,7 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
including those wrapped in functools.partial objects.
"""
warnings.warn(
- "iscoroutinefunction_or_partial is deprecated, "
- "and will be removed in a future release.",
+ "iscoroutinefunction_or_partial is deprecated, " "and will be removed in a future release.",
DeprecationWarning,
)
while isinstance(obj, functools.partial):
@@ -143,9 +142,7 @@ def compile_path(
for match in PARAM_REGEX.finditer(path):
param_name, convertor_type = match.groups("str")
convertor_type = convertor_type.lstrip(":")
- assert (
- convertor_type in CONVERTOR_TYPES
- ), f"Unknown path convertor '{convertor_type}'"
+ assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'"
convertor = CONVERTOR_TYPES[convertor_type]
path_regex += re.escape(path[idx : match.start()])
@@ -275,9 +272,7 @@ def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
if name != self.name or seen_params != expected_params:
raise NoMatchFound(name, path_params)
- path, remaining_params = replace_params(
- self.path_format, self.param_convertors, path_params
- )
+ path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
assert not remaining_params
return URLPath(path=path, protocol="http")
@@ -287,9 +282,7 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
if "app" in scope:
raise HTTPException(status_code=405, headers=headers)
else:
- response = PlainTextResponse(
- "Method Not Allowed", status_code=405, headers=headers
- )
+ response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
await response(scope, receive, send)
else:
await self.app(scope, receive, send)
@@ -361,9 +354,7 @@ def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
if name != self.name or seen_params != expected_params:
raise NoMatchFound(name, path_params)
- path, remaining_params = replace_params(
- self.path_format, self.param_convertors, path_params
- )
+ path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
assert not remaining_params
return URLPath(path=path, protocol="websocket")
@@ -371,11 +362,7 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
- return (
- isinstance(other, WebSocketRoute)
- and self.path == other.path
- and self.endpoint == other.endpoint
- )
+ return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint
def __repr__(self) -> str:
return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})"
@@ -392,9 +379,7 @@ def __init__(
middleware: typing.Sequence[Middleware] | None = None,
) -> None:
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
- assert (
- app is not None or routes is not None
- ), "Either 'app=...', or 'routes=' must be specified"
+ assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified"
self.path = path.rstrip("/")
if app is not None:
self._base_app: ASGIApp = app
@@ -405,9 +390,7 @@ def __init__(
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)
self.name = name
- self.path_regex, self.path_format, self.param_convertors = compile_path(
- self.path + "/{path:path}"
- )
+ self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")
@property
def routes(self) -> list[BaseRoute]:
@@ -450,9 +433,7 @@ def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "".
path_params["path"] = path_params["path"].lstrip("/")
- path, remaining_params = replace_params(
- self.path_format, self.param_convertors, path_params
- )
+ path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
if not remaining_params:
return URLPath(path=path)
elif self.name is None or name.startswith(self.name + ":"):
@@ -464,17 +445,13 @@ def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
remaining_name = name[len(self.name) + 1 :]
path_kwarg = path_params.get("path")
path_params["path"] = ""
- path_prefix, remaining_params = replace_params(
- self.path_format, self.param_convertors, path_params
- )
+ path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
if path_kwarg is not None:
remaining_params["path"] = path_kwarg
for route in self.routes or []:
try:
url = route.url_path_for(remaining_name, **remaining_params)
- return URLPath(
- path=path_prefix.rstrip("/") + str(url), protocol=url.protocol
- )
+ return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol)
except NoMatchFound:
pass
raise NoMatchFound(name, path_params)
@@ -483,11 +460,7 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
- return (
- isinstance(other, Mount)
- and self.path == other.path
- and self.app == other.app
- )
+ return isinstance(other, Mount) and self.path == other.path and self.app == other.app
def __repr__(self) -> str:
class_name = self.__class__.__name__
@@ -526,9 +499,7 @@ def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "".
path = path_params.pop("path")
- host, remaining_params = replace_params(
- self.host_format, self.param_convertors, path_params
- )
+ host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
if not remaining_params:
return URLPath(path=path, host=host)
elif self.name is None or name.startswith(self.name + ":"):
@@ -538,9 +509,7 @@ def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
else:
# 'name' matches ":".
remaining_name = name[len(self.name) + 1 :]
- host, remaining_params = replace_params(
- self.host_format, self.param_convertors, path_params
- )
+ host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
for route in self.routes or []:
try:
url = route.url_path_for(remaining_name, **remaining_params)
@@ -553,11 +522,7 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
- return (
- isinstance(other, Host)
- and self.host == other.host
- and self.app == other.app
- )
+ return isinstance(other, Host) and self.host == other.host and self.app == other.app
def __repr__(self) -> str:
class_name = self.__class__.__name__
@@ -585,9 +550,7 @@ async def __aexit__(
def _wrap_gen_lifespan_context(
- lifespan_context: typing.Callable[
- [typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]
- ],
+ lifespan_context: typing.Callable[[typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]],
) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]:
cmgr = contextlib.contextmanager(lifespan_context)
@@ -730,9 +693,7 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
async with self.lifespan_context(app) as maybe_state:
if maybe_state is not None:
if "state" not in scope:
- raise RuntimeError(
- 'The server does not support "state" in the lifespan scope.'
- )
+ raise RuntimeError('The server does not support "state" in the lifespan scope.')
scope["state"].update(maybe_state)
await send({"type": "lifespan.startup.complete"})
started = True
@@ -806,15 +767,11 @@ async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
def __eq__(self, other: typing.Any) -> bool:
return isinstance(other, Router) and self.routes == other.routes
- def mount(
- self, path: str, app: ASGIApp, name: str | None = None
- ) -> None: # pragma: nocover
+ def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: # pragma: nocover
route = Mount(path, app=app, name=name)
self.routes.append(route)
- def host(
- self, host: str, app: ASGIApp, name: str | None = None
- ) -> None: # pragma: no cover
+ def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
route = Host(host, app=app, name=name)
self.routes.append(route)
@@ -860,11 +817,11 @@ def route(
"""
warnings.warn(
"The `route` decorator is deprecated, and will be removed in version 1.0.0."
- "Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.", # noqa: E501
+ "Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.",
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_route(
path,
func,
@@ -885,20 +842,18 @@ def websocket_route(self, path: str, name: str | None = None) -> typing.Callable
>>> app = Starlette(routes=routes)
"""
warnings.warn(
- "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to " # noqa: E501
- "https://www.starlette.io/routing/#websocket-routing for the recommended approach.", # noqa: E501
+ "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "
+ "https://www.starlette.io/routing/#websocket-routing for the recommended approach.",
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_websocket_route(path, func, name=name)
return func
return decorator
- def add_event_handler(
- self, event_type: str, func: typing.Callable[[], typing.Any]
- ) -> None: # pragma: no cover
+ def add_event_handler(self, event_type: str, func: typing.Callable[[], typing.Any]) -> None: # pragma: no cover
assert event_type in ("startup", "shutdown")
if event_type == "startup":
@@ -908,12 +863,12 @@ def add_event_handler(
def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg]
warnings.warn(
- "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
+ "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://www.starlette.io/lifespan/ for recommended approach.",
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_event_handler(event_type, func)
return func
diff --git a/starlette/schemas.py b/starlette/schemas.py
index 89fa20b89..688fd85be 100644
--- a/starlette/schemas.py
+++ b/starlette/schemas.py
@@ -19,9 +19,7 @@ class OpenAPIResponse(Response):
def render(self, content: typing.Any) -> bytes:
assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
- assert isinstance(
- content, dict
- ), "The schema passed to OpenAPIResponse should be a dictionary."
+ assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."
return yaml.dump(content, default_flow_style=False).encode("utf-8")
@@ -73,9 +71,7 @@ def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
for method in route.methods or ["GET"]:
if method == "HEAD":
continue
- endpoints_info.append(
- EndpointInfo(path, method.lower(), route.endpoint)
- )
+ endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint))
else:
path = self._remove_converter(route.path)
for method in ["get", "post", "put", "patch", "delete", "options"]:
@@ -95,9 +91,7 @@ def _remove_converter(self, path: str) -> str:
"""
return re.sub(r":\w+}", "}", path)
- def parse_docstring(
- self, func_or_method: typing.Callable[..., typing.Any]
- ) -> dict[str, typing.Any]:
+ def parse_docstring(self, func_or_method: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]:
"""
Given a function, parse the docstring as YAML and return a dictionary of info.
"""
diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py
index afb09b56b..7498c3011 100644
--- a/starlette/staticfiles.py
+++ b/starlette/staticfiles.py
@@ -32,11 +32,7 @@ class NotModifiedResponse(Response):
def __init__(self, headers: Headers):
super().__init__(
status_code=304,
- headers={
- name: value
- for name, value in headers.items()
- if name in self.NOT_MODIFIED_HEADERS
- },
+ headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS},
)
@@ -80,9 +76,7 @@ def get_directories(
spec = importlib.util.find_spec(package)
assert spec is not None, f"Package {package!r} could not be found."
assert spec.origin is not None, f"Package {package!r} could not be found."
- package_directory = os.path.normpath(
- os.path.join(spec.origin, "..", statics_dir)
- )
+ package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir))
assert os.path.isdir(
package_directory
), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
@@ -110,7 +104,7 @@ def get_path(self, scope: Scope) -> str:
with OS specific path separators, and any '..', '.' components removed.
"""
route_path = get_route_path(scope)
- return os.path.normpath(os.path.join(*route_path.split("/"))) # noqa: E501
+ return os.path.normpath(os.path.join(*route_path.split("/")))
async def get_response(self, path: str, scope: Scope) -> Response:
"""
@@ -120,9 +114,7 @@ async def get_response(self, path: str, scope: Scope) -> Response:
raise HTTPException(status_code=405)
try:
- full_path, stat_result = await anyio.to_thread.run_sync(
- self.lookup_path, path
- )
+ full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path)
except PermissionError:
raise HTTPException(status_code=401)
except OSError as exc:
@@ -140,9 +132,7 @@ async def get_response(self, path: str, scope: Scope) -> Response:
# We're in HTML mode, and have got a directory URL.
# Check if we have 'index.html' file to serve.
index_path = os.path.join(path, "index.html")
- full_path, stat_result = await anyio.to_thread.run_sync(
- self.lookup_path, index_path
- )
+ full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path)
if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
if not scope["path"].endswith("/"):
# Directory URLs should redirect to always end in "/".
@@ -153,9 +143,7 @@ async def get_response(self, path: str, scope: Scope) -> Response:
if self.html:
# Check for '404.html' if we're in HTML mode.
- full_path, stat_result = await anyio.to_thread.run_sync(
- self.lookup_path, "404.html"
- )
+ full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html")
if stat_result and stat.S_ISREG(stat_result.st_mode):
return FileResponse(full_path, stat_result=stat_result, status_code=404)
raise HTTPException(status_code=404)
@@ -187,9 +175,7 @@ def file_response(
) -> Response:
request_headers = Headers(scope=scope)
- response = FileResponse(
- full_path, status_code=status_code, stat_result=stat_result
- )
+ response = FileResponse(full_path, status_code=status_code, stat_result=stat_result)
if self.is_not_modified(response.headers, request_headers):
return NotModifiedResponse(response.headers)
return response
@@ -206,17 +192,11 @@ async def check_config(self) -> None:
try:
stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
except FileNotFoundError:
- raise RuntimeError(
- f"StaticFiles directory '{self.directory}' does not exist."
- )
+ raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.")
if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
- raise RuntimeError(
- f"StaticFiles path '{self.directory}' is not a directory."
- )
+ raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.")
- def is_not_modified(
- self, response_headers: Headers, request_headers: Headers
- ) -> bool:
+ def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool:
"""
Given the request and response headers, return `True` if an HTTP
"Not Modified" response could be returned instead.
@@ -232,11 +212,7 @@ def is_not_modified(
try:
if_modified_since = parsedate(request_headers["if-modified-since"])
last_modified = parsedate(response_headers["last-modified"])
- if (
- if_modified_since is not None
- and last_modified is not None
- and if_modified_since >= last_modified
- ):
+ if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified:
return True
except KeyError:
pass
diff --git a/starlette/templating.py b/starlette/templating.py
index aae2cbe24..48e54c0cc 100644
--- a/starlette/templating.py
+++ b/starlette/templating.py
@@ -68,8 +68,7 @@ def __init__(
self,
directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]],
*,
- context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
- | None = None,
+ context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
**env_options: typing.Any,
) -> None: ...
@@ -78,31 +77,24 @@ def __init__(
self,
*,
env: jinja2.Environment,
- context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
- | None = None,
+ context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
) -> None: ...
def __init__(
self,
- directory: str
- | PathLike[str]
- | typing.Sequence[str | PathLike[str]]
- | None = None,
+ directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]] | None = None,
*,
- context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
- | None = None,
+ context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
env: jinja2.Environment | None = None,
**env_options: typing.Any,
) -> None:
if env_options:
warnings.warn(
- "Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.", # noqa: E501
+ "Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.",
DeprecationWarning,
)
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
- assert bool(directory) ^ bool(
- env
- ), "either 'directory' or 'env' arguments must be passed"
+ assert bool(directory) ^ bool(env), "either 'directory' or 'env' arguments must be passed"
self.context_processors = context_processors or []
if directory is not None:
self.env = self._create_env(directory, **env_options)
@@ -163,25 +155,19 @@ def TemplateResponse(
# Deprecated usage
...
- def TemplateResponse(
- self, *args: typing.Any, **kwargs: typing.Any
- ) -> _TemplateResponse:
+ def TemplateResponse(self, *args: typing.Any, **kwargs: typing.Any) -> _TemplateResponse:
if args:
- if isinstance(
- args[0], str
- ): # the first argument is template name (old style)
+ if isinstance(args[0], str): # the first argument is template name (old style)
warnings.warn(
"The `name` is not the first parameter anymore. "
"The first parameter should be the `Request` instance.\n"
- 'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.', # noqa: E501
+ 'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.',
DeprecationWarning,
)
name = args[0]
context = args[1] if len(args) > 1 else kwargs.get("context", {})
- status_code = (
- args[2] if len(args) > 2 else kwargs.get("status_code", 200)
- )
+ status_code = args[2] if len(args) > 2 else kwargs.get("status_code", 200)
headers = args[2] if len(args) > 2 else kwargs.get("headers")
media_type = args[3] if len(args) > 3 else kwargs.get("media_type")
background = args[4] if len(args) > 4 else kwargs.get("background")
@@ -193,9 +179,7 @@ def TemplateResponse(
request = args[0]
name = args[1] if len(args) > 1 else kwargs["name"]
context = args[2] if len(args) > 2 else kwargs.get("context", {})
- status_code = (
- args[3] if len(args) > 3 else kwargs.get("status_code", 200)
- )
+ status_code = args[3] if len(args) > 3 else kwargs.get("status_code", 200)
headers = args[4] if len(args) > 4 else kwargs.get("headers")
media_type = args[5] if len(args) > 5 else kwargs.get("media_type")
background = args[6] if len(args) > 6 else kwargs.get("background")
@@ -203,7 +187,7 @@ def TemplateResponse(
if "request" not in kwargs:
warnings.warn(
"The `TemplateResponse` now requires the `request` argument.\n"
- 'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.', # noqa: E501
+ 'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.',
DeprecationWarning,
)
if "request" not in kwargs.get("context", {}):
diff --git a/starlette/testclient.py b/starlette/testclient.py
index bf928d23f..cc6c6e92c 100644
--- a/starlette/testclient.py
+++ b/starlette/testclient.py
@@ -37,9 +37,7 @@
"You can install this with:\n"
" $ pip install httpx\n"
)
-_PortalFactoryType = typing.Callable[
- [], typing.ContextManager[anyio.abc.BlockingPortal]
-]
+_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]]
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]
@@ -169,9 +167,7 @@ async def _asgi_send(self, message: Message) -> None:
def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
- raise WebSocketDisconnect(
- code=message.get("code", 1000), reason=message.get("reason", "")
- )
+ raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
elif message["type"] == "websocket.http.response.start":
status_code: int = message["status"]
headers: list[tuple[bytes, bytes]] = message["headers"]
@@ -199,9 +195,7 @@ def send_text(self, data: str) -> None:
def send_bytes(self, data: bytes) -> None:
self.send({"type": "websocket.receive", "bytes": data})
- def send_json(
- self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text"
- ) -> None:
+ def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None:
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
if mode == "text":
self.send({"type": "websocket.receive", "text": text})
@@ -227,9 +221,7 @@ def receive_bytes(self) -> bytes:
self._raise_on_close(message)
return typing.cast(bytes, message["bytes"])
- def receive_json(
- self, mode: typing.Literal["text", "binary"] = "text"
- ) -> typing.Any:
+ def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any:
message = self.receive()
self._raise_on_close(message)
if mode == "text":
@@ -280,10 +272,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
headers = [(b"host", (f"{host}:{port}").encode())]
# Include other request headers.
- headers += [
- (key.lower().encode(), value.encode())
- for key, value in request.headers.multi_items()
- ]
+ headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]
scope: dict[str, typing.Any]
@@ -365,22 +354,13 @@ async def send(message: Message) -> None:
nonlocal raw_kwargs, response_started, template, context
if message["type"] == "http.response.start":
- assert (
- not response_started
- ), 'Received multiple "http.response.start" messages.'
+ assert not response_started, 'Received multiple "http.response.start" messages.'
raw_kwargs["status_code"] = message["status"]
- raw_kwargs["headers"] = [
- (key.decode(), value.decode())
- for key, value in message.get("headers", [])
- ]
+ raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
response_started = True
elif message["type"] == "http.response.body":
- assert (
- response_started
- ), 'Received "http.response.body" without "http.response.start".'
- assert (
- not response_complete.is_set()
- ), 'Received "http.response.body" after response completed.'
+ assert response_started, 'Received "http.response.body" without "http.response.start".'
+ assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
body = message.get("body", b"")
more_body = message.get("more_body", False)
if request.method != "HEAD":
@@ -435,9 +415,7 @@ def __init__(
headers: dict[str, str] | None = None,
follow_redirects: bool = True,
) -> None:
- self.async_backend = _AsyncBackend(
- backend=backend, backend_options=backend_options or {}
- )
+ self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
if _is_asgi3(app):
asgi_app = app
else:
@@ -468,22 +446,15 @@ def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, No
if self.portal is not None:
yield self.portal
else:
- with anyio.from_thread.start_blocking_portal(
- **self.async_backend
- ) as portal:
+ with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
yield portal
def _choose_redirect_arg(
self, follow_redirects: bool | None, allow_redirects: bool | None
) -> bool | httpx._client.UseClientDefault:
- redirect: bool | httpx._client.UseClientDefault = (
- httpx._client.USE_CLIENT_DEFAULT
- )
+ redirect: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT
if allow_redirects is not None:
- message = (
- "The `allow_redirects` argument is deprecated. "
- "Use `follow_redirects` instead."
- )
+ message = "The `allow_redirects` argument is deprecated. " "Use `follow_redirects` instead."
warnings.warn(message, DeprecationWarning)
redirect = allow_redirects
if follow_redirects is not None:
@@ -506,12 +477,10 @@ def request( # type: ignore[override]
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
- auth: httpx._types.AuthTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
- timeout: httpx._types.TimeoutTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
url = self._merge_url(url)
@@ -539,12 +508,10 @@ def get( # type: ignore[override]
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
- auth: httpx._types.AuthTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
- timeout: httpx._types.TimeoutTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -566,12 +533,10 @@ def options( # type: ignore[override]
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
- auth: httpx._types.AuthTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
- timeout: httpx._types.TimeoutTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -593,12 +558,10 @@ def head( # type: ignore[override]
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
- auth: httpx._types.AuthTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
- timeout: httpx._types.TimeoutTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -624,12 +587,10 @@ def post( # type: ignore[override]
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
- auth: httpx._types.AuthTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
- timeout: httpx._types.TimeoutTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -659,12 +620,10 @@ def put( # type: ignore[override]
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
- auth: httpx._types.AuthTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
- timeout: httpx._types.TimeoutTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -694,12 +653,10 @@ def patch( # type: ignore[override]
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
- auth: httpx._types.AuthTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
- timeout: httpx._types.TimeoutTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -725,12 +682,10 @@ def delete( # type: ignore[override]
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
- auth: httpx._types.AuthTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
- timeout: httpx._types.TimeoutTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -770,9 +725,7 @@ def websocket_connect(
def __enter__(self) -> TestClient:
with contextlib.ExitStack() as stack:
- self.portal = portal = stack.enter_context(
- anyio.from_thread.start_blocking_portal(**self.async_backend)
- )
+ self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))
@stack.callback
def reset_portal() -> None:
diff --git a/starlette/types.py b/starlette/types.py
index f78dd63ae..893f87296 100644
--- a/starlette/types.py
+++ b/starlette/types.py
@@ -16,15 +16,9 @@
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
StatelessLifespan = typing.Callable[[AppType], typing.AsyncContextManager[None]]
-StatefulLifespan = typing.Callable[
- [AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]
-]
+StatefulLifespan = typing.Callable[[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]]
Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]
-HTTPExceptionHandler = typing.Callable[
- ["Request", Exception], "Response | typing.Awaitable[Response]"
-]
-WebSocketExceptionHandler = typing.Callable[
- ["WebSocket", Exception], typing.Awaitable[None]
-]
+HTTPExceptionHandler = typing.Callable[["Request", Exception], "Response | typing.Awaitable[Response]"]
+WebSocketExceptionHandler = typing.Callable[["WebSocket", Exception], typing.Awaitable[None]]
ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler]
diff --git a/starlette/websockets.py b/starlette/websockets.py
index 53ab5a70c..dc0457858 100644
--- a/starlette/websockets.py
+++ b/starlette/websockets.py
@@ -39,10 +39,7 @@ async def receive(self) -> Message:
message = await self._receive()
message_type = message["type"]
if message_type != "websocket.connect":
- raise RuntimeError(
- 'Expected ASGI message "websocket.connect", '
- f"but got {message_type!r}"
- )
+ raise RuntimeError('Expected ASGI message "websocket.connect", ' f"but got {message_type!r}")
self.client_state = WebSocketState.CONNECTED
return message
elif self.client_state == WebSocketState.CONNECTED:
@@ -50,16 +47,13 @@ async def receive(self) -> Message:
message_type = message["type"]
if message_type not in {"websocket.receive", "websocket.disconnect"}:
raise RuntimeError(
- 'Expected ASGI message "websocket.receive" or '
- f'"websocket.disconnect", but got {message_type!r}'
+ 'Expected ASGI message "websocket.receive" or ' f'"websocket.disconnect", but got {message_type!r}'
)
if message_type == "websocket.disconnect":
self.client_state = WebSocketState.DISCONNECTED
return message
else:
- raise RuntimeError(
- 'Cannot call "receive" once a disconnect message has been received.'
- )
+ raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')
async def send(self, message: Message) -> None:
"""
@@ -88,8 +82,7 @@ async def send(self, message: Message) -> None:
message_type = message["type"]
if message_type not in {"websocket.send", "websocket.close"}:
raise RuntimeError(
- 'Expected ASGI message "websocket.send" or "websocket.close", '
- f"but got {message_type!r}"
+ 'Expected ASGI message "websocket.send" or "websocket.close", ' f"but got {message_type!r}"
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
@@ -101,10 +94,7 @@ async def send(self, message: Message) -> None:
elif self.application_state == WebSocketState.RESPONSE:
message_type = message["type"]
if message_type != "websocket.http.response.body":
- raise RuntimeError(
- 'Expected ASGI message "websocket.http.response.body", '
- f"but got {message_type!r}"
- )
+ raise RuntimeError('Expected ASGI message "websocket.http.response.body", ' f"but got {message_type!r}")
if not message.get("more_body", False):
self.application_state = WebSocketState.DISCONNECTED
await self._send(message)
@@ -121,9 +111,7 @@ async def accept(
if self.client_state == WebSocketState.CONNECTING:
# If we haven't yet seen the 'connect' message, then wait for it first.
await self.receive()
- await self.send(
- {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
- )
+ await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})
def _raise_on_disconnect(self, message: Message) -> None:
if message["type"] == "websocket.disconnect":
@@ -131,18 +119,14 @@ def _raise_on_disconnect(self, message: Message) -> None:
async def receive_text(self) -> str:
if self.application_state != WebSocketState.CONNECTED:
- raise RuntimeError(
- 'WebSocket is not connected. Need to call "accept" first.'
- )
+ raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
return typing.cast(str, message["text"])
async def receive_bytes(self) -> bytes:
if self.application_state != WebSocketState.CONNECTED:
- raise RuntimeError(
- 'WebSocket is not connected. Need to call "accept" first.'
- )
+ raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
return typing.cast(bytes, message["bytes"])
@@ -151,9 +135,7 @@ async def receive_json(self, mode: str = "text") -> typing.Any:
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
if self.application_state != WebSocketState.CONNECTED:
- raise RuntimeError(
- 'WebSocket is not connected. Need to call "accept" first.'
- )
+ raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
@@ -200,17 +182,13 @@ async def send_json(self, data: typing.Any, mode: str = "text") -> None:
await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})
async def close(self, code: int = 1000, reason: str | None = None) -> None:
- await self.send(
- {"type": "websocket.close", "code": code, "reason": reason or ""}
- )
+ await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})
async def send_denial_response(self, response: Response) -> None:
if "websocket.http.response" in self.scope.get("extensions", {}):
await response(self.scope, self.receive, self.send)
else:
- raise RuntimeError(
- "The server doesn't support the Websocket Denial Response extension."
- )
+ raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")
class WebSocketClose:
@@ -219,6 +197,4 @@ def __init__(self, code: int = 1000, reason: str | None = None) -> None:
self.reason = reason or ""
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- await send(
- {"type": "websocket.close", "code": self.code, "reason": self.reason}
- )
+ await send({"type": "websocket.close", "code": self.code, "reason": self.reason})
diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py
index 8e410cb15..225038650 100644
--- a/tests/middleware/test_base.py
+++ b/tests/middleware/test_base.py
@@ -169,9 +169,7 @@ def test_app_middleware_argument(test_client_factory: TestClientFactory) -> None
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage")
- app = Starlette(
- routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)]
- )
+ app = Starlette(routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)])
client = test_client_factory(app)
response = client.get("/")
@@ -249,9 +247,7 @@ async def homepage(request: Request) -> PlainTextResponse:
ctxvar.set("set by endpoint")
return PlainTextResponse("Homepage")
- app = Starlette(
- middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]
- )
+ app = Starlette(middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)])
client = test_client_factory(app)
response = client.get("/")
@@ -316,13 +312,9 @@ async def sleep_and_set() -> None:
events.append("Background task finished")
async def endpoint_with_background_task(_: Request) -> PlainTextResponse:
- return PlainTextResponse(
- content="Hello", background=BackgroundTask(sleep_and_set)
- )
+ return PlainTextResponse(content="Hello", background=BackgroundTask(sleep_and_set))
- async def passthrough(
- request: Request, call_next: RequestResponseEndpoint
- ) -> Response:
+ async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response:
return await call_next(request)
app = Starlette(
@@ -490,9 +482,7 @@ async def cancel_on_disconnect(
}
)
- pytest.fail(
- "http.disconnect should have been received and canceled the scope"
- ) # pragma: no cover
+ pytest.fail("http.disconnect should have been received and canceled the scope") # pragma: no cover
app = DiscardingMiddleware(downstream_app)
@@ -787,7 +777,7 @@ async def send(msg: Message) -> None:
await rcv_stream.aclose()
-def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( # noqa: E501
+def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next(
test_client_factory: TestClientFactory,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
@@ -800,9 +790,7 @@ async def dispatch(
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
- assert (
- await request.body() == b"a"
- ) # this buffers the request body in memory
+ assert await request.body() == b"a" # this buffers the request body in memory
resp = await call_next(request)
async for chunk in request.stream():
if chunk:
@@ -819,7 +807,7 @@ async def dispatch(
assert response.status_code == 200
-def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( # noqa: E501
+def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next(
test_client_factory: TestClientFactory,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
@@ -832,9 +820,7 @@ async def dispatch(
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
- assert (
- await request.body() == b"a"
- ) # this buffers the request body in memory
+ assert await request.body() == b"a" # this buffers the request body in memory
resp = await call_next(request)
assert await request.body() == b"a" # no problem here
return resp
@@ -1026,9 +1012,7 @@ def __init__(self, app: ASGIApp, version: int, events: list[str]) -> None:
self.events = events
super().__init__(app)
- async def dispatch(
- self, request: Request, call_next: RequestResponseEndpoint
- ) -> Response:
+ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
self.events.append(f"{self.version}:STARTED")
res = await call_next(request)
self.events.append(f"{self.version}:COMPLETED")
@@ -1047,9 +1031,7 @@ async def sleepy(request: Request) -> Response:
app = Starlette(
routes=[Route("/", sleepy)],
- middleware=[
- Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)
- ],
+ middleware=[Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)],
)
scope = {
@@ -1114,9 +1096,7 @@ async def app_poll_disconnect(scope: Scope, receive: Receive, send: Send) -> Non
await Response(b"good!")(scope, receive, send)
class MyMiddleware(BaseHTTPMiddleware):
- async def dispatch(
- self, request: Request, call_next: RequestResponseEndpoint
- ) -> Response:
+ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
return await call_next(request)
app = MyMiddleware(app_poll_disconnect)
diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py
index 630361243..0d987263e 100644
--- a/tests/middleware/test_cors.py
+++ b/tests/middleware/test_cors.py
@@ -252,9 +252,7 @@ def homepage(request: Request) -> None:
app = Starlette(
routes=[Route("/", endpoint=homepage)],
- middleware=[
- Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
- ],
+ middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])],
)
client = test_client_factory(app)
@@ -284,9 +282,7 @@ def homepage(request: Request) -> PlainTextResponse:
methods=["delete", "get", "head", "options", "patch", "post", "put"],
)
],
- middleware=[
- Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
- ],
+ middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])],
)
client = test_client_factory(app)
@@ -397,10 +393,7 @@ def homepage(request: Request) -> PlainTextResponse:
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
- assert (
- response.headers["access-control-allow-origin"]
- == "https://subdomain.example.org"
- )
+ assert response.headers["access-control-allow-origin"] == "https://subdomain.example.org"
assert "access-control-allow-credentials" not in response.headers
# Test diallowed standard response
@@ -456,9 +449,7 @@ def test_cors_vary_header_is_not_set_for_non_credentialed_request(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
- return PlainTextResponse(
- "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
- )
+ return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"})
app = Starlette(
routes=[Route("/", endpoint=homepage)],
@@ -475,9 +466,7 @@ def test_cors_vary_header_is_properly_set_for_credentialed_request(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
- return PlainTextResponse(
- "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
- )
+ return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"})
app = Starlette(
routes=[Route("/", endpoint=homepage)],
@@ -485,9 +474,7 @@ def homepage(request: Request) -> PlainTextResponse:
)
client = test_client_factory(app)
- response = client.get(
- "/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}
- )
+ response = client.get("/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"})
assert response.status_code == 200
assert response.headers["vary"] == "Accept-Encoding, Origin"
@@ -496,9 +483,7 @@ def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
- return PlainTextResponse(
- "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
- )
+ return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"})
app = Starlette(
routes=[
@@ -538,9 +523,7 @@ def homepage(request: Request) -> PlainTextResponse:
assert response.headers["access-control-allow-origin"] == "*"
assert "access-control-allow-credentials" not in response.headers
- response = client.get(
- "/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}
- )
+ response = client.get("/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"})
assert response.headers["access-control-allow-origin"] == "https://someplace.org"
assert "access-control-allow-credentials" not in response.headers
diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py
index b6f68296d..b20a7cb84 100644
--- a/tests/middleware/test_gzip.py
+++ b/tests/middleware/test_gzip.py
@@ -91,9 +91,7 @@ async def generator(bytes: bytes, count: int) -> ContentStream:
yield bytes
streaming = generator(bytes=b"x" * 400, count=10)
- return StreamingResponse(
- streaming, status_code=200, headers={"Content-Encoding": "text"}
- )
+ return StreamingResponse(streaming, status_code=200, headers={"Content-Encoding": "text"})
app = Starlette(
routes=[Route("/", endpoint=homepage)],
diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py
index 9a0d70a0d..b4f3c64fa 100644
--- a/tests/middleware/test_session.py
+++ b/tests/middleware/test_session.py
@@ -89,9 +89,7 @@ def test_secure_session(test_client_factory: TestClientFactory) -> None:
Route("/update_session", endpoint=update_session, methods=["POST"]),
Route("/clear_session", endpoint=clear_session, methods=["POST"]),
],
- middleware=[
- Middleware(SessionMiddleware, secret_key="example", https_only=True)
- ],
+ middleware=[Middleware(SessionMiddleware, secret_key="example", https_only=True)],
)
secure_client = test_client_factory(app, base_url="https://testserver")
unsecure_client = test_client_factory(app, base_url="http://testserver")
@@ -126,9 +124,7 @@ def test_session_cookie_subpath(test_client_factory: TestClientFactory) -> None:
routes=[
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
- middleware=[
- Middleware(SessionMiddleware, secret_key="example", path="/second_app")
- ],
+ middleware=[Middleware(SessionMiddleware, secret_key="example", path="/second_app")],
)
app = Starlette(routes=[Mount("/second_app", app=second_app)])
client = test_client_factory(app, base_url="http://testserver/second_app")
@@ -188,9 +184,7 @@ def test_domain_cookie(test_client_factory: TestClientFactory) -> None:
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
- middleware=[
- Middleware(SessionMiddleware, secret_key="example", domain=".example.com")
- ],
+ middleware=[Middleware(SessionMiddleware, secret_key="example", domain=".example.com")],
)
client: TestClient = test_client_factory(app)
diff --git a/tests/middleware/test_trusted_host.py b/tests/middleware/test_trusted_host.py
index ddff46c48..5b8b217c3 100644
--- a/tests/middleware/test_trusted_host.py
+++ b/tests/middleware/test_trusted_host.py
@@ -13,11 +13,7 @@ def homepage(request: Request) -> PlainTextResponse:
app = Starlette(
routes=[Route("/", endpoint=homepage)],
- middleware=[
- Middleware(
- TrustedHostMiddleware, allowed_hosts=["testserver", "*.testserver"]
- )
- ],
+ middleware=[Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.testserver"])],
)
client = test_client_factory(app)
@@ -45,9 +41,7 @@ def homepage(request: Request) -> PlainTextResponse:
app = Starlette(
routes=[Route("/", endpoint=homepage)],
- middleware=[
- Middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"])
- ],
+ middleware=[Middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"])],
)
client = test_client_factory(app, base_url="https://example.com")
diff --git a/tests/test_applications.py b/tests/test_applications.py
index 20da7ea81..e86eba322 100644
--- a/tests/test_applications.py
+++ b/tests/test_applications.py
@@ -109,9 +109,7 @@ def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) ->
CustomWSException: custom_ws_exception_handler,
}
-middleware = [
- Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.example.org"])
-]
+middleware = [Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.example.org"])]
app = Starlette(
routes=[
@@ -349,9 +347,7 @@ def run_cleanup() -> None:
nonlocal cleanup_complete
cleanup_complete = True
- with pytest.deprecated_call(
- match="The on_startup and on_shutdown parameters are deprecated"
- ):
+ with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
app = Starlette(
on_startup=[run_startup],
on_shutdown=[run_cleanup],
@@ -445,51 +441,34 @@ def test_decorator_deprecations() -> None:
app = Starlette()
with pytest.deprecated_call(
- match=(
- "The `exception_handler` decorator is deprecated, "
- "and will be removed in version 1.0.0."
- )
+ match=("The `exception_handler` decorator is deprecated, " "and will be removed in version 1.0.0.")
) as record:
app.exception_handler(500)(http_exception)
assert len(record) == 1
with pytest.deprecated_call(
- match=(
- "The `middleware` decorator is deprecated, "
- "and will be removed in version 1.0.0."
- )
+ match=("The `middleware` decorator is deprecated, " "and will be removed in version 1.0.0.")
) as record:
- async def middleware(
- request: Request, call_next: RequestResponseEndpoint
- ) -> None: ... # pragma: no cover
+ async def middleware(request: Request, call_next: RequestResponseEndpoint) -> None: ... # pragma: no cover
app.middleware("http")(middleware)
assert len(record) == 1
with pytest.deprecated_call(
- match=(
- "The `route` decorator is deprecated, "
- "and will be removed in version 1.0.0."
- )
+ match=("The `route` decorator is deprecated, " "and will be removed in version 1.0.0.")
) as record:
app.route("/")(async_homepage)
assert len(record) == 1
with pytest.deprecated_call(
- match=(
- "The `websocket_route` decorator is deprecated, "
- "and will be removed in version 1.0.0."
- )
+ match=("The `websocket_route` decorator is deprecated, " "and will be removed in version 1.0.0.")
) as record:
app.websocket_route("/ws")(websocket_endpoint)
assert len(record) == 1
with pytest.deprecated_call(
- match=(
- "The `on_event` decorator is deprecated, "
- "and will be removed in version 1.0.0."
- )
+ match=("The `on_event` decorator is deprecated, " "and will be removed in version 1.0.0.")
) as record:
async def startup() -> None: ... # pragma: no cover
diff --git a/tests/test_authentication.py b/tests/test_authentication.py
index 35c1110d1..a1bde67b9 100644
--- a/tests/test_authentication.py
+++ b/tests/test_authentication.py
@@ -259,9 +259,7 @@ def test_authentication_required(test_client_factory: TestClientFactory) -> None
response = client.get("/dashboard/decorated")
assert response.status_code == 403
- response = client.get(
- "/dashboard/decorated/sync", auth=("tomchristie", "example")
- )
+ response = client.get("/dashboard/decorated/sync", auth=("tomchristie", "example"))
assert response.status_code == 200
assert response.json() == {
"authenticated": True,
@@ -286,14 +284,10 @@ def test_websocket_authentication_required(
pass # pragma: nocover
with pytest.raises(WebSocketDisconnect):
- with client.websocket_connect(
- "/ws", headers={"Authorization": "basic foobar"}
- ):
+ with client.websocket_connect("/ws", headers={"Authorization": "basic foobar"}):
pass # pragma: nocover
- with client.websocket_connect(
- "/ws", auth=("tomchristie", "example")
- ) as websocket:
+ with client.websocket_connect("/ws", auth=("tomchristie", "example")) as websocket:
data = websocket.receive_json()
assert data == {"authenticated": True, "user": "tomchristie"}
@@ -302,14 +296,10 @@ def test_websocket_authentication_required(
pass # pragma: nocover
with pytest.raises(WebSocketDisconnect):
- with client.websocket_connect(
- "/ws/decorated", headers={"Authorization": "basic foobar"}
- ):
+ with client.websocket_connect("/ws/decorated", headers={"Authorization": "basic foobar"}):
pass # pragma: nocover
- with client.websocket_connect(
- "/ws/decorated", auth=("tomchristie", "example")
- ) as websocket:
+ with client.websocket_connect("/ws/decorated", auth=("tomchristie", "example")) as websocket:
data = websocket.receive_json()
assert data == {
"authenticated": True,
@@ -322,9 +312,7 @@ def test_authentication_redirect(test_client_factory: TestClientFactory) -> None
with test_client_factory(app) as client:
response = client.get("/admin")
assert response.status_code == 200
- url = "{}?{}".format(
- "http://testserver/", urlencode({"next": "http://testserver/admin"})
- )
+ url = "{}?{}".format("http://testserver/", urlencode({"next": "http://testserver/admin"}))
assert response.url == url
response = client.get("/admin", auth=("tomchristie", "example"))
@@ -333,9 +321,7 @@ def test_authentication_redirect(test_client_factory: TestClientFactory) -> None
response = client.get("/admin/sync")
assert response.status_code == 200
- url = "{}?{}".format(
- "http://testserver/", urlencode({"next": "http://testserver/admin/sync"})
- )
+ url = "{}?{}".format("http://testserver/", urlencode({"next": "http://testserver/admin/sync"}))
assert response.url == url
response = client.get("/admin/sync", auth=("tomchristie", "example"))
@@ -359,11 +345,7 @@ def control_panel(request: Request) -> JSONResponse:
other_app = Starlette(
routes=[Route("/control-panel", control_panel)],
- middleware=[
- Middleware(
- AuthenticationMiddleware, backend=BasicAuth(), on_error=on_auth_error
- )
- ],
+ middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuth(), on_error=on_auth_error)],
)
@@ -373,8 +355,6 @@ def test_custom_on_error(test_client_factory: TestClientFactory) -> None:
assert response.status_code == 200
assert response.json() == {"authenticated": True, "user": "tomchristie"}
- response = client.get(
- "/control-panel", headers={"Authorization": "basic foobar"}
- )
+ response = client.get("/control-panel", headers={"Authorization": "basic foobar"})
assert response.status_code == 401
assert response.json() == {"error": "Invalid basic auth credentials"}
diff --git a/tests/test_background.py b/tests/test_background.py
index cbffcc06a..990e270ea 100644
--- a/tests/test_background.py
+++ b/tests/test_background.py
@@ -56,9 +56,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
tasks.add_task(increment, amount=1)
tasks.add_task(increment, amount=2)
tasks.add_task(increment, amount=3)
- response = Response(
- "tasks initiated", media_type="text/plain", background=tasks
- )
+ response = Response("tasks initiated", media_type="text/plain", background=tasks)
await response(scope, receive, send)
client = test_client_factory(app)
@@ -82,9 +80,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
tasks = BackgroundTasks()
tasks.add_task(increment)
tasks.add_task(increment)
- response = Response(
- "tasks initiated", media_type="text/plain", background=tasks
- )
+ response = Response("tasks initiated", media_type="text/plain", background=tasks)
await response(scope, receive, send)
client = test_client_factory(app)
diff --git a/tests/test_config.py b/tests/test_config.py
index f37591007..7d2cd1f9d 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -14,9 +14,7 @@ def test_config_types() -> None:
"""
We use `assert_type` to test the types returned by Config via mypy.
"""
- config = Config(
- environ={"STR": "some_str_value", "STR_CAST": "some_str_value", "BOOL": "true"}
- )
+ config = Config(environ={"STR": "some_str_value", "STR_CAST": "some_str_value", "BOOL": "true"})
assert_type(config("STR"), str)
assert_type(config("STR_DEFAULT", default=""), str)
@@ -138,9 +136,7 @@ def test_environ() -> None:
def test_config_with_env_prefix(tmpdir: Path, monkeypatch: pytest.MonkeyPatch) -> None:
- config = Config(
- environ={"APP_DEBUG": "value", "ENVIRONMENT": "dev"}, env_prefix="APP_"
- )
+ config = Config(environ={"APP_DEBUG": "value", "ENVIRONMENT": "dev"}, env_prefix="APP_")
assert config.get("DEBUG") == "value"
with pytest.raises(KeyError):
diff --git a/tests/test_convertors.py b/tests/test_convertors.py
index 520c98767..ced1b86cc 100644
--- a/tests/test_convertors.py
+++ b/tests/test_convertors.py
@@ -48,23 +48,18 @@ def datetime_convertor(request: Request) -> JSONResponse:
)
-def test_datetime_convertor(
- test_client_factory: TestClientFactory, app: Router
-) -> None:
+def test_datetime_convertor(test_client_factory: TestClientFactory, app: Router) -> None:
client = test_client_factory(app)
response = client.get("/datetime/2020-01-01T00:00:00")
assert response.json() == {"datetime": "2020-01-01T00:00:00"}
assert (
- app.url_path_for("datetime-convertor", param=datetime(1996, 1, 22, 23, 0, 0))
- == "/datetime/1996-01-22T23:00:00"
+ app.url_path_for("datetime-convertor", param=datetime(1996, 1, 22, 23, 0, 0)) == "/datetime/1996-01-22T23:00:00"
)
@pytest.mark.parametrize("param, status_code", [("1.0", 200), ("1-0", 404)])
-def test_default_float_convertor(
- test_client_factory: TestClientFactory, param: str, status_code: int
-) -> None:
+def test_default_float_convertor(test_client_factory: TestClientFactory, param: str, status_code: int) -> None:
def float_convertor(request: Request) -> JSONResponse:
param = request.path_params["param"]
assert isinstance(param, float)
diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py
index a6bca6ef6..0e7d35c3c 100644
--- a/tests/test_datastructures.py
+++ b/tests/test_datastructures.py
@@ -115,9 +115,7 @@ def test_csv() -> None:
def test_url_from_scope() -> None:
- u = URL(
- scope={"path": "/path/to/somewhere", "query_string": b"abc=123", "headers": []}
- )
+ u = URL(scope={"path": "/path/to/somewhere", "query_string": b"abc=123", "headers": []})
assert u == "/path/to/somewhere?abc=123"
assert repr(u) == "URL('/path/to/somewhere?abc=123')"
@@ -296,13 +294,9 @@ def test_queryparams() -> None:
assert dict(q) == {"a": "456", "b": "789"}
assert str(q) == "a=123&a=456&b=789"
assert repr(q) == "QueryParams('a=123&a=456&b=789')"
- assert QueryParams({"a": "123", "b": "456"}) == QueryParams(
- [("a", "123"), ("b", "456")]
- )
+ assert QueryParams({"a": "123", "b": "456"}) == QueryParams([("a", "123"), ("b", "456")])
assert QueryParams({"a": "123", "b": "456"}) == QueryParams("a=123&b=456")
- assert QueryParams({"a": "123", "b": "456"}) == QueryParams(
- {"b": "456", "a": "123"}
- )
+ assert QueryParams({"a": "123", "b": "456"}) == QueryParams({"b": "456", "a": "123"})
assert QueryParams() == QueryParams({})
assert QueryParams([("a", "123"), ("a", "456")]) == QueryParams("a=123&a=456")
assert QueryParams({"a": "123", "b": "456"}) != "invalid"
@@ -382,10 +376,7 @@ def test_formdata() -> None:
assert len(form) == 2
assert list(form) == ["a", "b"]
assert dict(form) == {"a": "456", "b": upload}
- assert (
- repr(form)
- == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])"
- )
+ assert repr(form) == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])"
assert FormData(form) == form
assert FormData({"a": "123", "b": "789"}) == FormData([("a", "123"), ("b", "789")])
assert FormData({"a": "123", "b": "789"}) != {"a": "123", "b": "789"}
@@ -402,10 +393,7 @@ async def test_upload_file_repr() -> None:
async def test_upload_file_repr_headers() -> None:
stream = io.BytesIO(b"data")
file = UploadFile(filename="file", file=stream, headers=Headers({"foo": "bar"}))
- assert (
- repr(file)
- == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))"
- )
+ assert repr(file) == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))"
def test_multidict() -> None:
@@ -425,9 +413,7 @@ def test_multidict() -> None:
assert dict(q) == {"a": "456", "b": "789"}
assert str(q) == "MultiDict([('a', '123'), ('a', '456'), ('b', '789')])"
assert repr(q) == "MultiDict([('a', '123'), ('a', '456'), ('b', '789')])"
- assert MultiDict({"a": "123", "b": "456"}) == MultiDict(
- [("a", "123"), ("b", "456")]
- )
+ assert MultiDict({"a": "123", "b": "456"}) == MultiDict([("a", "123"), ("b", "456")])
assert MultiDict({"a": "123", "b": "456"}) == MultiDict({"b": "456", "a": "123"})
assert MultiDict() == MultiDict({})
assert MultiDict({"a": "123", "b": "456"}) != "invalid"
diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py
index 8f201e25b..42776a5b3 100644
--- a/tests/test_endpoints.py
+++ b/tests/test_endpoints.py
@@ -19,9 +19,7 @@ async def get(self, request: Request) -> PlainTextResponse:
return PlainTextResponse(f"Hello, {username}!")
-app = Router(
- routes=[Route("/", endpoint=Homepage), Route("/{username}", endpoint=Homepage)]
-)
+app = Router(routes=[Route("/", endpoint=Homepage), Route("/{username}", endpoint=Homepage)])
@pytest.fixture
diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py
index f4e91ad87..b3dc7843f 100644
--- a/tests/test_exceptions.py
+++ b/tests/test_exceptions.py
@@ -42,9 +42,7 @@ async def read_body_and_raise_exc(request: Request) -> None:
raise BadBodyException(422)
-async def handler_that_reads_body(
- request: Request, exc: BadBodyException
-) -> JSONResponse:
+async def handler_that_reads_body(request: Request, exc: BadBodyException) -> JSONResponse:
body = await request.body()
return JSONResponse(status_code=422, content={"body": body.decode()})
@@ -158,9 +156,7 @@ def test_http_str() -> None:
def test_http_repr() -> None:
- assert repr(HTTPException(404)) == (
- "HTTPException(status_code=404, detail='Not Found')"
- )
+ assert repr(HTTPException(404)) == ("HTTPException(status_code=404, detail='Not Found')")
assert repr(HTTPException(404, detail="Not Found: foo")) == (
"HTTPException(status_code=404, detail='Not Found: foo')"
)
diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py
index 8d97a0ba7..a5ebdd043 100644
--- a/tests/test_formparsers.py
+++ b/tests/test_formparsers.py
@@ -127,17 +127,13 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
return app
-def test_multipart_request_data(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART)
assert response.json() == {"some": "data"}
-def test_multipart_request_files(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "test.txt")
with open(path, "wb") as file:
file.write(b"")
@@ -155,9 +151,7 @@ def test_multipart_request_files(
}
-def test_multipart_request_files_with_content_type(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_files_with_content_type(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "test.txt")
with open(path, "wb") as file:
file.write(b"")
@@ -175,9 +169,7 @@ def test_multipart_request_files_with_content_type(
}
-def test_multipart_request_multiple_files(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_multiple_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path1 = os.path.join(tmpdir, "test1.txt")
with open(path1, "wb") as file:
file.write(b"")
@@ -188,9 +180,7 @@ def test_multipart_request_multiple_files(
client = test_client_factory(app)
with open(path1, "rb") as f1, open(path2, "rb") as f2:
- response = client.post(
- "/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")}
- )
+ response = client.post("/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")})
assert response.json() == {
"test1": {
"filename": "test1.txt",
@@ -207,9 +197,7 @@ def test_multipart_request_multiple_files(
}
-def test_multipart_request_multiple_files_with_headers(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_multiple_files_with_headers(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path1 = os.path.join(tmpdir, "test1.txt")
with open(path1, "wb") as file:
file.write(b"")
@@ -281,9 +269,7 @@ def test_multi_items(tmpdir: Path, test_client_factory: TestClientFactory) -> No
}
-def test_multipart_request_mixed_files_and_data(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post(
"/",
@@ -303,11 +289,7 @@ def test_multipart_request_mixed_files_and_data(
b"value1\r\n"
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
),
- headers={
- "Content-Type": (
- "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"
- )
- },
+ headers={"Content-Type": ("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
)
assert response.json() == {
"file": {
@@ -321,26 +303,19 @@ def test_multipart_request_mixed_files_and_data(
}
-def test_multipart_request_with_charset_for_filename(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post(
"/",
data=(
# file
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" # type: ignore
- b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' # noqa: E501
+ b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n'
b"Content-Type: text/plain\r\n\r\n"
b"\r\n"
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
),
- headers={
- "Content-Type": (
- "multipart/form-data; charset=utf-8; "
- "boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"
- )
- },
+ headers={"Content-Type": ("multipart/form-data; charset=utf-8; " "boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
)
assert response.json() == {
"file": {
@@ -352,25 +327,19 @@ def test_multipart_request_with_charset_for_filename(
}
-def test_multipart_request_without_charset_for_filename(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_without_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post(
"/",
data=(
# file
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" # type: ignore
- b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n' # noqa: E501
+ b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n'
b"Content-Type: image/jpeg\r\n\r\n"
b"\r\n"
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
),
- headers={
- "Content-Type": (
- "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"
- )
- },
+ headers={"Content-Type": ("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
)
assert response.json() == {
"file": {
@@ -382,9 +351,7 @@ def test_multipart_request_without_charset_for_filename(
}
-def test_multipart_request_with_encoded_value(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_with_encoded_value(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post(
"/",
@@ -395,19 +362,12 @@ def test_multipart_request_with_encoded_value(
b"Transf\xc3\xa9rer\r\n"
b"--20b303e711c4ab8c443184ac833ab00f--\r\n"
),
- headers={
- "Content-Type": (
- "multipart/form-data; charset=utf-8; "
- "boundary=20b303e711c4ab8c443184ac833ab00f"
- )
- },
+ headers={"Content-Type": ("multipart/form-data; charset=utf-8; " "boundary=20b303e711c4ab8c443184ac833ab00f")},
)
assert response.json() == {"value": "Transférer"}
-def test_urlencoded_request_data(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_urlencoded_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post("/", data={"some": "data"})
assert response.json() == {"some": "data"}
@@ -419,37 +379,27 @@ def test_no_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -
assert response.json() == {}
-def test_urlencoded_percent_encoding(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_urlencoded_percent_encoding(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post("/", data={"some": "da ta"})
assert response.json() == {"some": "da ta"}
-def test_urlencoded_percent_encoding_keys(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_urlencoded_percent_encoding_keys(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post("/", data={"so me": "data"})
assert response.json() == {"so me": "data"}
-def test_urlencoded_multi_field_app_reads_body(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_urlencoded_multi_field_app_reads_body(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app_read_body)
response = client.post("/", data={"some": "data", "second": "key pair"})
assert response.json() == {"some": "data", "second": "key pair"}
-def test_multipart_multi_field_app_reads_body(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_multi_field_app_reads_body(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app_read_body)
- response = client.post(
- "/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART
- )
+ response = client.post("/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART)
assert response.json() == {"some": "data", "second": "key pair"}
@@ -481,7 +431,7 @@ def test_missing_boundary_parameter(
"/",
data=(
# file
- b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' # type: ignore # noqa: E501
+ b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' # type: ignore
b"Content-Type: text/plain\r\n\r\n"
b"\r\n"
),
@@ -513,16 +463,10 @@ def test_missing_name_parameter_on_content_disposition(
b'Content-Disposition: form-data; ="field0"\r\n\r\n'
b"value0\r\n"
),
- headers={
- "Content-Type": (
- "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"
- )
- },
+ headers={"Content-Type": ("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
)
assert res.status_code == 400
- assert (
- res.text == 'The Content-Disposition header field "name" must be provided.'
- )
+ assert res.text == 'The Content-Disposition header field "name" must be provided.'
@pytest.mark.parametrize(
@@ -540,9 +484,7 @@ def test_too_many_fields_raise(
client = test_client_factory(app)
fields = []
for i in range(1001):
- fields.append(
- "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n"
- )
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
@@ -569,11 +511,7 @@ def test_too_many_files_raise(
client = test_client_factory(app)
fields = []
for i in range(1001):
- fields.append(
- "--B\r\n"
- f'Content-Disposition: form-data; name="N{i}"; filename="F{i}";\r\n\r\n'
- "\r\n"
- )
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}"; filename="F{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
@@ -602,11 +540,7 @@ def test_too_many_files_single_field_raise(
for i in range(1001):
# This uses the same field name "N" for all files, equivalent to a
# multifile upload form field
- fields.append(
- "--B\r\n"
- f'Content-Disposition: form-data; name="N"; filename="F{i}";\r\n\r\n'
- "\r\n"
- )
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="N"; filename="F{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
@@ -633,14 +567,8 @@ def test_too_many_files_and_fields_raise(
client = test_client_factory(app)
fields = []
for i in range(1001):
- fields.append(
- "--B\r\n"
- f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n'
- "\r\n"
- )
- fields.append(
- "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n"
- )
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' "\r\n")
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
@@ -670,9 +598,7 @@ def test_max_fields_is_customizable_low_raises(
client = test_client_factory(app)
fields = []
for i in range(2):
- fields.append(
- "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n"
- )
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
@@ -702,11 +628,7 @@ def test_max_files_is_customizable_low_raises(
client = test_client_factory(app)
fields = []
for i in range(2):
- fields.append(
- "--B\r\n"
- f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n'
- "\r\n"
- )
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
@@ -724,14 +646,8 @@ def test_max_fields_is_customizable_high(
client = test_client_factory(make_app_max_parts(max_fields=2000, max_files=2000))
fields = []
for i in range(2000):
- fields.append(
- "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n"
- )
- fields.append(
- "--B\r\n"
- f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n'
- "\r\n"
- )
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n")
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
data += b"--B--\r\n"
res = client.post(
diff --git a/tests/test_responses.py b/tests/test_responses.py
index c63c92de5..ad1901ca5 100644
--- a/tests/test_responses.py
+++ b/tests/test_responses.py
@@ -118,9 +118,7 @@ async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None:
cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9)
generator = numbers(1, 5)
- response = StreamingResponse(
- generator, media_type="text/plain", background=cleanup_task
- )
+ response = StreamingResponse(generator, media_type="text/plain", background=cleanup_task)
await response(scope, receive, send)
assert filled_by_bg_task == ""
@@ -236,9 +234,7 @@ async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None:
cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
- response = FileResponse(
- path=path, filename="example.png", background=cleanup_task
- )
+ response = FileResponse(path=path, filename="example.png", background=cleanup_task)
await response(scope, receive, send)
assert filled_by_bg_task == ""
@@ -284,9 +280,7 @@ async def send(message: Message) -> None:
await app({"type": "http", "method": "head"}, receive, send)
-def test_file_response_set_media_type(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_set_media_type(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "xyz"
path.write_bytes(b"")
@@ -298,9 +292,7 @@ def test_file_response_set_media_type(
assert response.headers["content-type"] == "image/jpeg"
-def test_file_response_with_directory_raises_error(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_with_directory_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
app = FileResponse(path=tmp_path, filename="example.png")
client = test_client_factory(app)
with pytest.raises(RuntimeError) as exc_info:
@@ -308,9 +300,7 @@ def test_file_response_with_directory_raises_error(
assert "is not a file" in str(exc_info.value)
-def test_file_response_with_missing_file_raises_error(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_with_missing_file_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "404.txt"
app = FileResponse(path=path, filename="404.txt")
client = test_client_factory(app)
@@ -319,9 +309,7 @@ def test_file_response_with_missing_file_raises_error(
assert "does not exist" in str(exc_info.value)
-def test_file_response_with_chinese_filename(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_with_chinese_filename(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
content = b"file content"
filename = "你好.txt" # probably "Hello.txt" in Chinese
path = tmp_path / filename
@@ -335,9 +323,7 @@ def test_file_response_with_chinese_filename(
assert response.headers["content-disposition"] == expected_disposition
-def test_file_response_with_inline_disposition(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_with_inline_disposition(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
content = b"file content"
filename = "hello.txt"
path = tmp_path / filename
@@ -356,9 +342,7 @@ def test_file_response_with_method_warns(tmp_path: Path) -> None:
FileResponse(path=tmp_path, filename="example.png", method="GET")
-def test_set_cookie(
- test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch
-) -> None:
+def test_set_cookie(test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch) -> None:
# Mock time used as a reference for `Expires` by stdlib `SimpleCookie`.
mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc)
monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp())
@@ -382,8 +366,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = client.get("/")
assert response.text == "Hello, world!"
assert (
- response.headers["set-cookie"]
- == "mycookie=myvalue; Domain=localhost; expires=Thu, 22 Jan 2037 12:00:10 GMT; "
+ response.headers["set-cookie"] == "mycookie=myvalue; Domain=localhost; expires=Thu, 22 Jan 2037 12:00:10 GMT; "
"HttpOnly; Max-Age=10; Path=/; SameSite=none; Secure"
)
@@ -403,9 +386,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
@pytest.mark.parametrize(
"expires",
[
- pytest.param(
- dt.datetime(2037, 1, 22, 12, 0, 10, tzinfo=dt.timezone.utc), id="datetime"
- ),
+ pytest.param(dt.datetime(2037, 1, 22, 12, 0, 10, tzinfo=dt.timezone.utc), id="datetime"),
pytest.param("Thu, 22 Jan 2037 12:00:10 GMT", id="str"),
pytest.param(10, id="int"),
],
@@ -495,9 +476,7 @@ def test_response_do_not_add_redundant_charset(
assert response.headers["content-type"] == "text/plain; charset=utf-8"
-def test_file_response_known_size(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_known_size(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "xyz"
content = b"" * 1000
path.write_bytes(content)
@@ -518,9 +497,7 @@ def test_streaming_response_unknown_size(
def test_streaming_response_known_size(test_client_factory: TestClientFactory) -> None:
- app = StreamingResponse(
- content=iter(["hello", "world"]), headers={"content-length": "10"}
- )
+ app = StreamingResponse(content=iter(["hello", "world"]), headers={"content-length": "10"})
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.headers["content-length"] == "10"
diff --git a/tests/test_routing.py b/tests/test_routing.py
index 1490723b4..132baa602 100644
--- a/tests/test_routing.py
+++ b/tests/test_routing.py
@@ -232,10 +232,7 @@ def test_route_converters(client: TestClient) -> None:
response = client.get("/path-with-parentheses(7)")
assert response.status_code == 200
assert response.json() == {"int": 7}
- assert (
- app.url_path_for("path-with-parentheses", param=7)
- == "/path-with-parentheses(7)"
- )
+ assert app.url_path_for("path-with-parentheses", param=7) == "/path-with-parentheses(7)"
# Test float conversion
response = client.get("/float/25.5")
@@ -247,18 +244,14 @@ def test_route_converters(client: TestClient) -> None:
response = client.get("/path/some/example")
assert response.status_code == 200
assert response.json() == {"path": "some/example"}
- assert (
- app.url_path_for("path-convertor", param="some/example") == "/path/some/example"
- )
+ assert app.url_path_for("path-convertor", param="some/example") == "/path/some/example"
# Test UUID conversion
response = client.get("/uuid/ec38df32-ceda-4cfa-9b4a-1aeb94ad551a")
assert response.status_code == 200
assert response.json() == {"uuid": "ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"}
assert (
- app.url_path_for(
- "uuid-convertor", param=uuid.UUID("ec38df32-ceda-4cfa-9b4a-1aeb94ad551a")
- )
+ app.url_path_for("uuid-convertor", param=uuid.UUID("ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"))
== "/uuid/ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"
)
@@ -267,13 +260,9 @@ def test_url_path_for() -> None:
assert app.url_path_for("homepage") == "/"
assert app.url_path_for("user", username="tomchristie") == "/users/tomchristie"
assert app.url_path_for("websocket_endpoint") == "/ws"
- with pytest.raises(
- NoMatchFound, match='No route exists for name "broken" and params "".'
- ):
+ with pytest.raises(NoMatchFound, match='No route exists for name "broken" and params "".'):
assert app.url_path_for("broken")
- with pytest.raises(
- NoMatchFound, match='No route exists for name "broken" and params "key, key2".'
- ):
+ with pytest.raises(NoMatchFound, match='No route exists for name "broken" and params "key, key2".'):
assert app.url_path_for("broken", key="value", key2="value2")
with pytest.raises(AssertionError):
app.url_path_for("user", username="tom/christie")
@@ -282,32 +271,21 @@ def test_url_path_for() -> None:
def test_url_for() -> None:
+ assert app.url_path_for("homepage").make_absolute_url(base_url="https://example.org") == "https://example.org/"
assert (
- app.url_path_for("homepage").make_absolute_url(base_url="https://example.org")
- == "https://example.org/"
- )
- assert (
- app.url_path_for("homepage").make_absolute_url(
- base_url="https://example.org/root_path/"
- )
+ app.url_path_for("homepage").make_absolute_url(base_url="https://example.org/root_path/")
== "https://example.org/root_path/"
)
assert (
- app.url_path_for("user", username="tomchristie").make_absolute_url(
- base_url="https://example.org"
- )
+ app.url_path_for("user", username="tomchristie").make_absolute_url(base_url="https://example.org")
== "https://example.org/users/tomchristie"
)
assert (
- app.url_path_for("user", username="tomchristie").make_absolute_url(
- base_url="https://example.org/root_path/"
- )
+ app.url_path_for("user", username="tomchristie").make_absolute_url(base_url="https://example.org/root_path/")
== "https://example.org/root_path/users/tomchristie"
)
assert (
- app.url_path_for("websocket_endpoint").make_absolute_url(
- base_url="https://example.org"
- )
+ app.url_path_for("websocket_endpoint").make_absolute_url(base_url="https://example.org")
== "wss://example.org/ws"
)
@@ -409,13 +387,8 @@ def test_reverse_mount_urls() -> None:
users = Router([Route("/{username}", ok, name="user")])
mounted = Router([Mount("/{subpath}/users", users, name="users")])
- assert (
- mounted.url_path_for("users:user", subpath="test", username="tom")
- == "/test/users/tom"
- )
- assert (
- mounted.url_path_for("users", subpath="test", path="/tom") == "/test/users/tom"
- )
+ assert mounted.url_path_for("users:user", subpath="test", username="tom") == "/test/users/tom"
+ assert mounted.url_path_for("users", subpath="test", path="/tom") == "/test/users/tom"
def test_mount_at_root(test_client_factory: TestClientFactory) -> None:
@@ -472,9 +445,7 @@ def test_host_routing(test_client_factory: TestClientFactory) -> None:
response = client.get("/")
assert response.status_code == 200
- client = test_client_factory(
- mixed_hosts_app, base_url="https://port.example.org:3600/"
- )
+ client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org:3600/")
response = client.get("/users")
assert response.status_code == 404
@@ -489,31 +460,23 @@ def test_host_routing(test_client_factory: TestClientFactory) -> None:
response = client.get("/")
assert response.status_code == 200
- client = test_client_factory(
- mixed_hosts_app, base_url="https://port.example.org:5600/"
- )
+ client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org:5600/")
response = client.get("/")
assert response.status_code == 200
def test_host_reverse_urls() -> None:
+ assert mixed_hosts_app.url_path_for("homepage").make_absolute_url("https://whatever") == "https://www.example.org/"
assert (
- mixed_hosts_app.url_path_for("homepage").make_absolute_url("https://whatever")
- == "https://www.example.org/"
- )
- assert (
- mixed_hosts_app.url_path_for("users").make_absolute_url("https://whatever")
- == "https://www.example.org/users"
+ mixed_hosts_app.url_path_for("users").make_absolute_url("https://whatever") == "https://www.example.org/users"
)
assert (
mixed_hosts_app.url_path_for("api:users").make_absolute_url("https://whatever")
== "https://api.example.org/users"
)
assert (
- mixed_hosts_app.url_path_for("port:homepage").make_absolute_url(
- "https://whatever"
- )
+ mixed_hosts_app.url_path_for("port:homepage").make_absolute_url("https://whatever")
== "https://port.example.org:3600/"
)
@@ -523,9 +486,7 @@ async def subdomain_app(scope: Scope, receive: Receive, send: Send) -> None:
await response(scope, receive, send)
-subdomain_router = Router(
- routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")]
-)
+subdomain_router = Router(routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")])
def test_subdomain_routing(test_client_factory: TestClientFactory) -> None:
@@ -538,9 +499,9 @@ def test_subdomain_routing(test_client_factory: TestClientFactory) -> None:
def test_subdomain_reverse_urls() -> None:
assert (
- subdomain_router.url_path_for(
- "subdomains", subdomain="foo", path="/homepage"
- ).make_absolute_url("https://whatever")
+ subdomain_router.url_path_for("subdomains", subdomain="foo", path="/homepage").make_absolute_url(
+ "https://whatever"
+ )
== "https://foo.example.org/homepage"
)
@@ -566,9 +527,7 @@ async def echo_urls(request: Request) -> JSONResponse:
def test_url_for_with_root_path(test_client_factory: TestClientFactory) -> None:
app = Starlette(routes=echo_url_routes)
- client = test_client_factory(
- app, base_url="https://www.example.org/", root_path="/sub_path"
- )
+ client = test_client_factory(app, base_url="https://www.example.org/", root_path="/sub_path")
response = client.get("/sub_path/")
assert response.json() == {
"index": "https://www.example.org/sub_path/",
@@ -657,9 +616,7 @@ async def run_shutdown() -> None:
nonlocal shutdown_complete
shutdown_complete = True
- with pytest.deprecated_call(
- match="The on_startup and on_shutdown parameters are deprecated"
- ):
+ with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
app = Router(
on_startup=[run_startup],
on_shutdown=[run_shutdown],
@@ -697,18 +654,14 @@ def run_shutdown() -> None: # pragma: no cover
nonlocal shutdown_called
shutdown_called = True
- with pytest.deprecated_call(
- match="The on_startup and on_shutdown parameters are deprecated"
- ):
+ with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
with pytest.warns(
UserWarning,
match=(
- "The `lifespan` parameter cannot be used with `on_startup` or `on_shutdown`." # noqa: E501
+ "The `lifespan` parameter cannot be used with `on_startup` or `on_shutdown`."
),
):
- app = Router(
- on_startup=[run_startup], on_shutdown=[run_shutdown], lifespan=lifespan
- )
+ app = Router(on_startup=[run_startup], on_shutdown=[run_shutdown], lifespan=lifespan)
assert not lifespan_called
assert not startup_called
@@ -738,9 +691,7 @@ def run_shutdown() -> None:
nonlocal shutdown_complete
shutdown_complete = True
- with pytest.deprecated_call(
- match="The on_startup and on_shutdown parameters are deprecated"
- ):
+ with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
app = Router(
on_startup=[run_startup],
on_shutdown=[run_shutdown],
@@ -775,9 +726,7 @@ async def no_state_wrapper(scope: Scope, receive: Receive, send: Send) -> None:
del scope["state"]
await app(scope, receive, send)
- with pytest.raises(
- RuntimeError, match='The server does not support "state" in the lifespan scope'
- ):
+ with pytest.raises(RuntimeError, match='The server does not support "state" in the lifespan scope'):
with test_client_factory(no_state_wrapper):
raise AssertionError("Should not be called") # pragma: no cover
@@ -834,9 +783,7 @@ def test_raise_on_startup(test_client_factory: TestClientFactory) -> None:
def run_startup() -> None:
raise RuntimeError()
- with pytest.deprecated_call(
- match="The on_startup and on_shutdown parameters are deprecated"
- ):
+ with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
router = Router(on_startup=[run_startup])
startup_failed = False
@@ -859,9 +806,7 @@ def test_raise_on_shutdown(test_client_factory: TestClientFactory) -> None:
def run_shutdown() -> None:
raise RuntimeError()
- with pytest.deprecated_call(
- match="The on_startup and on_shutdown parameters are deprecated"
- ):
+ with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
app = Router(on_shutdown=[run_shutdown])
with pytest.raises(RuntimeError):
@@ -934,9 +879,7 @@ def __call__(self, request: Request) -> None: ... # pragma: no cover
pytest.param(lambda request: ..., "", id="lambda"),
],
)
-def test_route_name(
- endpoint: typing.Callable[..., Response], expected_name: str
-) -> None:
+def test_route_name(endpoint: typing.Callable[..., Response], expected_name: str) -> None:
assert Route(path="/", endpoint=endpoint).name == expected_name
@@ -1172,10 +1115,7 @@ async def modified_send(msg: Message) -> None:
def test_route_repr() -> None:
route = Route("/welcome", endpoint=homepage)
- assert (
- repr(route)
- == "Route(path='/welcome', name='homepage', methods=['GET', 'HEAD'])"
- )
+ assert repr(route) == "Route(path='/welcome', name='homepage', methods=['GET', 'HEAD'])"
def test_route_repr_without_methods() -> None:
@@ -1264,9 +1204,7 @@ async def echo_paths(request: Request, name: str) -> JSONResponse:
)
-async def pure_asgi_echo_paths(
- scope: Scope, receive: Receive, send: Send, name: str
-) -> None:
+async def pure_asgi_echo_paths(scope: Scope, receive: Receive, send: Send, name: str) -> None:
data = {"name": name, "path": scope["path"], "root_path": scope["root_path"]}
content = json.dumps(data).encode("utf-8")
await send(
@@ -1304,9 +1242,7 @@ async def pure_asgi_echo_paths(
def test_paths_with_root_path(test_client_factory: TestClientFactory) -> None:
app = Starlette(routes=echo_paths_routes)
- client = test_client_factory(
- app, base_url="https://www.example.org/", root_path="/root"
- )
+ client = test_client_factory(app, base_url="https://www.example.org/", root_path="/root")
response = client.get("/root/path")
assert response.status_code == 200
assert response.json() == {
diff --git a/tests/test_schemas.py b/tests/test_schemas.py
index f4a5b4ad9..0ed4d5801 100644
--- a/tests/test_schemas.py
+++ b/tests/test_schemas.py
@@ -7,9 +7,7 @@
from starlette.websockets import WebSocket
from tests.types import TestClientFactory
-schemas = SchemaGenerator(
- {"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}}
-)
+schemas = SchemaGenerator({"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}})
def ws(session: WebSocket) -> None:
@@ -157,25 +155,13 @@ def test_schema_generation() -> None:
},
},
"/regular-docstring-and-schema": {
- "get": {
- "responses": {
- 200: {"description": "This is included in the schema."}
- }
- }
+ "get": {"responses": {200: {"description": "This is included in the schema."}}}
},
"/subapp/subapp-endpoint": {
- "get": {
- "responses": {
- 200: {"description": "This endpoint is part of a subapp."}
- }
- }
+ "get": {"responses": {200: {"description": "This endpoint is part of a subapp."}}}
},
"/subapp2/subapp-endpoint": {
- "get": {
- "responses": {
- 200: {"description": "This endpoint is part of a subapp."}
- }
- }
+ "get": {"responses": {200: {"description": "This endpoint is part of a subapp."}}}
},
"/users": {
"get": {
@@ -186,11 +172,7 @@ def test_schema_generation() -> None:
}
}
},
- "post": {
- "responses": {
- 200: {"description": "A user.", "examples": {"username": "tom"}}
- }
- },
+ "post": {"responses": {200: {"description": "A user.", "examples": {"username": "tom"}}}},
},
"/users/{id}": {
"get": {
diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py
index 65d71b97b..8beb3cd87 100644
--- a/tests/test_staticfiles.py
+++ b/tests/test_staticfiles.py
@@ -31,9 +31,7 @@ def test_staticfiles(tmpdir: Path, test_client_factory: TestClientFactory) -> No
assert response.text == ""
-def test_staticfiles_with_pathlib(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_with_pathlib(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "example.txt"
with open(path, "w") as file:
file.write("")
@@ -45,9 +43,7 @@ def test_staticfiles_with_pathlib(
assert response.text == ""
-def test_staticfiles_head_with_middleware(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_head_with_middleware(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
"""
see https://github.com/encode/starlette/pull/935
"""
@@ -55,9 +51,7 @@ def test_staticfiles_head_with_middleware(
with open(path, "w") as file:
file.write("x" * 100)
- async def does_nothing_middleware(
- request: Request, call_next: RequestResponseEndpoint
- ) -> Response:
+ async def does_nothing_middleware(request: Request, call_next: RequestResponseEndpoint) -> Response:
response = await call_next(request)
return response
@@ -99,9 +93,7 @@ def test_staticfiles_post(tmpdir: Path, test_client_factory: TestClientFactory)
assert response.text == "Method Not Allowed"
-def test_staticfiles_with_directory_returns_404(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_with_directory_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("")
@@ -115,9 +107,7 @@ def test_staticfiles_with_directory_returns_404(
assert response.text == "Not Found"
-def test_staticfiles_with_missing_file_returns_404(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_with_missing_file_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("")
@@ -138,9 +128,7 @@ def test_staticfiles_instantiated_with_missing_directory(tmpdir: Path) -> None:
assert "does not exist" in str(exc_info.value)
-def test_staticfiles_configured_with_missing_directory(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_configured_with_missing_directory(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "no_such_directory")
app = StaticFiles(directory=path, check_dir=False)
client = test_client_factory(app)
@@ -163,9 +151,7 @@ def test_staticfiles_configured_with_file_instead_of_directory(
assert "is not a directory" in str(exc_info.value)
-def test_staticfiles_config_check_occurs_only_once(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_config_check_occurs_only_once(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
app = StaticFiles(directory=tmpdir)
client = test_client_factory(app)
assert not app.config_checked
@@ -199,9 +185,7 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir: Path) -> None:
assert exc_info.value.detail == "Not Found"
-def test_staticfiles_never_read_file_for_head_method(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_never_read_file_for_head_method(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("")
@@ -214,9 +198,7 @@ def test_staticfiles_never_read_file_for_head_method(
assert response.headers["content-length"] == "14"
-def test_staticfiles_304_with_etag_match(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_304_with_etag_match(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("")
@@ -229,9 +211,7 @@ def test_staticfiles_304_with_etag_match(
second_resp = client.get("/example.txt", headers={"if-none-match": last_etag})
assert second_resp.status_code == 304
assert second_resp.content == b""
- second_resp = client.get(
- "/example.txt", headers={"if-none-match": f'W/{last_etag}, "123"'}
- )
+ second_resp = client.get("/example.txt", headers={"if-none-match": f'W/{last_etag}, "123"'})
assert second_resp.status_code == 304
assert second_resp.content == b""
@@ -240,9 +220,7 @@ def test_staticfiles_304_with_last_modified_compare_last_req(
tmpdir: Path, test_client_factory: TestClientFactory
) -> None:
path = os.path.join(tmpdir, "example.txt")
- file_last_modified_time = time.mktime(
- time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S")
- )
+ file_last_modified_time = time.mktime(time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S"))
with open(path, "w") as file:
file.write("")
os.utime(path, (file_last_modified_time, file_last_modified_time))
@@ -250,22 +228,16 @@ def test_staticfiles_304_with_last_modified_compare_last_req(
app = StaticFiles(directory=tmpdir)
client = test_client_factory(app)
# last modified less than last request, 304
- response = client.get(
- "/example.txt", headers={"If-Modified-Since": "Thu, 11 Oct 2013 15:30:19 GMT"}
- )
+ response = client.get("/example.txt", headers={"If-Modified-Since": "Thu, 11 Oct 2013 15:30:19 GMT"})
assert response.status_code == 304
assert response.content == b""
# last modified greater than last request, 200 with content
- response = client.get(
- "/example.txt", headers={"If-Modified-Since": "Thu, 20 Feb 2012 15:30:19 GMT"}
- )
+ response = client.get("/example.txt", headers={"If-Modified-Since": "Thu, 20 Feb 2012 15:30:19 GMT"})
assert response.status_code == 200
assert response.content == b""
-def test_staticfiles_html_normal(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_html_normal(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "404.html")
with open(path, "w") as file:
file.write("Custom not found page
")
@@ -298,9 +270,7 @@ def test_staticfiles_html_normal(
assert response.text == "Custom not found page
"
-def test_staticfiles_html_without_index(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_html_without_index(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "404.html")
with open(path, "w") as file:
file.write("Custom not found page
")
@@ -325,9 +295,7 @@ def test_staticfiles_html_without_index(
assert response.text == "Custom not found page
"
-def test_staticfiles_html_without_404(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_html_without_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "dir")
os.mkdir(path)
path = os.path.join(path, "index.html")
@@ -352,9 +320,7 @@ def test_staticfiles_html_without_404(
assert exc_info.value.status_code == 404
-def test_staticfiles_html_only_files(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_html_only_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "hello.html")
with open(path, "w") as file:
file.write("Hello
")
@@ -381,9 +347,7 @@ def test_staticfiles_cache_invalidation_for_deleted_file_html_mode(
with open(path_some, "w") as file:
file.write("some file
")
- common_modified_time = time.mktime(
- time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S")
- )
+ common_modified_time = time.mktime(time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S"))
os.utime(path_404, (common_modified_time, common_modified_time))
os.utime(path_some, (common_modified_time, common_modified_time))
@@ -435,9 +399,7 @@ def test_staticfiles_with_invalid_dir_permissions_returns_401(
tmp_path.chmod(original_mode)
-def test_staticfiles_with_missing_dir_returns_404(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_with_missing_dir_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("")
@@ -451,9 +413,7 @@ def test_staticfiles_with_missing_dir_returns_404(
assert response.text == "Not Found"
-def test_staticfiles_access_file_as_dir_returns_404(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_access_file_as_dir_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("")
@@ -467,9 +427,7 @@ def test_staticfiles_access_file_as_dir_returns_404(
assert response.text == "Not Found"
-def test_staticfiles_filename_too_long(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_filename_too_long(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")]
app = Starlette(routes=routes)
client = test_client_factory(app)
@@ -503,9 +461,7 @@ def mock_timeout(*args: typing.Any, **kwargs: typing.Any) -> None:
assert response.text == "Internal Server Error"
-def test_staticfiles_follows_symlinks(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_follows_symlinks(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
statics_path = os.path.join(tmpdir, "statics")
os.mkdir(statics_path)
@@ -526,9 +482,7 @@ def test_staticfiles_follows_symlinks(
assert response.text == "Hello
"
-def test_staticfiles_follows_symlink_directories(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_follows_symlink_directories(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
statics_path = os.path.join(tmpdir, "statics")
statics_html_path = os.path.join(statics_path, "html")
os.mkdir(statics_path)
diff --git a/tests/test_status.py b/tests/test_status.py
index 04719e87e..1371bc1a7 100644
--- a/tests/test_status.py
+++ b/tests/test_status.py
@@ -8,13 +8,11 @@
(
(
"WS_1004_NO_STATUS_RCVD",
- "'WS_1004_NO_STATUS_RCVD' is deprecated. "
- "Use 'WS_1005_NO_STATUS_RCVD' instead.",
+ "'WS_1004_NO_STATUS_RCVD' is deprecated. " "Use 'WS_1005_NO_STATUS_RCVD' instead.",
),
(
"WS_1005_ABNORMAL_CLOSURE",
- "'WS_1005_ABNORMAL_CLOSURE' is deprecated. "
- "Use 'WS_1006_ABNORMAL_CLOSURE' instead.",
+ "'WS_1005_ABNORMAL_CLOSURE' is deprecated. " "Use 'WS_1006_ABNORMAL_CLOSURE' instead.",
),
),
)
diff --git a/tests/test_templates.py b/tests/test_templates.py
index 8e344f331..6b2080c17 100644
--- a/tests/test_templates.py
+++ b/tests/test_templates.py
@@ -36,9 +36,7 @@ async def homepage(request: Request) -> Response:
assert set(response.context.keys()) == {"request"} # type: ignore
-def test_calls_context_processors(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_calls_context_processors(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "index.html"
path.write_text("Hello {{ username }}")
@@ -66,9 +64,7 @@ def hello_world_processor(request: Request) -> dict[str, str]:
assert set(response.context.keys()) == {"request", "username"} # type: ignore
-def test_template_with_middleware(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_template_with_middleware(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
file.write("Hello, world")
@@ -77,9 +73,7 @@ async def homepage(request: Request) -> Response:
return templates.TemplateResponse(request, "index.html")
class CustomMiddleware(BaseHTTPMiddleware):
- async def dispatch(
- self, request: Request, call_next: RequestResponseEndpoint
- ) -> Response:
+ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
return await call_next(request)
app = Starlette(
@@ -96,9 +90,7 @@ async def dispatch(
assert set(response.context.keys()) == {"request"} # type: ignore
-def test_templates_with_directories(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_templates_with_directories(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
dir_a = tmp_path.resolve() / "a"
dir_a.mkdir()
template_a = dir_a / "template_a.html"
@@ -134,16 +126,12 @@ async def page_b(request: Request) -> Response:
def test_templates_require_directory_or_environment() -> None:
- with pytest.raises(
- AssertionError, match="either 'directory' or 'env' arguments must be passed"
- ):
+ with pytest.raises(AssertionError, match="either 'directory' or 'env' arguments must be passed"):
Jinja2Templates() # type: ignore[call-overload]
def test_templates_require_directory_or_enviroment_not_both() -> None:
- with pytest.raises(
- AssertionError, match="either 'directory' or 'env' arguments must be passed"
- ):
+ with pytest.raises(AssertionError, match="either 'directory' or 'env' arguments must be passed"):
Jinja2Templates(directory="dir", env=jinja2.Environment())
@@ -157,9 +145,7 @@ def test_templates_with_directory(tmpdir: Path) -> None:
assert template.render({}) == "Hello"
-def test_templates_with_environment(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_templates_with_environment(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
file.write("Hello, world")
@@ -185,9 +171,7 @@ def test_templates_with_environment_options_emit_warning(tmpdir: Path) -> None:
Jinja2Templates(str(tmpdir), autoescape=True)
-def test_templates_with_kwargs_only(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_templates_with_kwargs_only(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
# MAINTAINERS: remove after 1.0
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
@@ -242,9 +226,7 @@ def test_templates_with_kwargs_only_warns_when_no_request_keyword(
templates = Jinja2Templates(directory=str(tmpdir))
def page(request: Request) -> Response:
- return templates.TemplateResponse(
- name="index.html", context={"request": request}
- )
+ return templates.TemplateResponse(name="index.html", context={"request": request})
app = Starlette(routes=[Route("/", page)])
client = test_client_factory(app)
@@ -297,9 +279,7 @@ def page(request: Request) -> Response:
spy.assert_called()
-def test_templates_when_first_argument_is_request(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_templates_when_first_argument_is_request(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
# MAINTAINERS: remove after 1.0
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
diff --git a/tests/test_testclient.py b/tests/test_testclient.py
index 77de3d976..67d060723 100644
--- a/tests/test_testclient.py
+++ b/tests/test_testclient.py
@@ -88,9 +88,7 @@ def test_testclient_headers_behavior() -> None:
assert client.headers.get("Authentication") == "Bearer 123"
-def test_use_testclient_as_contextmanager(
- test_client_factory: TestClientFactory, anyio_backend_name: str
-) -> None:
+def test_use_testclient_as_contextmanager(test_client_factory: TestClientFactory, anyio_backend_name: str) -> None:
"""
This test asserts a number of properties that are important for an
app level task_group
@@ -169,9 +167,7 @@ async def loop_id(request: Request) -> JSONResponse:
def test_error_on_startup(test_client_factory: TestClientFactory) -> None:
- with pytest.deprecated_call(
- match="The on_startup and on_shutdown parameters are deprecated"
- ):
+ with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
startup_error_app = Starlette(on_startup=[startup])
with pytest.raises(RuntimeError):
@@ -306,8 +302,7 @@ def homepage(request: Request) -> Response:
marks=[
pytest.mark.xfail(
sys.version_info < (3, 11),
- reason="Fails due to domain handling in http.cookiejar module (see "
- "#2152)",
+ reason="Fails due to domain handling in http.cookiejar module (see " "#2152)",
),
],
),
@@ -316,9 +311,7 @@ def homepage(request: Request) -> Response:
("example.com", False),
],
)
-def test_domain_restricted_cookies(
- test_client_factory: TestClientFactory, domain: str, ok: bool
-) -> None:
+def test_domain_restricted_cookies(test_client_factory: TestClientFactory, domain: str, ok: bool) -> None:
"""
Test that test client discards domain restricted cookies which do not match the
base_url of the testclient (`http://testserver` by default).
diff --git a/tests/test_websockets.py b/tests/test_websockets.py
index 16d2d0f1f..385e510b9 100644
--- a/tests/test_websockets.py
+++ b/tests/test_websockets.py
@@ -402,10 +402,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
client = test_client_factory(app)
with pytest.raises(
RuntimeError,
- match=(
- 'Expected ASGI message "websocket.http.response.body", but got '
- "'websocket.http.response.start'"
- ),
+ match=('Expected ASGI message "websocket.http.response.body", but got ' "'websocket.http.response.start'"),
):
with client.websocket_connect("/"):
pass # pragma: no cover
From eaecae829f63ebaea780a3a7c6589f9aa9af4ba8 Mon Sep 17 00:00:00 2001
From: Marcelo Trylesinski
Date: Sun, 1 Sep 2024 14:57:12 +0200
Subject: [PATCH 2/4] Add links to selected rules
---
pyproject.toml | 11 +++++++++--
1 file changed, 9 insertions(+), 2 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 04533cb49..52156528e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,8 +54,15 @@ path = "starlette/__init__.py"
line-length = 120
[tool.ruff.lint]
-select = ["E", "F", "I", "FA", "UP", "RUF100"]
-ignore = ["UP031"]
+select = [
+ "E", # https://docs.astral.sh/ruff/rules/#error-e
+ "F", # https://docs.astral.sh/ruff/rules/#pyflakes-f
+ "I", # https://docs.astral.sh/ruff/rules/#isort-i
+ "FA", # https://docs.astral.sh/ruff/rules/#flake8-future-annotations-fa
+ "UP", # https://docs.astral.sh/ruff/rules/#pyupgrade-up
+ "RUF100", # https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
+]
+ignore = ["UP031"] # https://docs.astral.sh/ruff/rules/printf-string-formatting/
[tool.ruff.lint.isort]
combine-as-imports = true
From f55aa57b43af124bd93448b61c0040f798e50fdf Mon Sep 17 00:00:00 2001
From: Marcelo Trylesinski
Date: Sun, 1 Sep 2024 15:00:49 +0200
Subject: [PATCH 3/4] Remove empty strings
---
starlette/concurrency.py | 2 +-
starlette/config.py | 4 ++--
starlette/routing.py | 2 +-
starlette/testclient.py | 2 +-
tests/test_applications.py | 10 +++++-----
tests/test_formparsers.py | 4 ++--
tests/test_schemas.py | 2 +-
tests/test_status.py | 4 ++--
tests/test_testclient.py | 2 +-
9 files changed, 16 insertions(+), 16 deletions(-)
diff --git a/starlette/concurrency.py b/starlette/concurrency.py
index 22979404a..ce3f5c82b 100644
--- a/starlette/concurrency.py
+++ b/starlette/concurrency.py
@@ -18,7 +18,7 @@
async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg]
warnings.warn(
- "run_until_first_complete is deprecated " "and will be removed in a future version.",
+ "run_until_first_complete is deprecated and will be removed in a future version.",
DeprecationWarning,
)
diff --git a/starlette/config.py b/starlette/config.py
index 7b46e16fb..ca15c5646 100644
--- a/starlette/config.py
+++ b/starlette/config.py
@@ -25,12 +25,12 @@ def __getitem__(self, key: str) -> str:
def __setitem__(self, key: str, value: str) -> None:
if key in self._has_been_read:
- raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been " "read.")
+ raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been read.")
self._environ.__setitem__(key, value)
def __delitem__(self, key: str) -> None:
if key in self._has_been_read:
- raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already " "been read.")
+ raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.")
self._environ.__delitem__(key)
def __iter__(self) -> typing.Iterator[str]:
diff --git a/starlette/routing.py b/starlette/routing.py
index 300711626..cde771563 100644
--- a/starlette/routing.py
+++ b/starlette/routing.py
@@ -47,7 +47,7 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
including those wrapped in functools.partial objects.
"""
warnings.warn(
- "iscoroutinefunction_or_partial is deprecated, " "and will be removed in a future release.",
+ "iscoroutinefunction_or_partial is deprecated, and will be removed in a future release.",
DeprecationWarning,
)
while isinstance(obj, functools.partial):
diff --git a/starlette/testclient.py b/starlette/testclient.py
index cc6c6e92c..fcf392e33 100644
--- a/starlette/testclient.py
+++ b/starlette/testclient.py
@@ -454,7 +454,7 @@ def _choose_redirect_arg(
) -> bool | httpx._client.UseClientDefault:
redirect: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT
if allow_redirects is not None:
- message = "The `allow_redirects` argument is deprecated. " "Use `follow_redirects` instead."
+ message = "The `allow_redirects` argument is deprecated. Use `follow_redirects` instead."
warnings.warn(message, DeprecationWarning)
redirect = allow_redirects
if follow_redirects is not None:
diff --git a/tests/test_applications.py b/tests/test_applications.py
index e86eba322..86c713c38 100644
--- a/tests/test_applications.py
+++ b/tests/test_applications.py
@@ -441,13 +441,13 @@ def test_decorator_deprecations() -> None:
app = Starlette()
with pytest.deprecated_call(
- match=("The `exception_handler` decorator is deprecated, " "and will be removed in version 1.0.0.")
+ match=("The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0.")
) as record:
app.exception_handler(500)(http_exception)
assert len(record) == 1
with pytest.deprecated_call(
- match=("The `middleware` decorator is deprecated, " "and will be removed in version 1.0.0.")
+ match=("The `middleware` decorator is deprecated, and will be removed in version 1.0.0.")
) as record:
async def middleware(request: Request, call_next: RequestResponseEndpoint) -> None: ... # pragma: no cover
@@ -456,19 +456,19 @@ async def middleware(request: Request, call_next: RequestResponseEndpoint) -> No
assert len(record) == 1
with pytest.deprecated_call(
- match=("The `route` decorator is deprecated, " "and will be removed in version 1.0.0.")
+ match=("The `route` decorator is deprecated, and will be removed in version 1.0.0.")
) as record:
app.route("/")(async_homepage)
assert len(record) == 1
with pytest.deprecated_call(
- match=("The `websocket_route` decorator is deprecated, " "and will be removed in version 1.0.0.")
+ match=("The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0.")
) as record:
app.websocket_route("/ws")(websocket_endpoint)
assert len(record) == 1
with pytest.deprecated_call(
- match=("The `on_event` decorator is deprecated, " "and will be removed in version 1.0.0.")
+ match=("The `on_event` decorator is deprecated, and will be removed in version 1.0.0.")
) as record:
async def startup() -> None: ... # pragma: no cover
diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py
index a5ebdd043..61c1bede1 100644
--- a/tests/test_formparsers.py
+++ b/tests/test_formparsers.py
@@ -315,7 +315,7 @@ def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_f
b"\r\n"
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
),
- headers={"Content-Type": ("multipart/form-data; charset=utf-8; " "boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
+ headers={"Content-Type": ("multipart/form-data; charset=utf-8; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
)
assert response.json() == {
"file": {
@@ -362,7 +362,7 @@ def test_multipart_request_with_encoded_value(tmpdir: Path, test_client_factory:
b"Transf\xc3\xa9rer\r\n"
b"--20b303e711c4ab8c443184ac833ab00f--\r\n"
),
- headers={"Content-Type": ("multipart/form-data; charset=utf-8; " "boundary=20b303e711c4ab8c443184ac833ab00f")},
+ headers={"Content-Type": ("multipart/form-data; charset=utf-8; boundary=20b303e711c4ab8c443184ac833ab00f")},
)
assert response.json() == {"value": "Transférer"}
diff --git a/tests/test_schemas.py b/tests/test_schemas.py
index 0ed4d5801..3b321ca0b 100644
--- a/tests/test_schemas.py
+++ b/tests/test_schemas.py
@@ -140,7 +140,7 @@ def test_schema_generation() -> None:
"get": {
"responses": {
200: {
- "description": "A list of " "organisations.",
+ "description": "A list of organisations.",
"examples": [{"name": "Foo Corp."}, {"name": "Acme Ltd."}],
}
}
diff --git a/tests/test_status.py b/tests/test_status.py
index 1371bc1a7..4852c06ef 100644
--- a/tests/test_status.py
+++ b/tests/test_status.py
@@ -8,11 +8,11 @@
(
(
"WS_1004_NO_STATUS_RCVD",
- "'WS_1004_NO_STATUS_RCVD' is deprecated. " "Use 'WS_1005_NO_STATUS_RCVD' instead.",
+ "'WS_1004_NO_STATUS_RCVD' is deprecated. Use 'WS_1005_NO_STATUS_RCVD' instead.",
),
(
"WS_1005_ABNORMAL_CLOSURE",
- "'WS_1005_ABNORMAL_CLOSURE' is deprecated. " "Use 'WS_1006_ABNORMAL_CLOSURE' instead.",
+ "'WS_1005_ABNORMAL_CLOSURE' is deprecated. Use 'WS_1006_ABNORMAL_CLOSURE' instead.",
),
),
)
diff --git a/tests/test_testclient.py b/tests/test_testclient.py
index 67d060723..92f16d336 100644
--- a/tests/test_testclient.py
+++ b/tests/test_testclient.py
@@ -302,7 +302,7 @@ def homepage(request: Request) -> Response:
marks=[
pytest.mark.xfail(
sys.version_info < (3, 11),
- reason="Fails due to domain handling in http.cookiejar module (see " "#2152)",
+ reason="Fails due to domain handling in http.cookiejar module (see #2152)",
),
],
),
From d59c797a185c1befd42b4148b4ea15faffc46587 Mon Sep 17 00:00:00 2001
From: Marcelo Trylesinski
Date: Sun, 1 Sep 2024 15:08:44 +0200
Subject: [PATCH 4/4] Fix more stuff
---
starlette/formparsers.py | 2 +-
starlette/middleware/errors.py | 2 +-
starlette/websockets.py | 17 ++++++-----------
tests/test_routing.py | 5 +----
tests/test_websockets.py | 26 +++++---------------------
5 files changed, 14 insertions(+), 38 deletions(-)
diff --git a/starlette/formparsers.py b/starlette/formparsers.py
index 9be98626e..56f63a8be 100644
--- a/starlette/formparsers.py
+++ b/starlette/formparsers.py
@@ -184,7 +184,7 @@ def on_headers_finished(self) -> None:
try:
self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
except KeyError:
- raise MultiPartException('The Content-Disposition header field "name" must be ' "provided.")
+ raise MultiPartException('The Content-Disposition header field "name" must be provided.')
if b"filename" in options:
self._current_files += 1
if self._current_files > self.max_files:
diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py
index d8cb052ed..76ad776be 100644
--- a/starlette/middleware/errors.py
+++ b/starlette/middleware/errors.py
@@ -112,7 +112,7 @@
{code_context}
-"""
+""" # noqa: E501
LINE = """
diff --git a/starlette/websockets.py b/starlette/websockets.py
index dc0457858..b7acaa3f0 100644
--- a/starlette/websockets.py
+++ b/starlette/websockets.py
@@ -39,7 +39,7 @@ async def receive(self) -> Message:
message = await self._receive()
message_type = message["type"]
if message_type != "websocket.connect":
- raise RuntimeError('Expected ASGI message "websocket.connect", ' f"but got {message_type!r}")
+ raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
self.client_state = WebSocketState.CONNECTED
return message
elif self.client_state == WebSocketState.CONNECTED:
@@ -47,7 +47,7 @@ async def receive(self) -> Message:
message_type = message["type"]
if message_type not in {"websocket.receive", "websocket.disconnect"}:
raise RuntimeError(
- 'Expected ASGI message "websocket.receive" or ' f'"websocket.disconnect", but got {message_type!r}'
+ f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
)
if message_type == "websocket.disconnect":
self.client_state = WebSocketState.DISCONNECTED
@@ -61,14 +61,9 @@ async def send(self, message: Message) -> None:
"""
if self.application_state == WebSocketState.CONNECTING:
message_type = message["type"]
- if message_type not in {
- "websocket.accept",
- "websocket.close",
- "websocket.http.response.start",
- }:
+ if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
raise RuntimeError(
- 'Expected ASGI message "websocket.accept",'
- '"websocket.close" or "websocket.http.response.start",'
+ 'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
f"but got {message_type!r}"
)
if message_type == "websocket.close":
@@ -82,7 +77,7 @@ async def send(self, message: Message) -> None:
message_type = message["type"]
if message_type not in {"websocket.send", "websocket.close"}:
raise RuntimeError(
- 'Expected ASGI message "websocket.send" or "websocket.close", ' f"but got {message_type!r}"
+ f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
@@ -94,7 +89,7 @@ async def send(self, message: Message) -> None:
elif self.application_state == WebSocketState.RESPONSE:
message_type = message["type"]
if message_type != "websocket.http.response.body":
- raise RuntimeError('Expected ASGI message "websocket.http.response.body", ' f"but got {message_type!r}")
+ raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
if not message.get("more_body", False):
self.application_state = WebSocketState.DISCONNECTED
await self._send(message)
diff --git a/tests/test_routing.py b/tests/test_routing.py
index 132baa602..9fa44def4 100644
--- a/tests/test_routing.py
+++ b/tests/test_routing.py
@@ -656,10 +656,7 @@ def run_shutdown() -> None: # pragma: no cover
with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
with pytest.warns(
- UserWarning,
- match=(
- "The `lifespan` parameter cannot be used with `on_startup` or `on_shutdown`."
- ),
+ UserWarning, match="The `lifespan` parameter cannot be used with `on_startup` or `on_shutdown`."
):
app = Router(on_startup=[run_startup], on_shutdown=[run_shutdown], lifespan=lifespan)
diff --git a/tests/test_websockets.py b/tests/test_websockets.py
index 385e510b9..7a9b9272a 100644
--- a/tests/test_websockets.py
+++ b/tests/test_websockets.py
@@ -270,8 +270,7 @@ async def receive() -> Message:
async def send(message: Message) -> None:
if message["type"] == "websocket.accept":
return
- # Simulate the exception the server would send to the application when the
- # client disconnects.
+ # Simulate the exception the server would send to the application when the client disconnects.
raise OSError
with pytest.raises(WebSocketDisconnect) as ctx:
@@ -334,19 +333,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
"headers": [(b"content-type", b"text/plain"), (b"foo", b"bar")],
}
)
- await websocket.send(
- {
- "type": "websocket.http.response.body",
- "body": b"hard",
- "more_body": True,
- }
- )
- await websocket.send(
- {
- "type": "websocket.http.response.body",
- "body": b"body",
- }
- )
+ await websocket.send({"type": "websocket.http.response.body", "body": b"hard", "more_body": True})
+ await websocket.send({"type": "websocket.http.response.body", "body": b"body"})
client = test_client_factory(app)
with pytest.raises(WebSocketDenialResponse) as exc:
@@ -402,7 +390,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
client = test_client_factory(app)
with pytest.raises(
RuntimeError,
- match=('Expected ASGI message "websocket.http.response.body", but got ' "'websocket.http.response.start'"),
+ match=("Expected ASGI message \"websocket.http.response.body\", but got 'websocket.http.response.start'"),
):
with client.websocket_connect("/"):
pass # pragma: no cover
@@ -490,11 +478,7 @@ async def mock_receive() -> Message: # type: ignore
async def mock_send(message: Message) -> None: ... # pragma: no cover
- websocket = WebSocket(
- {"type": "websocket", "path": "/abc/", "headers": []},
- receive=mock_receive,
- send=mock_send,
- )
+ websocket = WebSocket({"type": "websocket", "path": "/abc/", "headers": []}, receive=mock_receive, send=mock_send)
assert websocket["type"] == "websocket"
assert dict(websocket) == {"type": "websocket", "path": "/abc/", "headers": []}
assert len(websocket) == 3