Skip to content

Commit

Permalink
fix coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Aug 21, 2024
1 parent 0679356 commit c50731d
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 61 deletions.
12 changes: 8 additions & 4 deletions starlette/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@


@typing.overload
def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ...
def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]:
...


@typing.overload
def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]: ...
def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]:
...


def is_async_callable(obj: typing.Any) -> typing.Any:
Expand All @@ -47,11 +49,13 @@ def is_async_callable(obj: typing.Any) -> typing.Any:

class AwaitableOrContextManager(
typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]
): ...
):
...


class SupportsAsyncClose(typing.Protocol):
async def close(self) -> None: ... # pragma: no cover
async def close(self) -> None:
... # pragma: no cover


SupportsAsyncCloseType = typing.TypeVar(
Expand Down
17 changes: 10 additions & 7 deletions starlette/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,26 +68,29 @@ def __init__(
self.file_values = self._read_file(env_file)

@typing.overload
def __call__(self, key: str, *, default: None) -> str | None: ...
def __call__(self, key: str, *, default: None) -> str | None:
...

@typing.overload
def __call__(self, key: str, cast: type[T], default: T = ...) -> T: ...
def __call__(self, key: str, cast: type[T], default: T = ...) -> T:
...

@typing.overload
def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str: ...
def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str:
...

@typing.overload
def __call__(
self,
key: str,
cast: typing.Callable[[typing.Any], T] = ...,
default: typing.Any = ...,
) -> T: ...
) -> T:
...

@typing.overload
def __call__(
self, key: str, cast: type[str] = ..., default: T = ...
) -> T | str: ...
def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str:
...

def __call__(
self,
Expand Down
10 changes: 4 additions & 6 deletions starlette/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@


class _MiddlewareClass(Protocol[P]):
def __init__(
self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs
) -> None: ... # pragma: no cover
def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None:
... # pragma: no cover

async def __call__(
self, scope: Scope, receive: Receive, send: Send
) -> None: ... # pragma: no cover
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
... # pragma: no cover


class Middleware:
Expand Down
9 changes: 6 additions & 3 deletions starlette/templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def __init__(
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
| None = None,
**env_options: typing.Any,
) -> None: ...
) -> None:
...

@typing.overload
def __init__(
Expand All @@ -80,7 +81,8 @@ def __init__(
env: jinja2.Environment,
context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
| None = None,
) -> None: ...
) -> None:
...

def __init__(
self,
Expand Down Expand Up @@ -148,7 +150,8 @@ def TemplateResponse(
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> _TemplateResponse: ...
) -> _TemplateResponse:
...

@typing.overload
def TemplateResponse(
Expand Down
27 changes: 6 additions & 21 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,13 +294,7 @@ async def passthrough(
}

async def receive() -> Message:
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
# We simulate a client that disconnects immediately after receiving the response
await response_complete.wait()
return {"type": "http.disconnect"}
raise NotImplementedError('Should not be called!') # pragma: no cover

async def send(message: Message) -> None:
if message["type"] == "http.response.body":
Expand Down Expand Up @@ -346,12 +340,7 @@ async def passthrough(
}

async def receive() -> Message:
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
await response_complete.wait()
return {"type": "http.disconnect"}
raise NotImplementedError('Should not be called!') # pragma: no cover

async def send(message: Message) -> None:
if message["type"] == "http.response.body":
Expand Down Expand Up @@ -425,13 +414,7 @@ async def passthrough(
}

async def receive() -> Message:
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
# We simulate a client that disconnects immediately after receiving the response
await response_complete.wait()
return {"type": "http.disconnect"}
raise NotImplementedError('Should not be called!') # pragma: no cover

async def send(message: Message) -> None:
if message["type"] == "http.response.body":
Expand Down Expand Up @@ -779,7 +762,9 @@ async def rcv() -> AsyncGenerator[Message, None]:
yield {"type": "http.request", "body": b"1", "more_body": True}
yield {"type": "http.request", "body": b"2", "more_body": True}
yield {"type": "http.request", "body": b"3"}
await anyio.sleep(float("inf"))
raise AssertionError( # pragma: no cover
"Should not be called, no need to poll for disconnect"
)

sent: list[Message] = []

Expand Down
33 changes: 22 additions & 11 deletions tests/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@


def test_async_func() -> None:
async def async_func() -> None: ... # pragma: no cover
async def async_func() -> None:
... # pragma: no cover

def func() -> None: ... # pragma: no cover
def func() -> None:
... # pragma: no cover

assert is_async_callable(async_func)
assert not is_async_callable(func)


def test_async_partial() -> None:
async def async_func(a: Any, b: Any) -> None: ... # pragma: no cover
async def async_func(a: Any, b: Any) -> None:
... # pragma: no cover

def func(a: Any, b: Any) -> None: ... # pragma: no cover
def func(a: Any, b: Any) -> None:
... # pragma: no cover

partial = functools.partial(async_func, 1)
assert is_async_callable(partial)
Expand All @@ -27,21 +31,25 @@ def func(a: Any, b: Any) -> None: ... # pragma: no cover

def test_async_method() -> None:
class Async:
async def method(self) -> None: ... # pragma: no cover
async def method(self) -> None:
... # pragma: no cover

class Sync:
def method(self) -> None: ... # pragma: no cover
def method(self) -> None:
... # pragma: no cover

assert is_async_callable(Async().method)
assert not is_async_callable(Sync().method)


def test_async_object_call() -> None:
class Async:
async def __call__(self) -> None: ... # pragma: no cover
async def __call__(self) -> None:
... # pragma: no cover

class Sync:
def __call__(self) -> None: ... # pragma: no cover
def __call__(self) -> None:
... # pragma: no cover

assert is_async_callable(Async())
assert not is_async_callable(Sync())
Expand All @@ -53,14 +61,16 @@ async def __call__(
self,
a: Any,
b: Any,
) -> None: ... # pragma: no cover
) -> None:
... # pragma: no cover

class Sync:
def __call__(
self,
a: Any,
b: Any,
) -> None: ... # pragma: no cover
) -> None:
... # pragma: no cover

partial = functools.partial(Async(), 1)
assert is_async_callable(partial)
Expand All @@ -73,7 +83,8 @@ def test_async_nested_partial() -> None:
async def async_func(
a: Any,
b: Any,
) -> None: ... # pragma: no cover
) -> None:
... # pragma: no cover

