diff --git a/examples/endpoint_request_methods.py b/examples/endpoint_request_methods.py new file mode 100644 index 00000000..5987c123 --- /dev/null +++ b/examples/endpoint_request_methods.py @@ -0,0 +1,33 @@ +import responder + +api = responder.API() + + +@api.route("/{greeting}") +async def greet(req, resp, *, greeting): # any request method. + resp.text = f"{greeting}, world!" + + +@api.route("/me/{greeting}", methods=["POST"]) +async def greet_me(req, resp, *, greeting): + resp.text = f"POST - {greeting}, world!" + + +@api.route("/class/{greeting}") +class GreetingResource: + def on_get(self, req, resp, *, greeting): + resp.text = f"GET class - {greeting}, world!" + resp.headers.update({"X-Life": "42"}) + resp.status_code = api.status_codes.HTTP_201 + + def on_post(self, req, resp, *, greeting): + resp.text = f"POST class - {greeting}, world!" + resp.headers.update({"X-Life": "42"}) + + def on_request(self, req, resp, *, greeting): # any request method. + resp.text = f"any class - {greeting}, world!" + resp.headers.update({"X-Life": "42"}) + + +if __name__ == "__main__": + api.run() diff --git a/responder/api.py b/responder/api.py index 6b826448..a0f4e44c 100644 --- a/responder/api.py +++ b/responder/api.py @@ -189,6 +189,7 @@ def add_route( check_existing=True, websocket=False, before_request=False, + methods=(), ): """Adds a route to the API. @@ -212,6 +213,7 @@ def add_route( websocket=websocket, before_request=before_request, check_existing=check_existing, + methods=methods, ) async def _static_response(self, req, resp): diff --git a/responder/routes.py b/responder/routes.py index e5fa2d2d..9593febe 100644 --- a/responder/routes.py +++ b/responder/routes.py @@ -56,11 +56,12 @@ async def __call__(self, scope, receive, send): class Route(BaseRoute): - def __init__(self, route, endpoint, *, before_request=False): + def __init__(self, route, endpoint, *, before_request=False, methods=()): assert route.startswith("/"), "Route path must start with '/'" self.route = route self.endpoint = endpoint self.before_request = before_request + self.methods = methods self.path_re, self.param_convertors = compile_path(route) @@ -123,6 +124,10 @@ async def __call__(self, scope, receive, send): if on_request is None: raise HTTPException(status_code=status_codes.HTTP_405) from None else: + if self.methods and request.method not in [ + method.lower() for method in self.methods + ]: + raise HTTPException(status_code=status_codes.HTTP_405) from None views.append(self.endpoint) for view in views: @@ -225,11 +230,13 @@ def add_route( websocket=False, before_request=False, check_existing=False, + methods=(), ): """Adds a route to the router. :param route: A string representation of the route :param endpoint: The endpoint for the route -- can be callable, or class. :param default: If ``True``, all unknown requests will route to this view. + :param methods: A list of supported request methods for this endpoint. """ if before_request: if websocket: @@ -249,7 +256,7 @@ def add_route( if websocket: route = WebSocketRoute(route, endpoint) else: - route = Route(route, endpoint) + route = Route(route, endpoint, methods=methods) self.routes.append(route) diff --git a/tests/test_responder.py b/tests/test_responder.py index 6fe7a590..86880fe6 100644 --- a/tests/test_responder.py +++ b/tests/test_responder.py @@ -1024,7 +1024,7 @@ class Item(BaseModel): resp_mock = mocker.MagicMock() - @api.route("/create") + @api.route("/create", methods=["POST"]) @api.trust(Item) async def create_item(req, resp, *, data): resp.text = "created" @@ -1052,7 +1052,7 @@ class ItemSchema(Schema): resp_mock = mocker.MagicMock() - @api.route("/create") + @api.route("/create", methods=["POST"]) @api.trust(ItemSchema) async def create_item(req, resp, *, data): resp.text = "created" @@ -1072,3 +1072,68 @@ async def create_item(req, resp, *, data): response = api.requests.post(api.url_for(create_item), json=data) assert response.status_code == api.status_codes.HTTP_400 assert "error" in response.text + + +def test_endpoint_request_methods(api): + @api.route("/{greeting}") + async def greet(req, resp, *, greeting): # defaults to get. + resp.text = f"{greeting}, world!" + + @api.route("/me/{greeting}", methods=["POST"]) + async def greet_me(req, resp, *, greeting): + resp.text = f"POST - {greeting}, world!" + + @api.route("/no/{greeting}") + class NoGreeting: + pass + + @api.route("/class/{greeting}") + class GreetingResource: + def on_get(self, req, resp, *, greeting): + resp.text = f"GET class - {greeting}, world!" + resp.headers.update({"X-Life": "41"}) + resp.status_code = api.status_codes.HTTP_201 + + def on_post(self, req, resp, *, greeting): + resp.text = f"POST class - {greeting}, world!" + resp.headers.update({"X-Life": "42"}) + + def on_request(self, req, resp, *, greeting): # any request method. + resp.text = f"any class - {greeting}, world!" + resp.headers.update({"X-Life": "43"}) + + resp = api.requests.get("http://;/Hello") + assert resp.status_code == api.status_codes.HTTP_200 + assert resp.text == "Hello, world!" + + resp = api.requests.post("http://;/Hello") + assert resp.status_code == api.status_codes.HTTP_200 + assert resp.text == "Hello, world!" + + resp = api.requests.get("http://;/me/Hey") + assert resp.status_code == api.status_codes.HTTP_405 + + resp = api.requests.post("http://;/me/Hey") + assert resp.status_code == api.status_codes.HTTP_200 + assert resp.text == "POST - Hey, world!" + + resp = api.requests.get("http://;/no/Hello") + assert resp.status_code == api.status_codes.HTTP_405 + + resp = api.requests.post("http://;/no/Hello") + assert resp.status_code == api.status_codes.HTTP_405 + + resp = api.requests.get("http://;/class/Hi") + assert resp.text == "GET class - Hi, world!" + assert resp.headers["X-Life"] == "41" + assert resp.status_code == api.status_codes.HTTP_201 + + resp = api.requests.post("http://;/class/Hi") + assert resp.text == "POST class - Hi, world!" + assert resp.headers["X-Life"] == "42" + assert resp.status_code == api.status_codes.HTTP_200 + + resp = api.requests.put("http://;/class/Hi") + assert resp.text == "any class - Hi, world!" + assert resp.headers["X-Life"] == "43" + assert resp.status_code == api.status_codes.HTTP_200