From d92fb39b6d652d52b892ae52e1977e8ce6205a42 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Thu, 5 Dec 2024 16:55:11 -0800 Subject: [PATCH] Refactor test classes to match module order for better readability Add test coverage for `utils` submodule Update `create_mq_callback` to support decorating both functions and methods (including static methods) --- neon_mq_connector/utils/consumer_utils.py | 1 + neon_mq_connector/utils/rabbit_utils.py | 34 +++- tests/test_utils.py | 232 ++++++++++++++++------ 3 files changed, 194 insertions(+), 73 deletions(-) diff --git a/neon_mq_connector/utils/consumer_utils.py b/neon_mq_connector/utils/consumer_utils.py index f65d3e6..070a5c2 100644 --- a/neon_mq_connector/utils/consumer_utils.py +++ b/neon_mq_connector/utils/consumer_utils.py @@ -29,6 +29,7 @@ from ovos_utils.log import LOG + def default_error_handler(*args): """ Default handler for Consumer instances diff --git a/neon_mq_connector/utils/rabbit_utils.py b/neon_mq_connector/utils/rabbit_utils.py index 00ec73b..baf24f8 100644 --- a/neon_mq_connector/utils/rabbit_utils.py +++ b/neon_mq_connector/utils/rabbit_utils.py @@ -25,7 +25,8 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - +import inspect +from inspect import ismethod from functools import wraps from ovos_utils.log import LOG @@ -34,17 +35,16 @@ def create_mq_callback(include_callback_props: tuple = ('body',)): - """ Creates MQ callback method by filtering relevant MQ attributes """ + """ + Creates MQ callback method by filtering relevant MQ attributes + """ if not include_callback_props: include_callback_props = () def wrapper(f): - - @wraps(f) - def wrapped(self, *f_args): + def _parse_kwargs(*f_args) -> dict: mq_props = ['channel', 'method', 'properties', 'body'] - callback_kwargs = {} for idx in range(len(mq_props)): @@ -54,19 +54,39 @@ def wrapped(self, *f_args): if value and isinstance(value, bytes): dict_data = b64_to_dict(value) callback_kwargs['body'] = dict_data + elif value and isinstance(value, dict): + callback_kwargs['body'] = value else: raise TypeError(f'Invalid body received, expected: ' f'bytes string; got: {type(value)}') else: callback_kwargs[mq_props[idx]] = value + return callback_kwargs + + @wraps(f) + def wrapped_classmethod(self, *f_args): + try: + res = f(self, **_parse_kwargs(*f_args)) + except Exception as ex: + LOG.error(f'Execution of {f.__name__} failed due to ' + f'exception={ex}') + res = None + return res + + @wraps(f) + def wrapped(*f_args): try: - res = f(self, **callback_kwargs) + res = f(**_parse_kwargs(*f_args)) except Exception as ex: LOG.error(f'Execution of {f.__name__} failed due to ' f'exception={ex}') res = None return res + # Use the appropriate wrapper for a class method vs a function + signature = inspect.signature(f).parameters + if 'self' in signature: + return wrapped_classmethod return wrapped return wrapper diff --git a/tests/test_utils.py b/tests/test_utils.py index 9cf464b..744cd65 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -30,11 +30,14 @@ import sys import time import unittest +from unittest.mock import Mock + import pytest import pika from threading import Thread + sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) from neon_mq_connector.utils import RepeatingTimer @@ -66,7 +69,28 @@ def callback_on_failure(): return False -class TestMQConnector(MQConnector): +class MqCallbackDecoratorClass: + from neon_mq_connector.utils.rabbit_utils import create_mq_callback + class_callback = Mock() + + def __init__(self): + self.callback = Mock() + + @create_mq_callback() + def default_callback(self, body): + self.callback(body) + + @create_mq_callback(()) + def no_kwargs_callback(self, **kwargs): + self.callback(**kwargs) + + @staticmethod + @create_mq_callback() + def static_callback(body): + MqCallbackDecoratorClass.class_callback(body) + + +class SimpleMQConnector(MQConnector): def __init__(self, config: dict, service_name: str, vhost: str): super().__init__(config, service_name) self.vhost = vhost @@ -86,69 +110,8 @@ def respond(channel, method, _, body): channel.basic_ack(delivery_tag=method.delivery_tag) -class TestMQConnectorUtils(unittest.TestCase): - counter = 0 - - def repeating_method(self): - """Simple method incrementing counter by one""" - self.counter += 1 - - @retry(num_retries=3, backoff_factor=0.1, - callback_on_exceeded=callback_on_failure, use_self=True) - def method_passing_on_nth_attempt(self, num_attempts: int = 3) -> bool: - """ - Simple method that is passing check only after n-th attempt - :param num_attempts: number of attempts before passing - """ - if self.counter < num_attempts - 1: - self.repeating_method() - raise AssertionError('Awaiting counter equal to 3') - return True - - def test_01_get_timeout(self): - """Tests of getting timeout with backoff factor applied""" - __backoff_factor, __number_of_retries = 0.1, 1 - timeout = get_timeout(__backoff_factor, __number_of_retries) - self.assertEqual(timeout, 0.1) - __number_of_retries += 1 - timeout = get_timeout(__backoff_factor, __number_of_retries) - self.assertEqual(timeout, 0.2) - __number_of_retries += 1 - timeout = get_timeout(__backoff_factor, __number_of_retries) - self.assertEqual(timeout, 0.4) - - def test_02_retry_succeed(self): - """Testing retry decorator""" - outcome = self.method_passing_on_nth_attempt(num_attempts=3) - self.assertTrue(outcome) - self.assertEqual(2, self.counter) - - def test_03_retry_failed(self): - """Testing retry decorator""" - outcome = self.method_passing_on_nth_attempt(num_attempts=4) - self.assertFalse(outcome) - self.assertEqual(3, self.counter) - - def test_repeating_timer(self): - """Testing repeating timer thread""" - interval_timeout = 3 - timer_thread = RepeatingTimer(interval=0.9, - function=self.repeating_method) - timer_thread.start() - time.sleep(interval_timeout) - timer_thread.cancel() - self.assertEqual(self.counter, 3) - - def test_wait_for_mq_startup(self): - self.assertTrue(wait_for_mq_startup("mq.neonaiservices.com", 5672)) - self.assertFalse(wait_for_mq_startup("www.neon.ai", 5672, 1)) - - def setUp(self) -> None: - self.counter = 0 - - @pytest.mark.usefixtures("rmq_instance") -class MqUtilTests(unittest.TestCase): +class TestClientUtils(unittest.TestCase): test_connector = None def setUp(self) -> None: @@ -161,9 +124,9 @@ def setUp(self) -> None: import neon_mq_connector.utils.client_utils neon_mq_connector.utils.client_utils._default_mq_config = test_conf vhost = "/neon_testing" - self.test_connector = TestMQConnector(config=test_conf, - service_name="mq_handler", - vhost=vhost) + self.test_connector = SimpleMQConnector(config=test_conf, + service_name="mq_handler", + vhost=vhost) self.test_connector.register_consumer("neon_utils_test", vhost, INPUT_CHANNEL, @@ -237,6 +200,66 @@ def test_send_mq_request_invalid_vhost(self): send_mq_request("invalid_endpoint", {}, "test", "test", timeout=5) +class TestMQConnectionUtils(unittest.TestCase): + counter = 0 + + def setUp(self) -> None: + self.counter = 0 + + def repeating_method(self): + """Simple method incrementing counter by one""" + self.counter += 1 + + @retry(num_retries=3, backoff_factor=0.1, + callback_on_exceeded=callback_on_failure, use_self=True) + def method_passing_on_nth_attempt(self, num_attempts: int = 3) -> bool: + """ + Simple method that is passing check only after n-th attempt + :param num_attempts: number of attempts before passing + """ + if self.counter < num_attempts - 1: + self.repeating_method() + raise AssertionError(f'Awaiting counter equal to {num_attempts}') + return True + + def test_get_timeout(self): + """Tests of getting timeout with backoff factor applied""" + __backoff_factor, __number_of_retries = 0.1, 1 + timeout = get_timeout(__backoff_factor, __number_of_retries) + self.assertEqual(timeout, 0.1) + __number_of_retries += 1 + timeout = get_timeout(__backoff_factor, __number_of_retries) + self.assertEqual(timeout, 0.2) + __number_of_retries += 1 + timeout = get_timeout(__backoff_factor, __number_of_retries) + self.assertEqual(timeout, 0.4) + + def test_retry(self): + """Testing retry decorator""" + outcome = self.method_passing_on_nth_attempt(num_attempts=3) + self.assertTrue(outcome) + self.assertEqual(2, self.counter) + + outcome = self.method_passing_on_nth_attempt(num_attempts=4) + self.assertFalse(outcome) + self.assertEqual(3, self.counter) + + def test_wait_for_mq_startup(self): + self.assertTrue(wait_for_mq_startup("mq.neonaiservices.com", 5672)) + self.assertFalse(wait_for_mq_startup("www.neon.ai", 5672, 1)) + + +class TestConsumerUtils(unittest.TestCase): + def test_default_error_handler(self): + from neon_mq_connector.utils.consumer_utils import default_error_handler + with self.assertRaises(Exception): + default_error_handler() + + with self.assertRaises(Exception) as e: + default_error_handler("error message") + self.assertEqual(str(e.exception), "error message") + + class TestNetworkUtils(unittest.TestCase): def test_dict_to_b64(self): b64_str = dict_to_b64(TEST_DICT) @@ -254,3 +277,80 @@ def test_check_port_is_open(self): from neon_mq_connector.utils.network_utils import check_port_is_open self.assertTrue(check_port_is_open("mq.neonaiservices.com", 5672)) self.assertFalse(check_port_is_open("www.neon.ai", 5672)) + + +class TestRabbitUtils(unittest.TestCase): + def test_create_mq_callback(self): + from neon_mq_connector.utils.rabbit_utils import create_mq_callback + callback = Mock() + test_body = {"test": True} + valid_request = {"channel": Mock(), "method": Mock(), + "properties": Mock(), + "body": dict_to_b64(test_body)} + + @create_mq_callback() + def default_handler_body(body: dict): + callback(body) + + @create_mq_callback() + def default_handler_kwargs(**kwargs): + callback(**kwargs) + + @create_mq_callback(('body', 'method')) + def extra_kwargs_handler(**kwargs): + callback(**kwargs) + + @create_mq_callback(()) + def no_kwargs_handler(**kwargs): + callback(**kwargs) + + # Default handler + default_handler_body(*valid_request.values()) + callback.assert_called_once_with(test_body) + + # Handler accepts kwargs + default_handler_kwargs(*valid_request.values()) + callback.assert_called_with(body=test_body) + + # Handler accepts multiple kwargs + extra_kwargs_handler(*valid_request.values()) + callback.assert_called_with(body=test_body, + method=valid_request['method']) + + # Handler accepts no kwargs + no_kwargs_handler(*valid_request.values()) + callback.assert_called_with() + + test_handlers = MqCallbackDecoratorClass() + # Class handler with default args + test_handlers.default_callback(*valid_request.values()) + test_handlers.callback.assert_called_once_with(test_body) + + # Class handler with no kwargs + test_handlers.no_kwargs_callback(*valid_request.values()) + test_handlers.callback.assert_called_with() + + # Class staticmethod handler + test_handlers.static_callback(*valid_request.values()) + test_handlers.class_callback.assert_called_once_with(test_body) + + +class TestThreadUtils(unittest.TestCase): + counter = 0 + + def setUp(self) -> None: + self.counter = 0 + + def repeating_method(self): + """Simple method incrementing counter by one""" + self.counter += 1 + + def test_repeating_timer(self): + """Testing repeating timer thread""" + interval_timeout = 3 + timer_thread = RepeatingTimer(interval=0.9, + function=self.repeating_method) + timer_thread.start() + time.sleep(interval_timeout) + timer_thread.cancel() + self.assertEqual(self.counter, 3)