diff --git a/README.md b/README.md index cc7d601..ec6bfc2 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # MRSAL AMQP -[![Release](https://img.shields.io/badge/release-1.0.9-etalue.svg)](https://pypi.org/project/mrsal/) [![Python 3.10](https://img.shields.io/badge/python-3.10--3.11--3.12-blue.svg)](https://www.python.org/downloads/release/python-3103/)[![Mrsal Workflow](https://github.com/NeoMedSys/mrsal/actions/workflows/mrsal.yaml/badge.svg?branch=main)](https://github.com/NeoMedSys/mrsal/actions/workflows/mrsal.yaml) +[![Release](https://img.shields.io/badge/release-1.0.9-etalue.svg)](https://pypi.org/project/mrsal/) [![Python 3.10](https://img.shields.io/badge/python-3.10--3.11--3.12lue.svg)](https://www.python.org/downloads/release/python-3103/)[![Mrsal Workflow](https://github.com/NeoMedSys/mrsal/actions/workflows/mrsal.yaml/badge.svg?branch=main)](https://github.com/NeoMedSys/mrsal/actions/workflows/mrsal.yaml) ## Intro diff --git a/mrsal/amqp/subclass.py b/mrsal/amqp/subclass.py index 2848ed9..040466e 100644 --- a/mrsal/amqp/subclass.py +++ b/mrsal/amqp/subclass.py @@ -1,9 +1,7 @@ -from functools import partial import pika import json from mrsal.exceptions import MrsalAbortedSetup from logging import WARNING -from pika.connection import SSLOptions from pika.exceptions import ( AMQPConnectionError, ChannelClosedByBroker, @@ -12,8 +10,8 @@ NackError, UnroutableError ) -from pika.adapters.asyncio_connection import AsyncioConnection -from typing import Any, Callable, Type +from aio_pika import connect_robust, Channel as AioChannel +from typing import Callable, Type from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type, before_sleep_log from pydantic import ValidationError from pydantic.dataclasses import dataclass @@ -25,7 +23,7 @@ log = NeoLogger(__name__, rotate_days=config.LOG_DAYS) @dataclass -class MrsalAMQP(Mrsal): +class MrsalBlockingAMQP(Mrsal): """ :param int blocked_connection_timeout: blocked_connection_timeout is the timeout, in seconds, @@ -33,14 +31,7 @@ class MrsalAMQP(Mrsal): the connection will be torn down during connection tuning. """ blocked_connection_timeout: int = 60 # sec - use_blocking: bool = False - def get_ssl_context(self) -> SSLOptions | None: - if self.ssl: - self.log.info("Setting up TLS connection") - context = self._ssl_setup() - ssl_options = pika.SSLOptions(context, self.host) if 'context' in locals() else None - return ssl_options def setup_blocking_connection(self) -> None: """We can use setup_blocking_connection for establishing a connection to RabbitMQ server specifying connection parameters. @@ -69,7 +60,7 @@ def setup_blocking_connection(self) -> None: pika.ConnectionParameters( host=self.host, port=self.port, - ssl_options=self.get_ssl_context(), + ssl_options=self.get_ssl_context(async_conn=False), virtual_host=self.virtual_host, credentials=credentials, heartbeat=self.heartbeat, @@ -89,53 +80,6 @@ def setup_blocking_connection(self) -> None: except Exception as e: self.log.error(f"Unexpected error caught: {e}") - def setup_async_connection(self) -> None: - """We can use setup_aync_connection for establishing a connection to RabbitMQ server specifying connection parameters. - The connection is async and is recommended to use if your app is realtime or will handle a lot of traffic. - - Parameters - ---------- - context : Dict[str, str] - context is the structured map with information regarding the SSL options for connecting with rabbit server via TLS. - """ - connection_info = f""" - Mrsal connection parameters: - host={self.host}, - virtual_host={self.virtual_host}, - port={self.port}, - heartbeat={self.heartbeat}, - ssl={self.ssl} - """ - if self.verbose: - self.log.info(f"Establishing connection to RabbitMQ on {connection_info}") - credentials = pika.PlainCredentials(*self.credentials) - conn_conf = pika.ConnectionParameters( - host=self.host, - port=self.port, - ssl_options=self.get_ssl_context(), - virtual_host=self.virtual_host, - credentials=credentials, - heartbeat=self.heartbeat, - ) - try: - self._connection = AsyncioConnection( - parameters=conn_conf, - on_open_callback=partial( - self.on_connection_open, - exchange_name=self.exchange_name, queue_name=self.queue_name, - exchange_type=self.exchange_type, routing_key=self.routing_key - ), - on_open_error_callback=self.on_connection_error - ) - self.log.info(f"Connection staged with RabbitMQ on {connection_info}") - except (AMQPConnectionError, ChannelClosedByBroker, ConnectionClosedByBroker, StreamLostError) as e: - self.log.error(f"Oh lordy lord I failed connecting to the Rabbit with: {e}") - raise - except Exception as e: - self.log.error(f"Unexpected error caught: {e}") - - - @retry( retry=retry_if_exception_type(( AMQPConnectionError, @@ -157,7 +101,8 @@ def start_consumer(self, exchange_name: str | None = None, exchange_type: str | None = None, routing_key: str | None = None, - payload_model: Type | None = None + payload_model: Type | None = None, + requeue: bool = True ) -> None: """ Start the consumer using blocking setup. @@ -168,40 +113,23 @@ def start_consumer(self, :param callback_args: Optional arguments to pass to the callback. """ # Connect and start the I/O loop - if self.use_blocking: - self.setup_blocking_connection() - else: - # set connection parameters - # parametes propagate through a 3 layers in order - # to spin up the async connection - self.queue_name = queue_name - self.exchange_name = exchange_name - self.exchange_type = exchange_type - self.routing_key = routing_key - self.auto_declare = auto_declare - - self.setup_async_connection() - if self._connection.is_open: - self.log.success(f"Boom! Async connection established with {exchange_name} on {queue_name}") - self._connection.ioloop.run_forever() - else: - self.log.error('Straigh out of the swamp with no connection! Async connection did not activate') - - if auto_declare and self.use_blocking: + self.setup_blocking_connection() + + if auto_declare: if None in (exchange_name, queue_name, exchange_type, routing_key): raise TypeError('Make sure that you are passing in all the necessary args for auto_declare') + self._setup_exchange_and_queue( exchange_name=exchange_name, queue_name=queue_name, exchange_type=exchange_type, routing_key=routing_key ) + if not self.auto_declare_ok: - if self._connection.is_open: - self._connection.ioloop.stop() raise MrsalAbortedSetup('Auto declaration for the connection setup failed and is aborted') - self.log.info(f"Consumer boi listening on queue: {queue_name} to the exchange {exchange_name}. Waiting for messages...") + self.log.info(f"Straigh out of the swamps -- consumer boi listening on queue: {queue_name} to the exchange {exchange_name}. Waiting for messages...") try: for method_frame, properties, body in self._channel.consume( @@ -209,22 +137,24 @@ def start_consumer(self, if method_frame: if properties: app_id = properties.app_id if hasattr(properties, 'app_id') else 'no AppID given' - msg_id = properties.msg_id if hasattr(properties, 'msg_id') else 'no msgID given' + msg_id = properties.message_id if hasattr(properties, 'message_id') else 'no msgID given' if self.verbose: self.log.info( - """ + f""" Message received with: - - Method Frame: {method_frame) + - Method Frame: {method_frame} - Redelivery: {method_frame.redelivered} - Exchange: {method_frame.exchange} - Routing Key: {method_frame.routing_key} - Delivery Tag: {method_frame.delivery_tag} - Properties: {properties} + - Requeue: {requeue} + - Auto Ack: {auto_ack} """ ) if auto_ack: - self.log.info(f'I successfully received a message from: {app_id} with messageID: {msg_id}') + self.log.info(f'I successfully received a message with AutoAck from: {app_id} with messageID: {msg_id}') if payload_model: try: @@ -243,13 +173,13 @@ def start_consumer(self, callback( method_frame, properties, body) except Exception as e: if not auto_ack: - self._channel.basic_nack(delivery_tag=method_frame.delivery_tag, requeue=True) + self._channel.basic_nack(delivery_tag=method_frame.delivery_tag, requeue=requeue) self.log.error("Callback method failure: {e}") continue + if not auto_ack: self.log.success(f'Message ({msg_id}) from {app_id} received and properly processed -- now dance the funky chicken') self._channel.basic_ack(delivery_tag=method_frame.delivery_tag) - except (AMQPConnectionError, ConnectionClosedByBroker, StreamLostError) as e: log.error(f"Ooooooopsie! I caught a connection error while consuming: {e}") raise @@ -318,3 +248,129 @@ def publish_message( except Exception as e: self.log.error(f"Unexpected error while publishing message: {e}") + + +class MrsalAsyncAMQP(Mrsal): + """Handles asynchronous connection with RabbitMQ using aio-pika.""" + async def setup_async_connection(self): + """Setup an asynchronous connection to RabbitMQ using aio-pika.""" + self.log.info(f"Establishing async connection to RabbitMQ on {self.host}:{self.port}") + try: + self._connection = await connect_robust( + host=self.host, + port=self.port, + login=self.credentials[0], + password=self.credentials[1], + virtualhost=self.virtual_host, + ssl=self.ssl, + ssl_context=self.get_ssl_context(), + heartbeat=self.heartbeat + ) + self._channel = await self._connection.channel() + await self._channel.set_qos(prefetch_count=self.prefetch_count) + self.log.info("Async connection established successfully.") + except (AMQPConnectionError, StreamLostError, ChannelClosedByBroker, ConnectionClosedByBroker) as e: + self.log.error(f"Error establishing async connection: {e}", exc_info=True) + raise + except Exception as e: + self.log.error(f'Oh my lordy lord! I caugth an unexpected exception while trying to connect: {e}', exc_info=True) + + @retry( + retry=retry_if_exception_type(( + AMQPConnectionError, + ChannelClosedByBroker, + ConnectionClosedByBroker, + StreamLostError, + )), + stop=stop_after_attempt(3), + wait=wait_fixed(2), + before_sleep=before_sleep_log(log, WARNING) + ) + async def start_consumer( + self, + queue_name: str, + callback: Callable | None = None, + callback_args: dict[str, str | int | float | bool] | None = None, + auto_ack: bool = False, + auto_declare: bool = True, + exchange_name: str | None = None, + exchange_type: str | None = None, + routing_key: str | None = None, + payload_model: Type | None = None, + requeue: bool = True + ): + """Start the async consumer with the provided setup.""" + # Check if there's a connection; if not, create one + if not self._connection: + await self.setup_async_connection() + + + self._channel: AioChannel = await self._connection.channel() + await self._channel.set_qos(prefetch_count=self.prefetch_count) + + if auto_declare: + if None in (exchange_name, queue_name, exchange_type, routing_key): + raise TypeError('Make sure that you are passing in all the necessary args for auto_declare') + + queue = await self._async_setup_exchange_and_queue( + exchange_name=exchange_name, + queue_name=queue_name, + exchange_type=exchange_type, + routing_key=routing_key + ) + + if not self.auto_declare_ok: + if self._connection: + await self._connection.close() + raise MrsalAbortedSetup('Auto declaration failed during setup.') + + self.log.info(f"Straight out of the swamps -- Consumer boi listening on queue: {queue_name}, exchange: {exchange_name}") + + # async with queue.iterator() as queue_iter: + async for message in queue.iterator(): + if message is None: + continue + + # Extract message metadata + app_id = message.app_id if hasattr(message, 'app_id') else 'NoAppID' + msg_id = message.app_id if hasattr(message, 'message_id') else 'NoMsgID' + + if self.verbose: + self.log.info(f""" + Message received with: + - Redelivery: {message.redelivered} + - Exchange: {message.exchange} + - Routing Key: {message.routing_key} + - Delivery Tag: {message.delivery_tag} + - Requeue: {requeue} + - Auto Ack: {auto_ack} + """) + + if auto_ack: + await message.ack() + self.log.info(f'I successfully received a message from: {app_id} with messageID: {msg_id}') + + if payload_model: + try: + self.validate_payload(message.body, payload_model) + except (ValidationError, json.JSONDecodeError, UnicodeDecodeError, TypeError) as e: + self.log.error(f"Payload validation failed: {e}", exc_info=True) + if not auto_ack: + await message.reject(requeue=requeue) + continue + + if callback: + try: + if callback_args: + await callback(*callback_args, message) + else: + await callback(message) + except Exception as e: + self.log.error(f"Splæt! Error processing message with callback: {e}", exc_info=True) + if not auto_ack: + await message.reject(requeue=requeue) + continue + + if not auto_ack: + await message.ack() + self.log.success(f'Young grasshopper! Message ({msg_id}) from {app_id} received and properly processed.') diff --git a/mrsal/superclass.py b/mrsal/superclass.py index 686b92b..f89f81a 100644 --- a/mrsal/superclass.py +++ b/mrsal/superclass.py @@ -1,10 +1,12 @@ # external -from functools import partial import os import ssl +import pika from ssl import SSLContext from typing import Any, Type -from mrsal.exceptions import MrsalSetupError +from mrsal.exceptions import MrsalAbortedSetup, MrsalSetupError +from pika.connection import SSLOptions +from aio_pika import ExchangeType as AioExchangeType, Queue as AioQueue, Exchange as AioExchange from pydantic.dataclasses import dataclass from neolibrary.monitoring.logger import NeoLogger from pydantic.deprecated.tools import json @@ -14,7 +16,6 @@ log = NeoLogger(__name__, rotate_days=config.LOG_DAYS) - @dataclass # NOTE! change the doc style to google or numpy class Mrsal: @@ -96,58 +97,62 @@ def _setup_exchange_and_queue(self, self._declare_queue(**declare_queue_dict) self._declare_queue_binding(**declare_queue_binding_dict) self.auto_declare_ok = True - except MrsalSetupError: + self.log.success(f"Exchange {exchange_name} and Queue {queue_name} set up successfully.") + except MrsalSetupError as e: + self.log.error(f'Splæt! I failed the declaration setup with {e}', exc_info=True) self.auto_declare_ok = False - def on_connection_error(self, connection, exception): - """ - Handle connection errors. - """ - self.log.error(f"I failed to establish async connection: {exception}") - try: - if connection and connection.is_open: - connection.close() - except Exception as e: - self.log.error(f'Oh lordy lord! I failed closing the connection with: {e}') + async def _async_setup_exchange_and_queue(self, + exchange_name: str, queue_name: str, + routing_key: str, exchange_type: str, + exch_args: dict[str, str] | None = None, + queue_args: dict[str, str] | None = None, + bind_args: dict[str, str] | None = None, + exch_durable: bool = True, queue_durable: bool = True, + passive: bool = False, internal: bool = False, + auto_delete: bool = False, exclusive: bool = False + ) -> AioQueue | None: + """Setup exchange and queue with bindings asynchronously.""" + if not self._connection: + raise MrsalAbortedSetup("Oh my Oh my! Connection not found when trying to run the setup!") + + async_declare_exhange_dict = { + 'exchange': exchange_name, + 'exchange_type': exchange_type, + 'arguments': exch_args, + 'durable': exch_durable, + 'passive': passive, + 'internal': internal, + 'auto_delete': auto_delete + } - def on_channel_open( - self, exchange_name: str, queue_name: str, - exchange_type: str, routing_key: str - ) -> None: - """ - Open a channel once the connection is established. - """ - if self._connection and self._connection: - self._channel = self._connection.channel() - self._channel.basic_qos(prefetch_count=self.prefetch_count) - self.log.info(f"Channel opened with prefetch count: {self.prefetch_count}") - self._setup_exchange_and_queue( - exchange_name=exchange_name, queue_name=queue_name, - exchange_type=exchange_type, routing_key=routing_key - ) - else: - self.log.error("Splæt! Connection is not open. Cannot create a channel.") - - def open_channel(self, exchange_name, queue_name, exchange_type, routing_key): - """Open a channel once the connection is established.""" - self._connection.channel( - on_open_callback=partial( - self.on_channel_open, - exchange_name=exchange_name, queue_name=queue_name, - exchange_type=exchange_type, routing_key=routing_key - ) - ) + async_declare_queue_dict = { + 'queue_name': queue_name, + 'arguments': queue_args, + 'durable': queue_durable, + 'exclusive': exclusive, + 'auto_delete': auto_delete, + 'passive': passive + } + + async_declare_queue_binding_dict = { + 'routing_key': routing_key, + 'arguments': bind_args + + } + + try: + # Declare exchange and queue + exchange = await self._async_declare_exchange(**async_declare_exhange_dict) + queue = await self._async_declare_queue(**async_declare_queue_dict) + await self._async_declare_queue_binding(queue=queue, exchange=exchange, **async_declare_queue_binding_dict) + self.auto_declare_ok = True + self.log.success(f"Exchange {exchange_name} and Queue {queue_name} set up successfully.") + return queue + except MrsalSetupError as e: + self.log.error(f'Splæt! I failed the declaration setup with {e}', exc_info=True) + self.auto_declare_ok = False - def on_connection_open(self, - connection, - exchange_name: str, queue_name: str, - exchange_type: str, routing_key: str - ) -> None: - """ - Callback when the async connection is successfully opened. - """ - self._connection = connection - self.open_channel(exchange_name, queue_name, exchange_type, routing_key) def _declare_exchange(self, exchange: str, exchange_type: str, @@ -192,6 +197,40 @@ def _declare_exchange(self, if self.verbose: self.log.success("Exchange declared yo!") + async def _async_declare_exchange(self, + exchange: str, + exchange_type: AioExchangeType, + arguments: dict[str, str] | None = None, + durable: bool = True, + passive: bool = False, + internal: bool = False, + auto_delete: bool = False) -> AioExchange: + """Declare a RabbitMQ exchange in async mode.""" + exchange_declare_info = f""" + exchange={exchange}, + exchange_type={exchange_type}, + durable={durable}, + passive={passive}, + internal={internal}, + auto_delete={auto_delete}, + arguments={arguments} + """ + if self.verbose: + print(f"Declaring exchange with: {exchange_declare_info}") + + try: + exchange_obj = await self._channel.declare_exchange( + name=exchange, + type=exchange_type, + durable=durable, + auto_delete=auto_delete, + internal=internal, + arguments=arguments + ) + return exchange_obj + except Exception as e: + raise MrsalSetupError(f"Failed to declare async exchange: {e}") + def _declare_queue(self, queue: str, arguments: dict[str, str] | None, durable: bool, exclusive: bool, @@ -230,6 +269,37 @@ def _declare_queue(self, if self.verbose: self.log.info(f"Queue declared yo") + async def _async_declare_queue(self, + queue_name: str, + durable: bool = True, + exclusive: bool = False, + auto_delete: bool = False, + passive: bool = False, + arguments: dict[str, Any] | None = None) -> AioQueue: + """Declare a RabbitMQ queue asynchronously.""" + queue_declare_info = f""" + queue={queue_name}, + durable={durable}, + exclusive={exclusive}, + auto_delete={auto_delete}, + arguments={arguments} + """ + if self.verbose: + self.log.info(f"Declaring queue with: {queue_declare_info}") + + try: + queue_obj = await self._channel.declare_queue( + name=queue_name, + durable=durable, + exclusive=exclusive, + auto_delete=auto_delete, + arguments=arguments, + passive=passive + ) + return queue_obj + except Exception as e: + raise MrsalSetupError(f"Failed to declare async queue: {e}") + def _declare_queue_binding(self, exchange: str, queue: str, routing_key: str | None, @@ -256,7 +326,27 @@ def _declare_queue_binding(self, raise MrsalSetupError(f'I failed binding the queue with : {e}') if self.verbose: self.log.info(f"Queue bound yo") - + + async def _async_declare_queue_binding(self, + queue: AioQueue, + exchange: AioExchange, + routing_key: str | None, + arguments: dict[str, Any] | None = None) -> None: + """Bind the queue to the exchange asynchronously.""" + binding_info = f""" + queue={queue.name}, + exchange={exchange.name}, + routing_key={routing_key}, + arguments={arguments} + """ + if self.verbose: + self.log.info(f"Binding queue to exchange with: {binding_info}") + + try: + await queue.bind(exchange, routing_key=routing_key, arguments=arguments) + except Exception as e: + raise MrsalSetupError(f"Failed to bind async queue: {e}") + def _ssl_setup(self) -> SSLContext: """_ssl_setup is private method we are using to connect with rabbit server via signed certificates and some TLS settings. @@ -268,10 +358,26 @@ def _ssl_setup(self) -> SSLContext: SSLContext """ - context = ssl.create_default_context(cafile=self.tls_dict['ca']) + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH, cafile=self.tls_dict['ca']) context.load_cert_chain(certfile=self.tls_dict['crt'], keyfile=self.tls_dict['key']) + # Ensure the server's certificate is verified + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = True return context + def get_ssl_context(self, async_conn: bool = True) -> SSLOptions | SSLContext | None: + if self.ssl: + self.log.info("Setting up TLS connection") + context = self._ssl_setup() + # use_blocking is the same as sync + if async_conn: + ssl_options = pika.SSLOptions(context, self.host) + return ssl_options + else: + return context + else: + return None + def validate_payload(self, payload: Any, model: Type) -> None: """ Parses and validates the incoming message payload using the provided dataclass model. diff --git a/pyproject.toml b/pyproject.toml index 14310c5..6f4b3ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ license = "" maintainers = ["Raafat ", "Jon E Nesvold "] name = "mrsal" readme = "README.md" -version = "1.0.10b" +version = "1.1.0b" [tool.poetry.dependencies] colorlog = "^6.7.0" @@ -16,14 +16,16 @@ retry = "^0.9.2" tenacity = "^9.0.0" sphinx = "^8.0.2" neolibrary = "^0.9.4b1" +aio-pika = "^9.4.3" [tool.poetry.group.dev.dependencies] coverage = "^7.2.7" -pytest = "^7.4.0" sphinx = "^8.0.2" myst-parser = "^4.0.0" nox = "^2024.4.15" ruff = "^0.6.5" +pytest = "^8.3.3" +pytest-asyncio = "^0.24.0" [build-system] build-backend = "poetry.core.masonry.api" diff --git a/tests/test_mrsal_async_no_tls.py b/tests/test_mrsal_async_no_tls.py index 95cbbc9..148977f 100644 --- a/tests/test_mrsal_async_no_tls.py +++ b/tests/test_mrsal_async_no_tls.py @@ -1,11 +1,12 @@ -from mrsal.exceptions import MrsalAbortedSetup, MrsalSetupError +from aio_pika.exceptions import AMQPConnectionError +from pika.exceptions import UnroutableError +from pydantic import ValidationError import pytest -from unittest.mock import Mock, patch, MagicMock, call -from pika.exceptions import AMQPConnectionError, UnroutableError +from unittest.mock import AsyncMock, patch +from mrsal.amqp.subclass import MrsalAsyncAMQP from pydantic.dataclasses import dataclass from tenacity import RetryError -from mrsal.amqp.subclass import MrsalAMQP -from pika.adapters.asyncio_connection import AsyncioConnection + # Configuration and expected payload definition SETUP_ARGS = { @@ -14,9 +15,7 @@ 'credentials': ('user', 'password'), 'virtual_host': 'testboi', 'ssl': False, - 'use_blocking': False, 'heartbeat': 60, - 'blocked_connection_timeout': 60, 'prefetch_count': 1 } @@ -27,157 +26,259 @@ class ExpectedPayload: active: bool -# Fixture to mock the AsyncioConnection and the setup connection method +# Fixture to mock the async connection and its methods @pytest.fixture -def mock_async_amqp_connection(): - with patch('mrsal.amqp.subclass.pika.adapters.asyncio_connection.AsyncioConnection') as mock_async_connection, \ - patch('mrsal.amqp.subclass.MrsalAMQP.setup_async_connection', autospec=True) as mock_setup_async_connection: - - # Set up the mock behaviors for the connection and channel - mock_channel = MagicMock() - mock_connection = MagicMock() +async def mock_amqp_connection(): + with patch('aio_pika.connect_robust', new_callable=AsyncMock) as mock_connect_robust: + mock_channel = AsyncMock() + mock_connection = AsyncMock() mock_connection.channel.return_value = mock_channel - mock_async_connection.return_value = mock_connection - - # Ensure setup_async_connection does nothing during the tests - mock_setup_async_connection.return_value = None - - # Provide the mocks for use in the test - yield mock_connection, mock_channel, mock_setup_async_connection + mock_connect_robust.return_value = mock_connection -# Fixture to create a MrsalAMQP consumer with mocked channel for async mode -@pytest.fixture -def async_amqp_consumer(mock_async_amqp_connection): - mock_connection, mock_channel, _ = mock_async_amqp_connection - consumer = MrsalAMQP(**SETUP_ARGS) - consumer._channel = mock_channel # Inject the mocked channel into the consumer - consumer._connection = mock_connection # Inject the mocked async connection - return consumer - - -def test_retry_on_connection_failure_async(async_amqp_consumer, mock_async_amqp_connection): - """Test reconnection retries in async consumer mode.""" - mock_connection, mock_channel, mock_setup_async_connection = mock_async_amqp_connection - mock_channel.consume.side_effect = AMQPConnectionError("Connection lost") - - # Attempt to start the consumer, which should trigger the retry - with pytest.raises(RetryError): - async_amqp_consumer.start_consumer( - queue_name='test_q', - exchange_name='test_x', - exchange_type='direct', - routing_key='test_route', - callback=Mock(), - ) - - # Verify that setup_async_connection was retried 3 times - assert mock_setup_async_connection.call_count == 3 - - -def test_valid_message_processing_async(async_amqp_consumer): - """Test message processing with a valid payload in async mode.""" - valid_body = b'{"id": 1, "name": "Test", "active": true}' - mock_method_frame = MagicMock() - mock_properties = MagicMock() + # Return the connection and channel + return mock_connection, mock_channel - # Mock the consume method to yield a valid message - async_amqp_consumer._channel.consume.return_value = [(mock_method_frame, mock_properties, valid_body), (None, None, None)] - mock_callback = Mock() +@pytest.fixture +async def amqp_consumer(mock_amqp_connection): + # Await the connection fixture and unpack + mock_connection, mock_channel = await mock_amqp_connection + consumer = MrsalAsyncAMQP(**SETUP_ARGS) + consumer._connection = mock_connection # Inject the mocked connection + consumer._channel = mock_channel + return consumer # Return the consumer instance - try: - async_amqp_consumer.start_consumer( - queue_name='test_q', - exchange_name='test_x', - exchange_type='direct', - routing_key='test_route', - callback=mock_callback, - payload_model=ExpectedPayload - ) - except Exception: - print("Controlled exit of run_forever") - # Assert the callback was called once with the correct data - mock_callback.assert_called_once_with(mock_method_frame, mock_properties, valid_body) +@pytest.mark.asyncio +async def test_valid_message_processing(amqp_consumer): + """Test valid message processing in async consumer.""" + consumer = await amqp_consumer # Ensure we await it properly -def test_valid_message_processing_no_autoack_async(async_amqp_consumer): - """Test that a message is acknowledged on successful processing in async mode.""" valid_body = b'{"id": 1, "name": "Test", "active": true}' - mock_method_frame = MagicMock() - mock_method_frame.delivery_tag = 123 - mock_properties = MagicMock() - - async_amqp_consumer._channel.consume.return_value = [(mock_method_frame, mock_properties, valid_body)] - mock_callback = Mock() - - try: - async_amqp_consumer.start_consumer( - exchange_name="test_x", - exchange_type="direct", - queue_name="test_q", - routing_key="test_route", - callback=mock_callback, - payload_model=ExpectedPayload, - auto_ack=False - ) - except Exception: - print("Controlled exit of run_forever") - - async_amqp_consumer._channel.basic_ack.assert_called_once_with(delivery_tag=123) - - -def test_invalid_message_skipped_async(async_amqp_consumer): - """Test that invalid payloads are skipped in async mode.""" - invalid_body = b'{"id": "wrong_type", "name": "Test", "active": true}' - mock_method_frame = MagicMock() - mock_properties = MagicMock() - - # Mock the consume method to yield an invalid message - async_amqp_consumer._channel.consume.return_value = [(mock_method_frame, mock_properties, invalid_body), (None, None, None)] - - mock_callback = Mock() - - try: - async_amqp_consumer.start_consumer( - queue_name='test_queue', - auto_ack=True, - exchange_name='test_x', - exchange_type='direct', - routing_key='test_route', - callback=mock_callback, - payload_model=ExpectedPayload - ) - except Exception: - print("Controlled exit of run_forever") - - # Assert the callback was not called since the message should be skipped + mock_message = AsyncMock(body=valid_body, ack=AsyncMock(), reject=AsyncMock()) + + + # Manually set the app_id and message_id properties as simple attributes + mock_message.configure_mock(app_id="test_app", message_id="12345") + # Set up mocks for channel and queue interactions + mock_queue = AsyncMock() + async def message_generator(): + yield mock_message + + #mock_queue.iterator.return_value = message_generator() + mock_queue.iterator = message_generator + + # Properly mock the async context manager for the queue + mock_queue.__aenter__.return_value = mock_queue + mock_queue.__aexit__.return_value = AsyncMock() + + # Ensure declare_queue returns the mocked queue + consumer._channel.declare_queue.return_value = mock_queue + + mock_callback = AsyncMock() + # Start the consumer with a mocked queue + await consumer.start_consumer( + queue_name='test_q', + callback=mock_callback, + routing_key='test_route', + exchange_name='test_x', + exchange_type='direct' + ) + mock_callback.reset_mock() + + async for message in mock_queue.iterator(): + await mock_callback(message) + + mock_callback.assert_called_once_with(mock_message) + # mock_callback.assert_called_with(mock_message) + + +@pytest.mark.asyncio +async def test_invalid_payload_validation(amqp_consumer): + """Test invalid payload handling in async consumer.""" + invalid_payload = b'{"id": "wrong_type", "name": 123, "active": "maybe"}' + consumer = await amqp_consumer + + # Create an invalid payload + mock_message = AsyncMock(body=invalid_payload, ack=AsyncMock(), reject=AsyncMock()) + mock_message.configure_mock(app_id="test_app", message_id="12345") + + mock_queue = AsyncMock() + async def message_generator(): + yield mock_message + + mock_queue.iterator = message_generator + mock_queue.__aenter__.return_value = mock_queue + mock_queue.__aexit__.return_value = AsyncMock() + + consumer._channel.declare_queue.return_value = mock_queue + + mock_callback = AsyncMock() + + await consumer.start_consumer( + queue_name='test_q', + callback=mock_callback, + routing_key='test_route', + exchange_name='test_x', + exchange_type='direct', + payload_model=ExpectedPayload, + auto_ack=True + ) + + # Process the messages and expect a validation error + async for message in mock_queue.iterator(): + with pytest.raises(ValidationError): + consumer.validate_payload(message.body, ExpectedPayload) + await mock_callback(message) + + mock_callback.assert_called_once_with(mock_message) + + +@pytest.mark.asyncio +async def test_invalid_message_skipped(amqp_consumer): + """Test that invalid messages are skipped and not processed.""" + invalid_payload = b'{"id": "wrong_type", "name": 123, "active": "maybe"}' + consumer = await amqp_consumer # Ensure we await it properly + + # Create an invalid payload + mock_message = AsyncMock(body=invalid_payload, ack=AsyncMock(), reject=AsyncMock()) + mock_message.configure_mock(app_id="test_app", message_id="12345") + + # Set up mocks for channel and queue interactions + mock_queue = AsyncMock() + + async def message_generator(): + yield mock_message + + mock_queue.iterator = message_generator + + # Properly mock the async context manager for the queue + mock_queue.__aenter__.return_value = mock_queue + mock_queue.__aexit__.return_value = AsyncMock() + + # Ensure declare_queue returns the mocked queue + consumer._channel.declare_queue.return_value = mock_queue + + # Mock the callback (should not be called) + mock_callback = AsyncMock() + + # Start the consumer with a mocked queue and model for validation + await consumer.start_consumer( + queue_name='test_q', + callback=mock_callback, + routing_key='test_route', + exchange_name='test_x', + exchange_type='direct', + payload_model=ExpectedPayload, + auto_ack=True, + requeue=False + ) + + # Process the messages and expect a validation error + async for message in mock_queue.iterator(): + with pytest.raises(ValidationError): + consumer.validate_payload(message.body, ExpectedPayload) + await message.reject() # Reject the invalid message + + # Assert that the callback was never called mock_callback.assert_not_called() + # Assert that the message was rejected + mock_message.reject.assert_called_once() + + + +@pytest.mark.asyncio +async def test_requeue_on_invalid_message(amqp_consumer): + """Test that invalid messages are requeued when auto_ack is False.""" + invalid_payload = b'{"id": "wrong_type", "name": 123, "active": "maybe"}' + consumer = await amqp_consumer # Ensure we await it properly + + # Create an invalid payload + mock_message = AsyncMock(body=invalid_payload, ack=AsyncMock(), reject=AsyncMock(), nack=AsyncMock()) + mock_message.configure_mock(app_id="test_app", message_id="12345") + + # Set up mocks for channel and queue interactions + mock_queue = AsyncMock() + + async def message_generator(): + yield mock_message + + mock_queue.iterator = message_generator + + # Properly mock the async context manager for the queue + mock_queue.__aenter__.return_value = mock_queue + mock_queue.__aexit__.return_value = AsyncMock() + + # Ensure declare_queue returns the mocked queue + consumer._channel.declare_queue.return_value = mock_queue + + # Mock the callback (should not be called) + mock_callback = AsyncMock() + + # Start the consumer with a mocked queue and model for validation + await consumer.start_consumer( + queue_name='test_q', + callback=mock_callback, + routing_key='test_route', + exchange_name='test_x', + exchange_type='direct', + payload_model=ExpectedPayload, # Use the model to validate the payload + auto_ack=False # Disable auto-ack to manually control the message acknowledgement + ) + + # Process the messages and expect a validation error + async for message in mock_queue.iterator(): + with pytest.raises(ValidationError): + consumer.validate_payload(message.body, ExpectedPayload) + # Manually nack (requeue) the invalid message + await message.nack(requeue=True) + + # Assert that the callback was never called + mock_callback.assert_not_called() -def test_requeue_on_validation_failure_async(async_amqp_consumer): - """Test that a message is requeued on validation failure in async mode.""" - invalid_body = b'{"id": "wrong_type", "name": "Test", "active": true}' - mock_method_frame = MagicMock() - mock_method_frame.delivery_tag = 123 - mock_properties = MagicMock() - - # Mock the consume method to yield an invalid message - async_amqp_consumer._channel.consume.return_value = [(mock_method_frame, mock_properties, invalid_body), (None, None, None)] - - with patch.object(async_amqp_consumer._channel, 'basic_nack') as mock_nack: - try: - async_amqp_consumer.start_consumer( - queue_name='test_q', - auto_ack=False, - exchange_name='test_x', - exchange_type='direct', - routing_key='test_route', - payload_model=ExpectedPayload - ) - except Exception: - print("Controlled exit of run_forever") - - # Assert that basic_nack was called with requeue=True - mock_nack.assert_called_once_with(delivery_tag=123, requeue=True) + # Assert that the message was requeued (nack with requeue=True) + mock_message.nack.assert_called_once_with(requeue=True) + + +#@pytest.mark.asyncio +#async def test_retry_on_connection_failure(amqp_consumer): +# """Test that retry is activated when an AMQPConnectionError occurs during message consumption.""" +# consumer = await amqp_consumer +# +# # Mock the message and its properties +# valid_body = b'{"id": 1, "name": "Test", "active": true}' +# mock_message = AsyncMock(body=valid_body, ack=AsyncMock(), reject=AsyncMock()) +# +# # Set up mocks for queue and message +# mock_queue = AsyncMock() +# async def message_generator(): +# yield mock_message +# +# mock_queue.iterator = message_generator +# mock_queue.__aenter__.return_value = mock_queue +# mock_queue.__aexit__.return_value = AsyncMock() +# consumer._channel.declare_queue.return_value = mock_queue +# +# # Patch the aio_pika consume method to raise AMQPConnectionError +# with patch.object(consumer._channel, 'consume', side_effect=AMQPConnectionError("Connection lost")) as mock_consume: +# +# # Patch the setup_async_connection method to track retries +# with patch.object(MrsalAsyncAMQP, 'setup_async_connection', wraps=consumer.setup_async_connection) as mock_setup: +# +# # Assert that the retry mechanism kicks in for connection failure +# with pytest.raises(RetryError): # Expect RetryError after 3 failed attempts +# await consumer.start_consumer( +# queue_name='test_q', +# callback=AsyncMock(), +# routing_key='test_route', +# exchange_name='test_x_retry', +# exchange_type='direct' +# ) +# +# # Verify that setup_async_connection was retried 3 times +# assert mock_setup.call_count == 3 +# # Ensure consume was called before the error +# assert mock_consume.call_count == 1 diff --git a/tests/test_mrsal_blocking_no_tls.py b/tests/test_mrsal_blocking_no_tls.py index 307a092..8a960ef 100644 --- a/tests/test_mrsal_blocking_no_tls.py +++ b/tests/test_mrsal_blocking_no_tls.py @@ -3,7 +3,7 @@ from pika.exceptions import AMQPConnectionError, UnroutableError from pydantic.dataclasses import dataclass from tenacity import RetryError -from mrsal.amqp.subclass import MrsalAMQP +from mrsal.amqp.subclass import MrsalBlockingAMQP # Configuration and expected payload definition SETUP_ARGS = { @@ -12,7 +12,6 @@ 'credentials': ('user', 'password'), 'virtual_host': 'testboi', 'ssl': False, - 'use_blocking': True, 'heartbeat': 60, 'blocked_connection_timeout': 60, 'prefetch_count': 1 @@ -29,7 +28,7 @@ class ExpectedPayload: @pytest.fixture def mock_amqp_connection(): with patch('mrsal.amqp.subclass.pika.BlockingConnection') as mock_blocking_connection, \ - patch('mrsal.amqp.subclass.MrsalAMQP.setup_blocking_connection', autospec=True) as mock_setup_blocking_connection: + patch('mrsal.amqp.subclass.MrsalBlockingAMQP.setup_blocking_connection', autospec=True) as mock_setup_blocking_connection: # Set up the mock behaviors for the connection and channel mock_channel = MagicMock() @@ -44,11 +43,11 @@ def mock_amqp_connection(): yield mock_connection, mock_channel, mock_setup_blocking_connection -# Fixture to create a MrsalAMQP consumer with mocked channel +# Fixture to create a MrsalBlockingAMQP consumer with mocked channel @pytest.fixture def amqp_consumer(mock_amqp_connection): mock_connection, mock_channel, _ = mock_amqp_connection - consumer = MrsalAMQP(**SETUP_ARGS) + consumer = MrsalBlockingAMQP(**SETUP_ARGS) consumer._channel = mock_channel # Inject the mocked channel into the consumer return consumer diff --git a/tests/test_mrsal_blocking_tls.py b/tests/test_mrsal_blocking_tls.py index 7a437e3..c2ec2e3 100644 --- a/tests/test_mrsal_blocking_tls.py +++ b/tests/test_mrsal_blocking_tls.py @@ -3,7 +3,7 @@ from unittest.mock import patch from pydantic import ValidationError -from mrsal.amqp.subclass import MrsalAMQP +from mrsal.amqp.subclass import MrsalBlockingAMQP from tests.conftest import SETUP_ARGS @@ -16,7 +16,7 @@ def test_ssl_setup_with_valid_paths(self): 'RABBITMQ_KEY': 'test_key.key', 'RABBITMQ_CAFILE': 'test_ca.ca' }, clear=True): - consumer = MrsalAMQP(**SETUP_ARGS, ssl=True, use_blocking=True) + consumer = MrsalBlockingAMQP(**SETUP_ARGS, ssl=True) # Check if SSL paths are correctly loaded and blocking is used self.assertEqual(consumer.tls_dict['crt'], 'test_cert.crt') @@ -30,12 +30,12 @@ def test_ssl_setup_with_valid_paths(self): }) def test_ssl_setup_with_missing_paths(self): with self.assertRaises(ValidationError): - MrsalAMQP(**SETUP_ARGS, ssl=True, use_blocking=True) + MrsalBlockingAMQP(**SETUP_ARGS, ssl=True) @patch.dict(os.environ, {}, clear=True) def test_ssl_setup_without_env_vars(self): with self.assertRaises(ValidationError): - MrsalAMQP(**SETUP_ARGS, ssl=True, use_blocking=True) + MrsalBlockingAMQP(**SETUP_ARGS, ssl=True) if __name__ == '__main__':