diff --git a/src/main/java/org/springframework/integration/aws/outbound/KplMessageHandler.java b/src/main/java/org/springframework/integration/aws/outbound/KplMessageHandler.java index 1b2fe6b8..ea9e19ab 100644 --- a/src/main/java/org/springframework/integration/aws/outbound/KplMessageHandler.java +++ b/src/main/java/org/springframework/integration/aws/outbound/KplMessageHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2019-2024 the original author or authors. + * Copyright 2019-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,6 +49,7 @@ import org.springframework.expression.Expression; import org.springframework.expression.common.LiteralExpression; import org.springframework.integration.aws.support.AwsHeaders; +import org.springframework.integration.aws.support.KplBackpressureException; import org.springframework.integration.aws.support.UserRecordResponse; import org.springframework.integration.expression.ValueExpression; import org.springframework.integration.handler.AbstractMessageHandler; @@ -63,11 +64,16 @@ import org.springframework.util.StringUtils; /** - * The {@link AbstractMessageHandler} implementation for the Amazon Kinesis Producer - * Library {@code putRecord(s)}. + * The {@link AbstractMessageHandler} implementation for the Amazon Kinesis Producer Library {@code putRecord(s)}. + *

+ * The {@link KplBackpressureException} is thrown when backpressure handling is enabled and buffer is at max capacity. + * This exception can be handled with + * {@link org.springframework.integration.handler.advice.AbstractRequestHandlerAdvice}. + *

* * @author Arnaud Lecollaire * @author Artem Bilan + * @author Siddharth Jain * * @since 2.2 * @@ -99,6 +105,8 @@ public class KplMessageHandler extends AbstractAwsMessageHandler implement private volatile ScheduledFuture flushFuture; + private long backPressureThreshold = 0; + public KplMessageHandler(KinesisProducer kinesisProducer) { Assert.notNull(kinesisProducer, "'kinesisProducer' must not be null."); this.kinesisProducer = kinesisProducer; @@ -115,6 +123,19 @@ public void setConverter(Converter converter) { setMessageConverter(new ConvertingFromMessageConverter(converter)); } + /** + * Configure maximum records in flight for handling backpressure. By default, backpressure handling is not enabled. + * When backpressure handling is enabled and number of records in flight exceeds the threshold, a + * {@link KplBackpressureException} would be thrown. + * @param backPressureThreshold Set a value greater than 0 to enable backpressure handling. + * @since 3.0.9 + */ + public void setBackPressureThreshold(long backPressureThreshold) { + Assert.isTrue(backPressureThreshold >= 0, + "'backPressureThreshold must be greater than or equal to 0."); + this.backPressureThreshold = backPressureThreshold; + } + /** * Configure a {@link MessageConverter} for converting payload to {@code byte[]} for Kinesis record. * @param messageConverter the {@link MessageConverter} to use. @@ -368,6 +389,14 @@ private void setGlueSchemaIntoUserRecordIfAny(UserRecord userRecord, Message } private CompletableFuture handleUserRecord(UserRecord userRecord) { + if (this.backPressureThreshold > 0) { + var numberOfRecordsInFlight = this.kinesisProducer.getOutstandingRecordsCount(); + if (numberOfRecordsInFlight > this.backPressureThreshold) { + throw new KplBackpressureException("Cannot send record to Kinesis since buffer is at max capacity.", + userRecord); + } + } + ListenableFuture recordResult = this.kinesisProducer.addUserRecord(userRecord); return listenableFutureToCompletableFuture(recordResult) .thenApply(UserRecordResponse::new); @@ -403,7 +432,8 @@ private PutRecordRequest buildPutRecordRequest(Message message) { if (!StringUtils.hasText(partitionKey) && this.partitionKeyExpression != null) { partitionKey = this.partitionKeyExpression.getValue(getEvaluationContext(), message, String.class); } - Assert.state(partitionKey != null, "'partitionKey' must not be null for sending a Kinesis record. " + Assert.state(partitionKey != null, + "'partitionKey' must not be null for sending a Kinesis record." + "Consider configuring this handler with a 'partitionKey'( or 'partitionKeyExpression') " + "or supply an 'aws_partitionKey' message header."); diff --git a/src/main/java/org/springframework/integration/aws/support/KplBackpressureException.java b/src/main/java/org/springframework/integration/aws/support/KplBackpressureException.java new file mode 100644 index 00000000..f13c5641 --- /dev/null +++ b/src/main/java/org/springframework/integration/aws/support/KplBackpressureException.java @@ -0,0 +1,47 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.aws.support; + +import com.amazonaws.services.kinesis.producer.UserRecord; + +/** + * An exception triggered from {@link org.springframework.integration.aws.outbound.KplMessageHandler} while sending + * records to Kinesis when maximum number of records in flight exceeds the backpressure threshold. + * + * @author Siddharth Jain + * + * @since 3.0.9 + */ +public class KplBackpressureException extends RuntimeException { + + private static final long serialVersionUID = 1L; + + private final UserRecord userRecord; + + public KplBackpressureException(String message, UserRecord userRecord) { + super(message); + this.userRecord = userRecord; + } + + /** + * Get the {@link UserRecord} related. + * @return {@link UserRecord} linked while sending the record to Kinesis. + */ + public UserRecord getUserRecord() { + return this.userRecord; + } +} diff --git a/src/test/java/org/springframework/integration/aws/outbound/KplMessageHandlerTests.java b/src/test/java/org/springframework/integration/aws/outbound/KplMessageHandlerTests.java new file mode 100644 index 00000000..b9449ae5 --- /dev/null +++ b/src/test/java/org/springframework/integration/aws/outbound/KplMessageHandlerTests.java @@ -0,0 +1,184 @@ +/* + * Copyright 2019-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.aws.outbound; + +import com.amazonaws.services.kinesis.producer.KinesisProducer; +import com.amazonaws.services.kinesis.producer.UserRecord; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.integration.annotation.ServiceActivator; +import org.springframework.integration.aws.support.AwsHeaders; +import org.springframework.integration.aws.support.KplBackpressureException; +import org.springframework.integration.config.EnableIntegration; +import org.springframework.integration.handler.advice.RequestHandlerRetryAdvice; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessageHandlingException; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.clearInvocations; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** The class contains test cases for KplMessageHandler. + * + * @author Siddharth Jain + * + * @since 3.0.9 + */ +@SpringJUnitConfig +@DirtiesContext +public class KplMessageHandlerTests { + + @Autowired + protected KinesisProducer kinesisProducer; + + @Autowired + protected MessageChannel kinesisSendChannel; + + @Autowired + protected KplMessageHandler kplMessageHandler; + + @Test + @SuppressWarnings("unchecked") + void kplMessageHandlerWithRawPayloadBackpressureDisabledSuccess() { + given(this.kinesisProducer.addUserRecord(any(UserRecord.class))) + .willReturn(mock()); + final Message message = MessageBuilder + .withPayload("someMessage") + .setHeader(AwsHeaders.PARTITION_KEY, "somePartitionKey") + .setHeader(AwsHeaders.SEQUENCE_NUMBER, "10") + .setHeader("someHeaderKey", "someHeaderValue") + .build(); + + + ArgumentCaptor userRecordRequestArgumentCaptor = ArgumentCaptor + .forClass(UserRecord.class); + this.kplMessageHandler.setBackPressureThreshold(0); + this.kinesisSendChannel.send(message); + verify(this.kinesisProducer).addUserRecord(userRecordRequestArgumentCaptor.capture()); + verify(this.kinesisProducer, Mockito.never()).getOutstandingRecordsCount(); + UserRecord userRecord = userRecordRequestArgumentCaptor.getValue(); + assertThat(userRecord.getStreamName()).isEqualTo("someStream"); + assertThat(userRecord.getPartitionKey()).isEqualTo("somePartitionKey"); + assertThat(userRecord.getExplicitHashKey()).isNull(); + } + + @Test + @SuppressWarnings("unchecked") + void kplMessageHandlerWithRawPayloadBackpressureEnabledCapacityAvailable() { + given(this.kinesisProducer.addUserRecord(any(UserRecord.class))) + .willReturn(mock()); + this.kplMessageHandler.setBackPressureThreshold(2); + given(this.kinesisProducer.getOutstandingRecordsCount()) + .willReturn(1); + final Message message = MessageBuilder + .withPayload("someMessage") + .setHeader(AwsHeaders.PARTITION_KEY, "somePartitionKey") + .setHeader(AwsHeaders.SEQUENCE_NUMBER, "10") + .setHeader("someHeaderKey", "someHeaderValue") + .build(); + + + ArgumentCaptor userRecordRequestArgumentCaptor = ArgumentCaptor + .forClass(UserRecord.class); + + this.kinesisSendChannel.send(message); + verify(this.kinesisProducer).addUserRecord(userRecordRequestArgumentCaptor.capture()); + verify(this.kinesisProducer).getOutstandingRecordsCount(); + UserRecord userRecord = userRecordRequestArgumentCaptor.getValue(); + assertThat(userRecord.getStreamName()).isEqualTo("someStream"); + assertThat(userRecord.getPartitionKey()).isEqualTo("somePartitionKey"); + assertThat(userRecord.getExplicitHashKey()).isNull(); + } + + @Test + @SuppressWarnings("unchecked") + void kplMessageHandlerWithRawPayloadBackpressureEnabledCapacityInsufficient() { + given(this.kinesisProducer.addUserRecord(any(UserRecord.class))) + .willReturn(mock()); + this.kplMessageHandler.setBackPressureThreshold(2); + given(this.kinesisProducer.getOutstandingRecordsCount()) + .willReturn(5); + final Message message = MessageBuilder + .withPayload("someMessage") + .setHeader(AwsHeaders.PARTITION_KEY, "somePartitionKey") + .setHeader(AwsHeaders.SEQUENCE_NUMBER, "10") + .setHeader("someHeaderKey", "someHeaderValue") + .build(); + + assertThatExceptionOfType(RuntimeException.class) + .isThrownBy(() -> this.kinesisSendChannel.send(message)) + .withCauseInstanceOf(MessageHandlingException.class) + .withRootCauseExactlyInstanceOf(KplBackpressureException.class) + .withStackTraceContaining("Cannot send record to Kinesis since buffer is at max capacity."); + + verify(this.kinesisProducer, Mockito.never()).addUserRecord(any(UserRecord.class)); + verify(this.kinesisProducer).getOutstandingRecordsCount(); + } + + @AfterEach + public void tearDown() { + clearInvocations(this.kinesisProducer); + } + + @Configuration + @EnableIntegration + public static class ContextConfiguration { + + @Bean + public KinesisProducer kinesisProducer() { + return mock(); + } + + @Bean + public RequestHandlerRetryAdvice retryAdvice() { + RequestHandlerRetryAdvice requestHandlerRetryAdvice = new RequestHandlerRetryAdvice(); + requestHandlerRetryAdvice.setRetryTemplate(RetryTemplate.builder() + .retryOn(KplBackpressureException.class) + .exponentialBackoff(100, 2.0, 1000) + .maxAttempts(3) + .build()); + return requestHandlerRetryAdvice; + } + + @Bean + @ServiceActivator(inputChannel = "kinesisSendChannel", adviceChain = {"retryAdvice"}) + public MessageHandler kplMessageHandler(KinesisProducer kinesisProducer) { + KplMessageHandler kplMessageHandler = new KplMessageHandler(kinesisProducer); + kplMessageHandler.setAsync(true); + kplMessageHandler.setStream("someStream"); + return kplMessageHandler; + } + + } + +}