partial = functools.partial(async_func, b=2)
nested_partial = functools.partial(partial, a=1)
Expand Down
6 changes: 4 additions & 2 deletions tests/test_applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,8 @@ def test_decorator_deprecations() -> None:

async def middleware(
request: Request, call_next: RequestResponseEndpoint
) -> None: ... # pragma: no cover
) -> None:
... # pragma: no cover

app.middleware("http")(middleware)
assert len(record) == 1
Expand Down Expand Up @@ -492,7 +493,8 @@ async def middleware(
)
) as record:

async def startup() -> None: ... # pragma: no cover
async def startup() -> None:
... # pragma: no cover

app.on_event("startup")(startup)
assert len(record) == 1
Expand Down
15 changes: 10 additions & 5 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,15 +908,19 @@ def test_duplicated_param_names() -> None:


class Endpoint:
async def my_method(self, request: Request) -> None: ... # pragma: no cover
async def my_method(self, request: Request) -> None:
... # pragma: no cover

@classmethod
async def my_classmethod(cls, request: Request) -> None: ... # pragma: no cover
async def my_classmethod(cls, request: Request) -> None:
... # pragma: no cover

@staticmethod
async def my_staticmethod(request: Request) -> None: ... # pragma: no cover
async def my_staticmethod(request: Request) -> None:
... # pragma: no cover

def __call__(self, request: Request) -> None: ... # pragma: no cover
def __call__(self, request: Request) -> None:
... # pragma: no cover


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1249,7 +1253,8 @@ def test_decorator_deprecations() -> None:

with pytest.deprecated_call():

async def startup() -> None: ... # pragma: nocover
async def startup() -> None:
... # pragma: nocover

router.on_event("startup")(startup)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,8 @@ def test_websocket_scope_interface() -> None:
async def mock_receive() -> Message: # type: ignore
... # pragma: no cover

async def mock_send(message: Message) -> None: ... # pragma: no cover
async def mock_send(message: Message) -> None:
... # pragma: no cover

websocket = WebSocket(
{"type": "websocket", "path": "/abc/", "headers": []},
Expand Down
3 changes: 2 additions & 1 deletion tests/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def __call__(
cookies: httpx._types.CookieTypes | None = None,
headers: dict[str, str] | None = None,
follow_redirects: bool = True,
) -> TestClient: ...
) -> TestClient:
...
else: # pragma: no cover

class TestClientFactory:
Expand Down

0 comments on commit c50731d

Please sign in to comment.