diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java index 3201c700e..6808f647a 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/AbstractPipelineMessageListenerContainer.java @@ -241,7 +241,7 @@ protected TaskExecutor createTaskExecutor() { ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); int poolSize = getContainerOptions().getMaxConcurrentMessages() * this.messageSources.size(); executor.setMaxPoolSize(poolSize); - executor.setCorePoolSize(getContainerOptions().getMaxMessagesPerPoll()); + executor.setCorePoolSize(poolSize); // Necessary due to a small racing condition between releasing the permit and releasing the thread. executor.setQueueCapacity(poolSize); executor.setAllowCoreThreadTimeOut(true); diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsIntegrationTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsIntegrationTests.java index 7bdc22745..b9ba0e1ea 100644 --- a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsIntegrationTests.java +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/integration/SqsIntegrationTests.java @@ -17,6 +17,7 @@ import static java.util.Collections.singletonMap; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import com.fasterxml.jackson.databind.ObjectMapper; import io.awspring.cloud.sqs.CompletableFutures; @@ -53,8 +54,10 @@ import java.util.Collections; import java.util.List; import java.util.UUID; +import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; @@ -116,6 +119,8 @@ class SqsIntegrationTests extends BaseSqsIntegrationTest { static final String MANUALLY_CREATE_FACTORY_QUEUE_NAME = "manually_create_factory_test_queue"; + static final String MAX_CONCURRENT_MESSAGES_QUEUE_NAME = "max_concurrent_messages_test_queue"; + static final String LOW_RESOURCE_FACTORY = "lowResourceFactory"; static final String MANUAL_ACK_FACTORY = "manualAcknowledgementFactory"; @@ -139,7 +144,8 @@ static void beforeTests() { createQueue(client, RESOLVES_PARAMETER_TYPES_QUEUE_NAME, singletonMap(QueueAttributeName.VISIBILITY_TIMEOUT, "20")), createQueue(client, MANUALLY_CREATE_CONTAINER_QUEUE_NAME), - createQueue(client, MANUALLY_CREATE_FACTORY_QUEUE_NAME)).join(); + createQueue(client, MANUALLY_CREATE_FACTORY_QUEUE_NAME), + createQueue(client, MAX_CONCURRENT_MESSAGES_QUEUE_NAME)).join(); } @Autowired @@ -275,6 +281,20 @@ void manuallyCreatesFactory() throws Exception { assertThat(latchContainer.manuallyCreatedFactorySinkLatch.await(10, TimeUnit.SECONDS)).isTrue(); } + @Test + void maxConcurrentMessages() { + List> messages1 = IntStream.range(0, 10) + .mapToObj(index -> "maxConcurrentMessages-payload-" + index) + .map(payload -> MessageBuilder.withPayload(payload).build()).collect(Collectors.toList()); + List> messages2 = IntStream.range(10, 20) + .mapToObj(index -> "maxConcurrentMessages-payload-" + index) + .map(payload -> MessageBuilder.withPayload(payload).build()).collect(Collectors.toList()); + sqsTemplate.sendManyAsync(MAX_CONCURRENT_MESSAGES_QUEUE_NAME, messages1); + sqsTemplate.sendManyAsync(MAX_CONCURRENT_MESSAGES_QUEUE_NAME, messages2); + logger.debug("Sent messages to queue {} with messages {} and {}", MAX_CONCURRENT_MESSAGES_QUEUE_NAME, messages1, messages2); + assertDoesNotThrow(() -> latchContainer.maxConcurrentMessagesBarrier.await(10, TimeUnit.SECONDS)); + } + static class ReceivesMessageListener { @Autowired @@ -399,6 +419,18 @@ void listen(Message message, MessageHeaders headers, Acknowledgement ack } } + static class MaxConcurrentMessagesListener { + + @Autowired + LatchContainer latchContainer; + + @SqsListener(queueNames = MAX_CONCURRENT_MESSAGES_QUEUE_NAME, maxMessagesPerPoll = "10", maxConcurrentMessages = "20", id = "max-concurrent-messages") + void listen(String message) throws BrokenBarrierException, InterruptedException { + logger.debug("Received message in Listener Method: " + message); + latchContainer.maxConcurrentMessagesBarrier.await(); + } + } + static class LatchContainer { final CountDownLatch receivesMessageLatch = new CountDownLatch(1); @@ -421,6 +453,7 @@ static class LatchContainer { final CountDownLatch acknowledgementCallbackSuccessLatch = new CountDownLatch(1); final CountDownLatch acknowledgementCallbackBatchLatch = new CountDownLatch(1); final CountDownLatch acknowledgementCallbackErrorLatch = new CountDownLatch(1); + final CyclicBarrier maxConcurrentMessagesBarrier = new CyclicBarrier(21); } @@ -612,6 +645,11 @@ ResolvesParameterTypesListener resolvesParameterTypesListener() { return new ResolvesParameterTypesListener(); } + @Bean + MaxConcurrentMessagesListener maxConcurrentMessagesListener() { + return new MaxConcurrentMessagesListener(); + } + @Bean SqsListenerConfigurer customizer() { return registrar -> {