From b6973526b10a52bdd81efb96733d7e16abf8824a Mon Sep 17 00:00:00 2001 From: Eduardo Klosowski Date: Thu, 5 Sep 2024 02:38:25 -0300 Subject: [PATCH] Configure redrive policy on "infra create_queues" command --- pyproject.toml | 7 + src/qldebugger/actions/infra.py | 39 +++++- tests/qldebugger/actions/test_infra.py | 172 ++++++++++++++++++++++++- 3 files changed, 210 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4dfcb04..4eb71a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ classifiers = [ python = "^3.8" boto3 = "^1.35" click = "^8.1" +graphlib_backport = {version = "^1.1", python = "<3.9"} pydantic = "^2.8" tomli = "^2.0" @@ -81,6 +82,12 @@ strict = true plugins = ["pydantic.mypy"] files = ["src/**/*.py", "tests/**/*.py"] +[[tool.mypy.overrides]] +module = [ + "graphlib.*", +] +ignore_missing_imports = true + [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/src/qldebugger/actions/infra.py b/src/qldebugger/actions/infra.py index 6f85b6a..e435d4e 100644 --- a/src/qldebugger/actions/infra.py +++ b/src/qldebugger/actions/infra.py @@ -1,10 +1,16 @@ +import json import logging +from typing import TYPE_CHECKING, Dict from botocore.exceptions import ClientError +from graphlib import TopologicalSorter from qldebugger.aws import get_account_id, get_client from qldebugger.config import get_config -from qldebugger.config.file_parser import ConfigSecretString +from qldebugger.config.file_parser import ConfigQueue, ConfigSecretString + +if TYPE_CHECKING: + from mypy_boto3_sqs.literals import QueueAttributeNameType logger = logging.getLogger(__name__) @@ -43,10 +49,33 @@ def create_topics() -> None: def create_queues() -> None: sqs = get_client('sqs') queues = get_config().queues - - for queue_name in queues: - logger.info('Creating %r queue...', queue_name) - sqs.create_queue(QueueName=queue_name) + order = TopologicalSorter( + { + name: {queue.redrive_policy.dead_letter_queue} if queue.redrive_policy else set() + for name, queue in queues.items() + } + ).static_order() + + for queue_name in order: + attributes: Dict['QueueAttributeNameType', str] = {} + if redrive_policy := queues.get(queue_name, ConfigQueue()).redrive_policy: + logger.debug('Checking dead letter queue (%r) for %r...', redrive_policy.dead_letter_queue, queue_name) + dead_letter_queue_attributes = sqs.get_queue_attributes( + QueueUrl=redrive_policy.dead_letter_queue, AttributeNames=['QueueArn'] + ) + attributes['RedrivePolicy'] = json.dumps( + { + 'deadLetterTargetArn': dead_letter_queue_attributes['Attributes']['QueueArn'], + 'maxReceiveCount': redrive_policy.max_receive_count, + } + ) + try: + queue_url = sqs.get_queue_url(QueueName=queue_name) + logger.info('Updating %r queue...', queue_name) + sqs.set_queue_attributes(QueueUrl=queue_url['QueueUrl'], Attributes=attributes) + except ClientError: + logger.info('Creating %r queue...', queue_name) + sqs.create_queue(QueueName=queue_name, Attributes=attributes) def subscribe_topics() -> None: diff --git a/tests/qldebugger/actions/test_infra.py b/tests/qldebugger/actions/test_infra.py index 1e64e0a..80e5851 100644 --- a/tests/qldebugger/actions/test_infra.py +++ b/tests/qldebugger/actions/test_infra.py @@ -1,11 +1,15 @@ +import json +from collections import OrderedDict from random import randint -from unittest.mock import Mock, patch +from typing import Dict +from unittest.mock import Mock, call, patch from botocore.exceptions import ClientError from qldebugger.actions.infra import create_queues, create_secrets, create_topics, subscribe_topics from qldebugger.config.file_parser import ( ConfigQueue, + ConfigQueueRedrivePolicy, ConfigSecretBinary, ConfigSecretString, ConfigTopic, @@ -105,17 +109,179 @@ def test_run(self, mock_get_config: Mock, mock_get_client: Mock) -> None: class TestCreateQueues: @patch('qldebugger.actions.infra.get_client') @patch('qldebugger.actions.infra.get_config') - def test_run(self, mock_get_config: Mock, mock_get_client: Mock) -> None: + def test_create_queues(self, mock_get_config: Mock, mock_get_client: Mock) -> None: queues_names = [randstr() for _ in range(randint(2, 5))] mock_get_config.return_value.queues = {queue_name: ConfigQueue() for queue_name in queues_names} + mock_get_client.return_value.get_queue_url.side_effect = ClientError({}, '') create_queues() mock_get_client.assert_called_once_with('sqs') assert mock_get_client.return_value.create_queue.call_count == len(queues_names) for queue_name in queues_names: - mock_get_client.return_value.create_queue.assert_any_call(QueueName=queue_name) + mock_get_client.return_value.create_queue.assert_any_call(QueueName=queue_name, Attributes={}) + + @patch('qldebugger.actions.infra.get_client') + @patch('qldebugger.actions.infra.get_config') + def test_update_queues(self, mock_get_config: Mock, mock_get_client: Mock) -> None: + queues_names = [randstr() for _ in range(randint(2, 5))] + host = randstr() + + mock_get_config.return_value.queues = {queue_name: ConfigQueue() for queue_name in queues_names} + mock_get_client.return_value.get_queue_url.side_effect = lambda QueueName: { # noqa: N803 + 'QueueUrl': f'http://{host}/{QueueName}' + } + + create_queues() + + mock_get_client.assert_called_once_with('sqs') + assert mock_get_client.return_value.set_queue_attributes.call_count == len(queues_names) + for queue_name in queues_names: + mock_get_client.return_value.set_queue_attributes.assert_any_call( + QueueUrl=f'http://{host}/{queue_name}', Attributes={} + ) + + @patch('qldebugger.actions.infra.get_client') + @patch('qldebugger.actions.infra.get_config') + def test_dead_letter_queue_should_created_first(self, mock_get_config: Mock, mock_get_client: Mock) -> None: + a_max_receive_count = randint(1, 10) + b_max_receive_count = randint(1, 10) + queues = OrderedDict( + [ + ( + 'b', + ConfigQueue( + redrive_policy=ConfigQueueRedrivePolicy( + dead_letter_queue='c', max_receive_count=b_max_receive_count + ) + ), + ), + ( + 'a', + ConfigQueue( + redrive_policy=ConfigQueueRedrivePolicy( + dead_letter_queue='b', max_receive_count=a_max_receive_count + ) + ), + ), + ('c', ConfigQueue()), + ] + ) + + mock_get_config.return_value.queues = queues + mock_get_client.return_value.get_queue_url.side_effect = ClientError({}, '') + mock_get_client.return_value.get_queue_attributes.side_effect = lambda QueueUrl, AttributeNames: { # noqa: N803 + 'Attributes': {'QueueArn': f'arn:aws:sqs:us-east-1:123456789012:{QueueUrl}'} + } + + create_queues() + + mock_get_client.assert_called_once_with('sqs') + assert mock_get_client.return_value.create_queue.call_args_list == [ + call(QueueName='c', Attributes={}), + call( + QueueName='b', + Attributes={ + 'RedrivePolicy': json.dumps( + { + 'deadLetterTargetArn': 'arn:aws:sqs:us-east-1:123456789012:c', + 'maxReceiveCount': b_max_receive_count, + } + ) + }, + ), + call( + QueueName='a', + Attributes={ + 'RedrivePolicy': json.dumps( + { + 'deadLetterTargetArn': 'arn:aws:sqs:us-east-1:123456789012:b', + 'maxReceiveCount': a_max_receive_count, + } + ) + }, + ), + ] + + @patch('qldebugger.actions.infra.get_client') + @patch('qldebugger.actions.infra.get_config') + def test_create_dead_letter_queue_not_in_queues(self, mock_get_config: Mock, mock_get_client: Mock) -> None: + queue_name = randstr() + dead_letter_queue = randstr() + max_receive_count = randint(1, 10) + + mock_get_config.return_value.queues = { + queue_name: ConfigQueue( + redrive_policy=ConfigQueueRedrivePolicy( + dead_letter_queue=dead_letter_queue, max_receive_count=max_receive_count + ) + ) + } + mock_get_client.return_value.get_queue_url.side_effect = ClientError({}, '') + mock_get_client.return_value.get_queue_attributes.side_effect = lambda QueueUrl, AttributeNames: { # noqa: N803 + 'Attributes': {'QueueArn': f'arn:aws:sqs:us-east-1:123456789012:{QueueUrl}'} + } + + create_queues() + + mock_get_client.assert_called_once_with('sqs') + assert mock_get_client.return_value.create_queue.call_args_list == [ + call(QueueName=dead_letter_queue, Attributes={}), + call( + QueueName=queue_name, + Attributes={ + 'RedrivePolicy': json.dumps( + { + 'deadLetterTargetArn': f'arn:aws:sqs:us-east-1:123456789012:{dead_letter_queue}', + 'maxReceiveCount': max_receive_count, + } + ) + }, + ), + ] + + @patch('qldebugger.actions.infra.get_client') + @patch('qldebugger.actions.infra.get_config') + def test_create_dead_letter_and_update_queue(self, mock_get_config: Mock, mock_get_client: Mock) -> None: + queue_name = randstr() + dead_letter_queue = randstr() + max_receive_count = randint(1, 10) + host = randstr() + + def get_queue_url(*, QueueName: str) -> Dict[str, str]: # noqa: N803 + if QueueName != queue_name: + raise ClientError({}, '') + return {'QueueUrl': f'http://{host}/{QueueName}'} + + mock_get_config.return_value.queues = { + queue_name: ConfigQueue( + redrive_policy=ConfigQueueRedrivePolicy( + dead_letter_queue=dead_letter_queue, max_receive_count=max_receive_count + ) + ), + dead_letter_queue: ConfigQueue(), + } + mock_get_client.return_value.get_queue_url.side_effect = get_queue_url + mock_get_client.return_value.get_queue_attributes.side_effect = lambda QueueUrl, AttributeNames: { # noqa: N803 + 'Attributes': {'QueueArn': f'arn:aws:sqs:us-east-1:123456789012:{QueueUrl}'} + } + + create_queues() + + mock_get_client.assert_called_once_with('sqs') + mock_get_client.return_value.create_queue.assert_called_once_with(QueueName=dead_letter_queue, Attributes={}) + mock_get_client.return_value.set_queue_attributes.assert_called_once_with( + QueueUrl=f'http://{host}/{queue_name}', + Attributes={ + 'RedrivePolicy': json.dumps( + { + 'deadLetterTargetArn': f'arn:aws:sqs:us-east-1:123456789012:{dead_letter_queue}', + 'maxReceiveCount': max_receive_count, + } + ) + }, + ) class TestSubscribeTopics: