Skip to content

Commit

Permalink
Handling sync messages and model parsing in create_mq_callback (#123)
Browse files Browse the repository at this point in the history
* - adding pydantic dependency to the project
- `create_mq_callback` can now be called without explicit invocation
- added `request_model` attribute to seamlessly validate and pass pydantic model into callback instead of plain dict
- added handling of synchronous messages that have 'routing_key' for response

* fixed test failures

* Added tests for the new logic

* Renamed references

---------

Co-authored-by: NeonKirill <kirill.grim@gmail.com>
  • Loading branch information
NeonKirill and kirgrim authored Feb 4, 2025
1 parent 83bc863 commit c4c9157
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 18 deletions.
62 changes: 60 additions & 2 deletions neon_mq_connector/utils/rabbit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,51 @@
import inspect

from functools import wraps
from typing import Optional, Type, Callable, Any, Tuple

import pika.channel

from ovos_utils.log import LOG
from pydantic import BaseModel, ValidationError

from neon_mq_connector.utils.network_utils import b64_to_dict


def create_mq_callback(include_callback_props: tuple = ('body',)):
def create_mq_callback(
callback: Optional[
Callable[
[
pika.channel.Channel,
pika.spec.Basic.Deliver,
pika.spec.BasicProperties,
bytes,
],
Any
]
] = None,
*,
include_callback_props: Tuple[str] = ('body',),
request_model: Optional[Type[BaseModel]] = None,
):
"""
Creates MQ callback method by filtering relevant MQ attributes. Use this
decorator to simplify creation of MQ callbacks.
Note that the consumer must have `auto_ack=True` specified at registration
if the decorated function does not accept `channel` and `method` kwargs that
are required to acknowledge a message.
:param callback: callable to wrap into this decorator
:param include_callback_props: tuple of `pika` callback arguments to include (defaults to ('body',))
:param request_model: pydantic request model to convert received body to
"""

if callback and callable(callback): # No arguments passed, used directly
return create_mq_callback(
include_callback_props=include_callback_props,
request_model=request_model,
)(callback)

if not include_callback_props:
include_callback_props = ()

Expand All @@ -63,14 +93,38 @@ def _parse_kwargs(*f_args) -> dict:
else:
raise TypeError(f'Invalid body received, expected: '
f'bytes string; got: {type(value)}')
if request_model:
callback_kwargs['body'] = request_model.model_validate(
obj=callback_kwargs['body'],
)
else:
callback_kwargs[mq_props[idx]] = value
return callback_kwargs

@wraps(f)
def wrapped_classmethod(self, *f_args):
try:
res = f(self, **_parse_kwargs(*f_args))
parsed_request_kwargs = _parse_kwargs(*f_args)
res = f(self, **parsed_request_kwargs)

body = parsed_request_kwargs.get('body') or {}
if isinstance(body, BaseModel):
body = body.model_dump()

routing_key = body.get('routing_key')
message_id = body.get('message_id')

if routing_key and res and isinstance(res, dict):
res.setdefault("context", {}).setdefault("mq", {}).setdefault("message_id", message_id)
self.send_message(
request_data=res,
vhost=res.pop('vhost', self.vhost),
queue=routing_key,
)
except ValidationError as val_err:
LOG.error(f'Validation error when parsing request data of {f.__name__} failed due to '
f'error={val_err}')
res = None
except Exception as ex:
LOG.error(f'Execution of {f.__name__} failed due to '
f'exception={ex}')
Expand All @@ -81,6 +135,10 @@ def wrapped_classmethod(self, *f_args):
def wrapped(*f_args):
try:
res = f(**_parse_kwargs(*f_args))
except ValidationError as val_err:
LOG.error(f'Validation error when parsing request data of {f.__name__} failed due to '
f'error={val_err}')
res = None
except Exception as ex:
LOG.error(f'Execution of {f.__name__} failed due to '
f'exception={ex}')
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pika~=1.2
ovos-config~=0.0,>=0.0.8
ovos-utils~=0.0,>=0.0.32
pydantic
82 changes: 66 additions & 16 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from threading import Thread

from pika.exceptions import ProbableAuthenticationError
from pydantic import BaseModel

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

Expand All @@ -46,14 +47,18 @@
wait_for_mq_startup
from neon_mq_connector.utils.client_utils import MQConnector, NeonMQHandler
from neon_mq_connector.utils.network_utils import dict_to_b64, b64_to_dict
from neon_mq_connector.utils.rabbit_utils import create_mq_callback

from .fixtures import rmq_instance

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
TEST_PATH = os.path.join(ROOT_DIR, "tests", "ccl_files")

INPUT_CHANNEL = str(time.time())
OUTPUT_CHANNEL = str(time.time())
RANDOM_STR = str(int(time.time()))

INPUT_CHANNEL_A = RANDOM_STR + '_a'
INPUT_CHANNEL_B = RANDOM_STR + '_b'
OUTPUT_CHANNEL = RANDOM_STR + '_output'

TEST_DICT = {b"section 1": {"key1": "val1",
"key2": "val2"},
Expand All @@ -70,26 +75,35 @@ def callback_on_failure():
return False


class MockRequestModel(BaseModel):
message_id: str
test: bool = True


class MqCallbackDecoratorClass:
from neon_mq_connector.utils.rabbit_utils import create_mq_callback
class_callback = Mock()

def __init__(self):
self.callback = Mock()

@create_mq_callback()
@create_mq_callback
def default_callback(self, body):
self.callback(body)

@create_mq_callback(())
@create_mq_callback(include_callback_props=())
def no_kwargs_callback(self, **kwargs):
self.callback(**kwargs)

@staticmethod
@create_mq_callback()
@create_mq_callback
def static_callback(body):
MqCallbackDecoratorClass.class_callback(body)

@create_mq_callback(request_model=MockRequestModel)
def callback_with_pydantic_model(self, **kwargs):
self.callback(**kwargs)


class SimpleMQConnector(MQConnector):
def __init__(self, config: dict, service_name: str, vhost: str):
Expand All @@ -110,6 +124,15 @@ def respond(channel, method, _, body):
properties=pika.BasicProperties(expiration='1000'))
channel.basic_ack(delivery_tag=method.delivery_tag)

@create_mq_callback
def respond_wrapped(self, body: dict):
return {
"message_id": body["message_id"],
"success": True,
"request_data": body["data"],
}



@pytest.mark.usefixtures("rmq_instance")
class TestClientUtils(unittest.TestCase):
Expand All @@ -130,9 +153,14 @@ def setUp(self) -> None:
vhost=vhost)
self.test_connector.register_consumer("neon_utils_test",
vhost,
INPUT_CHANNEL,
INPUT_CHANNEL_A,
self.test_connector.respond,
auto_ack=False)
self.test_connector.register_consumer("neon_utils_test_wrapped",
vhost,
INPUT_CHANNEL_B,
self.test_connector.respond_wrapped,
auto_ack=False)
self.test_connector.run_consumers()

