diff --git a/tests/test_networks/test_brokers/test_handlers/test_handlers.py b/tests/test_networks/test_brokers/test_handlers/test_handlers.py index a1bbfa95..daa873df 100644 --- a/tests/test_networks/test_brokers/test_handlers/test_handlers.py +++ b/tests/test_networks/test_brokers/test_handlers/test_handlers.py @@ -33,6 +33,7 @@ PostgresAsyncTestCase, ) from minos.networks import ( + REQUEST_HEADERS_CONTEXT_VAR, REQUEST_USER_CONTEXT_VAR, BrokerHandler, BrokerHandlerEntry, @@ -91,6 +92,7 @@ def setUp(self) -> None: identifier=self.identifier, user=self.user, reply_topic="UpdateTicket", + headers={"foo": "bar"}, ) async def asyncSetUp(self): @@ -227,7 +229,7 @@ async def test_dispatch(self): identifier=self.message.identifier, status=BrokerMessageStatus.SUCCESS, user=self.user, - headers=dict(), + headers={"foo": "bar"}, ) ], send_mock.call_args_list, @@ -294,20 +296,20 @@ async def test_dispatch_one(self): async def test_get_callback(self): fn = self.handler.get_callback(_Cls._fn) - self.assertEqual((FakeModel("foo"), BrokerMessageStatus.SUCCESS, dict()), await fn(self.message)) + self.assertEqual((FakeModel("foo"), BrokerMessageStatus.SUCCESS, {"foo": "bar"}), await fn(self.message)) async def test_get_callback_none(self): fn = self.handler.get_callback(_Cls._fn_none) - self.assertEqual((None, BrokerMessageStatus.SUCCESS, dict()), await fn(self.message)) + self.assertEqual((None, BrokerMessageStatus.SUCCESS, {"foo": "bar"}), await fn(self.message)) async def test_get_callback_raises_response(self): fn = self.handler.get_callback(_Cls._fn_raises_response) - expected = (repr(BrokerResponseException("foo")), BrokerMessageStatus.ERROR, dict()) + expected = (repr(BrokerResponseException("foo")), BrokerMessageStatus.ERROR, {"foo": "bar"}) self.assertEqual(expected, await fn(self.message)) async def test_get_callback_raises_exception(self): fn = self.handler.get_callback(_Cls._fn_raises_exception) - expected = (repr(ValueError()), BrokerMessageStatus.SYSTEM_ERROR, dict()) + expected = (repr(ValueError()), BrokerMessageStatus.SYSTEM_ERROR, {"foo": "bar"}) self.assertEqual(expected, await fn(self.message)) async def test_get_callback_with_user(self): @@ -322,6 +324,18 @@ async def _fn(request) -> None: self.assertEqual(1, mock.call_count) + async def test_get_callback_with_headers(self): + async def _fn(request) -> None: + self.assertEqual({"foo": "bar"}, request.raw.headers) + REQUEST_HEADERS_CONTEXT_VAR.get()["bar"] = "foo" + + mock = AsyncMock(side_effect=_fn) + + handler = self.handler.get_callback(mock) + _, _, observed = await handler(self.message) + + self.assertEqual({"foo": "bar", "bar": "foo"}, observed) + async def test_dispatch_without_sorting(self): observed = list()