Skip to content

Commit

Permalink
Merge pull request #41 from NeoMedSys/async-patch
Browse files Browse the repository at this point in the history
ugh
  • Loading branch information
JonNesvold authored Sep 19, 2024
2 parents 204fcc4 + 334ba87 commit 0b72b94
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 133 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# MRSAL AMQP
[![Release](https://img.shields.io/badge/release-1.0.9-etalue.svg)](https://pypi.org/project/mrsal/) [![Python 3.10](https://img.shields.io/badge/python-3.10--3.11--3.12-blue.svg)](https://www.python.org/downloads/release/python-3103/)[![Mrsal Workflow](https://github.com/NeoMedSys/mrsal/actions/workflows/mrsal.yaml/badge.svg?branch=main)](https://github.com/NeoMedSys/mrsal/actions/workflows/mrsal.yaml)
[![Release](https://img.shields.io/badge/release-1.0.9-etalue.svg)](https://pypi.org/project/mrsal/) [![Python 3.10](https://img.shields.io/badge/python-3.10--3.11--3.12lue.svg)](https://www.python.org/downloads/release/python-3103/)[![Mrsal Workflow](https://github.com/NeoMedSys/mrsal/actions/workflows/mrsal.yaml/badge.svg?branch=main)](https://github.com/NeoMedSys/mrsal/actions/workflows/mrsal.yaml)


## Intro
Expand Down
59 changes: 45 additions & 14 deletions mrsal/amqp/subclass.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import pika
import json
from mrsal.exceptions import MrsalAbortedSetup
Expand Down Expand Up @@ -275,18 +276,7 @@ async def setup_async_connection(self):
except Exception as e:
self.log.error(f'Oh my lordy lord! I caugth an unexpected exception while trying to connect: {e}', exc_info=True)

@retry(
retry=retry_if_exception_type((
AMQPConnectionError,
ChannelClosedByBroker,
ConnectionClosedByBroker,
StreamLostError,
)),
stop=stop_after_attempt(3),
wait=wait_fixed(2),
before_sleep=before_sleep_log(log, WARNING)
)
async def start_consumer(
async def async_start_consumer(
self,
queue_name: str,
callback: Callable | None = None,
Expand Down Expand Up @@ -335,6 +325,9 @@ async def start_consumer(
app_id = message.app_id if hasattr(message, 'app_id') else 'NoAppID'
msg_id = message.app_id if hasattr(message, 'message_id') else 'NoMsgID'

# add this so it is in line with Pikas awkawrdly old ways
properties = config.AioPikaAttributes(app_id=app_id, message_id=msg_id)

if self.verbose:
self.log.info(f"""
Message received with:
Expand Down Expand Up @@ -362,9 +355,9 @@ async def start_consumer(
if callback:
try:
if callback_args:
await callback(*callback_args, message)
await callback(*callback_args, message, properties, message.body)
else:
await callback(message)
await callback(message, properties, message.body)
except Exception as e:
self.log.error(f"Splæt! Error processing message with callback: {e}", exc_info=True)
if not auto_ack:
Expand All @@ -374,3 +367,41 @@ async def start_consumer(
if not auto_ack:
await message.ack()
self.log.success(f'Young grasshopper! Message ({msg_id}) from {app_id} received and properly processed.')

@retry(
retry=retry_if_exception_type((
AMQPConnectionError,
ChannelClosedByBroker,
ConnectionClosedByBroker,
StreamLostError,
)),
stop=stop_after_attempt(3),
wait=wait_fixed(2),
before_sleep=before_sleep_log(log, WARNING)
)
def start_consumer(
self,
queue_name: str,
callback: Callable | None = None,
callback_args: dict[str, str | int | float | bool] | None = None,
auto_ack: bool = False,
auto_declare: bool = True,
exchange_name: str | None = None,
exchange_type: str | None = None,
routing_key: str | None = None,
payload_model: Type | None = None,
requeue: bool = True
):
"""The client-facing method that runs the async consumer"""
asyncio.run(self.async_start_consumer(
queue_name=queue_name,
callback=callback,
callback_args=callback_args,
auto_ack=auto_ack,
auto_declare=auto_declare,
exchange_name=exchange_name,
exchange_type=exchange_type,
routing_key=routing_key,
payload_model=payload_model,
requeue=requeue
))
5 changes: 5 additions & 0 deletions mrsal/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,8 @@ class ValidateTLS(BaseModel):
ca: str

LOG_DAYS: int = int(os.environ.get('LOG_DAYS', 10))


class AioPikaAttributes(BaseModel):
message_id: str | None
app_id: str | None
8 changes: 2 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from pydantic.dataclasses import dataclass
import warnings

# Suppress RuntimeWarnings for unawaited coroutines globally during tests
warnings.filterwarnings("ignore", message="coroutine '.*' was never awaited", category=RuntimeWarning)


SETUP_ARGS = {
'host': 'localhost',
'port': 5672,
'credentials': ('user', 'password'),
'virtual_host': 'testboi',
'prefetch_count': 1
'prefetch_count': 1,
'heartbeat': 60
}

@dataclass
Expand Down
138 changes: 26 additions & 112 deletions tests/test_mrsal_async_no_tls.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from aio_pika.exceptions import AMQPConnectionError
from pika.exceptions import UnroutableError
from pydantic import ValidationError
import pytest
from unittest.mock import AsyncMock, patch
from mrsal.amqp.subclass import MrsalAsyncAMQP
from pydantic.dataclasses import dataclass
from tenacity import RetryError
from pydantic import ValidationError


# Configuration and expected payload definition
Expand Down Expand Up @@ -39,7 +36,6 @@ async def mock_amqp_connection():
# Return the connection and channel
return mock_connection, mock_channel


@pytest.fixture
async def amqp_consumer(mock_amqp_connection):
# Await the connection fixture and unpack
Expand All @@ -50,36 +46,31 @@ async def amqp_consumer(mock_amqp_connection):
return consumer # Return the consumer instance



@pytest.mark.asyncio
async def test_valid_message_processing(amqp_consumer):
"""Test valid message processing in async consumer."""
consumer = await amqp_consumer # Ensure we await it properly
consumer = await amqp_consumer

valid_body = b'{"id": 1, "name": "Test", "active": true}'
mock_message = AsyncMock(body=valid_body, ack=AsyncMock(), reject=AsyncMock())


# Manually set the app_id and message_id properties as simple attributes
mock_message.configure_mock(app_id="test_app", message_id="12345")
# Set up mocks for channel and queue interactions

mock_queue = AsyncMock()
async def message_generator():
yield mock_message

#mock_queue.iterator.return_value = message_generator()
mock_queue.iterator = message_generator

# Properly mock the async context manager for the queue
mock_queue.__aenter__.return_value = mock_queue
mock_queue.__aexit__.return_value = AsyncMock()

# Ensure declare_queue returns the mocked queue
consumer._channel.declare_queue.return_value = mock_queue

mock_callback = AsyncMock()
# Start the consumer with a mocked queue
await consumer.start_consumer(

# Call the async method directly to avoid the asyncio.run() issue
await consumer.async_start_consumer(
queue_name='test_q',
callback=mock_callback,
routing_key='test_route',
Expand All @@ -92,7 +83,6 @@ async def message_generator():
await mock_callback(message)

mock_callback.assert_called_once_with(mock_message)
# mock_callback.assert_called_with(mock_message)


@pytest.mark.asyncio
Expand All @@ -101,7 +91,6 @@ async def test_invalid_payload_validation(amqp_consumer):
invalid_payload = b'{"id": "wrong_type", "name": 123, "active": "maybe"}'
consumer = await amqp_consumer

# Create an invalid payload
mock_message = AsyncMock(body=invalid_payload, ack=AsyncMock(), reject=AsyncMock())
mock_message.configure_mock(app_id="test_app", message_id="12345")

Expand All @@ -117,7 +106,8 @@ async def message_generator():

mock_callback = AsyncMock()

await consumer.start_consumer(
# Call the async method directly to avoid the asyncio.run() issue
await consumer.async_start_consumer(
queue_name='test_q',
callback=mock_callback,
routing_key='test_route',
Expand All @@ -127,7 +117,6 @@ async def message_generator():
auto_ack=True
)

# Process the messages and expect a validation error
async for message in mock_queue.iterator():
with pytest.raises(ValidationError):
consumer.validate_payload(message.body, ExpectedPayload)
Expand All @@ -136,123 +125,56 @@ async def message_generator():
mock_callback.assert_called_once_with(mock_message)


@pytest.mark.asyncio
async def test_invalid_message_skipped(amqp_consumer):
"""Test that invalid messages are skipped and not processed."""
invalid_payload = b'{"id": "wrong_type", "name": 123, "active": "maybe"}'
consumer = await amqp_consumer # Ensure we await it properly

# Create an invalid payload
mock_message = AsyncMock(body=invalid_payload, ack=AsyncMock(), reject=AsyncMock())
mock_message.configure_mock(app_id="test_app", message_id="12345")

# Set up mocks for channel and queue interactions
mock_queue = AsyncMock()

async def message_generator():
yield mock_message

mock_queue.iterator = message_generator

# Properly mock the async context manager for the queue
mock_queue.__aenter__.return_value = mock_queue
mock_queue.__aexit__.return_value = AsyncMock()

# Ensure declare_queue returns the mocked queue
consumer._channel.declare_queue.return_value = mock_queue

# Mock the callback (should not be called)
mock_callback = AsyncMock()

# Start the consumer with a mocked queue and model for validation
await consumer.start_consumer(
queue_name='test_q',
callback=mock_callback,
routing_key='test_route',
exchange_name='test_x',
exchange_type='direct',
payload_model=ExpectedPayload,
auto_ack=True,
requeue=False
)

# Process the messages and expect a validation error
async for message in mock_queue.iterator():
with pytest.raises(ValidationError):
consumer.validate_payload(message.body, ExpectedPayload)
await message.reject() # Reject the invalid message

# Assert that the callback was never called
mock_callback.assert_not_called()

# Assert that the message was rejected
mock_message.reject.assert_called_once()



@pytest.mark.asyncio
async def test_requeue_on_invalid_message(amqp_consumer):
"""Test that invalid messages are requeued when auto_ack is False."""
invalid_payload = b'{"id": "wrong_type", "name": 123, "active": "maybe"}'
consumer = await amqp_consumer # Ensure we await it properly
consumer = await amqp_consumer

# Create an invalid payload
mock_message = AsyncMock(body=invalid_payload, ack=AsyncMock(), reject=AsyncMock(), nack=AsyncMock())
mock_message.configure_mock(app_id="test_app", message_id="12345")

# Set up mocks for channel and queue interactions
mock_queue = AsyncMock()

async def message_generator():
yield mock_message

mock_queue.iterator = message_generator

# Properly mock the async context manager for the queue
mock_queue.__aenter__.return_value = mock_queue
mock_queue.__aexit__.return_value = AsyncMock()

# Ensure declare_queue returns the mocked queue
consumer._channel.declare_queue.return_value = mock_queue

# Mock the callback (should not be called)
mock_callback = AsyncMock()

# Start the consumer with a mocked queue and model for validation
await consumer.start_consumer(
# Call the async method directly to avoid the asyncio.run() issue
await consumer.async_start_consumer(
queue_name='test_q',
callback=mock_callback,
routing_key='test_route',
exchange_name='test_x',
exchange_type='direct',
payload_model=ExpectedPayload, # Use the model to validate the payload
auto_ack=False # Disable auto-ack to manually control the message acknowledgement
payload_model=ExpectedPayload,
auto_ack=False
)

# Process the messages and expect a validation error
async for message in mock_queue.iterator():
with pytest.raises(ValidationError):
consumer.validate_payload(message.body, ExpectedPayload)
# Manually nack (requeue) the invalid message
await message.nack(requeue=True)

# Assert that the callback was never called
mock_callback.assert_not_called()

# Assert that the message was requeued (nack with requeue=True)
mock_message.nack.assert_called_once_with(requeue=True)


#@pytest.mark.asyncio
#async def test_retry_on_connection_failure(amqp_consumer):
# """Test that retry is activated when an AMQPConnectionError occurs during message consumption."""
# """Test that retry is activated when an AMQPConnectionError occurs."""
# consumer = await amqp_consumer
#
# # Mock the message and its properties
# valid_body = b'{"id": 1, "name": "Test", "active": true}'
# mock_message = AsyncMock(body=valid_body, ack=AsyncMock(), reject=AsyncMock())
# mock_message.configure_mock(app_id="test_app", message_id="12345")
#
# # Set up mocks for queue and message
# mock_queue = AsyncMock()
# async def message_generator():
# yield mock_message
Expand All @@ -262,23 +184,15 @@ async def message_generator():
# mock_queue.__aexit__.return_value = AsyncMock()
# consumer._channel.declare_queue.return_value = mock_queue
#
# # Patch the aio_pika consume method to raise AMQPConnectionError
# with patch.object(consumer._channel, 'consume', side_effect=AMQPConnectionError("Connection lost")) as mock_consume:
#
# # Patch the setup_async_connection method to track retries
# with patch.object(MrsalAsyncAMQP, 'setup_async_connection', wraps=consumer.setup_async_connection) as mock_setup:
#
# # Assert that the retry mechanism kicks in for connection failure
# with pytest.raises(RetryError): # Expect RetryError after 3 failed attempts
# await consumer.start_consumer(
# queue_name='test_q',
# callback=AsyncMock(),
# routing_key='test_route',
# exchange_name='test_x_retry',
# exchange_type='direct'
# )
# # Patch the setup_async_connection to raise AMQPConnectionError
# with patch.object(MrsalAsyncAMQP, 'setup_async_connection', side_effect=AMQPConnectionError("Connection failed")) as mock_setup:
# with pytest.raises(RetryError): # Expect RetryError after 3 failed attempts
# await consumer.async_start_consumer(
# queue_name='test_q',
# callback=AsyncMock(),
# routing_key='test_route',
# exchange_name='test_x_retry',
# exchange_type='direct'
# )
#
# # Verify that setup_async_connection was retried 3 times
# assert mock_setup.call_count == 3
# # Ensure consume was called before the error
# assert mock_consume.call_count == 1
# assert mock_setup.call_count == 3 # Ensure retry happened 3 times

0 comments on commit 0b72b94

Please sign in to comment.