diff --git a/pyproject.toml b/pyproject.toml index f2721c870..52156528e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,9 +50,19 @@ Source = "https://github.com/encode/starlette" [tool.hatch.version] path = "starlette/__init__.py" +[tool.ruff] +line-length = 120 + [tool.ruff.lint] -select = ["E", "F", "I", "FA", "UP"] -ignore = ["UP031"] +select = [ + "E", # https://docs.astral.sh/ruff/rules/#error-e + "F", # https://docs.astral.sh/ruff/rules/#pyflakes-f + "I", # https://docs.astral.sh/ruff/rules/#isort-i + "FA", # https://docs.astral.sh/ruff/rules/#flake8-future-annotations-fa + "UP", # https://docs.astral.sh/ruff/rules/#pyupgrade-up + "RUF100", # https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf +] +ignore = ["UP031"] # https://docs.astral.sh/ruff/rules/printf-string-formatting/ [tool.ruff.lint.isort] combine-as-imports = true diff --git a/starlette/_compat.py b/starlette/_compat.py index 9087a7645..718bc9020 100644 --- a/starlette/_compat.py +++ b/starlette/_compat.py @@ -15,9 +15,7 @@ # that reject usedforsecurity=True hashlib.md5(b"data", usedforsecurity=False) # type: ignore[call-arg] - def md5_hexdigest( - data: bytes, *, usedforsecurity: bool = True - ) -> str: # pragma: no cover + def md5_hexdigest(data: bytes, *, usedforsecurity: bool = True) -> str: # pragma: no cover return hashlib.md5( # type: ignore[call-arg] data, usedforsecurity=usedforsecurity ).hexdigest() diff --git a/starlette/_exception_handler.py b/starlette/_exception_handler.py index 99cb6b64c..4fbc86394 100644 --- a/starlette/_exception_handler.py +++ b/starlette/_exception_handler.py @@ -22,9 +22,7 @@ StatusHandlers = typing.Dict[int, ExceptionHandler] -def _lookup_exception_handler( - exc_handlers: ExceptionHandlers, exc: Exception -) -> ExceptionHandler | None: +def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None: for cls in type(exc).__mro__: if cls in exc_handlers: return exc_handlers[cls] diff --git a/starlette/_utils.py b/starlette/_utils.py index b6970542b..90bd346fd 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -37,26 +37,20 @@ def is_async_callable(obj: typing.Any) -> typing.Any: while isinstance(obj, functools.partial): obj = obj.func - return asyncio.iscoroutinefunction(obj) or ( - callable(obj) and asyncio.iscoroutinefunction(obj.__call__) - ) + return asyncio.iscoroutinefunction(obj) or (callable(obj) and asyncio.iscoroutinefunction(obj.__call__)) T_co = typing.TypeVar("T_co", covariant=True) -class AwaitableOrContextManager( - typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co] -): ... +class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]): ... class SupportsAsyncClose(typing.Protocol): async def close(self) -> None: ... # pragma: no cover -SupportsAsyncCloseType = typing.TypeVar( - "SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False -) +SupportsAsyncCloseType = typing.TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False) class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]): diff --git a/starlette/applications.py b/starlette/applications.py index 913fd4c9d..f34e80ead 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -72,21 +72,15 @@ def __init__( self.debug = debug self.state = State() - self.router = Router( - routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan - ) - self.exception_handlers = ( - {} if exception_handlers is None else dict(exception_handlers) - ) + self.router = Router(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan) + self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers) self.user_middleware = [] if middleware is None else list(middleware) self.middleware_stack: ASGIApp | None = None def build_middleware_stack(self) -> ASGIApp: debug = self.debug error_handler = None - exception_handlers: dict[ - typing.Any, typing.Callable[[Request, Exception], Response] - ] = {} + exception_handlers: dict[typing.Any, typing.Callable[[Request, Exception], Response]] = {} for key, value in self.exception_handlers.items(): if key in (500, Exception): @@ -97,11 +91,7 @@ def build_middleware_stack(self) -> ASGIApp: middleware = ( [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)] + self.user_middleware - + [ - Middleware( - ExceptionMiddleware, handlers=exception_handlers, debug=debug - ) - ] + + [Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)] ) app = self.router @@ -163,9 +153,7 @@ def add_route( name: str | None = None, include_in_schema: bool = True, ) -> None: # pragma: no cover - self.router.add_route( - path, route, methods=methods, name=name, include_in_schema=include_in_schema - ) + self.router.add_route(path, route, methods=methods, name=name, include_in_schema=include_in_schema) def add_websocket_route( self, @@ -175,16 +163,14 @@ def add_websocket_route( ) -> None: # pragma: no cover self.router.add_websocket_route(path, route, name=name) - def exception_handler( - self, exc_class_or_status_code: int | type[Exception] - ) -> typing.Callable: # type: ignore[type-arg] + def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> typing.Callable: # type: ignore[type-arg] warnings.warn( - "The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501 - "Refer to https://www.starlette.io/exceptions/ for the recommended approach.", # noqa: E501 + "The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. " + "Refer to https://www.starlette.io/exceptions/ for the recommended approach.", DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501 + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] self.add_exception_handler(exc_class_or_status_code, func) return func @@ -205,12 +191,12 @@ def route( >>> app = Starlette(routes=routes) """ warnings.warn( - "The `route` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501 - "Refer to https://www.starlette.io/routing/ for the recommended approach.", # noqa: E501 + "The `route` decorator is deprecated, and will be removed in version 1.0.0. " + "Refer to https://www.starlette.io/routing/ for the recommended approach.", DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501 + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] self.router.add_route( path, func, @@ -231,18 +217,18 @@ def websocket_route(self, path: str, name: str | None = None) -> typing.Callable >>> app = Starlette(routes=routes) """ warnings.warn( - "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501 - "Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.", # noqa: E501 + "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. " + "Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.", DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501 + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] self.router.add_websocket_route(path, func, name=name) return func return decorator - def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg] # noqa: E501 + def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg] """ We no longer document this decorator style API, and its usage is discouraged. Instead you should use the following approach: @@ -251,15 +237,13 @@ def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[t >>> app = Starlette(middleware=middleware) """ warnings.warn( - "The `middleware` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501 - "Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.", # noqa: E501 + "The `middleware` decorator is deprecated, and will be removed in version 1.0.0. " + "Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.", DeprecationWarning, ) - assert ( - middleware_type == "http" - ), 'Currently only middleware("http") is supported.' + assert middleware_type == "http", 'Currently only middleware("http") is supported.' - def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501 + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] self.add_middleware(BaseHTTPMiddleware, dispatch=func) return func diff --git a/starlette/authentication.py b/starlette/authentication.py index f2586a042..4fd866412 100644 --- a/starlette/authentication.py +++ b/starlette/authentication.py @@ -31,9 +31,7 @@ def requires( scopes: str | typing.Sequence[str], status_code: int = 403, redirect: str | None = None, -) -> typing.Callable[ - [typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any] -]: +) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]: scopes_list = [scopes] if isinstance(scopes, str) else list(scopes) def decorator( @@ -45,17 +43,13 @@ def decorator( type_ = parameter.name break else: - raise Exception( - f'No "request" or "websocket" argument on function "{func}"' - ) + raise Exception(f'No "request" or "websocket" argument on function "{func}"') if type_ == "websocket": # Handle websocket functions. (Always async) @functools.wraps(func) async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: - websocket = kwargs.get( - "websocket", args[idx] if idx < len(args) else None - ) + websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None) assert isinstance(websocket, WebSocket) if not has_required_scope(websocket, scopes_list): @@ -107,9 +101,7 @@ class AuthenticationError(Exception): class AuthenticationBackend: - async def authenticate( - self, conn: HTTPConnection - ) -> tuple[AuthCredentials, BaseUser] | None: + async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None: raise NotImplementedError() # pragma: no cover diff --git a/starlette/background.py b/starlette/background.py index 1cbed3b22..0430fc08b 100644 --- a/starlette/background.py +++ b/starlette/background.py @@ -15,9 +15,7 @@ class BackgroundTask: - def __init__( - self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs - ) -> None: + def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None: self.func = func self.args = args self.kwargs = kwargs @@ -34,9 +32,7 @@ class BackgroundTasks(BackgroundTask): def __init__(self, tasks: typing.Sequence[BackgroundTask] | None = None): self.tasks = list(tasks) if tasks else [] - def add_task( - self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs - ) -> None: + def add_task(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None: task = BackgroundTask(func, *args, **kwargs) self.tasks.append(task) diff --git a/starlette/concurrency.py b/starlette/concurrency.py index 215e3a63b..ce3f5c82b 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -16,16 +16,15 @@ T = typing.TypeVar("T") -async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg] # noqa: E501 +async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg] warnings.warn( - "run_until_first_complete is deprecated " - "and will be removed in a future version.", + "run_until_first_complete is deprecated and will be removed in a future version.", DeprecationWarning, ) async with anyio.create_task_group() as task_group: - async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg] # noqa: E501 + async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg] await func() task_group.cancel_scope.cancel() @@ -33,9 +32,7 @@ async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ign task_group.start_soon(run, functools.partial(func, **kwargs)) -async def run_in_threadpool( - func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs -) -> T: +async def run_in_threadpool(func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: if kwargs: # pragma: no cover # run_sync doesn't accept 'kwargs', so bind them in here func = functools.partial(func, **kwargs) diff --git a/starlette/config.py b/starlette/config.py index 4c3dfe5b0..ca15c5646 100644 --- a/starlette/config.py +++ b/starlette/config.py @@ -25,18 +25,12 @@ def __getitem__(self, key: str) -> str: def __setitem__(self, key: str, value: str) -> None: if key in self._has_been_read: - raise EnvironError( - f"Attempting to set environ['{key}'], but the value has already been " - "read." - ) + raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been read.") self._environ.__setitem__(key, value) def __delitem__(self, key: str) -> None: if key in self._has_been_read: - raise EnvironError( - f"Attempting to delete environ['{key}'], but the value has already " - "been read." - ) + raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.") self._environ.__delitem__(key) def __iter__(self) -> typing.Iterator[str]: @@ -85,9 +79,7 @@ def __call__( ) -> T: ... @typing.overload - def __call__( - self, key: str, cast: type[str] = ..., default: T = ... - ) -> T | str: ... + def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ... def __call__( self, @@ -138,13 +130,9 @@ def _perform_cast( mapping = {"true": True, "1": True, "false": False, "0": False} value = value.lower() if value not in mapping: - raise ValueError( - f"Config '{key}' has value '{value}'. Not a valid bool." - ) + raise ValueError(f"Config '{key}' has value '{value}'. Not a valid bool.") return mapping[value] try: return cast(value) except (TypeError, ValueError): - raise ValueError( - f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}." - ) + raise ValueError(f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}.") diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 54b5e54f3..90a7296a0 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -108,12 +108,7 @@ def is_secure(self) -> bool: return self.scheme in ("https", "wss") def replace(self, **kwargs: typing.Any) -> URL: - if ( - "username" in kwargs - or "password" in kwargs - or "hostname" in kwargs - or "port" in kwargs - ): + if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs: hostname = kwargs.pop("hostname", None) port = kwargs.pop("port", self.port) username = kwargs.pop("username", self.username) @@ -264,17 +259,12 @@ def __init__( value: typing.Any = args[0] if args else [] if kwargs: - value = ( - ImmutableMultiDict(value).multi_items() - + ImmutableMultiDict(kwargs).multi_items() - ) + value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items() if not value: _items: list[tuple[typing.Any, typing.Any]] = [] elif hasattr(value, "multi_items"): - value = typing.cast( - ImmutableMultiDict[_KeyType, _CovariantValueType], value - ) + value = typing.cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value) _items = list(value.multi_items()) elif hasattr(value, "items"): value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value) @@ -371,9 +361,7 @@ def append(self, key: typing.Any, value: typing.Any) -> None: def update( self, - *args: MultiDict - | typing.Mapping[typing.Any, typing.Any] - | list[tuple[typing.Any, typing.Any]], + *args: MultiDict | typing.Mapping[typing.Any, typing.Any] | list[tuple[typing.Any, typing.Any]], **kwargs: typing.Any, ) -> None: value = MultiDict(*args, **kwargs) @@ -403,9 +391,7 @@ def __init__( if isinstance(value, str): super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs) elif isinstance(value, bytes): - super().__init__( - parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs - ) + super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs) else: super().__init__(*args, **kwargs) # type: ignore[arg-type] self._list = [(str(k), str(v)) for k, v in self._list] @@ -490,9 +476,7 @@ class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]): def __init__( self, - *args: FormData - | typing.Mapping[str, str | UploadFile] - | list[tuple[str, str | UploadFile]], + *args: FormData | typing.Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]], **kwargs: str | UploadFile, ) -> None: super().__init__(*args, **kwargs) @@ -518,10 +502,7 @@ def __init__( if headers is not None: assert raw is None, 'Cannot set both "headers" and "raw".' assert scope is None, 'Cannot set both "headers" and "scope".' - self._list = [ - (key.lower().encode("latin-1"), value.encode("latin-1")) - for key, value in headers.items() - ] + self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()] elif raw is not None: assert scope is None, 'Cannot set both "raw" and "scope".' self._list = raw @@ -541,18 +522,11 @@ def values(self) -> list[str]: # type: ignore[override] return [value.decode("latin-1") for key, value in self._list] def items(self) -> list[tuple[str, str]]: # type: ignore[override] - return [ - (key.decode("latin-1"), value.decode("latin-1")) - for key, value in self._list - ] + return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list] def getlist(self, key: str) -> list[str]: get_header_key = key.lower().encode("latin-1") - return [ - item_value.decode("latin-1") - for item_key, item_value in self._list - if item_key == get_header_key - ] + return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key] def mutablecopy(self) -> MutableHeaders: return MutableHeaders(raw=self._list[:]) diff --git a/starlette/endpoints.py b/starlette/endpoints.py index 57f718824..eb1dace42 100644 --- a/starlette/endpoints.py +++ b/starlette/endpoints.py @@ -30,15 +30,9 @@ def __await__(self) -> typing.Generator[typing.Any, None, None]: async def dispatch(self) -> None: request = Request(self.scope, receive=self.receive) - handler_name = ( - "get" - if request.method == "HEAD" and not hasattr(self, "head") - else request.method.lower() - ) - - handler: typing.Callable[[Request], typing.Any] = getattr( - self, handler_name, self.method_not_allowed - ) + handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower() + + handler: typing.Callable[[Request], typing.Any] = getattr(self, handler_name, self.method_not_allowed) is_async = is_async_callable(handler) if is_async: response = await handler(request) @@ -81,9 +75,7 @@ async def dispatch(self) -> None: data = await self.decode(websocket, message) await self.on_receive(websocket, data) elif message["type"] == "websocket.disconnect": - close_code = int( - message.get("code") or status.WS_1000_NORMAL_CLOSURE - ) + close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE) break except Exception as exc: close_code = status.WS_1011_INTERNAL_ERROR @@ -116,9 +108,7 @@ async def decode(self, websocket: WebSocket, message: Message) -> typing.Any: await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) raise RuntimeError("Malformed JSON data received.") - assert ( - self.encoding is None - ), f"Unsupported 'encoding' attribute {self.encoding}" + assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}" return message["text"] if message.get("text") else message["bytes"] async def on_connect(self, websocket: WebSocket) -> None: diff --git a/starlette/formparsers.py b/starlette/formparsers.py index 2e12c7faa..56f63a8be 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -46,12 +46,8 @@ def __init__(self, message: str) -> None: class FormParser: - def __init__( - self, headers: Headers, stream: typing.AsyncGenerator[bytes, None] - ) -> None: - assert ( - multipart is not None - ), "The `python-multipart` library must be installed to use form parsing." + def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None: + assert multipart is not None, "The `python-multipart` library must be installed to use form parsing." self.headers = headers self.stream = stream self.messages: list[tuple[FormMessage, bytes]] = [] @@ -128,9 +124,7 @@ def __init__( max_files: int | float = 1000, max_fields: int | float = 1000, ) -> None: - assert ( - multipart is not None - ), "The `python-multipart` library must be installed to use form parsing." + assert multipart is not None, "The `python-multipart` library must be installed to use form parsing." self.headers = headers self.stream = stream self.max_files = max_files @@ -181,30 +175,20 @@ def on_header_end(self) -> None: field = self._current_partial_header_name.lower() if field == b"content-disposition": self._current_part.content_disposition = self._current_partial_header_value - self._current_part.item_headers.append( - (field, self._current_partial_header_value) - ) + self._current_part.item_headers.append((field, self._current_partial_header_value)) self._current_partial_header_name = b"" self._current_partial_header_value = b"" def on_headers_finished(self) -> None: - disposition, options = parse_options_header( - self._current_part.content_disposition - ) + disposition, options = parse_options_header(self._current_part.content_disposition) try: - self._current_part.field_name = _user_safe_decode( - options[b"name"], self._charset - ) + self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset) except KeyError: - raise MultiPartException( - 'The Content-Disposition header field "name" must be ' "provided." - ) + raise MultiPartException('The Content-Disposition header field "name" must be provided.') if b"filename" in options: self._current_files += 1 if self._current_files > self.max_files: - raise MultiPartException( - f"Too many files. Maximum number of files is {self.max_files}." - ) + raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.") filename = _user_safe_decode(options[b"filename"], self._charset) tempfile = SpooledTemporaryFile(max_size=self.max_file_size) self._files_to_close_on_error.append(tempfile) @@ -217,9 +201,7 @@ def on_headers_finished(self) -> None: else: self._current_fields += 1 if self._current_fields > self.max_fields: - raise MultiPartException( - f"Too many fields. Maximum number of fields is {self.max_fields}." - ) + raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.") self._current_part.file = None def on_end(self) -> None: diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py index d9e64f574..8566aac08 100644 --- a/starlette/middleware/__init__.py +++ b/starlette/middleware/__init__.py @@ -14,13 +14,9 @@ class _MiddlewareClass(Protocol[P]): - def __init__( - self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs - ) -> None: ... # pragma: no cover + def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None: ... # pragma: no cover - async def __call__( - self, scope: Scope, receive: Receive, send: Send - ) -> None: ... # pragma: no cover + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ... # pragma: no cover class Middleware: diff --git a/starlette/middleware/authentication.py b/starlette/middleware/authentication.py index 966c639bb..8555ee078 100644 --- a/starlette/middleware/authentication.py +++ b/starlette/middleware/authentication.py @@ -18,14 +18,13 @@ def __init__( self, app: ASGIApp, backend: AuthenticationBackend, - on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] - | None = None, + on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None, ) -> None: self.app = app self.backend = backend - self.on_error: typing.Callable[ - [HTTPConnection, AuthenticationError], Response - ] = on_error if on_error is not None else self.default_on_error + self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = ( + on_error if on_error is not None else self.default_on_error + ) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] not in ["http", "websocket"]: diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 87c0f51f8..2ac6f7f7f 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -11,9 +11,7 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] -DispatchFunction = typing.Callable[ - [Request, RequestResponseEndpoint], typing.Awaitable[Response] -] +DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]] T = typing.TypeVar("T") @@ -180,9 +178,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: if app_exc is not None: raise app_exc - response = _StreamingResponse( - status_code=message["status"], content=body_stream(), info=info - ) + response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info) response.raw_headers = message["headers"] return response @@ -192,9 +188,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: await response(scope, wrapped_receive, send) response_sent.set() - async def dispatch( - self, request: Request, call_next: RequestResponseEndpoint - ) -> Response: + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: raise NotImplementedError() # pragma: no cover diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 4b8e97bc9..61502691a 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -96,9 +96,7 @@ def is_allowed_origin(self, origin: str) -> bool: if self.allow_all_origins: return True - if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch( - origin - ): + if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin): return True return origin in self.allow_origins @@ -141,15 +139,11 @@ def preflight_response(self, request_headers: Headers) -> Response: return PlainTextResponse("OK", status_code=200, headers=headers) - async def simple_response( - self, scope: Scope, receive: Receive, send: Send, request_headers: Headers - ) -> None: + async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None: send = functools.partial(self.send, send=send, request_headers=request_headers) await self.app(scope, receive, send) - async def send( - self, message: Message, send: Send, request_headers: Headers - ) -> None: + async def send(self, message: Message, send: Send, request_headers: Headers) -> None: if message["type"] != "http.response.start": await send(message) return diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py index 3fc4a4402..76ad776be 100644 --- a/starlette/middleware/errors.py +++ b/starlette/middleware/errors.py @@ -186,9 +186,7 @@ async def _send(message: Message) -> None: # to optionally raise the error within the test case. raise exc - def format_line( - self, index: int, line: str, frame_lineno: int, frame_index: int - ) -> str: + def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str: values = { # HTML escape - line could contain < or > "line": html.escape(line).replace(" ", " "), @@ -225,9 +223,7 @@ def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> s return FRAME_TEMPLATE.format(**values) def generate_html(self, exc: Exception, limit: int = 7) -> str: - traceback_obj = traceback.TracebackException.from_exception( - exc, capture_locals=True - ) + traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True) exc_html = "" is_collapsed = False diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index b2bf88dbf..d708929e3 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -18,10 +18,7 @@ class ExceptionMiddleware: def __init__( self, app: ASGIApp, - handlers: typing.Mapping[ - typing.Any, typing.Callable[[Request, Exception], Response] - ] - | None = None, + handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None, debug: bool = False, ) -> None: self.app = app @@ -68,9 +65,7 @@ def http_exception(self, request: Request, exc: Exception) -> Response: assert isinstance(exc, HTTPException) if exc.status_code in {204, 304}: return Response(status_code=exc.status_code, headers=exc.headers) - return PlainTextResponse( - exc.detail, status_code=exc.status_code, headers=exc.headers - ) + return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers) async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None: assert isinstance(exc, WebSocketException) diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py index 0579e0410..127b91e7a 100644 --- a/starlette/middleware/gzip.py +++ b/starlette/middleware/gzip.py @@ -7,9 +7,7 @@ class GZipMiddleware: - def __init__( - self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9 - ) -> None: + def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None: self.app = app self.minimum_size = minimum_size self.compresslevel = compresslevel @@ -18,9 +16,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "http": headers = Headers(scope=scope) if "gzip" in headers.get("Accept-Encoding", ""): - responder = GZipResponder( - self.app, self.minimum_size, compresslevel=self.compresslevel - ) + responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel) await responder(scope, receive, send) return await self.app(scope, receive, send) @@ -35,9 +31,7 @@ def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> N self.started = False self.content_encoding_set = False self.gzip_buffer = io.BytesIO() - self.gzip_file = gzip.GzipFile( - mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel - ) + self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: self.send = send diff --git a/starlette/middleware/sessions.py b/starlette/middleware/sessions.py index 5855912ca..5f9fcd883 100644 --- a/starlette/middleware/sessions.py +++ b/starlette/middleware/sessions.py @@ -61,7 +61,7 @@ async def send_wrapper(message: Message) -> None: data = b64encode(json.dumps(scope["session"]).encode("utf-8")) data = self.signer.sign(data) headers = MutableHeaders(scope=message) - header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa E501 + header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( session_cookie=self.session_cookie, data=data.decode("utf-8"), path=self.path, @@ -72,7 +72,7 @@ async def send_wrapper(message: Message) -> None: elif not initial_session_was_empty: # The session has been cleared. headers = MutableHeaders(scope=message) - header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( # noqa E501 + header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( session_cookie=self.session_cookie, data="null", path=self.path, diff --git a/starlette/middleware/trustedhost.py b/starlette/middleware/trustedhost.py index 59e527363..2d1c999e2 100644 --- a/starlette/middleware/trustedhost.py +++ b/starlette/middleware/trustedhost.py @@ -41,9 +41,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: is_valid_host = False found_www_redirect = False for pattern in self.allowed_hosts: - if host == pattern or ( - pattern.startswith("*") and host.endswith(pattern[1:]) - ): + if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])): is_valid_host = True break elif "www." + host == pattern: diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index c9a7e1328..71f4ab5de 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -89,9 +89,7 @@ def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None: self.scope = scope self.status = None self.response_headers = None - self.stream_send, self.stream_receive = anyio.create_memory_object_stream( - math.inf - ) + self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf) self.response_started = False self.exc_info: typing.Any = None @@ -151,6 +149,4 @@ def wsgi( {"type": "http.response.body", "body": chunk, "more_body": True}, ) - anyio.from_thread.run( - self.stream_send.send, {"type": "http.response.body", "body": b""} - ) + anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""}) diff --git a/starlette/requests.py b/starlette/requests.py index a2fdfd81e..23f8ac70a 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -104,9 +104,7 @@ def base_url(self) -> URL: # This is used by request.url_for, it might be used inside a Mount which # would have its own child scope with its own root_path, but the base URL # for url_for should still be the top level app root path. - app_root_path = base_url_scope.get( - "app_root_path", base_url_scope.get("root_path", "") - ) + app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", "")) path = app_root_path if not path.endswith("/"): path += "/" @@ -153,23 +151,17 @@ def client(self) -> Address | None: @property def session(self) -> dict[str, typing.Any]: - assert ( - "session" in self.scope - ), "SessionMiddleware must be installed to access request.session" + assert "session" in self.scope, "SessionMiddleware must be installed to access request.session" return self.scope["session"] # type: ignore[no-any-return] @property def auth(self) -> typing.Any: - assert ( - "auth" in self.scope - ), "AuthenticationMiddleware must be installed to access request.auth" + assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth" return self.scope["auth"] @property def user(self) -> typing.Any: - assert ( - "user" in self.scope - ), "AuthenticationMiddleware must be installed to access request.user" + assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user" return self.scope["user"] @property @@ -199,9 +191,7 @@ async def empty_send(message: Message) -> typing.NoReturn: class Request(HTTPConnection): _form: FormData | None - def __init__( - self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send - ): + def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send): super().__init__(scope) assert scope["type"] == "http" self._receive = receive @@ -252,9 +242,7 @@ async def json(self) -> typing.Any: self._json = json.loads(body) return self._json - async def _get_form( - self, *, max_files: int | float = 1000, max_fields: int | float = 1000 - ) -> FormData: + async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | float = 1000) -> FormData: if self._form is None: assert ( parse_options_header is not None @@ -285,9 +273,7 @@ async def _get_form( def form( self, *, max_files: int | float = 1000, max_fields: int | float = 1000 ) -> AwaitableOrContextManager[FormData]: - return AwaitableOrContextManagerWrapper( - self._get_form(max_files=max_files, max_fields=max_fields) - ) + return AwaitableOrContextManagerWrapper(self._get_form(max_files=max_files, max_fields=max_fields)) async def close(self) -> None: if self._form is not None: @@ -312,9 +298,5 @@ async def send_push_promise(self, path: str) -> None: raw_headers: list[tuple[bytes, bytes]] = [] for name in SERVER_PUSH_HEADERS_TO_COPY: for value in self.headers.getlist(name): - raw_headers.append( - (name.encode("latin-1"), value.encode("latin-1")) - ) - await self._send( - {"type": "http.response.push", "path": path, "headers": raw_headers} - ) + raw_headers.append((name.encode("latin-1"), value.encode("latin-1"))) + await self._send({"type": "http.response.push", "path": path, "headers": raw_headers}) diff --git a/starlette/responses.py b/starlette/responses.py index 4f15404ca..06d6ce5ca 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -54,10 +54,7 @@ def init_headers(self, headers: typing.Mapping[str, str] | None = None) -> None: populate_content_length = True populate_content_type = True else: - raw_headers = [ - (k.lower().encode("latin-1"), v.encode("latin-1")) - for k, v in headers.items() - ] + raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()] keys = [h[0] for h in raw_headers] populate_content_length = b"content-length" not in keys populate_content_type = b"content-type" not in keys @@ -73,10 +70,7 @@ def init_headers(self, headers: typing.Mapping[str, str] | None = None) -> None: content_type = self.media_type if content_type is not None and populate_content_type: - if ( - content_type.startswith("text/") - and "charset=" not in content_type.lower() - ): + if content_type.startswith("text/") and "charset=" not in content_type.lower(): content_type += "; charset=" + self.charset raw_headers.append((b"content-type", content_type.encode("latin-1"))) @@ -201,9 +195,7 @@ def __init__( headers: typing.Mapping[str, str] | None = None, background: BackgroundTask | None = None, ) -> None: - super().__init__( - content=b"", status_code=status_code, headers=headers, background=background - ) + super().__init__(content=b"", status_code=status_code, headers=headers, background=background) self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;") @@ -299,11 +291,9 @@ def __init__( if self.filename is not None: content_disposition_filename = quote(self.filename) if content_disposition_filename != self.filename: - content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}" # noqa: E501 + content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}" else: - content_disposition = ( - f'{content_disposition_type}; filename="{self.filename}"' - ) + content_disposition = f'{content_disposition_type}; filename="{self.filename}"' self.headers.setdefault("content-disposition", content_disposition) self.stat_result = stat_result if stat_result is not None: diff --git a/starlette/routing.py b/starlette/routing.py index 481b13f5d..cde771563 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -47,8 +47,7 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover including those wrapped in functools.partial objects. """ warnings.warn( - "iscoroutinefunction_or_partial is deprecated, " - "and will be removed in a future release.", + "iscoroutinefunction_or_partial is deprecated, and will be removed in a future release.", DeprecationWarning, ) while isinstance(obj, functools.partial): @@ -143,9 +142,7 @@ def compile_path( for match in PARAM_REGEX.finditer(path): param_name, convertor_type = match.groups("str") convertor_type = convertor_type.lstrip(":") - assert ( - convertor_type in CONVERTOR_TYPES - ), f"Unknown path convertor '{convertor_type}'" + assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'" convertor = CONVERTOR_TYPES[convertor_type] path_regex += re.escape(path[idx : match.start()]) @@ -275,9 +272,7 @@ def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: if name != self.name or seen_params != expected_params: raise NoMatchFound(name, path_params) - path, remaining_params = replace_params( - self.path_format, self.param_convertors, path_params - ) + path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params) assert not remaining_params return URLPath(path=path, protocol="http") @@ -287,9 +282,7 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: if "app" in scope: raise HTTPException(status_code=405, headers=headers) else: - response = PlainTextResponse( - "Method Not Allowed", status_code=405, headers=headers - ) + response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers) await response(scope, receive, send) else: await self.app(scope, receive, send) @@ -361,9 +354,7 @@ def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: if name != self.name or seen_params != expected_params: raise NoMatchFound(name, path_params) - path, remaining_params = replace_params( - self.path_format, self.param_convertors, path_params - ) + path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params) assert not remaining_params return URLPath(path=path, protocol="websocket") @@ -371,11 +362,7 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) def __eq__(self, other: typing.Any) -> bool: - return ( - isinstance(other, WebSocketRoute) - and self.path == other.path - and self.endpoint == other.endpoint - ) + return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint def __repr__(self) -> str: return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})" @@ -392,9 +379,7 @@ def __init__( middleware: typing.Sequence[Middleware] | None = None, ) -> None: assert path == "" or path.startswith("/"), "Routed paths must start with '/'" - assert ( - app is not None or routes is not None - ), "Either 'app=...', or 'routes=' must be specified" + assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified" self.path = path.rstrip("/") if app is not None: self._base_app: ASGIApp = app @@ -405,9 +390,7 @@ def __init__( for cls, args, kwargs in reversed(middleware): self.app = cls(app=self.app, *args, **kwargs) self.name = name - self.path_regex, self.path_format, self.param_convertors = compile_path( - self.path + "/{path:path}" - ) + self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}") @property def routes(self) -> list[BaseRoute]: @@ -450,9 +433,7 @@ def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: if self.name is not None and name == self.name and "path" in path_params: # 'name' matches "". path_params["path"] = path_params["path"].lstrip("/") - path, remaining_params = replace_params( - self.path_format, self.param_convertors, path_params - ) + path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params) if not remaining_params: return URLPath(path=path) elif self.name is None or name.startswith(self.name + ":"): @@ -464,17 +445,13 @@ def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: remaining_name = name[len(self.name) + 1 :] path_kwarg = path_params.get("path") path_params["path"] = "" - path_prefix, remaining_params = replace_params( - self.path_format, self.param_convertors, path_params - ) + path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params) if path_kwarg is not None: remaining_params["path"] = path_kwarg for route in self.routes or []: try: url = route.url_path_for(remaining_name, **remaining_params) - return URLPath( - path=path_prefix.rstrip("/") + str(url), protocol=url.protocol - ) + return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol) except NoMatchFound: pass raise NoMatchFound(name, path_params) @@ -483,11 +460,7 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) def __eq__(self, other: typing.Any) -> bool: - return ( - isinstance(other, Mount) - and self.path == other.path - and self.app == other.app - ) + return isinstance(other, Mount) and self.path == other.path and self.app == other.app def __repr__(self) -> str: class_name = self.__class__.__name__ @@ -526,9 +499,7 @@ def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: if self.name is not None and name == self.name and "path" in path_params: # 'name' matches "". path = path_params.pop("path") - host, remaining_params = replace_params( - self.host_format, self.param_convertors, path_params - ) + host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params) if not remaining_params: return URLPath(path=path, host=host) elif self.name is None or name.startswith(self.name + ":"): @@ -538,9 +509,7 @@ def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: else: # 'name' matches ":". remaining_name = name[len(self.name) + 1 :] - host, remaining_params = replace_params( - self.host_format, self.param_convertors, path_params - ) + host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params) for route in self.routes or []: try: url = route.url_path_for(remaining_name, **remaining_params) @@ -553,11 +522,7 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) def __eq__(self, other: typing.Any) -> bool: - return ( - isinstance(other, Host) - and self.host == other.host - and self.app == other.app - ) + return isinstance(other, Host) and self.host == other.host and self.app == other.app def __repr__(self) -> str: class_name = self.__class__.__name__ @@ -585,9 +550,7 @@ async def __aexit__( def _wrap_gen_lifespan_context( - lifespan_context: typing.Callable[ - [typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any] - ], + lifespan_context: typing.Callable[[typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]], ) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]: cmgr = contextlib.contextmanager(lifespan_context) @@ -730,9 +693,7 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: async with self.lifespan_context(app) as maybe_state: if maybe_state is not None: if "state" not in scope: - raise RuntimeError( - 'The server does not support "state" in the lifespan scope.' - ) + raise RuntimeError('The server does not support "state" in the lifespan scope.') scope["state"].update(maybe_state) await send({"type": "lifespan.startup.complete"}) started = True @@ -806,15 +767,11 @@ async def app(self, scope: Scope, receive: Receive, send: Send) -> None: def __eq__(self, other: typing.Any) -> bool: return isinstance(other, Router) and self.routes == other.routes - def mount( - self, path: str, app: ASGIApp, name: str | None = None - ) -> None: # pragma: nocover + def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: # pragma: nocover route = Mount(path, app=app, name=name) self.routes.append(route) - def host( - self, host: str, app: ASGIApp, name: str | None = None - ) -> None: # pragma: no cover + def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover route = Host(host, app=app, name=name) self.routes.append(route) @@ -860,11 +817,11 @@ def route( """ warnings.warn( "The `route` decorator is deprecated, and will be removed in version 1.0.0." - "Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.", # noqa: E501 + "Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.", DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501 + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] self.add_route( path, func, @@ -885,20 +842,18 @@ def websocket_route(self, path: str, name: str | None = None) -> typing.Callable >>> app = Starlette(routes=routes) """ warnings.warn( - "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to " # noqa: E501 - "https://www.starlette.io/routing/#websocket-routing for the recommended approach.", # noqa: E501 + "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to " + "https://www.starlette.io/routing/#websocket-routing for the recommended approach.", DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501 + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] self.add_websocket_route(path, func, name=name) return func return decorator - def add_event_handler( - self, event_type: str, func: typing.Callable[[], typing.Any] - ) -> None: # pragma: no cover + def add_event_handler(self, event_type: str, func: typing.Callable[[], typing.Any]) -> None: # pragma: no cover assert event_type in ("startup", "shutdown") if event_type == "startup": @@ -908,12 +863,12 @@ def add_event_handler( def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg] warnings.warn( - "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501 + "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. " "Refer to https://www.starlette.io/lifespan/ for recommended approach.", DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501 + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] self.add_event_handler(event_type, func) return func diff --git a/starlette/schemas.py b/starlette/schemas.py index 89fa20b89..688fd85be 100644 --- a/starlette/schemas.py +++ b/starlette/schemas.py @@ -19,9 +19,7 @@ class OpenAPIResponse(Response): def render(self, content: typing.Any) -> bytes: assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse." - assert isinstance( - content, dict - ), "The schema passed to OpenAPIResponse should be a dictionary." + assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary." return yaml.dump(content, default_flow_style=False).encode("utf-8") @@ -73,9 +71,7 @@ def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]: for method in route.methods or ["GET"]: if method == "HEAD": continue - endpoints_info.append( - EndpointInfo(path, method.lower(), route.endpoint) - ) + endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint)) else: path = self._remove_converter(route.path) for method in ["get", "post", "put", "patch", "delete", "options"]: @@ -95,9 +91,7 @@ def _remove_converter(self, path: str) -> str: """ return re.sub(r":\w+}", "}", path) - def parse_docstring( - self, func_or_method: typing.Callable[..., typing.Any] - ) -> dict[str, typing.Any]: + def parse_docstring(self, func_or_method: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]: """ Given a function, parse the docstring as YAML and return a dictionary of info. """ diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index afb09b56b..7498c3011 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -32,11 +32,7 @@ class NotModifiedResponse(Response): def __init__(self, headers: Headers): super().__init__( status_code=304, - headers={ - name: value - for name, value in headers.items() - if name in self.NOT_MODIFIED_HEADERS - }, + headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS}, ) @@ -80,9 +76,7 @@ def get_directories( spec = importlib.util.find_spec(package) assert spec is not None, f"Package {package!r} could not be found." assert spec.origin is not None, f"Package {package!r} could not be found." - package_directory = os.path.normpath( - os.path.join(spec.origin, "..", statics_dir) - ) + package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir)) assert os.path.isdir( package_directory ), f"Directory '{statics_dir!r}' in package {package!r} could not be found." @@ -110,7 +104,7 @@ def get_path(self, scope: Scope) -> str: with OS specific path separators, and any '..', '.' components removed. """ route_path = get_route_path(scope) - return os.path.normpath(os.path.join(*route_path.split("/"))) # noqa: E501 + return os.path.normpath(os.path.join(*route_path.split("/"))) async def get_response(self, path: str, scope: Scope) -> Response: """ @@ -120,9 +114,7 @@ async def get_response(self, path: str, scope: Scope) -> Response: raise HTTPException(status_code=405) try: - full_path, stat_result = await anyio.to_thread.run_sync( - self.lookup_path, path - ) + full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path) except PermissionError: raise HTTPException(status_code=401) except OSError as exc: @@ -140,9 +132,7 @@ async def get_response(self, path: str, scope: Scope) -> Response: # We're in HTML mode, and have got a directory URL. # Check if we have 'index.html' file to serve. index_path = os.path.join(path, "index.html") - full_path, stat_result = await anyio.to_thread.run_sync( - self.lookup_path, index_path - ) + full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path) if stat_result is not None and stat.S_ISREG(stat_result.st_mode): if not scope["path"].endswith("/"): # Directory URLs should redirect to always end in "/". @@ -153,9 +143,7 @@ async def get_response(self, path: str, scope: Scope) -> Response: if self.html: # Check for '404.html' if we're in HTML mode. - full_path, stat_result = await anyio.to_thread.run_sync( - self.lookup_path, "404.html" - ) + full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html") if stat_result and stat.S_ISREG(stat_result.st_mode): return FileResponse(full_path, stat_result=stat_result, status_code=404) raise HTTPException(status_code=404) @@ -187,9 +175,7 @@ def file_response( ) -> Response: request_headers = Headers(scope=scope) - response = FileResponse( - full_path, status_code=status_code, stat_result=stat_result - ) + response = FileResponse(full_path, status_code=status_code, stat_result=stat_result) if self.is_not_modified(response.headers, request_headers): return NotModifiedResponse(response.headers) return response @@ -206,17 +192,11 @@ async def check_config(self) -> None: try: stat_result = await anyio.to_thread.run_sync(os.stat, self.directory) except FileNotFoundError: - raise RuntimeError( - f"StaticFiles directory '{self.directory}' does not exist." - ) + raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.") if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)): - raise RuntimeError( - f"StaticFiles path '{self.directory}' is not a directory." - ) + raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.") - def is_not_modified( - self, response_headers: Headers, request_headers: Headers - ) -> bool: + def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool: """ Given the request and response headers, return `True` if an HTTP "Not Modified" response could be returned instead. @@ -232,11 +212,7 @@ def is_not_modified( try: if_modified_since = parsedate(request_headers["if-modified-since"]) last_modified = parsedate(response_headers["last-modified"]) - if ( - if_modified_since is not None - and last_modified is not None - and if_modified_since >= last_modified - ): + if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified: return True except KeyError: pass diff --git a/starlette/templating.py b/starlette/templating.py index aae2cbe24..48e54c0cc 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -68,8 +68,7 @@ def __init__( self, directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]], *, - context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] - | None = None, + context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None, **env_options: typing.Any, ) -> None: ... @@ -78,31 +77,24 @@ def __init__( self, *, env: jinja2.Environment, - context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] - | None = None, + context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None, ) -> None: ... def __init__( self, - directory: str - | PathLike[str] - | typing.Sequence[str | PathLike[str]] - | None = None, + directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]] | None = None, *, - context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] - | None = None, + context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None, env: jinja2.Environment | None = None, **env_options: typing.Any, ) -> None: if env_options: warnings.warn( - "Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.", # noqa: E501 + "Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.", DeprecationWarning, ) assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates" - assert bool(directory) ^ bool( - env - ), "either 'directory' or 'env' arguments must be passed" + assert bool(directory) ^ bool(env), "either 'directory' or 'env' arguments must be passed" self.context_processors = context_processors or [] if directory is not None: self.env = self._create_env(directory, **env_options) @@ -163,25 +155,19 @@ def TemplateResponse( # Deprecated usage ... - def TemplateResponse( - self, *args: typing.Any, **kwargs: typing.Any - ) -> _TemplateResponse: + def TemplateResponse(self, *args: typing.Any, **kwargs: typing.Any) -> _TemplateResponse: if args: - if isinstance( - args[0], str - ): # the first argument is template name (old style) + if isinstance(args[0], str): # the first argument is template name (old style) warnings.warn( "The `name` is not the first parameter anymore. " "The first parameter should be the `Request` instance.\n" - 'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.', # noqa: E501 + 'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.', DeprecationWarning, ) name = args[0] context = args[1] if len(args) > 1 else kwargs.get("context", {}) - status_code = ( - args[2] if len(args) > 2 else kwargs.get("status_code", 200) - ) + status_code = args[2] if len(args) > 2 else kwargs.get("status_code", 200) headers = args[2] if len(args) > 2 else kwargs.get("headers") media_type = args[3] if len(args) > 3 else kwargs.get("media_type") background = args[4] if len(args) > 4 else kwargs.get("background") @@ -193,9 +179,7 @@ def TemplateResponse( request = args[0] name = args[1] if len(args) > 1 else kwargs["name"] context = args[2] if len(args) > 2 else kwargs.get("context", {}) - status_code = ( - args[3] if len(args) > 3 else kwargs.get("status_code", 200) - ) + status_code = args[3] if len(args) > 3 else kwargs.get("status_code", 200) headers = args[4] if len(args) > 4 else kwargs.get("headers") media_type = args[5] if len(args) > 5 else kwargs.get("media_type") background = args[6] if len(args) > 6 else kwargs.get("background") @@ -203,7 +187,7 @@ def TemplateResponse( if "request" not in kwargs: warnings.warn( "The `TemplateResponse` now requires the `request` argument.\n" - 'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.', # noqa: E501 + 'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.', DeprecationWarning, ) if "request" not in kwargs.get("context", {}): diff --git a/starlette/testclient.py b/starlette/testclient.py index bf928d23f..fcf392e33 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -37,9 +37,7 @@ "You can install this with:\n" " $ pip install httpx\n" ) -_PortalFactoryType = typing.Callable[ - [], typing.ContextManager[anyio.abc.BlockingPortal] -] +_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]] ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]] ASGI2App = typing.Callable[[Scope], ASGIInstance] @@ -169,9 +167,7 @@ async def _asgi_send(self, message: Message) -> None: def _raise_on_close(self, message: Message) -> None: if message["type"] == "websocket.close": - raise WebSocketDisconnect( - code=message.get("code", 1000), reason=message.get("reason", "") - ) + raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", "")) elif message["type"] == "websocket.http.response.start": status_code: int = message["status"] headers: list[tuple[bytes, bytes]] = message["headers"] @@ -199,9 +195,7 @@ def send_text(self, data: str) -> None: def send_bytes(self, data: bytes) -> None: self.send({"type": "websocket.receive", "bytes": data}) - def send_json( - self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text" - ) -> None: + def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None: text = json.dumps(data, separators=(",", ":"), ensure_ascii=False) if mode == "text": self.send({"type": "websocket.receive", "text": text}) @@ -227,9 +221,7 @@ def receive_bytes(self) -> bytes: self._raise_on_close(message) return typing.cast(bytes, message["bytes"]) - def receive_json( - self, mode: typing.Literal["text", "binary"] = "text" - ) -> typing.Any: + def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any: message = self.receive() self._raise_on_close(message) if mode == "text": @@ -280,10 +272,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: headers = [(b"host", (f"{host}:{port}").encode())] # Include other request headers. - headers += [ - (key.lower().encode(), value.encode()) - for key, value in request.headers.multi_items() - ] + headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()] scope: dict[str, typing.Any] @@ -365,22 +354,13 @@ async def send(message: Message) -> None: nonlocal raw_kwargs, response_started, template, context if message["type"] == "http.response.start": - assert ( - not response_started - ), 'Received multiple "http.response.start" messages.' + assert not response_started, 'Received multiple "http.response.start" messages.' raw_kwargs["status_code"] = message["status"] - raw_kwargs["headers"] = [ - (key.decode(), value.decode()) - for key, value in message.get("headers", []) - ] + raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])] response_started = True elif message["type"] == "http.response.body": - assert ( - response_started - ), 'Received "http.response.body" without "http.response.start".' - assert ( - not response_complete.is_set() - ), 'Received "http.response.body" after response completed.' + assert response_started, 'Received "http.response.body" without "http.response.start".' + assert not response_complete.is_set(), 'Received "http.response.body" after response completed.' body = message.get("body", b"") more_body = message.get("more_body", False) if request.method != "HEAD": @@ -435,9 +415,7 @@ def __init__( headers: dict[str, str] | None = None, follow_redirects: bool = True, ) -> None: - self.async_backend = _AsyncBackend( - backend=backend, backend_options=backend_options or {} - ) + self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {}) if _is_asgi3(app): asgi_app = app else: @@ -468,22 +446,15 @@ def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, No if self.portal is not None: yield self.portal else: - with anyio.from_thread.start_blocking_portal( - **self.async_backend - ) as portal: + with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal: yield portal def _choose_redirect_arg( self, follow_redirects: bool | None, allow_redirects: bool | None ) -> bool | httpx._client.UseClientDefault: - redirect: bool | httpx._client.UseClientDefault = ( - httpx._client.USE_CLIENT_DEFAULT - ) + redirect: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT if allow_redirects is not None: - message = ( - "The `allow_redirects` argument is deprecated. " - "Use `follow_redirects` instead." - ) + message = "The `allow_redirects` argument is deprecated. Use `follow_redirects` instead." warnings.warn(message, DeprecationWarning) redirect = allow_redirects if follow_redirects is not None: @@ -506,12 +477,10 @@ def request( # type: ignore[override] params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, - auth: httpx._types.AuthTypes - | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | None = None, allow_redirects: bool | None = None, - timeout: httpx._types.TimeoutTypes - | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: url = self._merge_url(url) @@ -539,12 +508,10 @@ def get( # type: ignore[override] params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, - auth: httpx._types.AuthTypes - | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | None = None, allow_redirects: bool | None = None, - timeout: httpx._types.TimeoutTypes - | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) @@ -566,12 +533,10 @@ def options( # type: ignore[override] params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, - auth: httpx._types.AuthTypes - | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | None = None, allow_redirects: bool | None = None, - timeout: httpx._types.TimeoutTypes - | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) @@ -593,12 +558,10 @@ def head( # type: ignore[override] params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, - auth: httpx._types.AuthTypes - | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | None = None, allow_redirects: bool | None = None, - timeout: httpx._types.TimeoutTypes - | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) @@ -624,12 +587,10 @@ def post( # type: ignore[override] params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, - auth: httpx._types.AuthTypes - | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | None = None, allow_redirects: bool | None = None, - timeout: httpx._types.TimeoutTypes - | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) @@ -659,12 +620,10 @@ def put( # type: ignore[override] params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, - auth: httpx._types.AuthTypes - | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | None = None, allow_redirects: bool | None = None, - timeout: httpx._types.TimeoutTypes - | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) @@ -694,12 +653,10 @@ def patch( # type: ignore[override] params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, - auth: httpx._types.AuthTypes - | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | None = None, allow_redirects: bool | None = None, - timeout: httpx._types.TimeoutTypes - | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) @@ -725,12 +682,10 @@ def delete( # type: ignore[override] params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, - auth: httpx._types.AuthTypes - | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | None = None, allow_redirects: bool | None = None, - timeout: httpx._types.TimeoutTypes - | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) @@ -770,9 +725,7 @@ def websocket_connect( def __enter__(self) -> TestClient: with contextlib.ExitStack() as stack: - self.portal = portal = stack.enter_context( - anyio.from_thread.start_blocking_portal(**self.async_backend) - ) + self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend)) @stack.callback def reset_portal() -> None: diff --git a/starlette/types.py b/starlette/types.py index f78dd63ae..893f87296 100644 --- a/starlette/types.py +++ b/starlette/types.py @@ -16,15 +16,9 @@ ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] StatelessLifespan = typing.Callable[[AppType], typing.AsyncContextManager[None]] -StatefulLifespan = typing.Callable[ - [AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]] -] +StatefulLifespan = typing.Callable[[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]] Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]] -HTTPExceptionHandler = typing.Callable[ - ["Request", Exception], "Response | typing.Awaitable[Response]" -] -WebSocketExceptionHandler = typing.Callable[ - ["WebSocket", Exception], typing.Awaitable[None] -] +HTTPExceptionHandler = typing.Callable[["Request", Exception], "Response | typing.Awaitable[Response]"] +WebSocketExceptionHandler = typing.Callable[["WebSocket", Exception], typing.Awaitable[None]] ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler] diff --git a/starlette/websockets.py b/starlette/websockets.py index 53ab5a70c..b7acaa3f0 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -39,10 +39,7 @@ async def receive(self) -> Message: message = await self._receive() message_type = message["type"] if message_type != "websocket.connect": - raise RuntimeError( - 'Expected ASGI message "websocket.connect", ' - f"but got {message_type!r}" - ) + raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}') self.client_state = WebSocketState.CONNECTED return message elif self.client_state == WebSocketState.CONNECTED: @@ -50,16 +47,13 @@ async def receive(self) -> Message: message_type = message["type"] if message_type not in {"websocket.receive", "websocket.disconnect"}: raise RuntimeError( - 'Expected ASGI message "websocket.receive" or ' - f'"websocket.disconnect", but got {message_type!r}' + f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}' ) if message_type == "websocket.disconnect": self.client_state = WebSocketState.DISCONNECTED return message else: - raise RuntimeError( - 'Cannot call "receive" once a disconnect message has been received.' - ) + raise RuntimeError('Cannot call "receive" once a disconnect message has been received.') async def send(self, message: Message) -> None: """ @@ -67,14 +61,9 @@ async def send(self, message: Message) -> None: """ if self.application_state == WebSocketState.CONNECTING: message_type = message["type"] - if message_type not in { - "websocket.accept", - "websocket.close", - "websocket.http.response.start", - }: + if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}: raise RuntimeError( - 'Expected ASGI message "websocket.accept",' - '"websocket.close" or "websocket.http.response.start",' + 'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", ' f"but got {message_type!r}" ) if message_type == "websocket.close": @@ -88,8 +77,7 @@ async def send(self, message: Message) -> None: message_type = message["type"] if message_type not in {"websocket.send", "websocket.close"}: raise RuntimeError( - 'Expected ASGI message "websocket.send" or "websocket.close", ' - f"but got {message_type!r}" + f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}' ) if message_type == "websocket.close": self.application_state = WebSocketState.DISCONNECTED @@ -101,10 +89,7 @@ async def send(self, message: Message) -> None: elif self.application_state == WebSocketState.RESPONSE: message_type = message["type"] if message_type != "websocket.http.response.body": - raise RuntimeError( - 'Expected ASGI message "websocket.http.response.body", ' - f"but got {message_type!r}" - ) + raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}') if not message.get("more_body", False): self.application_state = WebSocketState.DISCONNECTED await self._send(message) @@ -121,9 +106,7 @@ async def accept( if self.client_state == WebSocketState.CONNECTING: # If we haven't yet seen the 'connect' message, then wait for it first. await self.receive() - await self.send( - {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers} - ) + await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}) def _raise_on_disconnect(self, message: Message) -> None: if message["type"] == "websocket.disconnect": @@ -131,18 +114,14 @@ def _raise_on_disconnect(self, message: Message) -> None: async def receive_text(self) -> str: if self.application_state != WebSocketState.CONNECTED: - raise RuntimeError( - 'WebSocket is not connected. Need to call "accept" first.' - ) + raise RuntimeError('WebSocket is not connected. Need to call "accept" first.') message = await self.receive() self._raise_on_disconnect(message) return typing.cast(str, message["text"]) async def receive_bytes(self) -> bytes: if self.application_state != WebSocketState.CONNECTED: - raise RuntimeError( - 'WebSocket is not connected. Need to call "accept" first.' - ) + raise RuntimeError('WebSocket is not connected. Need to call "accept" first.') message = await self.receive() self._raise_on_disconnect(message) return typing.cast(bytes, message["bytes"]) @@ -151,9 +130,7 @@ async def receive_json(self, mode: str = "text") -> typing.Any: if mode not in {"text", "binary"}: raise RuntimeError('The "mode" argument should be "text" or "binary".') if self.application_state != WebSocketState.CONNECTED: - raise RuntimeError( - 'WebSocket is not connected. Need to call "accept" first.' - ) + raise RuntimeError('WebSocket is not connected. Need to call "accept" first.') message = await self.receive() self._raise_on_disconnect(message) @@ -200,17 +177,13 @@ async def send_json(self, data: typing.Any, mode: str = "text") -> None: await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")}) async def close(self, code: int = 1000, reason: str | None = None) -> None: - await self.send( - {"type": "websocket.close", "code": code, "reason": reason or ""} - ) + await self.send({"type": "websocket.close", "code": code, "reason": reason or ""}) async def send_denial_response(self, response: Response) -> None: if "websocket.http.response" in self.scope.get("extensions", {}): await response(self.scope, self.receive, self.send) else: - raise RuntimeError( - "The server doesn't support the Websocket Denial Response extension." - ) + raise RuntimeError("The server doesn't support the Websocket Denial Response extension.") class WebSocketClose: @@ -219,6 +192,4 @@ def __init__(self, code: int = 1000, reason: str | None = None) -> None: self.reason = reason or "" async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - await send( - {"type": "websocket.close", "code": self.code, "reason": self.reason} - ) + await send({"type": "websocket.close", "code": self.code, "reason": self.reason}) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 8e410cb15..225038650 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -169,9 +169,7 @@ def test_app_middleware_argument(test_client_factory: TestClientFactory) -> None def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage") - app = Starlette( - routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)] - ) + app = Starlette(routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)]) client = test_client_factory(app) response = client.get("/") @@ -249,9 +247,7 @@ async def homepage(request: Request) -> PlainTextResponse: ctxvar.set("set by endpoint") return PlainTextResponse("Homepage") - app = Starlette( - middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)] - ) + app = Starlette(middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]) client = test_client_factory(app) response = client.get("/") @@ -316,13 +312,9 @@ async def sleep_and_set() -> None: events.append("Background task finished") async def endpoint_with_background_task(_: Request) -> PlainTextResponse: - return PlainTextResponse( - content="Hello", background=BackgroundTask(sleep_and_set) - ) + return PlainTextResponse(content="Hello", background=BackgroundTask(sleep_and_set)) - async def passthrough( - request: Request, call_next: RequestResponseEndpoint - ) -> Response: + async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response: return await call_next(request) app = Starlette( @@ -490,9 +482,7 @@ async def cancel_on_disconnect( } ) - pytest.fail( - "http.disconnect should have been received and canceled the scope" - ) # pragma: no cover + pytest.fail("http.disconnect should have been received and canceled the scope") # pragma: no cover app = DiscardingMiddleware(downstream_app) @@ -787,7 +777,7 @@ async def send(msg: Message) -> None: await rcv_stream.aclose() -def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( # noqa: E501 +def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( test_client_factory: TestClientFactory, ) -> None: async def homepage(request: Request) -> PlainTextResponse: @@ -800,9 +790,7 @@ async def dispatch( request: Request, call_next: RequestResponseEndpoint, ) -> Response: - assert ( - await request.body() == b"a" - ) # this buffers the request body in memory + assert await request.body() == b"a" # this buffers the request body in memory resp = await call_next(request) async for chunk in request.stream(): if chunk: @@ -819,7 +807,7 @@ async def dispatch( assert response.status_code == 200 -def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( # noqa: E501 +def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( test_client_factory: TestClientFactory, ) -> None: async def homepage(request: Request) -> PlainTextResponse: @@ -832,9 +820,7 @@ async def dispatch( request: Request, call_next: RequestResponseEndpoint, ) -> Response: - assert ( - await request.body() == b"a" - ) # this buffers the request body in memory + assert await request.body() == b"a" # this buffers the request body in memory resp = await call_next(request) assert await request.body() == b"a" # no problem here return resp @@ -1026,9 +1012,7 @@ def __init__(self, app: ASGIApp, version: int, events: list[str]) -> None: self.events = events super().__init__(app) - async def dispatch( - self, request: Request, call_next: RequestResponseEndpoint - ) -> Response: + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: self.events.append(f"{self.version}:STARTED") res = await call_next(request) self.events.append(f"{self.version}:COMPLETED") @@ -1047,9 +1031,7 @@ async def sleepy(request: Request) -> Response: app = Starlette( routes=[Route("/", sleepy)], - middleware=[ - Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10) - ], + middleware=[Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)], ) scope = { @@ -1114,9 +1096,7 @@ async def app_poll_disconnect(scope: Scope, receive: Receive, send: Send) -> Non await Response(b"good!")(scope, receive, send) class MyMiddleware(BaseHTTPMiddleware): - async def dispatch( - self, request: Request, call_next: RequestResponseEndpoint - ) -> Response: + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: return await call_next(request) app = MyMiddleware(app_poll_disconnect) diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 630361243..0d987263e 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -252,9 +252,7 @@ def homepage(request: Request) -> None: app = Starlette( routes=[Route("/", endpoint=homepage)], - middleware=[ - Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"]) - ], + middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])], ) client = test_client_factory(app) @@ -284,9 +282,7 @@ def homepage(request: Request) -> PlainTextResponse: methods=["delete", "get", "head", "options", "patch", "post", "put"], ) ], - middleware=[ - Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"]) - ], + middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])], ) client = test_client_factory(app) @@ -397,10 +393,7 @@ def homepage(request: Request) -> PlainTextResponse: response = client.get("/", headers=headers) assert response.status_code == 200 assert response.text == "Homepage" - assert ( - response.headers["access-control-allow-origin"] - == "https://subdomain.example.org" - ) + assert response.headers["access-control-allow-origin"] == "https://subdomain.example.org" assert "access-control-allow-credentials" not in response.headers # Test diallowed standard response @@ -456,9 +449,7 @@ def test_cors_vary_header_is_not_set_for_non_credentialed_request( test_client_factory: TestClientFactory, ) -> None: def homepage(request: Request) -> PlainTextResponse: - return PlainTextResponse( - "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} - ) + return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}) app = Starlette( routes=[Route("/", endpoint=homepage)], @@ -475,9 +466,7 @@ def test_cors_vary_header_is_properly_set_for_credentialed_request( test_client_factory: TestClientFactory, ) -> None: def homepage(request: Request) -> PlainTextResponse: - return PlainTextResponse( - "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} - ) + return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}) app = Starlette( routes=[Route("/", endpoint=homepage)], @@ -485,9 +474,7 @@ def homepage(request: Request) -> PlainTextResponse: ) client = test_client_factory(app) - response = client.get( - "/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"} - ) + response = client.get("/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}) assert response.status_code == 200 assert response.headers["vary"] == "Accept-Encoding, Origin" @@ -496,9 +483,7 @@ def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard( test_client_factory: TestClientFactory, ) -> None: def homepage(request: Request) -> PlainTextResponse: - return PlainTextResponse( - "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} - ) + return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}) app = Starlette( routes=[ @@ -538,9 +523,7 @@ def homepage(request: Request) -> PlainTextResponse: assert response.headers["access-control-allow-origin"] == "*" assert "access-control-allow-credentials" not in response.headers - response = client.get( - "/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"} - ) + response = client.get("/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}) assert response.headers["access-control-allow-origin"] == "https://someplace.org" assert "access-control-allow-credentials" not in response.headers diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index b6f68296d..b20a7cb84 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -91,9 +91,7 @@ async def generator(bytes: bytes, count: int) -> ContentStream: yield bytes streaming = generator(bytes=b"x" * 400, count=10) - return StreamingResponse( - streaming, status_code=200, headers={"Content-Encoding": "text"} - ) + return StreamingResponse(streaming, status_code=200, headers={"Content-Encoding": "text"}) app = Starlette( routes=[Route("/", endpoint=homepage)], diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index 9a0d70a0d..b4f3c64fa 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -89,9 +89,7 @@ def test_secure_session(test_client_factory: TestClientFactory) -> None: Route("/update_session", endpoint=update_session, methods=["POST"]), Route("/clear_session", endpoint=clear_session, methods=["POST"]), ], - middleware=[ - Middleware(SessionMiddleware, secret_key="example", https_only=True) - ], + middleware=[Middleware(SessionMiddleware, secret_key="example", https_only=True)], ) secure_client = test_client_factory(app, base_url="https://testserver") unsecure_client = test_client_factory(app, base_url="http://testserver") @@ -126,9 +124,7 @@ def test_session_cookie_subpath(test_client_factory: TestClientFactory) -> None: routes=[ Route("/update_session", endpoint=update_session, methods=["POST"]), ], - middleware=[ - Middleware(SessionMiddleware, secret_key="example", path="/second_app") - ], + middleware=[Middleware(SessionMiddleware, secret_key="example", path="/second_app")], ) app = Starlette(routes=[Mount("/second_app", app=second_app)]) client = test_client_factory(app, base_url="http://testserver/second_app") @@ -188,9 +184,7 @@ def test_domain_cookie(test_client_factory: TestClientFactory) -> None: Route("/view_session", endpoint=view_session), Route("/update_session", endpoint=update_session, methods=["POST"]), ], - middleware=[ - Middleware(SessionMiddleware, secret_key="example", domain=".example.com") - ], + middleware=[Middleware(SessionMiddleware, secret_key="example", domain=".example.com")], ) client: TestClient = test_client_factory(app) diff --git a/tests/middleware/test_trusted_host.py b/tests/middleware/test_trusted_host.py index ddff46c48..5b8b217c3 100644 --- a/tests/middleware/test_trusted_host.py +++ b/tests/middleware/test_trusted_host.py @@ -13,11 +13,7 @@ def homepage(request: Request) -> PlainTextResponse: app = Starlette( routes=[Route("/", endpoint=homepage)], - middleware=[ - Middleware( - TrustedHostMiddleware, allowed_hosts=["testserver", "*.testserver"] - ) - ], + middleware=[Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.testserver"])], ) client = test_client_factory(app) @@ -45,9 +41,7 @@ def homepage(request: Request) -> PlainTextResponse: app = Starlette( routes=[Route("/", endpoint=homepage)], - middleware=[ - Middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"]) - ], + middleware=[Middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"])], ) client = test_client_factory(app, base_url="https://example.com") diff --git a/tests/test_applications.py b/tests/test_applications.py index 20da7ea81..86c713c38 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -109,9 +109,7 @@ def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> CustomWSException: custom_ws_exception_handler, } -middleware = [ - Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.example.org"]) -] +middleware = [Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.example.org"])] app = Starlette( routes=[ @@ -349,9 +347,7 @@ def run_cleanup() -> None: nonlocal cleanup_complete cleanup_complete = True - with pytest.deprecated_call( - match="The on_startup and on_shutdown parameters are deprecated" - ): + with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"): app = Starlette( on_startup=[run_startup], on_shutdown=[run_cleanup], @@ -445,51 +441,34 @@ def test_decorator_deprecations() -> None: app = Starlette() with pytest.deprecated_call( - match=( - "The `exception_handler` decorator is deprecated, " - "and will be removed in version 1.0.0." - ) + match=("The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0.") ) as record: app.exception_handler(500)(http_exception) assert len(record) == 1 with pytest.deprecated_call( - match=( - "The `middleware` decorator is deprecated, " - "and will be removed in version 1.0.0." - ) + match=("The `middleware` decorator is deprecated, and will be removed in version 1.0.0.") ) as record: - async def middleware( - request: Request, call_next: RequestResponseEndpoint - ) -> None: ... # pragma: no cover + async def middleware(request: Request, call_next: RequestResponseEndpoint) -> None: ... # pragma: no cover app.middleware("http")(middleware) assert len(record) == 1 with pytest.deprecated_call( - match=( - "The `route` decorator is deprecated, " - "and will be removed in version 1.0.0." - ) + match=("The `route` decorator is deprecated, and will be removed in version 1.0.0.") ) as record: app.route("/")(async_homepage) assert len(record) == 1 with pytest.deprecated_call( - match=( - "The `websocket_route` decorator is deprecated, " - "and will be removed in version 1.0.0." - ) + match=("The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0.") ) as record: app.websocket_route("/ws")(websocket_endpoint) assert len(record) == 1 with pytest.deprecated_call( - match=( - "The `on_event` decorator is deprecated, " - "and will be removed in version 1.0.0." - ) + match=("The `on_event` decorator is deprecated, and will be removed in version 1.0.0.") ) as record: async def startup() -> None: ... # pragma: no cover diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 35c1110d1..a1bde67b9 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -259,9 +259,7 @@ def test_authentication_required(test_client_factory: TestClientFactory) -> None response = client.get("/dashboard/decorated") assert response.status_code == 403 - response = client.get( - "/dashboard/decorated/sync", auth=("tomchristie", "example") - ) + response = client.get("/dashboard/decorated/sync", auth=("tomchristie", "example")) assert response.status_code == 200 assert response.json() == { "authenticated": True, @@ -286,14 +284,10 @@ def test_websocket_authentication_required( pass # pragma: nocover with pytest.raises(WebSocketDisconnect): - with client.websocket_connect( - "/ws", headers={"Authorization": "basic foobar"} - ): + with client.websocket_connect("/ws", headers={"Authorization": "basic foobar"}): pass # pragma: nocover - with client.websocket_connect( - "/ws", auth=("tomchristie", "example") - ) as websocket: + with client.websocket_connect("/ws", auth=("tomchristie", "example")) as websocket: data = websocket.receive_json() assert data == {"authenticated": True, "user": "tomchristie"} @@ -302,14 +296,10 @@ def test_websocket_authentication_required( pass # pragma: nocover with pytest.raises(WebSocketDisconnect): - with client.websocket_connect( - "/ws/decorated", headers={"Authorization": "basic foobar"} - ): + with client.websocket_connect("/ws/decorated", headers={"Authorization": "basic foobar"}): pass # pragma: nocover - with client.websocket_connect( - "/ws/decorated", auth=("tomchristie", "example") - ) as websocket: + with client.websocket_connect("/ws/decorated", auth=("tomchristie", "example")) as websocket: data = websocket.receive_json() assert data == { "authenticated": True, @@ -322,9 +312,7 @@ def test_authentication_redirect(test_client_factory: TestClientFactory) -> None with test_client_factory(app) as client: response = client.get("/admin") assert response.status_code == 200 - url = "{}?{}".format( - "http://testserver/", urlencode({"next": "http://testserver/admin"}) - ) + url = "{}?{}".format("http://testserver/", urlencode({"next": "http://testserver/admin"})) assert response.url == url response = client.get("/admin", auth=("tomchristie", "example")) @@ -333,9 +321,7 @@ def test_authentication_redirect(test_client_factory: TestClientFactory) -> None response = client.get("/admin/sync") assert response.status_code == 200 - url = "{}?{}".format( - "http://testserver/", urlencode({"next": "http://testserver/admin/sync"}) - ) + url = "{}?{}".format("http://testserver/", urlencode({"next": "http://testserver/admin/sync"})) assert response.url == url response = client.get("/admin/sync", auth=("tomchristie", "example")) @@ -359,11 +345,7 @@ def control_panel(request: Request) -> JSONResponse: other_app = Starlette( routes=[Route("/control-panel", control_panel)], - middleware=[ - Middleware( - AuthenticationMiddleware, backend=BasicAuth(), on_error=on_auth_error - ) - ], + middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuth(), on_error=on_auth_error)], ) @@ -373,8 +355,6 @@ def test_custom_on_error(test_client_factory: TestClientFactory) -> None: assert response.status_code == 200 assert response.json() == {"authenticated": True, "user": "tomchristie"} - response = client.get( - "/control-panel", headers={"Authorization": "basic foobar"} - ) + response = client.get("/control-panel", headers={"Authorization": "basic foobar"}) assert response.status_code == 401 assert response.json() == {"error": "Invalid basic auth credentials"} diff --git a/tests/test_background.py b/tests/test_background.py index cbffcc06a..990e270ea 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -56,9 +56,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: tasks.add_task(increment, amount=1) tasks.add_task(increment, amount=2) tasks.add_task(increment, amount=3) - response = Response( - "tasks initiated", media_type="text/plain", background=tasks - ) + response = Response("tasks initiated", media_type="text/plain", background=tasks) await response(scope, receive, send) client = test_client_factory(app) @@ -82,9 +80,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: tasks = BackgroundTasks() tasks.add_task(increment) tasks.add_task(increment) - response = Response( - "tasks initiated", media_type="text/plain", background=tasks - ) + response = Response("tasks initiated", media_type="text/plain", background=tasks) await response(scope, receive, send) client = test_client_factory(app) diff --git a/tests/test_config.py b/tests/test_config.py index f37591007..7d2cd1f9d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -14,9 +14,7 @@ def test_config_types() -> None: """ We use `assert_type` to test the types returned by Config via mypy. """ - config = Config( - environ={"STR": "some_str_value", "STR_CAST": "some_str_value", "BOOL": "true"} - ) + config = Config(environ={"STR": "some_str_value", "STR_CAST": "some_str_value", "BOOL": "true"}) assert_type(config("STR"), str) assert_type(config("STR_DEFAULT", default=""), str) @@ -138,9 +136,7 @@ def test_environ() -> None: def test_config_with_env_prefix(tmpdir: Path, monkeypatch: pytest.MonkeyPatch) -> None: - config = Config( - environ={"APP_DEBUG": "value", "ENVIRONMENT": "dev"}, env_prefix="APP_" - ) + config = Config(environ={"APP_DEBUG": "value", "ENVIRONMENT": "dev"}, env_prefix="APP_") assert config.get("DEBUG") == "value" with pytest.raises(KeyError): diff --git a/tests/test_convertors.py b/tests/test_convertors.py index 520c98767..ced1b86cc 100644 --- a/tests/test_convertors.py +++ b/tests/test_convertors.py @@ -48,23 +48,18 @@ def datetime_convertor(request: Request) -> JSONResponse: ) -def test_datetime_convertor( - test_client_factory: TestClientFactory, app: Router -) -> None: +def test_datetime_convertor(test_client_factory: TestClientFactory, app: Router) -> None: client = test_client_factory(app) response = client.get("/datetime/2020-01-01T00:00:00") assert response.json() == {"datetime": "2020-01-01T00:00:00"} assert ( - app.url_path_for("datetime-convertor", param=datetime(1996, 1, 22, 23, 0, 0)) - == "/datetime/1996-01-22T23:00:00" + app.url_path_for("datetime-convertor", param=datetime(1996, 1, 22, 23, 0, 0)) == "/datetime/1996-01-22T23:00:00" ) @pytest.mark.parametrize("param, status_code", [("1.0", 200), ("1-0", 404)]) -def test_default_float_convertor( - test_client_factory: TestClientFactory, param: str, status_code: int -) -> None: +def test_default_float_convertor(test_client_factory: TestClientFactory, param: str, status_code: int) -> None: def float_convertor(request: Request) -> JSONResponse: param = request.path_params["param"] assert isinstance(param, float) diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index a6bca6ef6..0e7d35c3c 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -115,9 +115,7 @@ def test_csv() -> None: def test_url_from_scope() -> None: - u = URL( - scope={"path": "/path/to/somewhere", "query_string": b"abc=123", "headers": []} - ) + u = URL(scope={"path": "/path/to/somewhere", "query_string": b"abc=123", "headers": []}) assert u == "/path/to/somewhere?abc=123" assert repr(u) == "URL('/path/to/somewhere?abc=123')" @@ -296,13 +294,9 @@ def test_queryparams() -> None: assert dict(q) == {"a": "456", "b": "789"} assert str(q) == "a=123&a=456&b=789" assert repr(q) == "QueryParams('a=123&a=456&b=789')" - assert QueryParams({"a": "123", "b": "456"}) == QueryParams( - [("a", "123"), ("b", "456")] - ) + assert QueryParams({"a": "123", "b": "456"}) == QueryParams([("a", "123"), ("b", "456")]) assert QueryParams({"a": "123", "b": "456"}) == QueryParams("a=123&b=456") - assert QueryParams({"a": "123", "b": "456"}) == QueryParams( - {"b": "456", "a": "123"} - ) + assert QueryParams({"a": "123", "b": "456"}) == QueryParams({"b": "456", "a": "123"}) assert QueryParams() == QueryParams({}) assert QueryParams([("a", "123"), ("a", "456")]) == QueryParams("a=123&a=456") assert QueryParams({"a": "123", "b": "456"}) != "invalid" @@ -382,10 +376,7 @@ def test_formdata() -> None: assert len(form) == 2 assert list(form) == ["a", "b"] assert dict(form) == {"a": "456", "b": upload} - assert ( - repr(form) - == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])" - ) + assert repr(form) == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])" assert FormData(form) == form assert FormData({"a": "123", "b": "789"}) == FormData([("a", "123"), ("b", "789")]) assert FormData({"a": "123", "b": "789"}) != {"a": "123", "b": "789"} @@ -402,10 +393,7 @@ async def test_upload_file_repr() -> None: async def test_upload_file_repr_headers() -> None: stream = io.BytesIO(b"data") file = UploadFile(filename="file", file=stream, headers=Headers({"foo": "bar"})) - assert ( - repr(file) - == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))" - ) + assert repr(file) == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))" def test_multidict() -> None: @@ -425,9 +413,7 @@ def test_multidict() -> None: assert dict(q) == {"a": "456", "b": "789"} assert str(q) == "MultiDict([('a', '123'), ('a', '456'), ('b', '789')])" assert repr(q) == "MultiDict([('a', '123'), ('a', '456'), ('b', '789')])" - assert MultiDict({"a": "123", "b": "456"}) == MultiDict( - [("a", "123"), ("b", "456")] - ) + assert MultiDict({"a": "123", "b": "456"}) == MultiDict([("a", "123"), ("b", "456")]) assert MultiDict({"a": "123", "b": "456"}) == MultiDict({"b": "456", "a": "123"}) assert MultiDict() == MultiDict({}) assert MultiDict({"a": "123", "b": "456"}) != "invalid" diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 8f201e25b..42776a5b3 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -19,9 +19,7 @@ async def get(self, request: Request) -> PlainTextResponse: return PlainTextResponse(f"Hello, {username}!") -app = Router( - routes=[Route("/", endpoint=Homepage), Route("/{username}", endpoint=Homepage)] -) +app = Router(routes=[Route("/", endpoint=Homepage), Route("/{username}", endpoint=Homepage)]) @pytest.fixture diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index f4e91ad87..b3dc7843f 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -42,9 +42,7 @@ async def read_body_and_raise_exc(request: Request) -> None: raise BadBodyException(422) -async def handler_that_reads_body( - request: Request, exc: BadBodyException -) -> JSONResponse: +async def handler_that_reads_body(request: Request, exc: BadBodyException) -> JSONResponse: body = await request.body() return JSONResponse(status_code=422, content={"body": body.decode()}) @@ -158,9 +156,7 @@ def test_http_str() -> None: def test_http_repr() -> None: - assert repr(HTTPException(404)) == ( - "HTTPException(status_code=404, detail='Not Found')" - ) + assert repr(HTTPException(404)) == ("HTTPException(status_code=404, detail='Not Found')") assert repr(HTTPException(404, detail="Not Found: foo")) == ( "HTTPException(status_code=404, detail='Not Found: foo')" ) diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 8d97a0ba7..61c1bede1 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -127,17 +127,13 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: return app -def test_multipart_request_data( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_multipart_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None: client = test_client_factory(app) response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART) assert response.json() == {"some": "data"} -def test_multipart_request_files( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_multipart_request_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "test.txt") with open(path, "wb") as file: file.write(b"") @@ -155,9 +151,7 @@ def test_multipart_request_files( } -def test_multipart_request_files_with_content_type( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_multipart_request_files_with_content_type(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "test.txt") with open(path, "wb") as file: file.write(b"") @@ -175,9 +169,7 @@ def test_multipart_request_files_with_content_type( } -def test_multipart_request_multiple_files( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_multipart_request_multiple_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path1 = os.path.join(tmpdir, "test1.txt") with open(path1, "wb") as file: file.write(b"") @@ -188,9 +180,7 @@ def test_multipart_request_multiple_files( client = test_client_factory(app) with open(path1, "rb") as f1, open(path2, "rb") as f2: - response = client.post( - "/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")} - ) + response = client.post("/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")}) assert response.json() == { "test1": { "filename": "test1.txt", @@ -207,9 +197,7 @@ def test_multipart_request_multiple_files( } -def test_multipart_request_multiple_files_with_headers( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_multipart_request_multiple_files_with_headers(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path1 = os.path.join(tmpdir, "test1.txt") with open(path1, "wb") as file: file.write(b"") @@ -281,9 +269,7 @@ def test_multi_items(tmpdir: Path, test_client_factory: TestClientFactory) -> No } -def test_multipart_request_mixed_files_and_data( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None: client = test_client_factory(app) response = client.post( "/", @@ -303,11 +289,7 @@ def test_multipart_request_mixed_files_and_data( b"value1\r\n" b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n" ), - headers={ - "Content-Type": ( - "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c" - ) - }, + headers={"Content-Type": ("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")}, ) assert response.json() == { "file": { @@ -321,26 +303,19 @@ def test_multipart_request_mixed_files_and_data( } -def test_multipart_request_with_charset_for_filename( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None: client = test_client_factory(app) response = client.post( "/", data=( # file b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" # type: ignore - b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' # noqa: E501 + b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' b"Content-Type: text/plain\r\n\r\n" b"\r\n" b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n" ), - headers={ - "Content-Type": ( - "multipart/form-data; charset=utf-8; " - "boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c" - ) - }, + headers={"Content-Type": ("multipart/form-data; charset=utf-8; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")}, ) assert response.json() == { "file": { @@ -352,25 +327,19 @@ def test_multipart_request_with_charset_for_filename( } -def test_multipart_request_without_charset_for_filename( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_multipart_request_without_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None: client = test_client_factory(app) response = client.post( "/", data=( # file b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" # type: ignore - b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n' # noqa: E501 + b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n' b"Content-Type: image/jpeg\r\n\r\n" b"\r\n" b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n" ), - headers={ - "Content-Type": ( - "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c" - ) - }, + headers={"Content-Type": ("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")}, ) assert response.json() == { "file": { @@ -382,9 +351,7 @@ def test_multipart_request_without_charset_for_filename( } -def test_multipart_request_with_encoded_value( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_multipart_request_with_encoded_value(tmpdir: Path, test_client_factory: TestClientFactory) -> None: client = test_client_factory(app) response = client.post( "/", @@ -395,19 +362,12 @@ def test_multipart_request_with_encoded_value( b"Transf\xc3\xa9rer\r\n" b"--20b303e711c4ab8c443184ac833ab00f--\r\n" ), - headers={ - "Content-Type": ( - "multipart/form-data; charset=utf-8; " - "boundary=20b303e711c4ab8c443184ac833ab00f" - ) - }, + headers={"Content-Type": ("multipart/form-data; charset=utf-8; boundary=20b303e711c4ab8c443184ac833ab00f")}, ) assert response.json() == {"value": "Transférer"} -def test_urlencoded_request_data( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_urlencoded_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None: client = test_client_factory(app) response = client.post("/", data={"some": "data"}) assert response.json() == {"some": "data"} @@ -419,37 +379,27 @@ def test_no_request_data(tmpdir: Path, test_client_factory: TestClientFactory) - assert response.json() == {} -def test_urlencoded_percent_encoding( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_urlencoded_percent_encoding(tmpdir: Path, test_client_factory: TestClientFactory) -> None: client = test_client_factory(app) response = client.post("/", data={"some": "da ta"}) assert response.json() == {"some": "da ta"} -def test_urlencoded_percent_encoding_keys( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_urlencoded_percent_encoding_keys(tmpdir: Path, test_client_factory: TestClientFactory) -> None: client = test_client_factory(app) response = client.post("/", data={"so me": "data"}) assert response.json() == {"so me": "data"} -def test_urlencoded_multi_field_app_reads_body( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_urlencoded_multi_field_app_reads_body(tmpdir: Path, test_client_factory: TestClientFactory) -> None: client = test_client_factory(app_read_body) response = client.post("/", data={"some": "data", "second": "key pair"}) assert response.json() == {"some": "data", "second": "key pair"} -def test_multipart_multi_field_app_reads_body( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_multipart_multi_field_app_reads_body(tmpdir: Path, test_client_factory: TestClientFactory) -> None: client = test_client_factory(app_read_body) - response = client.post( - "/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART - ) + response = client.post("/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART) assert response.json() == {"some": "data", "second": "key pair"} @@ -481,7 +431,7 @@ def test_missing_boundary_parameter( "/", data=( # file - b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' # type: ignore # noqa: E501 + b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' # type: ignore b"Content-Type: text/plain\r\n\r\n" b"\r\n" ), @@ -513,16 +463,10 @@ def test_missing_name_parameter_on_content_disposition( b'Content-Disposition: form-data; ="field0"\r\n\r\n' b"value0\r\n" ), - headers={ - "Content-Type": ( - "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c" - ) - }, + headers={"Content-Type": ("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")}, ) assert res.status_code == 400 - assert ( - res.text == 'The Content-Disposition header field "name" must be provided.' - ) + assert res.text == 'The Content-Disposition header field "name" must be provided.' @pytest.mark.parametrize( @@ -540,9 +484,7 @@ def test_too_many_fields_raise( client = test_client_factory(app) fields = [] for i in range(1001): - fields.append( - "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n" - ) + fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n") data = "".join(fields).encode("utf-8") with expectation: res = client.post( @@ -569,11 +511,7 @@ def test_too_many_files_raise( client = test_client_factory(app) fields = [] for i in range(1001): - fields.append( - "--B\r\n" - f'Content-Disposition: form-data; name="N{i}"; filename="F{i}";\r\n\r\n' - "\r\n" - ) + fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}"; filename="F{i}";\r\n\r\n' "\r\n") data = "".join(fields).encode("utf-8") with expectation: res = client.post( @@ -602,11 +540,7 @@ def test_too_many_files_single_field_raise( for i in range(1001): # This uses the same field name "N" for all files, equivalent to a # multifile upload form field - fields.append( - "--B\r\n" - f'Content-Disposition: form-data; name="N"; filename="F{i}";\r\n\r\n' - "\r\n" - ) + fields.append("--B\r\n" f'Content-Disposition: form-data; name="N"; filename="F{i}";\r\n\r\n' "\r\n") data = "".join(fields).encode("utf-8") with expectation: res = client.post( @@ -633,14 +567,8 @@ def test_too_many_files_and_fields_raise( client = test_client_factory(app) fields = [] for i in range(1001): - fields.append( - "--B\r\n" - f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' - "\r\n" - ) - fields.append( - "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n" - ) + fields.append("--B\r\n" f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' "\r\n") + fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n") data = "".join(fields).encode("utf-8") with expectation: res = client.post( @@ -670,9 +598,7 @@ def test_max_fields_is_customizable_low_raises( client = test_client_factory(app) fields = [] for i in range(2): - fields.append( - "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n" - ) + fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n") data = "".join(fields).encode("utf-8") with expectation: res = client.post( @@ -702,11 +628,7 @@ def test_max_files_is_customizable_low_raises( client = test_client_factory(app) fields = [] for i in range(2): - fields.append( - "--B\r\n" - f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' - "\r\n" - ) + fields.append("--B\r\n" f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' "\r\n") data = "".join(fields).encode("utf-8") with expectation: res = client.post( @@ -724,14 +646,8 @@ def test_max_fields_is_customizable_high( client = test_client_factory(make_app_max_parts(max_fields=2000, max_files=2000)) fields = [] for i in range(2000): - fields.append( - "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n" - ) - fields.append( - "--B\r\n" - f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' - "\r\n" - ) + fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n") + fields.append("--B\r\n" f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' "\r\n") data = "".join(fields).encode("utf-8") data += b"--B--\r\n" res = client.post( diff --git a/tests/test_responses.py b/tests/test_responses.py index c63c92de5..ad1901ca5 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -118,9 +118,7 @@ async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None: cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9) generator = numbers(1, 5) - response = StreamingResponse( - generator, media_type="text/plain", background=cleanup_task - ) + response = StreamingResponse(generator, media_type="text/plain", background=cleanup_task) await response(scope, receive, send) assert filled_by_bg_task == "" @@ -236,9 +234,7 @@ async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None: cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9) async def app(scope: Scope, receive: Receive, send: Send) -> None: - response = FileResponse( - path=path, filename="example.png", background=cleanup_task - ) + response = FileResponse(path=path, filename="example.png", background=cleanup_task) await response(scope, receive, send) assert filled_by_bg_task == "" @@ -284,9 +280,7 @@ async def send(message: Message) -> None: await app({"type": "http", "method": "head"}, receive, send) -def test_file_response_set_media_type( - tmp_path: Path, test_client_factory: TestClientFactory -) -> None: +def test_file_response_set_media_type(tmp_path: Path, test_client_factory: TestClientFactory) -> None: path = tmp_path / "xyz" path.write_bytes(b"") @@ -298,9 +292,7 @@ def test_file_response_set_media_type( assert response.headers["content-type"] == "image/jpeg" -def test_file_response_with_directory_raises_error( - tmp_path: Path, test_client_factory: TestClientFactory -) -> None: +def test_file_response_with_directory_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None: app = FileResponse(path=tmp_path, filename="example.png") client = test_client_factory(app) with pytest.raises(RuntimeError) as exc_info: @@ -308,9 +300,7 @@ def test_file_response_with_directory_raises_error( assert "is not a file" in str(exc_info.value) -def test_file_response_with_missing_file_raises_error( - tmp_path: Path, test_client_factory: TestClientFactory -) -> None: +def test_file_response_with_missing_file_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None: path = tmp_path / "404.txt" app = FileResponse(path=path, filename="404.txt") client = test_client_factory(app) @@ -319,9 +309,7 @@ def test_file_response_with_missing_file_raises_error( assert "does not exist" in str(exc_info.value) -def test_file_response_with_chinese_filename( - tmp_path: Path, test_client_factory: TestClientFactory -) -> None: +def test_file_response_with_chinese_filename(tmp_path: Path, test_client_factory: TestClientFactory) -> None: content = b"file content" filename = "你好.txt" # probably "Hello.txt" in Chinese path = tmp_path / filename @@ -335,9 +323,7 @@ def test_file_response_with_chinese_filename( assert response.headers["content-disposition"] == expected_disposition -def test_file_response_with_inline_disposition( - tmp_path: Path, test_client_factory: TestClientFactory -) -> None: +def test_file_response_with_inline_disposition(tmp_path: Path, test_client_factory: TestClientFactory) -> None: content = b"file content" filename = "hello.txt" path = tmp_path / filename @@ -356,9 +342,7 @@ def test_file_response_with_method_warns(tmp_path: Path) -> None: FileResponse(path=tmp_path, filename="example.png", method="GET") -def test_set_cookie( - test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch -) -> None: +def test_set_cookie(test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch) -> None: # Mock time used as a reference for `Expires` by stdlib `SimpleCookie`. mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc) monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp()) @@ -382,8 +366,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = client.get("/") assert response.text == "Hello, world!" assert ( - response.headers["set-cookie"] - == "mycookie=myvalue; Domain=localhost; expires=Thu, 22 Jan 2037 12:00:10 GMT; " + response.headers["set-cookie"] == "mycookie=myvalue; Domain=localhost; expires=Thu, 22 Jan 2037 12:00:10 GMT; " "HttpOnly; Max-Age=10; Path=/; SameSite=none; Secure" ) @@ -403,9 +386,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: @pytest.mark.parametrize( "expires", [ - pytest.param( - dt.datetime(2037, 1, 22, 12, 0, 10, tzinfo=dt.timezone.utc), id="datetime" - ), + pytest.param(dt.datetime(2037, 1, 22, 12, 0, 10, tzinfo=dt.timezone.utc), id="datetime"), pytest.param("Thu, 22 Jan 2037 12:00:10 GMT", id="str"), pytest.param(10, id="int"), ], @@ -495,9 +476,7 @@ def test_response_do_not_add_redundant_charset( assert response.headers["content-type"] == "text/plain; charset=utf-8" -def test_file_response_known_size( - tmp_path: Path, test_client_factory: TestClientFactory -) -> None: +def test_file_response_known_size(tmp_path: Path, test_client_factory: TestClientFactory) -> None: path = tmp_path / "xyz" content = b"" * 1000 path.write_bytes(content) @@ -518,9 +497,7 @@ def test_streaming_response_unknown_size( def test_streaming_response_known_size(test_client_factory: TestClientFactory) -> None: - app = StreamingResponse( - content=iter(["hello", "world"]), headers={"content-length": "10"} - ) + app = StreamingResponse(content=iter(["hello", "world"]), headers={"content-length": "10"}) client: TestClient = test_client_factory(app) response = client.get("/") assert response.headers["content-length"] == "10" diff --git a/tests/test_routing.py b/tests/test_routing.py index 1490723b4..9fa44def4 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -232,10 +232,7 @@ def test_route_converters(client: TestClient) -> None: response = client.get("/path-with-parentheses(7)") assert response.status_code == 200 assert response.json() == {"int": 7} - assert ( - app.url_path_for("path-with-parentheses", param=7) - == "/path-with-parentheses(7)" - ) + assert app.url_path_for("path-with-parentheses", param=7) == "/path-with-parentheses(7)" # Test float conversion response = client.get("/float/25.5") @@ -247,18 +244,14 @@ def test_route_converters(client: TestClient) -> None: response = client.get("/path/some/example") assert response.status_code == 200 assert response.json() == {"path": "some/example"} - assert ( - app.url_path_for("path-convertor", param="some/example") == "/path/some/example" - ) + assert app.url_path_for("path-convertor", param="some/example") == "/path/some/example" # Test UUID conversion response = client.get("/uuid/ec38df32-ceda-4cfa-9b4a-1aeb94ad551a") assert response.status_code == 200 assert response.json() == {"uuid": "ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"} assert ( - app.url_path_for( - "uuid-convertor", param=uuid.UUID("ec38df32-ceda-4cfa-9b4a-1aeb94ad551a") - ) + app.url_path_for("uuid-convertor", param=uuid.UUID("ec38df32-ceda-4cfa-9b4a-1aeb94ad551a")) == "/uuid/ec38df32-ceda-4cfa-9b4a-1aeb94ad551a" ) @@ -267,13 +260,9 @@ def test_url_path_for() -> None: assert app.url_path_for("homepage") == "/" assert app.url_path_for("user", username="tomchristie") == "/users/tomchristie" assert app.url_path_for("websocket_endpoint") == "/ws" - with pytest.raises( - NoMatchFound, match='No route exists for name "broken" and params "".' - ): + with pytest.raises(NoMatchFound, match='No route exists for name "broken" and params "".'): assert app.url_path_for("broken") - with pytest.raises( - NoMatchFound, match='No route exists for name "broken" and params "key, key2".' - ): + with pytest.raises(NoMatchFound, match='No route exists for name "broken" and params "key, key2".'): assert app.url_path_for("broken", key="value", key2="value2") with pytest.raises(AssertionError): app.url_path_for("user", username="tom/christie") @@ -282,32 +271,21 @@ def test_url_path_for() -> None: def test_url_for() -> None: + assert app.url_path_for("homepage").make_absolute_url(base_url="https://example.org") == "https://example.org/" assert ( - app.url_path_for("homepage").make_absolute_url(base_url="https://example.org") - == "https://example.org/" - ) - assert ( - app.url_path_for("homepage").make_absolute_url( - base_url="https://example.org/root_path/" - ) + app.url_path_for("homepage").make_absolute_url(base_url="https://example.org/root_path/") == "https://example.org/root_path/" ) assert ( - app.url_path_for("user", username="tomchristie").make_absolute_url( - base_url="https://example.org" - ) + app.url_path_for("user", username="tomchristie").make_absolute_url(base_url="https://example.org") == "https://example.org/users/tomchristie" ) assert ( - app.url_path_for("user", username="tomchristie").make_absolute_url( - base_url="https://example.org/root_path/" - ) + app.url_path_for("user", username="tomchristie").make_absolute_url(base_url="https://example.org/root_path/") == "https://example.org/root_path/users/tomchristie" ) assert ( - app.url_path_for("websocket_endpoint").make_absolute_url( - base_url="https://example.org" - ) + app.url_path_for("websocket_endpoint").make_absolute_url(base_url="https://example.org") == "wss://example.org/ws" ) @@ -409,13 +387,8 @@ def test_reverse_mount_urls() -> None: users = Router([Route("/{username}", ok, name="user")]) mounted = Router([Mount("/{subpath}/users", users, name="users")]) - assert ( - mounted.url_path_for("users:user", subpath="test", username="tom") - == "/test/users/tom" - ) - assert ( - mounted.url_path_for("users", subpath="test", path="/tom") == "/test/users/tom" - ) + assert mounted.url_path_for("users:user", subpath="test", username="tom") == "/test/users/tom" + assert mounted.url_path_for("users", subpath="test", path="/tom") == "/test/users/tom" def test_mount_at_root(test_client_factory: TestClientFactory) -> None: @@ -472,9 +445,7 @@ def test_host_routing(test_client_factory: TestClientFactory) -> None: response = client.get("/") assert response.status_code == 200 - client = test_client_factory( - mixed_hosts_app, base_url="https://port.example.org:3600/" - ) + client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org:3600/") response = client.get("/users") assert response.status_code == 404 @@ -489,31 +460,23 @@ def test_host_routing(test_client_factory: TestClientFactory) -> None: response = client.get("/") assert response.status_code == 200 - client = test_client_factory( - mixed_hosts_app, base_url="https://port.example.org:5600/" - ) + client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org:5600/") response = client.get("/") assert response.status_code == 200 def test_host_reverse_urls() -> None: + assert mixed_hosts_app.url_path_for("homepage").make_absolute_url("https://whatever") == "https://www.example.org/" assert ( - mixed_hosts_app.url_path_for("homepage").make_absolute_url("https://whatever") - == "https://www.example.org/" - ) - assert ( - mixed_hosts_app.url_path_for("users").make_absolute_url("https://whatever") - == "https://www.example.org/users" + mixed_hosts_app.url_path_for("users").make_absolute_url("https://whatever") == "https://www.example.org/users" ) assert ( mixed_hosts_app.url_path_for("api:users").make_absolute_url("https://whatever") == "https://api.example.org/users" ) assert ( - mixed_hosts_app.url_path_for("port:homepage").make_absolute_url( - "https://whatever" - ) + mixed_hosts_app.url_path_for("port:homepage").make_absolute_url("https://whatever") == "https://port.example.org:3600/" ) @@ -523,9 +486,7 @@ async def subdomain_app(scope: Scope, receive: Receive, send: Send) -> None: await response(scope, receive, send) -subdomain_router = Router( - routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")] -) +subdomain_router = Router(routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")]) def test_subdomain_routing(test_client_factory: TestClientFactory) -> None: @@ -538,9 +499,9 @@ def test_subdomain_routing(test_client_factory: TestClientFactory) -> None: def test_subdomain_reverse_urls() -> None: assert ( - subdomain_router.url_path_for( - "subdomains", subdomain="foo", path="/homepage" - ).make_absolute_url("https://whatever") + subdomain_router.url_path_for("subdomains", subdomain="foo", path="/homepage").make_absolute_url( + "https://whatever" + ) == "https://foo.example.org/homepage" ) @@ -566,9 +527,7 @@ async def echo_urls(request: Request) -> JSONResponse: def test_url_for_with_root_path(test_client_factory: TestClientFactory) -> None: app = Starlette(routes=echo_url_routes) - client = test_client_factory( - app, base_url="https://www.example.org/", root_path="/sub_path" - ) + client = test_client_factory(app, base_url="https://www.example.org/", root_path="/sub_path") response = client.get("/sub_path/") assert response.json() == { "index": "https://www.example.org/sub_path/", @@ -657,9 +616,7 @@ async def run_shutdown() -> None: nonlocal shutdown_complete shutdown_complete = True - with pytest.deprecated_call( - match="The on_startup and on_shutdown parameters are deprecated" - ): + with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"): app = Router( on_startup=[run_startup], on_shutdown=[run_shutdown], @@ -697,18 +654,11 @@ def run_shutdown() -> None: # pragma: no cover nonlocal shutdown_called shutdown_called = True - with pytest.deprecated_call( - match="The on_startup and on_shutdown parameters are deprecated" - ): + with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"): with pytest.warns( - UserWarning, - match=( - "The `lifespan` parameter cannot be used with `on_startup` or `on_shutdown`." # noqa: E501 - ), + UserWarning, match="The `lifespan` parameter cannot be used with `on_startup` or `on_shutdown`." ): - app = Router( - on_startup=[run_startup], on_shutdown=[run_shutdown], lifespan=lifespan - ) + app = Router(on_startup=[run_startup], on_shutdown=[run_shutdown], lifespan=lifespan) assert not lifespan_called assert not startup_called @@ -738,9 +688,7 @@ def run_shutdown() -> None: nonlocal shutdown_complete shutdown_complete = True - with pytest.deprecated_call( - match="The on_startup and on_shutdown parameters are deprecated" - ): + with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"): app = Router( on_startup=[run_startup], on_shutdown=[run_shutdown], @@ -775,9 +723,7 @@ async def no_state_wrapper(scope: Scope, receive: Receive, send: Send) -> None: del scope["state"] await app(scope, receive, send) - with pytest.raises( - RuntimeError, match='The server does not support "state" in the lifespan scope' - ): + with pytest.raises(RuntimeError, match='The server does not support "state" in the lifespan scope'): with test_client_factory(no_state_wrapper): raise AssertionError("Should not be called") # pragma: no cover @@ -834,9 +780,7 @@ def test_raise_on_startup(test_client_factory: TestClientFactory) -> None: def run_startup() -> None: raise RuntimeError() - with pytest.deprecated_call( - match="The on_startup and on_shutdown parameters are deprecated" - ): + with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"): router = Router(on_startup=[run_startup]) startup_failed = False @@ -859,9 +803,7 @@ def test_raise_on_shutdown(test_client_factory: TestClientFactory) -> None: def run_shutdown() -> None: raise RuntimeError() - with pytest.deprecated_call( - match="The on_startup and on_shutdown parameters are deprecated" - ): + with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"): app = Router(on_shutdown=[run_shutdown]) with pytest.raises(RuntimeError): @@ -934,9 +876,7 @@ def __call__(self, request: Request) -> None: ... # pragma: no cover pytest.param(lambda request: ..., "", id="lambda"), ], ) -def test_route_name( - endpoint: typing.Callable[..., Response], expected_name: str -) -> None: +def test_route_name(endpoint: typing.Callable[..., Response], expected_name: str) -> None: assert Route(path="/", endpoint=endpoint).name == expected_name @@ -1172,10 +1112,7 @@ async def modified_send(msg: Message) -> None: def test_route_repr() -> None: route = Route("/welcome", endpoint=homepage) - assert ( - repr(route) - == "Route(path='/welcome', name='homepage', methods=['GET', 'HEAD'])" - ) + assert repr(route) == "Route(path='/welcome', name='homepage', methods=['GET', 'HEAD'])" def test_route_repr_without_methods() -> None: @@ -1264,9 +1201,7 @@ async def echo_paths(request: Request, name: str) -> JSONResponse: ) -async def pure_asgi_echo_paths( - scope: Scope, receive: Receive, send: Send, name: str -) -> None: +async def pure_asgi_echo_paths(scope: Scope, receive: Receive, send: Send, name: str) -> None: data = {"name": name, "path": scope["path"], "root_path": scope["root_path"]} content = json.dumps(data).encode("utf-8") await send( @@ -1304,9 +1239,7 @@ async def pure_asgi_echo_paths( def test_paths_with_root_path(test_client_factory: TestClientFactory) -> None: app = Starlette(routes=echo_paths_routes) - client = test_client_factory( - app, base_url="https://www.example.org/", root_path="/root" - ) + client = test_client_factory(app, base_url="https://www.example.org/", root_path="/root") response = client.get("/root/path") assert response.status_code == 200 assert response.json() == { diff --git a/tests/test_schemas.py b/tests/test_schemas.py index f4a5b4ad9..3b321ca0b 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -7,9 +7,7 @@ from starlette.websockets import WebSocket from tests.types import TestClientFactory -schemas = SchemaGenerator( - {"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}} -) +schemas = SchemaGenerator({"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}}) def ws(session: WebSocket) -> None: @@ -142,7 +140,7 @@ def test_schema_generation() -> None: "get": { "responses": { 200: { - "description": "A list of " "organisations.", + "description": "A list of organisations.", "examples": [{"name": "Foo Corp."}, {"name": "Acme Ltd."}], } } @@ -157,25 +155,13 @@ def test_schema_generation() -> None: }, }, "/regular-docstring-and-schema": { - "get": { - "responses": { - 200: {"description": "This is included in the schema."} - } - } + "get": {"responses": {200: {"description": "This is included in the schema."}}} }, "/subapp/subapp-endpoint": { - "get": { - "responses": { - 200: {"description": "This endpoint is part of a subapp."} - } - } + "get": {"responses": {200: {"description": "This endpoint is part of a subapp."}}} }, "/subapp2/subapp-endpoint": { - "get": { - "responses": { - 200: {"description": "This endpoint is part of a subapp."} - } - } + "get": {"responses": {200: {"description": "This endpoint is part of a subapp."}}} }, "/users": { "get": { @@ -186,11 +172,7 @@ def test_schema_generation() -> None: } } }, - "post": { - "responses": { - 200: {"description": "A user.", "examples": {"username": "tom"}} - } - }, + "post": {"responses": {200: {"description": "A user.", "examples": {"username": "tom"}}}}, }, "/users/{id}": { "get": { diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 65d71b97b..8beb3cd87 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -31,9 +31,7 @@ def test_staticfiles(tmpdir: Path, test_client_factory: TestClientFactory) -> No assert response.text == "" -def test_staticfiles_with_pathlib( - tmp_path: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_with_pathlib(tmp_path: Path, test_client_factory: TestClientFactory) -> None: path = tmp_path / "example.txt" with open(path, "w") as file: file.write("") @@ -45,9 +43,7 @@ def test_staticfiles_with_pathlib( assert response.text == "" -def test_staticfiles_head_with_middleware( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_head_with_middleware(tmpdir: Path, test_client_factory: TestClientFactory) -> None: """ see https://github.com/encode/starlette/pull/935 """ @@ -55,9 +51,7 @@ def test_staticfiles_head_with_middleware( with open(path, "w") as file: file.write("x" * 100) - async def does_nothing_middleware( - request: Request, call_next: RequestResponseEndpoint - ) -> Response: + async def does_nothing_middleware(request: Request, call_next: RequestResponseEndpoint) -> Response: response = await call_next(request) return response @@ -99,9 +93,7 @@ def test_staticfiles_post(tmpdir: Path, test_client_factory: TestClientFactory) assert response.text == "Method Not Allowed" -def test_staticfiles_with_directory_returns_404( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_with_directory_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") @@ -115,9 +107,7 @@ def test_staticfiles_with_directory_returns_404( assert response.text == "Not Found" -def test_staticfiles_with_missing_file_returns_404( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_with_missing_file_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") @@ -138,9 +128,7 @@ def test_staticfiles_instantiated_with_missing_directory(tmpdir: Path) -> None: assert "does not exist" in str(exc_info.value) -def test_staticfiles_configured_with_missing_directory( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_configured_with_missing_directory(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "no_such_directory") app = StaticFiles(directory=path, check_dir=False) client = test_client_factory(app) @@ -163,9 +151,7 @@ def test_staticfiles_configured_with_file_instead_of_directory( assert "is not a directory" in str(exc_info.value) -def test_staticfiles_config_check_occurs_only_once( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_config_check_occurs_only_once(tmpdir: Path, test_client_factory: TestClientFactory) -> None: app = StaticFiles(directory=tmpdir) client = test_client_factory(app) assert not app.config_checked @@ -199,9 +185,7 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir: Path) -> None: assert exc_info.value.detail == "Not Found" -def test_staticfiles_never_read_file_for_head_method( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_never_read_file_for_head_method(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") @@ -214,9 +198,7 @@ def test_staticfiles_never_read_file_for_head_method( assert response.headers["content-length"] == "14" -def test_staticfiles_304_with_etag_match( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_304_with_etag_match(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") @@ -229,9 +211,7 @@ def test_staticfiles_304_with_etag_match( second_resp = client.get("/example.txt", headers={"if-none-match": last_etag}) assert second_resp.status_code == 304 assert second_resp.content == b"" - second_resp = client.get( - "/example.txt", headers={"if-none-match": f'W/{last_etag}, "123"'} - ) + second_resp = client.get("/example.txt", headers={"if-none-match": f'W/{last_etag}, "123"'}) assert second_resp.status_code == 304 assert second_resp.content == b"" @@ -240,9 +220,7 @@ def test_staticfiles_304_with_last_modified_compare_last_req( tmpdir: Path, test_client_factory: TestClientFactory ) -> None: path = os.path.join(tmpdir, "example.txt") - file_last_modified_time = time.mktime( - time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S") - ) + file_last_modified_time = time.mktime(time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S")) with open(path, "w") as file: file.write("") os.utime(path, (file_last_modified_time, file_last_modified_time)) @@ -250,22 +228,16 @@ def test_staticfiles_304_with_last_modified_compare_last_req( app = StaticFiles(directory=tmpdir) client = test_client_factory(app) # last modified less than last request, 304 - response = client.get( - "/example.txt", headers={"If-Modified-Since": "Thu, 11 Oct 2013 15:30:19 GMT"} - ) + response = client.get("/example.txt", headers={"If-Modified-Since": "Thu, 11 Oct 2013 15:30:19 GMT"}) assert response.status_code == 304 assert response.content == b"" # last modified greater than last request, 200 with content - response = client.get( - "/example.txt", headers={"If-Modified-Since": "Thu, 20 Feb 2012 15:30:19 GMT"} - ) + response = client.get("/example.txt", headers={"If-Modified-Since": "Thu, 20 Feb 2012 15:30:19 GMT"}) assert response.status_code == 200 assert response.content == b"" -def test_staticfiles_html_normal( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_html_normal(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "404.html") with open(path, "w") as file: file.write("

Custom not found page

") @@ -298,9 +270,7 @@ def test_staticfiles_html_normal( assert response.text == "

Custom not found page

" -def test_staticfiles_html_without_index( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_html_without_index(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "404.html") with open(path, "w") as file: file.write("

Custom not found page

") @@ -325,9 +295,7 @@ def test_staticfiles_html_without_index( assert response.text == "

Custom not found page

" -def test_staticfiles_html_without_404( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_html_without_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "dir") os.mkdir(path) path = os.path.join(path, "index.html") @@ -352,9 +320,7 @@ def test_staticfiles_html_without_404( assert exc_info.value.status_code == 404 -def test_staticfiles_html_only_files( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_html_only_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "hello.html") with open(path, "w") as file: file.write("

Hello

") @@ -381,9 +347,7 @@ def test_staticfiles_cache_invalidation_for_deleted_file_html_mode( with open(path_some, "w") as file: file.write("

some file

") - common_modified_time = time.mktime( - time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S") - ) + common_modified_time = time.mktime(time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S")) os.utime(path_404, (common_modified_time, common_modified_time)) os.utime(path_some, (common_modified_time, common_modified_time)) @@ -435,9 +399,7 @@ def test_staticfiles_with_invalid_dir_permissions_returns_401( tmp_path.chmod(original_mode) -def test_staticfiles_with_missing_dir_returns_404( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_with_missing_dir_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") @@ -451,9 +413,7 @@ def test_staticfiles_with_missing_dir_returns_404( assert response.text == "Not Found" -def test_staticfiles_access_file_as_dir_returns_404( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_access_file_as_dir_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") @@ -467,9 +427,7 @@ def test_staticfiles_access_file_as_dir_returns_404( assert response.text == "Not Found" -def test_staticfiles_filename_too_long( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_filename_too_long(tmpdir: Path, test_client_factory: TestClientFactory) -> None: routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] app = Starlette(routes=routes) client = test_client_factory(app) @@ -503,9 +461,7 @@ def mock_timeout(*args: typing.Any, **kwargs: typing.Any) -> None: assert response.text == "Internal Server Error" -def test_staticfiles_follows_symlinks( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_follows_symlinks(tmpdir: Path, test_client_factory: TestClientFactory) -> None: statics_path = os.path.join(tmpdir, "statics") os.mkdir(statics_path) @@ -526,9 +482,7 @@ def test_staticfiles_follows_symlinks( assert response.text == "

Hello

" -def test_staticfiles_follows_symlink_directories( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_staticfiles_follows_symlink_directories(tmpdir: Path, test_client_factory: TestClientFactory) -> None: statics_path = os.path.join(tmpdir, "statics") statics_html_path = os.path.join(statics_path, "html") os.mkdir(statics_path) diff --git a/tests/test_status.py b/tests/test_status.py index 04719e87e..4852c06ef 100644 --- a/tests/test_status.py +++ b/tests/test_status.py @@ -8,13 +8,11 @@ ( ( "WS_1004_NO_STATUS_RCVD", - "'WS_1004_NO_STATUS_RCVD' is deprecated. " - "Use 'WS_1005_NO_STATUS_RCVD' instead.", + "'WS_1004_NO_STATUS_RCVD' is deprecated. Use 'WS_1005_NO_STATUS_RCVD' instead.", ), ( "WS_1005_ABNORMAL_CLOSURE", - "'WS_1005_ABNORMAL_CLOSURE' is deprecated. " - "Use 'WS_1006_ABNORMAL_CLOSURE' instead.", + "'WS_1005_ABNORMAL_CLOSURE' is deprecated. Use 'WS_1006_ABNORMAL_CLOSURE' instead.", ), ), ) diff --git a/tests/test_templates.py b/tests/test_templates.py index 8e344f331..6b2080c17 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -36,9 +36,7 @@ async def homepage(request: Request) -> Response: assert set(response.context.keys()) == {"request"} # type: ignore -def test_calls_context_processors( - tmp_path: Path, test_client_factory: TestClientFactory -) -> None: +def test_calls_context_processors(tmp_path: Path, test_client_factory: TestClientFactory) -> None: path = tmp_path / "index.html" path.write_text("Hello {{ username }}") @@ -66,9 +64,7 @@ def hello_world_processor(request: Request) -> dict[str, str]: assert set(response.context.keys()) == {"request", "username"} # type: ignore -def test_template_with_middleware( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_template_with_middleware(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "index.html") with open(path, "w") as file: file.write("Hello, world") @@ -77,9 +73,7 @@ async def homepage(request: Request) -> Response: return templates.TemplateResponse(request, "index.html") class CustomMiddleware(BaseHTTPMiddleware): - async def dispatch( - self, request: Request, call_next: RequestResponseEndpoint - ) -> Response: + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: return await call_next(request) app = Starlette( @@ -96,9 +90,7 @@ async def dispatch( assert set(response.context.keys()) == {"request"} # type: ignore -def test_templates_with_directories( - tmp_path: Path, test_client_factory: TestClientFactory -) -> None: +def test_templates_with_directories(tmp_path: Path, test_client_factory: TestClientFactory) -> None: dir_a = tmp_path.resolve() / "a" dir_a.mkdir() template_a = dir_a / "template_a.html" @@ -134,16 +126,12 @@ async def page_b(request: Request) -> Response: def test_templates_require_directory_or_environment() -> None: - with pytest.raises( - AssertionError, match="either 'directory' or 'env' arguments must be passed" - ): + with pytest.raises(AssertionError, match="either 'directory' or 'env' arguments must be passed"): Jinja2Templates() # type: ignore[call-overload] def test_templates_require_directory_or_enviroment_not_both() -> None: - with pytest.raises( - AssertionError, match="either 'directory' or 'env' arguments must be passed" - ): + with pytest.raises(AssertionError, match="either 'directory' or 'env' arguments must be passed"): Jinja2Templates(directory="dir", env=jinja2.Environment()) @@ -157,9 +145,7 @@ def test_templates_with_directory(tmpdir: Path) -> None: assert template.render({}) == "Hello" -def test_templates_with_environment( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_templates_with_environment(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "index.html") with open(path, "w") as file: file.write("Hello, world") @@ -185,9 +171,7 @@ def test_templates_with_environment_options_emit_warning(tmpdir: Path) -> None: Jinja2Templates(str(tmpdir), autoescape=True) -def test_templates_with_kwargs_only( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_templates_with_kwargs_only(tmpdir: Path, test_client_factory: TestClientFactory) -> None: # MAINTAINERS: remove after 1.0 path = os.path.join(tmpdir, "index.html") with open(path, "w") as file: @@ -242,9 +226,7 @@ def test_templates_with_kwargs_only_warns_when_no_request_keyword( templates = Jinja2Templates(directory=str(tmpdir)) def page(request: Request) -> Response: - return templates.TemplateResponse( - name="index.html", context={"request": request} - ) + return templates.TemplateResponse(name="index.html", context={"request": request}) app = Starlette(routes=[Route("/", page)]) client = test_client_factory(app) @@ -297,9 +279,7 @@ def page(request: Request) -> Response: spy.assert_called() -def test_templates_when_first_argument_is_request( - tmpdir: Path, test_client_factory: TestClientFactory -) -> None: +def test_templates_when_first_argument_is_request(tmpdir: Path, test_client_factory: TestClientFactory) -> None: # MAINTAINERS: remove after 1.0 path = os.path.join(tmpdir, "index.html") with open(path, "w") as file: diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 77de3d976..92f16d336 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -88,9 +88,7 @@ def test_testclient_headers_behavior() -> None: assert client.headers.get("Authentication") == "Bearer 123" -def test_use_testclient_as_contextmanager( - test_client_factory: TestClientFactory, anyio_backend_name: str -) -> None: +def test_use_testclient_as_contextmanager(test_client_factory: TestClientFactory, anyio_backend_name: str) -> None: """ This test asserts a number of properties that are important for an app level task_group @@ -169,9 +167,7 @@ async def loop_id(request: Request) -> JSONResponse: def test_error_on_startup(test_client_factory: TestClientFactory) -> None: - with pytest.deprecated_call( - match="The on_startup and on_shutdown parameters are deprecated" - ): + with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"): startup_error_app = Starlette(on_startup=[startup]) with pytest.raises(RuntimeError): @@ -306,8 +302,7 @@ def homepage(request: Request) -> Response: marks=[ pytest.mark.xfail( sys.version_info < (3, 11), - reason="Fails due to domain handling in http.cookiejar module (see " - "#2152)", + reason="Fails due to domain handling in http.cookiejar module (see #2152)", ), ], ), @@ -316,9 +311,7 @@ def homepage(request: Request) -> Response: ("example.com", False), ], ) -def test_domain_restricted_cookies( - test_client_factory: TestClientFactory, domain: str, ok: bool -) -> None: +def test_domain_restricted_cookies(test_client_factory: TestClientFactory, domain: str, ok: bool) -> None: """ Test that test client discards domain restricted cookies which do not match the base_url of the testclient (`http://testserver` by default). diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 16d2d0f1f..7a9b9272a 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -270,8 +270,7 @@ async def receive() -> Message: async def send(message: Message) -> None: if message["type"] == "websocket.accept": return - # Simulate the exception the server would send to the application when the - # client disconnects. + # Simulate the exception the server would send to the application when the client disconnects. raise OSError with pytest.raises(WebSocketDisconnect) as ctx: @@ -334,19 +333,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: "headers": [(b"content-type", b"text/plain"), (b"foo", b"bar")], } ) - await websocket.send( - { - "type": "websocket.http.response.body", - "body": b"hard", - "more_body": True, - } - ) - await websocket.send( - { - "type": "websocket.http.response.body", - "body": b"body", - } - ) + await websocket.send({"type": "websocket.http.response.body", "body": b"hard", "more_body": True}) + await websocket.send({"type": "websocket.http.response.body", "body": b"body"}) client = test_client_factory(app) with pytest.raises(WebSocketDenialResponse) as exc: @@ -402,10 +390,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: client = test_client_factory(app) with pytest.raises( RuntimeError, - match=( - 'Expected ASGI message "websocket.http.response.body", but got ' - "'websocket.http.response.start'" - ), + match=("Expected ASGI message \"websocket.http.response.body\", but got 'websocket.http.response.start'"), ): with client.websocket_connect("/"): pass # pragma: no cover @@ -493,11 +478,7 @@ async def mock_receive() -> Message: # type: ignore async def mock_send(message: Message) -> None: ... # pragma: no cover - websocket = WebSocket( - {"type": "websocket", "path": "/abc/", "headers": []}, - receive=mock_receive, - send=mock_send, - ) + websocket = WebSocket({"type": "websocket", "path": "/abc/", "headers": []}, receive=mock_receive, send=mock_send) assert websocket["type"] == "websocket" assert dict(websocket) == {"type": "websocket", "path": "/abc/", "headers": []} assert len(websocket) == 3