Skip to content

Commit

Permalink
Refactor test classes to match module order for better readability
Browse files Browse the repository at this point in the history
Add test coverage for `utils` submodule
Update `create_mq_callback` to support decorating both functions and methods (including static methods)
  • Loading branch information
NeonDaniel committed Dec 6, 2024
1 parent ded01ad commit d92fb39
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 73 deletions.
1 change: 1 addition & 0 deletions neon_mq_connector/utils/consumer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from ovos_utils.log import LOG


def default_error_handler(*args):
"""
Default handler for Consumer instances
Expand Down
34 changes: 27 additions & 7 deletions neon_mq_connector/utils/rabbit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import inspect
from inspect import ismethod
from functools import wraps

from ovos_utils.log import LOG
Expand All @@ -34,17 +35,16 @@


def create_mq_callback(include_callback_props: tuple = ('body',)):
""" Creates MQ callback method by filtering relevant MQ attributes """
"""
Creates MQ callback method by filtering relevant MQ attributes
"""

if not include_callback_props:
include_callback_props = ()

def wrapper(f):

@wraps(f)
def wrapped(self, *f_args):
def _parse_kwargs(*f_args) -> dict:
mq_props = ['channel', 'method', 'properties', 'body']

callback_kwargs = {}

for idx in range(len(mq_props)):
Expand All @@ -54,19 +54,39 @@ def wrapped(self, *f_args):
if value and isinstance(value, bytes):
dict_data = b64_to_dict(value)
callback_kwargs['body'] = dict_data
elif value and isinstance(value, dict):
callback_kwargs['body'] = value
else:
raise TypeError(f'Invalid body received, expected: '
f'bytes string; got: {type(value)}')
else:
callback_kwargs[mq_props[idx]] = value
return callback_kwargs

@wraps(f)
def wrapped_classmethod(self, *f_args):
try:
res = f(self, **_parse_kwargs(*f_args))
except Exception as ex:
LOG.error(f'Execution of {f.__name__} failed due to '
f'exception={ex}')
res = None
return res

@wraps(f)
def wrapped(*f_args):
try:
res = f(self, **callback_kwargs)
res = f(**_parse_kwargs(*f_args))
except Exception as ex:
LOG.error(f'Execution of {f.__name__} failed due to '
f'exception={ex}')
res = None
return res

# Use the appropriate wrapper for a class method vs a function
signature = inspect.signature(f).parameters
if 'self' in signature:
return wrapped_classmethod
return wrapped

return wrapper
232 changes: 166 additions & 66 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@
import sys
import time
import unittest
from unittest.mock import Mock

import pytest
import pika

from threading import Thread


sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from neon_mq_connector.utils import RepeatingTimer
Expand Down Expand Up @@ -66,7 +69,28 @@ def callback_on_failure():
return False


class TestMQConnector(MQConnector):
class MqCallbackDecoratorClass:
from neon_mq_connector.utils.rabbit_utils import create_mq_callback
class_callback = Mock()

def __init__(self):
self.callback = Mock()

@create_mq_callback()
def default_callback(self, body):
self.callback(body)

@create_mq_callback(())
def no_kwargs_callback(self, **kwargs):
self.callback(**kwargs)

@staticmethod
@create_mq_callback()
def static_callback(body):
MqCallbackDecoratorClass.class_callback(body)


class SimpleMQConnector(MQConnector):
def __init__(self, config: dict, service_name: str, vhost: str):
super().__init__(config, service_name)
self.vhost = vhost
Expand All @@ -86,69 +110,8 @@ def respond(channel, method, _, body):
channel.basic_ack(delivery_tag=method.delivery_tag)


class TestMQConnectorUtils(unittest.TestCase):
counter = 0

def repeating_method(self):
"""Simple method incrementing counter by one"""
self.counter += 1

@retry(num_retries=3, backoff_factor=0.1,
callback_on_exceeded=callback_on_failure, use_self=True)
def method_passing_on_nth_attempt(self, num_attempts: int = 3) -> bool:
"""
Simple method that is passing check only after n-th attempt
:param num_attempts: number of attempts before passing
"""
if self.counter < num_attempts - 1:
self.repeating_method()
raise AssertionError('Awaiting counter equal to 3')
return True

def test_01_get_timeout(self):
"""Tests of getting timeout with backoff factor applied"""
__backoff_factor, __number_of_retries = 0.1, 1
timeout = get_timeout(__backoff_factor, __number_of_retries)
self.assertEqual(timeout, 0.1)
__number_of_retries += 1
timeout = get_timeout(__backoff_factor, __number_of_retries)
self.assertEqual(timeout, 0.2)
__number_of_retries += 1
timeout = get_timeout(__backoff_factor, __number_of_retries)
self.assertEqual(timeout, 0.4)

def test_02_retry_succeed(self):
"""Testing retry decorator"""
outcome = self.method_passing_on_nth_attempt(num_attempts=3)
self.assertTrue(outcome)
self.assertEqual(2, self.counter)

def test_03_retry_failed(self):
"""Testing retry decorator"""
outcome = self.method_passing_on_nth_attempt(num_attempts=4)
self.assertFalse(outcome)
self.assertEqual(3, self.counter)

def test_repeating_timer(self):
"""Testing repeating timer thread"""
interval_timeout = 3
timer_thread = RepeatingTimer(interval=0.9,
function=self.repeating_method)
timer_thread.start()
time.sleep(interval_timeout)
timer_thread.cancel()
self.assertEqual(self.counter, 3)

def test_wait_for_mq_startup(self):
self.assertTrue(wait_for_mq_startup("mq.neonaiservices.com", 5672))
self.assertFalse(wait_for_mq_startup("www.neon.ai", 5672, 1))

def setUp(self) -> None:
self.counter = 0


@pytest.mark.usefixtures("rmq_instance")
class MqUtilTests(unittest.TestCase):
class TestClientUtils(unittest.TestCase):
test_connector = None

def setUp(self) -> None:
Expand All @@ -161,9 +124,9 @@ def setUp(self) -> None:
import neon_mq_connector.utils.client_utils
neon_mq_connector.utils.client_utils._default_mq_config = test_conf
vhost = "/neon_testing"
self.test_connector = TestMQConnector(config=test_conf,
service_name="mq_handler",
vhost=vhost)
self.test_connector = SimpleMQConnector(config=test_conf,
service_name="mq_handler",
vhost=vhost)
self.test_connector.register_consumer("neon_utils_test",
vhost,
INPUT_CHANNEL,
Expand Down Expand Up @@ -237,6 +200,66 @@ def test_send_mq_request_invalid_vhost(self):
send_mq_request("invalid_endpoint", {}, "test", "test", timeout=5)


class TestMQConnectionUtils(unittest.TestCase):
counter = 0

def setUp(self) -> None:
self.counter = 0

def repeating_method(self):
"""Simple method incrementing counter by one"""
self.counter += 1

@retry(num_retries=3, backoff_factor=0.1,
callback_on_exceeded=callback_on_failure, use_self=True)
def method_passing_on_nth_attempt(self, num_attempts: int = 3) -> bool:
"""
Simple method that is passing check only after n-th attempt
:param num_attempts: number of attempts before passing
"""
if self.counter < num_attempts - 1:
self.repeating_method()
raise AssertionError(f'Awaiting counter equal to {num_attempts}')
return True

def test_get_timeout(self):
"""Tests of getting timeout with backoff factor applied"""
__backoff_factor, __number_of_retries = 0.1, 1
timeout = get_timeout(__backoff_factor, __number_of_retries)
self.assertEqual(timeout, 0.1)
__number_of_retries += 1
timeout = get_timeout(__backoff_factor, __number_of_retries)
self.assertEqual(timeout, 0.2)
__number_of_retries += 1
timeout = get_timeout(__backoff_factor, __number_of_retries)
self.assertEqual(timeout, 0.4)

def test_retry(self):
"""Testing retry decorator"""
outcome = self.method_passing_on_nth_attempt(num_attempts=3)
self.assertTrue(outcome)
self.assertEqual(2, self.counter)

outcome = self.method_passing_on_nth_attempt(num_attempts=4)
self.assertFalse(outcome)
self.assertEqual(3, self.counter)

def test_wait_for_mq_startup(self):
self.assertTrue(wait_for_mq_startup("mq.neonaiservices.com", 5672))
self.assertFalse(wait_for_mq_startup("www.neon.ai", 5672, 1))


class TestConsumerUtils(unittest.TestCase):
def test_default_error_handler(self):
from neon_mq_connector.utils.consumer_utils import default_error_handler
with self.assertRaises(Exception):
default_error_handler()

with self.assertRaises(Exception) as e:
default_error_handler("error message")
self.assertEqual(str(e.exception), "error message")


class TestNetworkUtils(unittest.TestCase):
def test_dict_to_b64(self):
b64_str = dict_to_b64(TEST_DICT)
Expand All @@ -254,3 +277,80 @@ def test_check_port_is_open(self):
from neon_mq_connector.utils.network_utils import check_port_is_open
self.assertTrue(check_port_is_open("mq.neonaiservices.com", 5672))
self.assertFalse(check_port_is_open("www.neon.ai", 5672))


class TestRabbitUtils(unittest.TestCase):
def test_create_mq_callback(self):
from neon_mq_connector.utils.rabbit_utils import create_mq_callback
callback = Mock()
test_body = {"test": True}
valid_request = {"channel": Mock(), "method": Mock(),
"properties": Mock(),
"body": dict_to_b64(test_body)}

@create_mq_callback()
def default_handler_body(body: dict):
callback(body)

@create_mq_callback()
def default_handler_kwargs(**kwargs):
callback(**kwargs)

@create_mq_callback(('body', 'method'))
def extra_kwargs_handler(**kwargs):
callback(**kwargs)

@create_mq_callback(())
def no_kwargs_handler(**kwargs):
callback(**kwargs)

# Default handler
default_handler_body(*valid_request.values())
callback.assert_called_once_with(test_body)

# Handler accepts kwargs
default_handler_kwargs(*valid_request.values())
callback.assert_called_with(body=test_body)

# Handler accepts multiple kwargs
extra_kwargs_handler(*valid_request.values())
callback.assert_called_with(body=test_body,
method=valid_request['method'])

# Handler accepts no kwargs
no_kwargs_handler(*valid_request.values())
callback.assert_called_with()

test_handlers = MqCallbackDecoratorClass()
# Class handler with default args
test_handlers.default_callback(*valid_request.values())
test_handlers.callback.assert_called_once_with(test_body)

# Class handler with no kwargs
test_handlers.no_kwargs_callback(*valid_request.values())
test_handlers.callback.assert_called_with()

# Class staticmethod handler
test_handlers.static_callback(*valid_request.values())
test_handlers.class_callback.assert_called_once_with(test_body)


class TestThreadUtils(unittest.TestCase):
counter = 0

def setUp(self) -> None:
self.counter = 0

def repeating_method(self):
"""Simple method incrementing counter by one"""
self.counter += 1

def test_repeating_timer(self):
"""Testing repeating timer thread"""
interval_timeout = 3
timer_thread = RepeatingTimer(interval=0.9,
function=self.repeating_method)
timer_thread.start()
time.sleep(interval_timeout)
timer_thread.cancel()
self.assertEqual(self.counter, 3)

0 comments on commit d92fb39

Please sign in to comment.