Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sandbox: Enable mypy type checker #560

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ markers = [
]
xfail_strict = true

[tool.mypy]
packages = [
"responder",
]
exclude = [
]
check_untyped_defs = true
explicit_package_bases = true
ignore_missing_imports = true
implicit_optional = true
install_types = true
namespace_packages = true
non_interactive = true

[tool.poe.tasks]

check = [
Expand Down Expand Up @@ -98,7 +112,7 @@ lint = [
{ cmd = "ruff format --check ." },
{ cmd = "ruff check ." },
{ cmd = "validate-pyproject pyproject.toml" },
# { cmd = "mypy" },
{ cmd = "mypy" },
]

release = [
Expand Down
11 changes: 8 additions & 3 deletions responder/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,16 @@ async def _static_response(self, req, resp):
with open(index, "r") as f:
resp.html = f.read()
else:
resp.status_code = status_codes.HTTP_404
resp.status_code = status_codes.HTTP_404 # type: ignore[attr-defined]
resp.text = "Not found."

def redirect(
self, resp, location, *, set_text=True, status_code=status_codes.HTTP_301
self,
resp,
location,
*,
set_text=True,
status_code=status_codes.HTTP_301, # type: ignore[attr-defined]
):
"""
Redirects a given response to a given location.
Expand Down Expand Up @@ -365,7 +370,7 @@ def serve(self, *, address=None, port=None, debug=False, **options):
port = 5042

def spawn():
uvicorn.run(self, host=address, port=port, debug=debug, **options)
uvicorn.run(self, host=address, port=port, **options)

spawn()

Expand Down
2 changes: 1 addition & 1 deletion responder/ext/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,6 @@ def docs_response(self, req, resp):
resp.html = self.docs

def schema_response(self, req, resp):
resp.status_code = status_codes.HTTP_200
resp.status_code = status_codes.HTTP_200 # type: ignore[attr-defined]
resp.headers["Content-Type"] = "application/x-yaml"
resp.content = self.openapi
41 changes: 31 additions & 10 deletions responder/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)

from .statics import DEFAULT_ENCODING
from .status_codes import HTTP_301
from .status_codes import HTTP_301 # type: ignore[attr-defined]


class QueryDict(dict):
Expand Down Expand Up @@ -107,7 +107,7 @@ def __init__(self, scope, receive, api=None, formats=None):
self.api = api
self._content = None

headers = CaseInsensitiveDict()
headers: CaseInsensitiveDict = CaseInsensitiveDict()
for key, value in self._starlette.headers.items():
headers[key] = value

Expand Down Expand Up @@ -150,7 +150,7 @@ def cookies(self):
cookies = RequestsCookieJar()
cookie_header = self.headers.get("Cookie", "")

bc = SimpleCookie(cookie_header)
bc: SimpleCookie = SimpleCookie(cookie_header)
for key, morsel in bc.items():
cookies[key] = morsel.value

Expand Down Expand Up @@ -239,9 +239,20 @@ async def media(self, format: t.Union[str, t.Callable] = None): # noqa: A001, A
format = "yaml" if "yaml" in self.mimetype or "" else "json" # noqa: A001
format = "form" if "form" in self.mimetype or "" else format # noqa: A001

if format in self.formats:
return await self.formats[format](self)
return await format(self)
formatter: t.Callable
if isinstance(format, str):
try:
formatter = self.formats[format]
except KeyError as ex:
raise ValueError(f"Unable to process data in '{format}' format") from ex

elif callable(format):
formatter = format

else:
raise TypeError(f"Invalid 'format' argument: {format}")

return await formatter(self)


def content_setter(mimetype):
Expand Down Expand Up @@ -275,7 +286,8 @@ class Response:

def __init__(self, req, *, formats):
self.req = req
self.status_code = None #: The HTTP Status Code to use for the Response.
#: The HTTP Status Code to use for the Response.
self.status_code: t.Union[int, None] = None
self.content = None #: A bytes representation of the response body.
self.mimetype = None
self.encoding = DEFAULT_ENCODING
Expand All @@ -285,7 +297,7 @@ def __init__(self, req, *, formats):
self.headers = {} #: A Python dictionary of ``{key: value}``,
#: representing the headers of the response.
self.formats = formats
self.cookies = SimpleCookie() #: The cookies set in the Response
self.cookies: SimpleCookie = SimpleCookie() #: The cookies set in the Response
self.session = (
req.session
) #: The cookie-based session data, in dict form, to add to the Response.
Expand Down Expand Up @@ -365,16 +377,25 @@ async def __call__(self, scope, receive, send):
if self.headers:
headers.update(self.headers)

response_cls: t.Union[
t.Type[StarletteResponse], t.Type[StarletteStreamingResponse]
]
if self._stream is not None:
response_cls = StarletteStreamingResponse
else:
response_cls = StarletteResponse

response = response_cls(body, status_code=self.status_code, headers=headers)
response = response_cls(body, status_code=self.status_code_safe, headers=headers)
self._prepare_cookies(response)

await response(scope, receive, send)

@property
def ok(self):
return 200 <= self.status_code < 300
return 200 <= self.status_code_safe < 300

@property
def status_code_safe(self) -> int:
if self.status_code is None:
raise RuntimeError("HTTP status code has not been defined")
return self.status_code
14 changes: 8 additions & 6 deletions responder/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import inspect
import re
import traceback
import typing as t
from collections import defaultdict

from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.middleware.wsgi import WSGIMiddleware
from starlette.types import ASGIApp
from starlette.websockets import WebSocket, WebSocketClose

from . import status_codes
Expand Down Expand Up @@ -121,7 +123,7 @@ async def __call__(self, scope, receive, send):
views.append(view)
except AttributeError as ex:
if on_request is None:
raise HTTPException(status_code=status_codes.HTTP_405) from ex
raise HTTPException(status_code=status_codes.HTTP_405) from ex # type: ignore[attr-defined]
else:
views.append(self.endpoint)

Expand All @@ -135,7 +137,7 @@ async def __call__(self, scope, receive, send):
await run_in_threadpool(view, request, response, **path_params)

if response.status_code is None:
response.status_code = status_codes.HTTP_200
response.status_code = status_codes.HTTP_200 # type: ignore[attr-defined]

await response(scope, receive, send)

Expand Down Expand Up @@ -207,7 +209,7 @@ class Router:
def __init__(self, routes=None, default_response=None, before_requests=None):
self.routes = [] if routes is None else list(routes)
# [TODO] Make its own router
self.apps = {}
self.apps: t.Dict[str, ASGIApp] = {}
self.default_endpoint = (
self.default_response if default_response is None else default_response
)
Expand Down Expand Up @@ -255,7 +257,7 @@ def add_route(

def mount(self, route, app):
"""Mounts ASGI / WSGI applications at a given route"""
self.apps.update(route, app)
self.apps.update({route: app})

def add_event_handler(self, event_type, handler):
assert event_type in (
Expand Down Expand Up @@ -287,14 +289,14 @@ def url_for(self, endpoint, **params):
async def default_response(self, scope, receive, send):
if scope["type"] == "websocket":
websocket_close = WebSocketClose()
await websocket_close(receive, send)
await websocket_close(scope, receive, send)
return

# FIXME: Please review!
request = Request(scope, receive)
response = Response(request, formats=get_formats()) # noqa: F841

raise HTTPException(status_code=status_codes.HTTP_404)
raise HTTPException(status_code=status_codes.HTTP_404) # type: ignore[attr-defined]

def _resolve_route(self, scope):
for route in self.routes:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def run(self):
],
"graphql": ["graphene"],
"release": ["build", "twine"],
"test": ["pytest", "pytest-cov", "pytest-mock", "flask"],
"test": ["flask", "mypy", "pytest", "pytest-cov", "pytest-mock"],
},
include_package_data=True,
license="Apache 2.0",
Expand Down
Loading