Skip to content

Commit

Permalink
fix(WebSocket): properly handle HTTPStatus and HTTPError cases (#2150)
Browse files Browse the repository at this point in the history
* fix(websocket): properly handle HTTPStatus and HTTPError cases when using websockets

* test: fix failing tests

* test: add missing test coverage

* docs: add newsfragment

* chore: explicit raise for unsupported call of internal methods

* style: fix imports

* docs(towncrier): improve the bugfix newsfragment

---------

Co-authored-by: Vytautas Liuolia <vytautas.liuolia@gmail.com>
  • Loading branch information
CaselIT and vytas7 committed Nov 12, 2023
1 parent 553e10c commit 4381638
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 15 deletions.
6 changes: 6 additions & 0 deletions docs/_newsfragments/2146.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

:ref:`WebSocket <ws>` implementation has been fixed to properly handle
:class:`~falcon.HTTPError` and :class:`~falcon.HTTPStatus` exceptions raised by
custom :func:`error handlers <falcon.asgi.App.add_error_handler>`.
The WebSocket connection is now correctly closed with an appropriate code
instead of bubbling up an unhandled error to the application server.
40 changes: 29 additions & 11 deletions falcon/asgi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from .request import Request
from .response import Response
from .structures import SSEvent
from .ws import http_status_to_ws_code
from .ws import WebSocket
from .ws import WebSocketOptions

Expand Down Expand Up @@ -1027,30 +1028,47 @@ def _prepare_middleware(self, middleware=None, independent_middleware=False):
asgi=True,
)

async def _http_status_handler(self, req, resp, status, params):
self._compose_status_response(req, resp, status)
async def _http_status_handler(self, req, resp, status, params, ws=None):
if resp:
self._compose_status_response(req, resp, status)
elif ws:
code = http_status_to_ws_code(status.status)
falcon._logger.error(
'[FALCON] HTTPStatus %s raised while handling WebSocket. '
'Closing with code %s',
status,
code,
)
await ws.close(code)
else:
raise NotImplementedError('resp or ws expected')

async def _http_error_handler(self, req, resp, error, params, ws=None):
if resp:
self._compose_error_response(req, resp, error)

if ws:
elif ws:
# NOTE(vytas): error.status_code is not yet in this backport.
# code = http_status_to_ws_code(error.status_code)
code = http_status_to_ws_code(http_status_to_code(error.status))
falcon._logger.error(
'[FALCON] WebSocket handshake rejected due to raised HTTP error: %s',
'[FALCON] HTTPError %s raised while handling WebSocket. '
'Closing with code %s',
error,
code,
)

code = 3000 + falcon.util.http_status_to_code(error.status)
await ws.close(code)
else:
raise NotImplementedError('resp or ws expected')

async def _python_error_handler(self, req, resp, error, params, ws=None):
falcon._logger.error('[FALCON] Unhandled exception in ASGI app', exc_info=error)

if resp:
self._compose_error_response(req, resp, falcon.HTTPInternalServerError())

if ws:
elif ws:
await self._ws_cleanup_on_error(ws)
else:
raise NotImplementedError('resp or ws expected')

