Skip to content

Commit

Permalink
Configure redrive policy on "infra create_queues" command
Browse files Browse the repository at this point in the history
  • Loading branch information
eduardoklosowski committed Sep 5, 2024
1 parent 067b1e2 commit b697352
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 8 deletions.
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ classifiers = [
python = "^3.8"
boto3 = "^1.35"
click = "^8.1"
graphlib_backport = {version = "^1.1", python = "<3.9"}
pydantic = "^2.8"
tomli = "^2.0"

Expand Down Expand Up @@ -81,6 +82,12 @@ strict = true
plugins = ["pydantic.mypy"]
files = ["src/**/*.py", "tests/**/*.py"]

[[tool.mypy.overrides]]
module = [
"graphlib.*",
]
ignore_missing_imports = true

[tool.pytest.ini_options]
testpaths = ["tests"]

Expand Down
39 changes: 34 additions & 5 deletions src/qldebugger/actions/infra.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import json
import logging
from typing import TYPE_CHECKING, Dict

from botocore.exceptions import ClientError
from graphlib import TopologicalSorter

from qldebugger.aws import get_account_id, get_client
from qldebugger.config import get_config
from qldebugger.config.file_parser import ConfigSecretString
from qldebugger.config.file_parser import ConfigQueue, ConfigSecretString

if TYPE_CHECKING:
from mypy_boto3_sqs.literals import QueueAttributeNameType

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -43,10 +49,33 @@ def create_topics() -> None:
def create_queues() -> None:
sqs = get_client('sqs')
queues = get_config().queues

for queue_name in queues:
logger.info('Creating %r queue...', queue_name)
sqs.create_queue(QueueName=queue_name)
order = TopologicalSorter(
{
name: {queue.redrive_policy.dead_letter_queue} if queue.redrive_policy else set()
for name, queue in queues.items()
}
).static_order()

for queue_name in order:
attributes: Dict['QueueAttributeNameType', str] = {}
if redrive_policy := queues.get(queue_name, ConfigQueue()).redrive_policy:
logger.debug('Checking dead letter queue (%r) for %r...', redrive_policy.dead_letter_queue, queue_name)
dead_letter_queue_attributes = sqs.get_queue_attributes(
QueueUrl=redrive_policy.dead_letter_queue, AttributeNames=['QueueArn']
)
attributes['RedrivePolicy'] = json.dumps(
{
'deadLetterTargetArn': dead_letter_queue_attributes['Attributes']['QueueArn'],
'maxReceiveCount': redrive_policy.max_receive_count,
}
)
try:
queue_url = sqs.get_queue_url(QueueName=queue_name)
logger.info('Updating %r queue...', queue_name)
sqs.set_queue_attributes(QueueUrl=queue_url['QueueUrl'], Attributes=attributes)
except ClientError:
logger.info('Creating %r queue...', queue_name)
sqs.create_queue(QueueName=queue_name, Attributes=attributes)


def subscribe_topics() -> None:
Expand Down
172 changes: 169 additions & 3 deletions tests/qldebugger/actions/test_infra.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import json
from collections import OrderedDict
from random import randint
from unittest.mock import Mock, patch
from typing import Dict
from unittest.mock import Mock, call, patch

from botocore.exceptions import ClientError

from qldebugger.actions.infra import create_queues, create_secrets, create_topics, subscribe_topics
from qldebugger.config.file_parser import (
ConfigQueue,
ConfigQueueRedrivePolicy,
ConfigSecretBinary,
ConfigSecretString,
ConfigTopic,
Expand Down Expand Up @@ -105,17 +109,179 @@ def test_run(self, mock_get_config: Mock, mock_get_client: Mock) -> None:
class TestCreateQueues:
@patch('qldebugger.actions.infra.get_client')
@patch('qldebugger.actions.infra.get_config')
def test_run(self, mock_get_config: Mock, mock_get_client: Mock) -> None:
def test_create_queues(self, mock_get_config: Mock, mock_get_client: Mock) -> None:
queues_names = [randstr() for _ in range(randint(2, 5))]

mock_get_config.return_value.queues = {queue_name: ConfigQueue() for queue_name in queues_names}
mock_get_client.return_value.get_queue_url.side_effect = ClientError({}, '')

create_queues()

mock_get_client.assert_called_once_with('sqs')
assert mock_get_client.return_value.create_queue.call_count == len(queues_names)
for queue_name in queues_names:
mock_get_client.return_value.create_queue.assert_any_call(QueueName=queue_name)
mock_get_client.return_value.create_queue.assert_any_call(QueueName=queue_name, Attributes={})

@patch('qldebugger.actions.infra.get_client')
@patch('qldebugger.actions.infra.get_config')
def test_update_queues(self, mock_get_config: Mock, mock_get_client: Mock) -> None:
queues_names = [randstr() for _ in range(randint(2, 5))]
host = randstr()

mock_get_config.return_value.queues = {queue_name: ConfigQueue() for queue_name in queues_names}
mock_get_client.return_value.get_queue_url.side_effect = lambda QueueName: { # noqa: N803
'QueueUrl': f'http://{host}/{QueueName}'
}

create_queues()

mock_get_client.assert_called_once_with('sqs')
assert mock_get_client.return_value.set_queue_attributes.call_count == len(queues_names)
for queue_name in queues_names:
mock_get_client.return_value.set_queue_attributes.assert_any_call(
QueueUrl=f'http://{host}/{queue_name}', Attributes={}
)

@patch('qldebugger.actions.infra.get_client')
@patch('qldebugger.actions.infra.get_config')
def test_dead_letter_queue_should_created_first(self, mock_get_config: Mock, mock_get_client: Mock) -> None:
a_max_receive_count = randint(1, 10)
b_max_receive_count = randint(1, 10)
queues = OrderedDict(
[
(
'b',
ConfigQueue(
redrive_policy=ConfigQueueRedrivePolicy(
dead_letter_queue='c', max_receive_count=b_max_receive_count
)
),
),
(
'a',
ConfigQueue(
redrive_policy=ConfigQueueRedrivePolicy(
dead_letter_queue='b', max_receive_count=a_max_receive_count
)
),
),
('c', ConfigQueue()),
]
)

mock_get_config.return_value.queues = queues
mock_get_client.return_value.get_queue_url.side_effect = ClientError({}, '')
mock_get_client.return_value.get_queue_attributes.side_effect = lambda QueueUrl, AttributeNames: { # noqa: N803
'Attributes': {'QueueArn': f'arn:aws:sqs:us-east-1:123456789012:{QueueUrl}'}
}

create_queues()

mock_get_client.assert_called_once_with('sqs')
assert mock_get_client.return_value.create_queue.call_args_list == [
call(QueueName='c', Attributes={}),
call(
QueueName='b',
Attributes={
'RedrivePolicy': json.dumps(
{
'deadLetterTargetArn': 'arn:aws:sqs:us-east-1:123456789012:c',
'maxReceiveCount': b_max_receive_count,
}
)
},
),
call(
QueueName='a',
Attributes={
'RedrivePolicy': json.dumps(
{
'deadLetterTargetArn': 'arn:aws:sqs:us-east-1:123456789012:b',
'maxReceiveCount': a_max_receive_count,
}
)
},
),
]

@patch('qldebugger.actions.infra.get_client')
@patch('qldebugger.actions.infra.get_config')
def test_create_dead_letter_queue_not_in_queues(self, mock_get_config: Mock, mock_get_client: Mock) -> None:
queue_name = randstr()
dead_letter_queue = randstr()
max_receive_count = randint(1, 10)

mock_get_config.return_value.queues = {
queue_name: ConfigQueue(
redrive_policy=ConfigQueueRedrivePolicy(
dead_letter_queue=dead_letter_queue, max_receive_count=max_receive_count
)
)
}
mock_get_client.return_value.get_queue_url.side_effect = ClientError({}, '')
mock_get_client.return_value.get_queue_attributes.side_effect = lambda QueueUrl, AttributeNames: { # noqa: N803
'Attributes': {'QueueArn': f'arn:aws:sqs:us-east-1:123456789012:{QueueUrl}'}
}

create_queues()

mock_get_client.assert_called_once_with('sqs')
assert mock_get_client.return_value.create_queue.call_args_list == [
call(QueueName=dead_letter_queue, Attributes={}),
call(
QueueName=queue_name,
Attributes={
'RedrivePolicy': json.dumps(
{
'deadLetterTargetArn': f'arn:aws:sqs:us-east-1:123456789012:{dead_letter_queue}',
'maxReceiveCount': max_receive_count,
}
)
},
),
]

@patch('qldebugger.actions.infra.get_client')
@patch('qldebugger.actions.infra.get_config')
def test_create_dead_letter_and_update_queue(self, mock_get_config: Mock, mock_get_client: Mock) -> None:
queue_name = randstr()
dead_letter_queue = randstr()
max_receive_count = randint(1, 10)
host = randstr()

def get_queue_url(*, QueueName: str) -> Dict[str, str]: # noqa: N803
if QueueName != queue_name:
raise ClientError({}, '')
return {'QueueUrl': f'http://{host}/{QueueName}'}

mock_get_config.return_value.queues = {
queue_name: ConfigQueue(
redrive_policy=ConfigQueueRedrivePolicy(
dead_letter_queue=dead_letter_queue, max_receive_count=max_receive_count
)
),
dead_letter_queue: ConfigQueue(),
}
mock_get_client.return_value.get_queue_url.side_effect = get_queue_url
mock_get_client.return_value.get_queue_attributes.side_effect = lambda QueueUrl, AttributeNames: { # noqa: N803
'Attributes': {'QueueArn': f'arn:aws:sqs:us-east-1:123456789012:{QueueUrl}'}
}

create_queues()

mock_get_client.assert_called_once_with('sqs')
mock_get_client.return_value.create_queue.assert_called_once_with(QueueName=dead_letter_queue, Attributes={})
mock_get_client.return_value.set_queue_attributes.assert_called_once_with(
QueueUrl=f'http://{host}/{queue_name}',
Attributes={
'RedrivePolicy': json.dumps(
{
'deadLetterTargetArn': f'arn:aws:sqs:us-east-1:123456789012:{dead_letter_queue}',
'maxReceiveCount': max_receive_count,
}
)
},
)


class TestSubscribeTopics:
Expand Down

0 comments on commit b697352

Please sign in to comment.