@classmethod
Expand All @@ -143,7 +171,7 @@ def tearDownClass(cls) -> None:
def test_send_mq_request_valid(self):
from neon_mq_connector.utils.client_utils import send_mq_request
request = {"data": time.time()}
response = send_mq_request("/neon_testing", request, INPUT_CHANNEL)
response = send_mq_request("/neon_testing", request, INPUT_CHANNEL_A)
self.assertIsInstance(response, dict)
self.assertTrue(response["success"])
self.assertEqual(response["request_data"], request["data"])
Expand All @@ -152,7 +180,16 @@ def test_send_mq_request_spec_output_channel_valid(self):
from neon_mq_connector.utils.client_utils import send_mq_request
request = {"data": time.time()}
response = send_mq_request("/neon_testing", request,
INPUT_CHANNEL, OUTPUT_CHANNEL)
INPUT_CHANNEL_A, OUTPUT_CHANNEL)
self.assertIsInstance(response, dict)
self.assertTrue(response["success"])
self.assertEqual(response["request_data"], request["data"])

def test_send_mq_request_response_emit_handled_by_create_mq_request_decorator(self):
from neon_mq_connector.utils.client_utils import send_mq_request

request = {"data": time.time()}
response = send_mq_request("/neon_testing", request, INPUT_CHANNEL_B)
self.assertIsInstance(response, dict)
self.assertTrue(response["success"])
self.assertEqual(response["request_data"], request["data"])
Expand All @@ -164,7 +201,7 @@ def test_multiple_mq_requests(self):

def check_response(name: str):
request = {"data": time.time()}
response = send_mq_request("/neon_testing", request, INPUT_CHANNEL)
response = send_mq_request("/neon_testing", request, INPUT_CHANNEL_A)
self.assertIsInstance(response, dict)
if not isinstance(response, dict):
responses[name] = {'success': False,
Expand Down Expand Up @@ -343,27 +380,36 @@ def test_check_port_is_open(self):


class TestRabbitUtils(unittest.TestCase):

@staticmethod
def create_mock_request(body):
return {
"channel": Mock(),
"method": Mock(),
"properties": Mock(),
"body": dict_to_b64(body)
}

def test_create_mq_callback(self):
from neon_mq_connector.utils.rabbit_utils import create_mq_callback
callback = Mock()
test_body = {"test": True}
valid_request = {"channel": Mock(), "method": Mock(),
"properties": Mock(),
"body": dict_to_b64(test_body)}
valid_request = self.create_mock_request(body=test_body)
mock_model = MockRequestModel(message_id="test_id")

@create_mq_callback()
@create_mq_callback
def default_handler_body(body: dict):
callback(body)

@create_mq_callback()
@create_mq_callback
def default_handler_kwargs(**kwargs):
callback(**kwargs)

@create_mq_callback(('body', 'method'))
@create_mq_callback(include_callback_props=('body', 'method'))
def extra_kwargs_handler(**kwargs):
callback(**kwargs)

@create_mq_callback(())
@create_mq_callback(include_callback_props=())
def no_kwargs_handler(**kwargs):
callback(**kwargs)

Expand Down Expand Up @@ -397,6 +443,10 @@ def no_kwargs_handler(**kwargs):
test_handlers.static_callback(*valid_request.values())
test_handlers.class_callback.assert_called_once_with(test_body)

# Pydantic model handler
valid_model_request = self.create_mock_request(body=mock_model.model_dump())
test_handlers.callback_with_pydantic_model(*valid_model_request.values())
test_handlers.callback.assert_called_with(body=mock_model)

class TestThreadUtils(unittest.TestCase):
counter = 0
Expand Down

0 comments on commit c4c9157

Please sign in to comment.