async def _ws_disconnected_error_handler(self, req, resp, error, params, ws):
falcon._logger.debug(
Expand Down Expand Up @@ -1095,9 +1113,9 @@ async def _handle_exception(self, req, resp, ex, params, ws=None):
await err_handler(req, resp, ex, params, **kwargs)

except HTTPStatus as status:
self._compose_status_response(req, resp, status)
await self._http_status_handler(req, resp, status, params, ws=ws)
except HTTPError as error:
self._compose_error_response(req, resp, error)
await self._http_error_handler(req, resp, error, params, ws=ws)

return True

Expand Down
5 changes: 5 additions & 0 deletions falcon/asgi/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,3 +701,8 @@ async def _pump(self):
if self._pop_message_waiter is not None:
self._pop_message_waiter.set_result(None)
self._pop_message_waiter = None


def http_status_to_ws_code(http_status: int) -> int:
"""Convert the provided http status to a websocket close code by adding 3000."""
return http_status + 3000
8 changes: 5 additions & 3 deletions falcon/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,18 +456,20 @@ def code_to_http_status(status):
if isinstance(status, http.HTTPStatus):
return '{} {}'.format(status.value, status.phrase)

if isinstance(status, str):
# NOTE(kgriffs): If it is a str but does not have a space, assume it is
# just the number by itself.
if isinstance(status, str) and ' ' in status:
return status

if isinstance(status, bytes):
return status.decode()

try:
code = int(status)
if not 100 <= code <= 999:
raise ValueError('{} is not a valid status code'.format(status))
except (ValueError, TypeError):
raise ValueError('{!r} is not a valid status code'.format(status))
if not 100 <= code <= 999:
raise ValueError('{} is not a valid status code'.format(status))

try:
# NOTE(kgriffs): We do this instead of using http.HTTPStatus since
Expand Down
32 changes: 32 additions & 0 deletions tests/asgi/test_misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# misc test for 100% coverage

from unittest.mock import MagicMock

import pytest

from falcon.asgi import App
from falcon.http_error import HTTPError
from falcon.http_status import HTTPStatus


@pytest.mark.asyncio
async def test_http_status_not_impl():
app = App()
with pytest.raises(NotImplementedError):
await app._http_status_handler(MagicMock(), None, HTTPStatus(200), {}, None)


@pytest.mark.asyncio
async def test_http_error_not_impl():
app = App()
with pytest.raises(NotImplementedError):
await app._http_error_handler(MagicMock(), None, HTTPError(400), {}, None)


@pytest.mark.asyncio
async def test_python_error_not_impl():
app = App()
with pytest.raises(NotImplementedError):
await app._python_error_handler(
MagicMock(), None, ValueError('error'), {}, None
)
123 changes: 123 additions & 0 deletions tests/asgi/test_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,3 +1093,126 @@ def test_msgpack_missing():

with pytest.raises(RuntimeError):
handler.deserialize(b'{}')


@pytest.mark.asyncio
@pytest.mark.parametrize('status', [200, 500, 422, 400])
@pytest.mark.parametrize('thing', [falcon.HTTPStatus, falcon.HTTPError])
@pytest.mark.parametrize('accept', [True, False])
async def test_ws_http_error_or_status_response(conductor, status, thing, accept):
class Resource:
async def on_websocket(self, req, ws):
if accept:
await ws.accept()
raise thing(status)

conductor.app.add_route('/', Resource())
exp_code = 3000 + status

async with conductor as c:
if accept:
async with c.simulate_ws() as ws:
assert ws.closed
assert ws.close_code == exp_code
else:
with pytest.raises(falcon.WebSocketDisconnected) as err:
async with c.simulate_ws():
pass
assert err.value.code == exp_code


@pytest.mark.asyncio
@pytest.mark.parametrize('status', [200, 500, 422, 400])
@pytest.mark.parametrize(
'thing',
[
falcon.HTTPStatus,
falcon.HTTPError,
],
)
@pytest.mark.parametrize('place', ['request', 'resource'])
async def test_ws_http_error_or_status_middleware(conductor, status, thing, place):
called = False

class Resource:
async def on_websocket(self, req, ws):
nonlocal called
called = True

class Middleware:
async def process_request_ws(self, req, ws):
if place == 'request':
raise thing(status)

async def process_resource_ws(self, req, ws, res, params):
if place == 'resource':
raise thing(status)

conductor.app.add_route('/', Resource())
conductor.app.add_middleware(Middleware())
exp_code = 3000 + status

async with conductor as c:
with pytest.raises(falcon.WebSocketDisconnected) as err:
async with c.simulate_ws():
pass
assert err.value.code == exp_code
assert not called


class FooBarError(Exception):
pass


@pytest.mark.asyncio
@pytest.mark.parametrize('status', [200, 500, 422, 400])
@pytest.mark.parametrize('thing', [falcon.HTTPStatus, falcon.HTTPError])
@pytest.mark.parametrize(
'place', ['request', 'resource', 'ws_before_accept', 'ws_after_accept']
)
@pytest.mark.parametrize('handler_has_ws', [True, False])
async def test_ws_http_error_or_status_error_handler(
conductor, status, thing, place, handler_has_ws
):
class Resource:
async def on_websocket(self, req, ws):
if place == 'ws_before_accept':
raise FooBarError
await ws.accept()
if place == 'ws_after_accept':
raise FooBarError

class Middleware:
async def process_request_ws(self, req, ws):
if place == 'request':
raise FooBarError

async def process_resource_ws(self, req, ws, res, params):
if place == 'resource':
raise FooBarError

if handler_has_ws:

async def handle_foobar(req, resp, ex, param, ws=None): # type: ignore
raise thing(status)

else:

async def handle_foobar(req, resp, ex, param): # type: ignore
raise thing(status)

conductor.app.add_route('/', Resource())
conductor.app.add_middleware(Middleware())
conductor.app.add_error_handler(FooBarError, handle_foobar)
exp_code = 3000 + status

async with conductor as c:
if place == 'ws_after_accept':
async with c.simulate_ws() as ws:
assert ws.closed
assert ws.close_code == exp_code
else:
with pytest.raises(falcon.WebSocketDisconnected) as err:
async with c.simulate_ws():
pass
assert err.value.code == exp_code
18 changes: 17 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,23 @@ def test_get_http_status(self):
def test_code_to_http_status(self, v_in, v_out):
assert falcon.code_to_http_status(v_in) == v_out

@pytest.mark.parametrize('v', [0, 13, 99, 1000, 1337.01, -99, -404.3, -404, -404.3])
@pytest.mark.parametrize(
'v',
[
0,
13,
99,
1000,
1337.01,
-99,
-404.3,
-404,
-404.3,
'Successful',
'Failed',
None,
],
)
def test_code_to_http_status_value_error(self, v):
with pytest.raises(ValueError):
falcon.code_to_http_status(v)
Expand Down

0 comments on commit 4381638

Please sign in to comment.