diff --git a/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/annotations/DataPrepperPlugin.java b/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/annotations/DataPrepperPlugin.java index c9345385dc..d94c0d8c19 100644 --- a/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/annotations/DataPrepperPlugin.java +++ b/data-prepper-api/src/main/java/org/opensearch/dataprepper/model/annotations/DataPrepperPlugin.java @@ -33,6 +33,8 @@ public @interface DataPrepperPlugin { String DEFAULT_DEPRECATED_NAME = ""; + String DEFAULT_ALTERNATE_NAME = ""; + /** * * @return Name of the plugin which should be unique for the type @@ -46,6 +48,12 @@ */ String deprecatedName() default DEFAULT_DEPRECATED_NAME; + /** + * + * @return Alternate name of the plugin which should be unique for the type + */ + String[] alternateNames() default {}; + /** * The class type for this plugin. * diff --git a/data-prepper-core/src/main/java/org/opensearch/dataprepper/DataPrepper.java b/data-prepper-core/src/main/java/org/opensearch/dataprepper/DataPrepper.java index 19c0822ce9..b2dd7a5541 100644 --- a/data-prepper-core/src/main/java/org/opensearch/dataprepper/DataPrepper.java +++ b/data-prepper-core/src/main/java/org/opensearch/dataprepper/DataPrepper.java @@ -97,17 +97,21 @@ public void shutdown() { shutdownServers(); } + private void shutdownPipelines() { + shutdownPipelines(DataPrepperShutdownOptions.defaultOptions()); + } + /** * Triggers the shutdown of all configured valid pipelines. */ - public void shutdownPipelines() { + public void shutdownPipelines(final DataPrepperShutdownOptions shutdownOptions) { transformationPipelines.forEach((name, pipeline) -> { pipeline.removeShutdownObserver(pipelinesObserver); }); for (final Pipeline pipeline : transformationPipelines.values()) { LOG.info("Shutting down pipeline: {}", pipeline.getName()); - pipeline.shutdown(); + pipeline.shutdown(shutdownOptions); } } @@ -127,11 +131,12 @@ public void shutdownServers() { * * @param pipeline name of the pipeline */ - public void shutdownPipelines(final String pipeline) { + public void shutdownPipeline(final String pipeline) { if (transformationPipelines.containsKey(pipeline)) { transformationPipelines.get(pipeline).shutdown(); } } + public PluginFactory getPluginFactory() { return pluginFactory; } diff --git a/data-prepper-core/src/main/java/org/opensearch/dataprepper/DataPrepperShutdownOptions.java b/data-prepper-core/src/main/java/org/opensearch/dataprepper/DataPrepperShutdownOptions.java new file mode 100644 index 0000000000..ea3edbf4f5 --- /dev/null +++ b/data-prepper-core/src/main/java/org/opensearch/dataprepper/DataPrepperShutdownOptions.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper; + +import java.time.Duration; + +public class DataPrepperShutdownOptions { + private final Duration bufferReadTimeout; + private final Duration bufferDrainTimeout; + + public static DataPrepperShutdownOptions defaultOptions() { + return new DataPrepperShutdownOptions(builder()); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private Duration bufferReadTimeout; + private Duration bufferDrainTimeout; + + private Builder() { + } + + public Builder withBufferReadTimeout(final Duration bufferReadTimeout) { + this.bufferReadTimeout = bufferReadTimeout; + return this; + } + + public Builder withBufferDrainTimeout(final Duration bufferDrainTimeout) { + 
this.bufferDrainTimeout = bufferDrainTimeout; + return this; + } + + public DataPrepperShutdownOptions build() { + return new DataPrepperShutdownOptions(this); + } + } + + private DataPrepperShutdownOptions(final Builder builder) { + this.bufferReadTimeout = builder.bufferReadTimeout; + this.bufferDrainTimeout = builder.bufferDrainTimeout; + + if(bufferReadTimeout != null && bufferDrainTimeout != null) { + if (bufferReadTimeout.compareTo(bufferDrainTimeout) > 0) { + throw new IllegalArgumentException("Buffer read timeout cannot be greater than buffer drain timeout"); + } + } + } + + public Duration getBufferReadTimeout() { + return bufferReadTimeout; + } + + public Duration getBufferDrainTimeout() { + return bufferDrainTimeout; + } +} diff --git a/data-prepper-core/src/main/java/org/opensearch/dataprepper/pipeline/Pipeline.java b/data-prepper-core/src/main/java/org/opensearch/dataprepper/pipeline/Pipeline.java index 29bb69db46..de22876041 100644 --- a/data-prepper-core/src/main/java/org/opensearch/dataprepper/pipeline/Pipeline.java +++ b/data-prepper-core/src/main/java/org/opensearch/dataprepper/pipeline/Pipeline.java @@ -6,6 +6,7 @@ package org.opensearch.dataprepper.pipeline; import com.google.common.base.Preconditions; +import org.opensearch.dataprepper.DataPrepperShutdownOptions; import org.opensearch.dataprepper.acknowledgements.InactiveAcknowledgementSetManager; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; import org.opensearch.dataprepper.model.buffer.Buffer; @@ -41,7 +42,6 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import java.util.stream.Collectors; @@ -55,7 +55,7 @@ public class Pipeline { private static final Logger LOG = LoggerFactory.getLogger(Pipeline.class); private static final int SINK_LOGGING_FREQUENCY = (int) Duration.ofSeconds(60).toMillis(); - private volatile AtomicBoolean stopRequested; + private final PipelineShutdown pipelineShutdown; private final String name; private final Source source; @@ -137,7 +137,7 @@ public Pipeline( this.sinkExecutorService = PipelineThreadPoolExecutor.newFixedThreadPool(processorThreads, new PipelineThreadFactory(format("%s-sink-worker", name)), this); - stopRequested = new AtomicBoolean(false); + this.pipelineShutdown = new PipelineShutdown(buffer); } AcknowledgementSetManager getAcknowledgementSetManager() { @@ -176,7 +176,11 @@ public Collection getSinks() { } public boolean isStopRequested() { - return stopRequested.get(); + return pipelineShutdown.isStopRequested(); + } + + public boolean isForceStopReadingBuffers() { + return pipelineShutdown.isForceStopReadingBuffers(); } public Duration getPeerForwarderDrainTimeout() { @@ -267,6 +271,10 @@ public void execute() { } } + public synchronized void shutdown() { + shutdown(DataPrepperShutdownOptions.defaultOptions()); + } + /** * Initiates shutdown of the pipeline by: * 1. Stopping the source to prevent new items from being consumed @@ -276,19 +284,20 @@ public void execute() { * 5. Shutting down processors and sinks * 6. Stopping the sink ExecutorService */ - public synchronized void shutdown() { + public synchronized void shutdown(final DataPrepperShutdownOptions dataPrepperShutdownOptions) { LOG.info("Pipeline [{}] - Received shutdown signal with buffer drain timeout {}, processor shutdown timeout {}, " + "and sink shutdown timeout {}. 
Initiating the shutdown process", name, buffer.getDrainTimeout(), processorShutdownTimeout, sinkShutdownTimeout); try { source.stop(); - stopRequested.set(true); } catch (Exception ex) { LOG.error("Pipeline [{}] - Encountered exception while stopping the source, " + "proceeding with termination of process workers", name, ex); } - shutdownExecutorService(processorExecutorService, buffer.getDrainTimeout().toMillis() + processorShutdownTimeout.toMillis(), "processor"); + pipelineShutdown.shutdown(dataPrepperShutdownOptions); + + shutdownExecutorService(processorExecutorService, pipelineShutdown.getBufferDrainTimeout().plus(processorShutdownTimeout), "processor"); processorSets.forEach(processorSet -> processorSet.forEach(Processor::shutdown)); buffer.shutdown(); @@ -297,7 +306,7 @@ public synchronized void shutdown() { .map(DataFlowComponent::getComponent) .forEach(Sink::shutdown); - shutdownExecutorService(sinkExecutorService, sinkShutdownTimeout.toMillis(), "sink"); + shutdownExecutorService(sinkExecutorService, sinkShutdownTimeout, "sink"); LOG.info("Pipeline [{}] - Pipeline fully shutdown.", name); @@ -312,13 +321,13 @@ public void removeShutdownObserver(final PipelineObserver pipelineObserver) { observers.remove(pipelineObserver); } - private void shutdownExecutorService(final ExecutorService executorService, final long timeoutForTerminationInMillis, final String workerName) { + private void shutdownExecutorService(final ExecutorService executorService, final Duration timeoutForTermination, final String workerName) { LOG.info("Pipeline [{}] - Shutting down {} process workers.", name, workerName); executorService.shutdown(); try { - if (!executorService.awaitTermination(timeoutForTerminationInMillis, TimeUnit.MILLISECONDS)) { - LOG.warn("Pipeline [{}] - Workers did not terminate in time, forcing termination of {} workers.", name, workerName); + if (!executorService.awaitTermination(timeoutForTermination.toMillis(), TimeUnit.MILLISECONDS)) { + LOG.warn("Pipeline [{}] - Workers did not terminate in {}, forcing termination of {} workers.", name, timeoutForTermination, workerName); executorService.shutdownNow(); } } catch (InterruptedException ex) { diff --git a/data-prepper-core/src/main/java/org/opensearch/dataprepper/pipeline/PipelineShutdown.java b/data-prepper-core/src/main/java/org/opensearch/dataprepper/pipeline/PipelineShutdown.java new file mode 100644 index 0000000000..f3731e9d67 --- /dev/null +++ b/data-prepper-core/src/main/java/org/opensearch/dataprepper/pipeline/PipelineShutdown.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.pipeline; + +import org.opensearch.dataprepper.DataPrepperShutdownOptions; +import org.opensearch.dataprepper.model.buffer.Buffer; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; + +class PipelineShutdown { + private final AtomicBoolean stopRequested = new AtomicBoolean(false); + private final Duration bufferDrainTimeout; + private final Clock clock; + private Instant shutdownRequestedAt; + private Instant forceStopReadingBuffersAt; + private Duration bufferDrainTimeoutOverride; + + PipelineShutdown(final Buffer buffer) { + this(buffer, Clock.systemDefaultZone()); + } + + PipelineShutdown(final Buffer buffer, final Clock clock) { + bufferDrainTimeout = Objects.requireNonNull(buffer.getDrainTimeout()); + this.clock = clock; + } + + public void 
shutdown(final DataPrepperShutdownOptions dataPrepperShutdownOptions) { + final boolean stopPreviouslyRequested = stopRequested.get(); + if(stopPreviouslyRequested) { + return; + } + + stopRequested.set(true); + shutdownRequestedAt = now(); + + final Duration bufferReadTimeout = dataPrepperShutdownOptions.getBufferReadTimeout(); + if(bufferReadTimeout != null) { + forceStopReadingBuffersAt = shutdownRequestedAt.plus(bufferReadTimeout); + } + + final Duration bufferDrainTimeoutOverride = dataPrepperShutdownOptions.getBufferDrainTimeout(); + if(bufferDrainTimeoutOverride != null) { + this.bufferDrainTimeoutOverride = bufferDrainTimeoutOverride; + } + } + + boolean isStopRequested() { + return stopRequested.get(); + } + + boolean isForceStopReadingBuffers() { + return forceStopReadingBuffersAt != null && now().isAfter(forceStopReadingBuffersAt); + } + + public Duration getBufferDrainTimeout() { + return bufferDrainTimeoutOverride != null ? + bufferDrainTimeoutOverride : bufferDrainTimeout; + } + + private Instant now() { + return Instant.ofEpochMilli(clock.millis()); + } +} diff --git a/data-prepper-core/src/main/java/org/opensearch/dataprepper/pipeline/ProcessWorker.java b/data-prepper-core/src/main/java/org/opensearch/dataprepper/pipeline/ProcessWorker.java index b5538dfe73..8117848f9a 100644 --- a/data-prepper-core/src/main/java/org/opensearch/dataprepper/pipeline/ProcessWorker.java +++ b/data-prepper-core/src/main/java/org/opensearch/dataprepper/pipeline/ProcessWorker.java @@ -60,37 +60,41 @@ public void run() { while (!pipeline.isStopRequested()) { doRun(); } - LOG.info("Processor shutdown phase 1 complete."); + executeShutdownProcess(); + } catch (final Exception e) { + LOG.error("Encountered exception during pipeline {} processing", pipeline.getName(), e); + } + } - // Phase 2 - execute until buffers are empty - LOG.info("Beginning processor shutdown phase 2, iterating until buffers empty."); - while (!readBuffer.isEmpty()) { - doRun(); - } - LOG.info("Processor shutdown phase 2 complete."); + private void executeShutdownProcess() { + LOG.info("Processor shutdown phase 1 complete."); - // Phase 3 - execute until peer forwarder drain period expires (best effort to process all peer forwarder data) - final long drainTimeoutExpiration = System.currentTimeMillis() + pipeline.getPeerForwarderDrainTimeout().toMillis(); - LOG.info("Beginning processor shutdown phase 3, iterating until {}.", drainTimeoutExpiration); - while (System.currentTimeMillis() < drainTimeoutExpiration) { - doRun(); - } - LOG.info("Processor shutdown phase 3 complete."); + // Phase 2 - execute until buffers are empty + LOG.info("Beginning processor shutdown phase 2, iterating until buffers empty."); + while (!isBufferReadyForShutdown()) { + doRun(); + } + LOG.info("Processor shutdown phase 2 complete."); - // Phase 4 - prepare processors for shutdown - LOG.info("Beginning processor shutdown phase 4, preparing processors for shutdown."); - processors.forEach(Processor::prepareForShutdown); - LOG.info("Processor shutdown phase 4 complete."); + // Phase 3 - execute until peer forwarder drain period expires (best effort to process all peer forwarder data) + final long drainTimeoutExpiration = System.currentTimeMillis() + pipeline.getPeerForwarderDrainTimeout().toMillis(); + LOG.info("Beginning processor shutdown phase 3, iterating until {}.", drainTimeoutExpiration); + while (System.currentTimeMillis() < drainTimeoutExpiration) { + doRun(); + } + LOG.info("Processor shutdown phase 3 complete."); - // Phase 5 - execute until 
processors are ready to shutdown - LOG.info("Beginning processor shutdown phase 5, iterating until processors are ready to shutdown."); - while (!areComponentsReadyForShutdown()) { - doRun(); - } - LOG.info("Processor shutdown phase 5 complete."); - } catch (final Exception e) { - LOG.error("Encountered exception during pipeline {} processing", pipeline.getName(), e); + // Phase 4 - prepare processors for shutdown + LOG.info("Beginning processor shutdown phase 4, preparing processors for shutdown."); + processors.forEach(Processor::prepareForShutdown); + LOG.info("Processor shutdown phase 4 complete."); + + // Phase 5 - execute until processors are ready to shutdown + LOG.info("Beginning processor shutdown phase 5, iterating until processors are ready to shutdown."); + while (!areComponentsReadyForShutdown()) { + doRun(); } + LOG.info("Processor shutdown phase 5 complete."); } private void processAcknowledgements(List inputEvents, Collection> outputRecords) { @@ -153,11 +157,19 @@ private void doRun() { } private boolean areComponentsReadyForShutdown() { - return readBuffer.isEmpty() && processors.stream() + return isBufferReadyForShutdown() && processors.stream() .map(Processor::isReadyForShutdown) .allMatch(result -> result == true); } + private boolean isBufferReadyForShutdown() { + final boolean isBufferEmpty = readBuffer.isEmpty(); + final boolean forceStopReadingBuffers = pipeline.isForceStopReadingBuffers(); + final boolean isBufferReadyForShutdown = isBufferEmpty || forceStopReadingBuffers; + LOG.debug("isBufferReadyForShutdown={}, isBufferEmpty={}, forceStopReadingBuffers={}", isBufferReadyForShutdown, isBufferEmpty, forceStopReadingBuffers); + return isBufferReadyForShutdown; + } + /** * TODO Add isolator pattern - Fail if one of the Sink fails [isolator Pattern] * Uses the pipeline method to publish to sinks, waits for each of the sink result to be true before attempting to diff --git a/data-prepper-core/src/main/java/org/opensearch/dataprepper/pipeline/server/ShutdownHandler.java b/data-prepper-core/src/main/java/org/opensearch/dataprepper/pipeline/server/ShutdownHandler.java index 08449e0b21..e3da8fc51d 100644 --- a/data-prepper-core/src/main/java/org/opensearch/dataprepper/pipeline/server/ShutdownHandler.java +++ b/data-prepper-core/src/main/java/org/opensearch/dataprepper/pipeline/server/ShutdownHandler.java @@ -7,16 +7,23 @@ import com.sun.net.httpserver.HttpExchange; import com.sun.net.httpserver.HttpHandler; +import org.apache.http.NameValuePair; +import org.apache.http.client.utils.URLEncodedUtils; import org.opensearch.dataprepper.DataPrepper; +import org.opensearch.dataprepper.DataPrepperShutdownOptions; +import org.opensearch.dataprepper.pipeline.parser.DataPrepperDurationParser; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.ws.rs.HttpMethod; import java.io.IOException; import java.net.HttpURLConnection; +import java.net.URI; +import java.nio.charset.Charset; +import java.util.List; /** - * HttpHandler to handle requests to shut down the data prepper instance + * HttpHandler to handle requests to shut down the Data Prepper instance */ public class ShutdownHandler implements HttpHandler { private final DataPrepper dataPrepper; @@ -40,7 +47,8 @@ public void handle(final HttpExchange exchange) throws IOException { LOG.info("Received HTTP shutdown request to shutdown Data Prepper. Shutdown pipelines and server. 
User-Agent='{}'", exchange.getRequestHeaders().getFirst("User-Agent")); } - dataPrepper.shutdownPipelines(); + final DataPrepperShutdownOptions dataPrepperShutdownOptions = mapShutdownOptions(exchange.getRequestURI()); + dataPrepper.shutdownPipelines(dataPrepperShutdownOptions); exchange.sendResponseHeaders(HttpURLConnection.HTTP_OK, 0); } catch (final Exception e) { LOG.error("Caught exception shutting down data prepper", e); @@ -50,4 +58,28 @@ public void handle(final HttpExchange exchange) throws IOException { dataPrepper.shutdownServers(); } } + + private DataPrepperShutdownOptions mapShutdownOptions(final URI requestURI) { + final List queryParams = URLEncodedUtils.parse(requestURI, Charset.defaultCharset()); + + DataPrepperShutdownOptions.Builder shutdownOptionsBuilder + = DataPrepperShutdownOptions.builder(); + + for (final NameValuePair queryParam : queryParams) { + final String value = queryParam.getValue(); + switch(queryParam.getName()) { + case "bufferReadTimeout": + shutdownOptionsBuilder = + shutdownOptionsBuilder.withBufferReadTimeout(DataPrepperDurationParser.parse(value)); + break; + case "bufferDrainTimeout": + shutdownOptionsBuilder = + shutdownOptionsBuilder.withBufferDrainTimeout(DataPrepperDurationParser.parse(value)); + break; + default: + LOG.warn("Unrecognized query parameter '{}'", queryParam.getName()); + } + } + return shutdownOptionsBuilder.build(); + } } diff --git a/data-prepper-core/src/test/java/org/opensearch/dataprepper/DataPrepperShutdownOptionsTest.java b/data-prepper-core/src/test/java/org/opensearch/dataprepper/DataPrepperShutdownOptionsTest.java new file mode 100644 index 0000000000..42ea27e97b --- /dev/null +++ b/data-prepper-core/src/test/java/org/opensearch/dataprepper/DataPrepperShutdownOptionsTest.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.time.temporal.ChronoUnit; +import java.util.Random; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class DataPrepperShutdownOptionsTest { + private Random random; + + @BeforeEach + void setUp() { + random = new Random(); + } + + @Test + void defaultOptions_returns_correct_defaults() { + final DataPrepperShutdownOptions options = DataPrepperShutdownOptions.defaultOptions(); + + assertThat(options.getBufferDrainTimeout(), nullValue()); + assertThat(options.getBufferReadTimeout(), nullValue()); + } + + @Test + void builder_returns_valid_builder() { + final DataPrepperShutdownOptions.Builder builder = DataPrepperShutdownOptions.builder(); + + assertThat(builder, notNullValue()); + } + + @Test + void build_throws_if_bufferReadTimeout_is_greater_than_bufferDrainTimeout() { + final Duration bufferDrainTimeout = Duration.ofSeconds(random.nextInt(20)); + final Duration bufferReadTimeout = bufferDrainTimeout.plus(1, ChronoUnit.MILLIS); + final DataPrepperShutdownOptions.Builder builder = DataPrepperShutdownOptions.builder() + .withBufferDrainTimeout(bufferDrainTimeout) + .withBufferReadTimeout(bufferReadTimeout); + assertThrows(IllegalArgumentException.class, builder::build); + } + + @Test + void 
build_creates_new_options_with_bufferReadTimeout_equal_to_bufferDrainTimeout() { + final Duration timeout = Duration.ofSeconds(random.nextInt(20)); + final DataPrepperShutdownOptions dataPrepperShutdownOptions = DataPrepperShutdownOptions.builder() + .withBufferDrainTimeout(timeout) + .withBufferReadTimeout(timeout) + .build(); + + + assertThat(dataPrepperShutdownOptions, notNullValue()); + assertThat(dataPrepperShutdownOptions.getBufferReadTimeout(), equalTo(timeout)); + assertThat(dataPrepperShutdownOptions.getBufferDrainTimeout(), equalTo(timeout)); + } + + @Test + void build_creates_new_options_with_bufferReadTimeout_less_than_bufferDrainTimeout() { + final Duration bufferReadTimeout = Duration.ofSeconds(random.nextInt(20)); + final Duration bufferDrainTimeout = Duration.ofSeconds(random.nextInt(20)).plus(bufferReadTimeout); + final DataPrepperShutdownOptions dataPrepperShutdownOptions = DataPrepperShutdownOptions.builder() + .withBufferDrainTimeout(bufferDrainTimeout) + .withBufferReadTimeout(bufferReadTimeout) + .build(); + + + assertThat(dataPrepperShutdownOptions, notNullValue()); + assertThat(dataPrepperShutdownOptions.getBufferReadTimeout(), equalTo(bufferReadTimeout)); + assertThat(dataPrepperShutdownOptions.getBufferDrainTimeout(), equalTo(bufferDrainTimeout)); + } +} \ No newline at end of file diff --git a/data-prepper-core/src/test/java/org/opensearch/dataprepper/DataPrepperTests.java b/data-prepper-core/src/test/java/org/opensearch/dataprepper/DataPrepperTests.java index 670d9664c6..3332be605f 100644 --- a/data-prepper-core/src/test/java/org/opensearch/dataprepper/DataPrepperTests.java +++ b/data-prepper-core/src/test/java/org/opensearch/dataprepper/DataPrepperTests.java @@ -111,8 +111,9 @@ public void testGivenValidPipelineParserWhenExecuteThenAllPipelinesExecuteAndSer @Test public void testDataPrepperShutdown() throws NoSuchFieldException, IllegalAccessException { - createObjectUnderTest().shutdownPipelines(); - verify(pipeline).shutdown(); + final DataPrepperShutdownOptions dataPrepperShutdownOptions = mock(DataPrepperShutdownOptions.class); + createObjectUnderTest().shutdownPipelines(dataPrepperShutdownOptions); + verify(pipeline).shutdown(dataPrepperShutdownOptions); } @Test @@ -120,14 +121,14 @@ public void testDataPrepperShutdownPipeline() throws NoSuchFieldException, Illeg final Pipeline randomPipeline = mock(Pipeline.class); lenient().when(randomPipeline.isReady()).thenReturn(true); parseConfigurationFixture.put("Random Pipeline", randomPipeline); - createObjectUnderTest().shutdownPipelines("Random Pipeline"); + createObjectUnderTest().shutdownPipeline("Random Pipeline"); verify(randomPipeline).shutdown(); } @Test public void testDataPrepperShutdownNonExistentPipelineWithoutException() throws NoSuchFieldException, IllegalAccessException { - createObjectUnderTest().shutdownPipelines("Missing Pipeline"); + createObjectUnderTest().shutdownPipeline("Missing Pipeline"); } @Test diff --git a/data-prepper-core/src/test/java/org/opensearch/dataprepper/pipeline/PipelineShutdownTest.java b/data-prepper-core/src/test/java/org/opensearch/dataprepper/pipeline/PipelineShutdownTest.java new file mode 100644 index 0000000000..36ac4aa3d1 --- /dev/null +++ b/data-prepper-core/src/test/java/org/opensearch/dataprepper/pipeline/PipelineShutdownTest.java @@ -0,0 +1,186 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.pipeline; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; 
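As context for the builder tests above, here is a minimal usage sketch of the DataPrepperShutdownOptions API introduced in this diff; the builder methods and the validation behavior come from the class itself, while the wrapper class, method, and concrete timeout values are illustrative only.

```java
import java.time.Duration;

import org.opensearch.dataprepper.DataPrepperShutdownOptions;

public class ShutdownOptionsExample {
    public static void main(String[] args) {
        // Both timeouts are optional: defaultOptions() leaves them null, and
        // the pipeline then falls back to the buffer's own drain timeout.
        final DataPrepperShutdownOptions options = DataPrepperShutdownOptions.builder()
                .withBufferReadTimeout(Duration.ofMillis(1500))  // force-stop buffer reads 1.5s into shutdown
                .withBufferDrainTimeout(Duration.ofSeconds(20))  // override the buffer's drain timeout
                .build();

        // build() runs the private constructor's validation: a read timeout
        // greater than the drain timeout throws IllegalArgumentException.
        System.out.println(options.getBufferReadTimeout() + " / " + options.getBufferDrainTimeout());
    }
}
```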
+import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.DataPrepperShutdownOptions; +import org.opensearch.dataprepper.model.buffer.Buffer; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Random; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class PipelineShutdownTest { + @Mock + private Buffer buffer; + + @Mock + private Clock clock; + + @Mock + private DataPrepperShutdownOptions dataPrepperShutdownOptions; + + private Duration bufferDrainTimeout; + private Random random; + + @BeforeEach + void setUp() { + random = new Random(); + bufferDrainTimeout = Duration.ofSeconds(random.nextInt(100) + 1_000); + + when(buffer.getDrainTimeout()).thenReturn(bufferDrainTimeout); + } + + private PipelineShutdown createObjectUnderTest() { + return new PipelineShutdown(buffer, clock); + } + + @Test + void constructor_throws_if_drainTimeout_is_null() { + reset(buffer); + when(buffer.getDrainTimeout()).thenReturn(null); + assertThrows(NullPointerException.class, this::createObjectUnderTest); + } + + @Test + void isStopRequested_returns_false() { + assertThat(createObjectUnderTest().isStopRequested(), equalTo(false)); + } + + @Test + void isForceStopReadingBuffers_returns_false() { + assertThat(createObjectUnderTest().isForceStopReadingBuffers(), equalTo(false)); + } + + @Test + void isStopRequested_returns_true_after_shutdown() { + final PipelineShutdown objectUnderTest = createObjectUnderTest(); + when(clock.millis()).thenReturn(Clock.systemUTC().millis()); + objectUnderTest.shutdown(dataPrepperShutdownOptions); + assertThat(objectUnderTest.isStopRequested(), equalTo(true)); + } + + @Test + void isStopRequested_returns_true_after_multiple_shutdown_calls() { + final PipelineShutdown objectUnderTest = createObjectUnderTest(); + when(clock.millis()).thenReturn(Clock.systemUTC().millis()); + for (int i = 0; i < 10; i++) { + objectUnderTest.shutdown(dataPrepperShutdownOptions); + } + assertThat(objectUnderTest.isStopRequested(), equalTo(true)); + } + + @Test + void isForceStopReadingBuffers_returns_false_after_shutdown_if_getBufferReadTimeout_is_null() { + final PipelineShutdown objectUnderTest = createObjectUnderTest(); + + when(dataPrepperShutdownOptions.getBufferReadTimeout()).thenReturn(null); + objectUnderTest.shutdown(dataPrepperShutdownOptions); + + assertThat(objectUnderTest.isForceStopReadingBuffers(), equalTo(false)); + } + + @Test + void isForceStopReadingBuffers_returns_false_after_shutdown_if_time_is_before_shutdown_plus_getBufferReadTimeout() { + final PipelineShutdown objectUnderTest = createObjectUnderTest(); + + when(dataPrepperShutdownOptions.getBufferReadTimeout()).thenReturn(Duration.ofSeconds(1)); + final Instant baseTime = Instant.now(); + when(clock.millis()) + .thenReturn(baseTime.toEpochMilli()); + + objectUnderTest.shutdown(dataPrepperShutdownOptions); + + assertThat(objectUnderTest.isForceStopReadingBuffers(), equalTo(false)); + } + + @Test + void isForceStopReadingBuffers_returns_true_after_shutdown_if_time_is_after_shutdown_plus_getBufferReadTimeout() { + final PipelineShutdown objectUnderTest = createObjectUnderTest(); + + 
when(dataPrepperShutdownOptions.getBufferReadTimeout()).thenReturn(Duration.ofSeconds(1)); + final Instant baseTime = Instant.now(); + when(clock.millis()) + .thenReturn(baseTime.toEpochMilli()) + .thenReturn(baseTime.plusSeconds(2).toEpochMilli()); + + objectUnderTest.shutdown(dataPrepperShutdownOptions); + + assertThat(objectUnderTest.isForceStopReadingBuffers(), equalTo(true)); + } + + @Test + void isForceStopReadingBuffers_returns_true_if_shutdown_is_called_multiple_times() { + final PipelineShutdown objectUnderTest = createObjectUnderTest(); + + when(dataPrepperShutdownOptions.getBufferReadTimeout()) + .thenReturn(Duration.ofSeconds(1)) + .thenReturn(Duration.ofSeconds(5)); + final Instant baseTime = Instant.now(); + when(clock.millis()) + .thenReturn(baseTime.toEpochMilli()) + .thenReturn(baseTime.plusSeconds(2).toEpochMilli()); + + objectUnderTest.shutdown(dataPrepperShutdownOptions); + objectUnderTest.shutdown(dataPrepperShutdownOptions); + + assertThat(objectUnderTest.isForceStopReadingBuffers(), equalTo(true)); + } + + @Test + void isForceStopReadingBuffers_returns_true_if_shutdown_is_called_in_between_isForceStopReadingBuffers_calls() { + final PipelineShutdown objectUnderTest = createObjectUnderTest(); + + when(dataPrepperShutdownOptions.getBufferReadTimeout()) + .thenReturn(Duration.ofSeconds(1)) + .thenReturn(Duration.ofSeconds(5)); + final Instant baseTime = Instant.now(); + when(clock.millis()) + .thenReturn(baseTime.toEpochMilli()) + .thenReturn(baseTime.plusSeconds(2).toEpochMilli()); + + objectUnderTest.shutdown(dataPrepperShutdownOptions); + assertThat(objectUnderTest.isForceStopReadingBuffers(), equalTo(true)); + + objectUnderTest.shutdown(dataPrepperShutdownOptions); + assertThat(objectUnderTest.isForceStopReadingBuffers(), equalTo(true)); + } + + + @Test + void getBufferDrainTimeout_returns_buffer_getDrainTimeout_if_shutdown_not_called() { + assertThat(createObjectUnderTest().getBufferDrainTimeout(), equalTo(bufferDrainTimeout)); + } + + @Test + void getBufferDrainTimeout_returns_buffer_getDrainTimeout_if_shutdown_called_without_bufferDrainTimeout() { + final PipelineShutdown objectUnderTest = createObjectUnderTest(); + when(dataPrepperShutdownOptions.getBufferDrainTimeout()).thenReturn(null); + objectUnderTest.shutdown(dataPrepperShutdownOptions); + assertThat(objectUnderTest.getBufferDrainTimeout(), equalTo(bufferDrainTimeout)); + } + + @Test + void getBufferDrainTimeout_returns_buffer_shutdownOptions_bufferDrainTimeout_if_provided() { + final PipelineShutdown objectUnderTest = createObjectUnderTest(); + Duration bufferDrainTimeoutFromOptions = Duration.ofSeconds(random.nextInt(100) + 100); + when(dataPrepperShutdownOptions.getBufferDrainTimeout()).thenReturn(bufferDrainTimeoutFromOptions); + objectUnderTest.shutdown(dataPrepperShutdownOptions); + assertThat(objectUnderTest.getBufferDrainTimeout(), equalTo(bufferDrainTimeoutFromOptions)); + } +} \ No newline at end of file diff --git a/data-prepper-core/src/test/java/org/opensearch/dataprepper/pipeline/PipelineTests.java b/data-prepper-core/src/test/java/org/opensearch/dataprepper/pipeline/PipelineTests.java index c2e0ad769f..66300969c0 100644 --- a/data-prepper-core/src/test/java/org/opensearch/dataprepper/pipeline/PipelineTests.java +++ b/data-prepper-core/src/test/java/org/opensearch/dataprepper/pipeline/PipelineTests.java @@ -9,6 +9,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.opensearch.dataprepper.DataPrepperShutdownOptions; 
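The Clock injected into PipelineShutdown is what makes the deadline arithmetic in these tests deterministic. Below is a sketch of that arithmetic outside of Mockito — a fragment, not a complete test, assuming same-package access (the class is package-private) and a `buffer` stub whose getDrainTimeout() returns a non-null Duration:

```java
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneOffset;

final Instant shutdownTime = Instant.parse("2024-01-01T00:00:00Z");
// Freeze the clock at the moment shutdown is requested.
final PipelineShutdown pipelineShutdown =
        new PipelineShutdown(buffer, Clock.fixed(shutdownTime, ZoneOffset.UTC));

pipelineShutdown.shutdown(DataPrepperShutdownOptions.builder()
        .withBufferReadTimeout(Duration.ofSeconds(1))
        .build());

// The force-stop deadline is shutdownTime + 1s. With the clock still frozen
// at shutdownTime, now() is not after the deadline, so reading continues:
assert !pipelineShutdown.isForceStopReadingBuffers();
// A clock advanced past shutdownTime + 1s would flip this to true.
```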
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.buffer.Buffer; @@ -93,9 +94,9 @@ void setup() { eventFactory = mock(EventFactory.class); acknowledgementSetManager = mock(AcknowledgementSetManager.class); sourceCoordinatorFactory = mock(SourceCoordinatorFactory.class); - processorShutdownTimeout = Duration.ofSeconds(Math.abs(new Random().nextInt(10))); - sinkShutdownTimeout = Duration.ofSeconds(Math.abs(new Random().nextInt(10))); - peerForwarderDrainTimeout = Duration.ofSeconds(Math.abs(new Random().nextInt(10))); + processorShutdownTimeout = Duration.ofMillis(Math.abs(new Random().nextInt(10))); + sinkShutdownTimeout = Duration.ofMillis(Math.abs(new Random().nextInt(10))); + peerForwarderDrainTimeout = Duration.ofMillis(Math.abs(new Random().nextInt(10))); } @AfterEach @@ -620,4 +621,33 @@ void shutdown_does_not_call_removed_PipelineObservers() { testPipeline.shutdown(); verifyNoInteractions(pipelineObserver); } + + @Test + void isForceStopReadingBuffers_returns_false_if_not_in_shutdown() { + final Source> testSource = new TestSource(); + final DataFlowComponent sinkDataFlowComponent = mock(DataFlowComponent.class); + final TestSink testSink = new TestSink(); + when(sinkDataFlowComponent.getComponent()).thenReturn(testSink); + testPipeline = new Pipeline(TEST_PIPELINE_NAME, testSource, new BlockingBuffer(TEST_PIPELINE_NAME), + Collections.emptyList(), Collections.singletonList(sinkDataFlowComponent), router, + eventFactory, acknowledgementSetManager, sourceCoordinatorFactory, TEST_PROCESSOR_THREADS, TEST_READ_BATCH_TIMEOUT, + processorShutdownTimeout, sinkShutdownTimeout, peerForwarderDrainTimeout); + assertThat(testPipeline.isForceStopReadingBuffers(), equalTo(false)); + } + + @Test + void isForceStopReadingBuffers_returns_true_if_bufferReadTimeout_is_exceeded() throws InterruptedException { + final Source> testSource = new TestSource(); + final DataFlowComponent sinkDataFlowComponent = mock(DataFlowComponent.class); + final TestSink testSink = new TestSink(); + when(sinkDataFlowComponent.getComponent()).thenReturn(testSink); + testPipeline = new Pipeline(TEST_PIPELINE_NAME, testSource, new BlockingBuffer(TEST_PIPELINE_NAME), + Collections.emptyList(), Collections.singletonList(sinkDataFlowComponent), router, + eventFactory, acknowledgementSetManager, sourceCoordinatorFactory, TEST_PROCESSOR_THREADS, TEST_READ_BATCH_TIMEOUT, + processorShutdownTimeout, sinkShutdownTimeout, peerForwarderDrainTimeout); + + testPipeline.shutdown(DataPrepperShutdownOptions.builder().withBufferReadTimeout(Duration.ofMillis(1)).build()); + Thread.sleep(2); + assertThat(testPipeline.isForceStopReadingBuffers(), is(true)); + } } diff --git a/data-prepper-core/src/test/java/org/opensearch/dataprepper/pipeline/server/ShutdownHandlerTest.java b/data-prepper-core/src/test/java/org/opensearch/dataprepper/pipeline/server/ShutdownHandlerTest.java index 19f1e839e1..0d36d05b1d 100644 --- a/data-prepper-core/src/test/java/org/opensearch/dataprepper/pipeline/server/ShutdownHandlerTest.java +++ b/data-prepper-core/src/test/java/org/opensearch/dataprepper/pipeline/server/ShutdownHandlerTest.java @@ -8,20 +8,29 @@ import com.sun.net.httpserver.Headers; import com.sun.net.httpserver.HttpExchange; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import 
org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.DataPrepper; +import org.opensearch.dataprepper.DataPrepperShutdownOptions; import javax.ws.rs.HttpMethod; import java.io.IOException; import java.io.OutputStream; import java.net.HttpURLConnection; +import java.net.URI; +import java.time.Duration; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.lenient; @@ -51,23 +60,6 @@ public void beforeEach() { .thenReturn(new Headers()); } - @Test - public void testWhenShutdownWithPostRequestThenResponseWritten() throws IOException { - when(exchange.getRequestMethod()) - .thenReturn(HttpMethod.POST); - - shutdownHandler.handle(exchange); - - verify(dataPrepper, times(1)) - .shutdownPipelines(); - verify(exchange, times(1)) - .sendResponseHeaders(eq(HttpURLConnection.HTTP_OK), eq(0L)); - verify(responseBody, times(1)) - .close(); - verify(dataPrepper, times(1)) - .shutdownServers(); - } - @ParameterizedTest @ValueSource(strings = { HttpMethod.DELETE, HttpMethod.GET, HttpMethod.PATCH, HttpMethod.PUT }) public void testWhenShutdownWithProhibitedHttpMethodThenErrorResponseWritten(String httpMethod) throws IOException { @@ -82,17 +74,91 @@ public void testWhenShutdownWithProhibitedHttpMethodThenErrorResponseWritten(Str .close(); } - @Test - public void testHandleException() throws IOException { - when(exchange.getRequestMethod()) - .thenReturn(HttpMethod.POST); - doThrow(RuntimeException.class).when(dataPrepper).shutdownPipelines(); - - shutdownHandler.handle(exchange); + @Nested + class WithoutQueryParameters { + @BeforeEach + void setUp() { + when(exchange.getRequestURI()).thenReturn(URI.create("/shutdown")); + } + + @Test + public void testWhenShutdownWithPostRequestThenResponseWritten() throws IOException { + when(exchange.getRequestMethod()) + .thenReturn(HttpMethod.POST); + + shutdownHandler.handle(exchange); + + ArgumentCaptor shutdownOptionsArgumentCaptor = ArgumentCaptor.forClass(DataPrepperShutdownOptions.class); + verify(dataPrepper, times(1)) + .shutdownPipelines(shutdownOptionsArgumentCaptor.capture()); + verify(exchange, times(1)) + .sendResponseHeaders(eq(HttpURLConnection.HTTP_OK), eq(0L)); + verify(responseBody, times(1)) + .close(); + verify(dataPrepper, times(1)) + .shutdownServers(); + + DataPrepperShutdownOptions actualShutdownOptions = shutdownOptionsArgumentCaptor.getValue(); + assertThat(actualShutdownOptions.getBufferDrainTimeout(), nullValue()); + assertThat(actualShutdownOptions.getBufferReadTimeout(), nullValue()); + } + + @Test + public void testHandleException() throws IOException { + when(exchange.getRequestMethod()) + .thenReturn(HttpMethod.POST); + doThrow(RuntimeException.class).when(dataPrepper).shutdownPipelines(any(DataPrepperShutdownOptions.class)); + + shutdownHandler.handle(exchange); + + verify(exchange, times(1)) + .sendResponseHeaders(eq(HttpURLConnection.HTTP_INTERNAL_ERROR), eq(0L)); + verify(responseBody, times(1)) + .close(); + } + } - verify(exchange, times(1)) - .sendResponseHeaders(eq(HttpURLConnection.HTTP_INTERNAL_ERROR), eq(0L)); - verify(responseBody, 
times(1)) - .close(); + @Nested + class WithoutShutdownQueryParameters { + @BeforeEach + void setUp() { + when(exchange.getRequestURI()).thenReturn(URI.create("/shutdown?bufferReadTimeout=1500ms&bufferDrainTimeout=20s")); + } + + @Test + public void testWhenShutdownWithPostRequestThenResponseWritten() throws IOException { + when(exchange.getRequestMethod()) + .thenReturn(HttpMethod.POST); + + shutdownHandler.handle(exchange); + + final ArgumentCaptor shutdownOptionsArgumentCaptor = ArgumentCaptor.forClass(DataPrepperShutdownOptions.class); + verify(dataPrepper, times(1)) + .shutdownPipelines(shutdownOptionsArgumentCaptor.capture()); + verify(exchange, times(1)) + .sendResponseHeaders(eq(HttpURLConnection.HTTP_OK), eq(0L)); + verify(responseBody, times(1)) + .close(); + verify(dataPrepper, times(1)) + .shutdownServers(); + + final DataPrepperShutdownOptions actualShutdownOptions = shutdownOptionsArgumentCaptor.getValue(); + assertThat(actualShutdownOptions.getBufferDrainTimeout(), equalTo(Duration.ofSeconds(20))); + assertThat(actualShutdownOptions.getBufferReadTimeout(), equalTo(Duration.ofMillis(1500))); + } + + @Test + public void testHandleException() throws IOException { + when(exchange.getRequestMethod()) + .thenReturn(HttpMethod.POST); + doThrow(RuntimeException.class).when(dataPrepper).shutdownPipelines(any(DataPrepperShutdownOptions.class)); + + shutdownHandler.handle(exchange); + + verify(exchange, times(1)) + .sendResponseHeaders(eq(HttpURLConnection.HTTP_INTERNAL_ERROR), eq(0L)); + verify(responseBody, times(1)) + .close(); + } } } diff --git a/data-prepper-pipeline-parser/src/main/java/org/opensearch/dataprepper/pipeline/parser/DataPrepperDurationDeserializer.java b/data-prepper-pipeline-parser/src/main/java/org/opensearch/dataprepper/pipeline/parser/DataPrepperDurationDeserializer.java index 5005eb9f96..d6ae65e2b0 100644 --- a/data-prepper-pipeline-parser/src/main/java/org/opensearch/dataprepper/pipeline/parser/DataPrepperDurationDeserializer.java +++ b/data-prepper-pipeline-parser/src/main/java/org/opensearch/dataprepper/pipeline/parser/DataPrepperDurationDeserializer.java @@ -11,9 +11,6 @@ import java.io.IOException; import java.time.Duration; -import java.time.format.DateTimeParseException; -import java.util.regex.Matcher; -import java.util.regex.Pattern; /** * This deserializer is used for configurations that use a {@link Duration} type when deserialized by Jackson @@ -24,54 +21,17 @@ */ public class DataPrepperDurationDeserializer extends StdDeserializer { - private static final String SIMPLE_DURATION_REGEX = "^(0|[1-9]\\d*)(s|ms)$"; - private static final Pattern SIMPLE_DURATION_PATTERN = Pattern.compile(SIMPLE_DURATION_REGEX); - public DataPrepperDurationDeserializer() { this(null); } - protected DataPrepperDurationDeserializer(Class vc) { + protected DataPrepperDurationDeserializer(final Class vc) { super(vc); } @Override - public Duration deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + public Duration deserialize(final JsonParser p, final DeserializationContext ctxt) throws IOException { final String durationString = p.getValueAsString(); - Duration duration; - - try { - duration = Duration.parse(durationString); - } catch (final DateTimeParseException e) { - duration = parseSimpleDuration(durationString); - if (duration == null) { - throw new IllegalArgumentException("Durations must use either ISO 8601 notation or simple notations for seconds (60s) or milliseconds (100ms). 
Whitespace is ignored."); - } - } - - return duration; - } - - private Duration parseSimpleDuration(final String durationString) throws IllegalArgumentException { - final String durationStringNoSpaces = durationString.replaceAll("\\s", ""); - final Matcher matcher = SIMPLE_DURATION_PATTERN.matcher(durationStringNoSpaces); - if (!matcher.find()) { - return null; - } - - final long durationNumber = Long.parseLong(matcher.group(1)); - final String durationUnit = matcher.group(2); - - return getDurationFromUnitAndNumber(durationNumber, durationUnit); - } - - private Duration getDurationFromUnitAndNumber(final long durationNumber, final String durationUnit) { - switch (durationUnit) { - case "s": - return Duration.ofSeconds(durationNumber); - case "ms": - return Duration.ofMillis(durationNumber); - } - return null; + return DataPrepperDurationParser.parse(durationString); } } diff --git a/data-prepper-pipeline-parser/src/main/java/org/opensearch/dataprepper/pipeline/parser/DataPrepperDurationParser.java b/data-prepper-pipeline-parser/src/main/java/org/opensearch/dataprepper/pipeline/parser/DataPrepperDurationParser.java new file mode 100644 index 0000000000..e758278cf7 --- /dev/null +++ b/data-prepper-pipeline-parser/src/main/java/org/opensearch/dataprepper/pipeline/parser/DataPrepperDurationParser.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.pipeline.parser; + +import java.time.Duration; +import java.time.format.DateTimeParseException; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Parses strings into {@link Duration} supporting the Data Prepper duration format. + * It supports ISO 8601 notation ("PT20.345S", "PT15M", etc.) and simple durations for + * seconds (60s) and milliseconds (100ms). It does not support combining the units for simple durations ("60s 100ms" is not allowed). + * Whitespace is ignored and leading zeroes are not allowed. + * @since 2.10 + */ +public class DataPrepperDurationParser { + private static final String SIMPLE_DURATION_REGEX = "^(0|[1-9]\\d*)(s|ms)$"; + private static final Pattern SIMPLE_DURATION_PATTERN = Pattern.compile(SIMPLE_DURATION_REGEX); + + public static Duration parse(final String durationString) { + try { + return Duration.parse(durationString); + } catch (final DateTimeParseException e) { + final Duration duration = parseSimpleDuration(durationString); + if (duration == null) { + throw new IllegalArgumentException("Durations must use either ISO 8601 notation or simple notations for seconds (60s) or milliseconds (100ms). 
Whitespace is ignored."); + } + return duration; + } + } + + private static Duration parseSimpleDuration(final String durationString) throws IllegalArgumentException { + final String durationStringNoSpaces = durationString.replaceAll("\\s", ""); + final Matcher matcher = SIMPLE_DURATION_PATTERN.matcher(durationStringNoSpaces); + if (!matcher.find()) { + return null; + } + + final long durationNumber = Long.parseLong(matcher.group(1)); + final String durationUnit = matcher.group(2); + + return getDurationFromUnitAndNumber(durationNumber, durationUnit); + } + + private static Duration getDurationFromUnitAndNumber(final long durationNumber, final String durationUnit) { + switch (durationUnit) { + case "s": + return Duration.ofSeconds(durationNumber); + case "ms": + return Duration.ofMillis(durationNumber); + } + return null; + } + +} diff --git a/data-prepper-pipeline-parser/src/test/java/org/opensearch/dataprepper/pipeline/parser/DataPrepperDurationParserTest.java b/data-prepper-pipeline-parser/src/test/java/org/opensearch/dataprepper/pipeline/parser/DataPrepperDurationParserTest.java new file mode 100644 index 0000000000..4913e9c545 --- /dev/null +++ b/data-prepper-pipeline-parser/src/test/java/org/opensearch/dataprepper/pipeline/parser/DataPrepperDurationParserTest.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.pipeline.parser; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.time.Duration; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class DataPrepperDurationParserTest { + + @ParameterizedTest + @ValueSource(strings = {"6s1s", "60ms 100s", "20.345s", "-1s", "06s", "100m", "100sm", "100"}) + void invalidDurationStringsThrowIllegalArgumentException(final String durationString) { + assertThrows(IllegalArgumentException.class, () -> DataPrepperDurationParser.parse(durationString)); + } + + @Test + void ISO_8601_duration_string_returns_correct_duration() { + final String durationString = "PT15M"; + final Duration result = DataPrepperDurationParser.parse(durationString); + assertThat(result, equalTo(Duration.ofMinutes(15))); + } + + @ParameterizedTest + @ValueSource(strings = {"0s", "0ms"}) + void simple_duration_strings_of_0_return_correct_duration(final String durationString) { + final Duration result = DataPrepperDurationParser.parse(durationString); + + assertThat(result, equalTo(Duration.ofSeconds(0))); + } + + @ParameterizedTest + @ValueSource(strings = {"60s", "60000ms", "60 s", "60000 ms", " 60 s "}) + void simple_duration_strings_of_60_seconds_return_correct_duration(final String durationString) { + final Duration result = DataPrepperDurationParser.parse(durationString); + + assertThat(result, equalTo(Duration.ofSeconds(60))); + } + + @ParameterizedTest + @ValueSource(strings = {"5s", "5000ms", "5 s", "5000 ms", " 5 s "}) + void simple_duration_strings_of_5_seconds_return_correct_duration(final String durationString) { + final Duration result = DataPrepperDurationParser.parse(durationString); + + assertThat(result, equalTo(Duration.ofSeconds(5))); + } +} \ No newline at end of file diff --git a/data-prepper-plugin-framework/src/main/java/org/opensearch/dataprepper/plugin/ClasspathPluginProvider.java 
b/data-prepper-plugin-framework/src/main/java/org/opensearch/dataprepper/plugin/ClasspathPluginProvider.java index f5217ef8f5..764c83f4db 100644 --- a/data-prepper-plugin-framework/src/main/java/org/opensearch/dataprepper/plugin/ClasspathPluginProvider.java +++ b/data-prepper-plugin-framework/src/main/java/org/opensearch/dataprepper/plugin/ClasspathPluginProvider.java @@ -15,8 +15,11 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Function; +import java.util.function.Predicate; import java.util.stream.Collectors; +import static org.opensearch.dataprepper.model.annotations.DataPrepperPlugin.DEFAULT_ALTERNATE_NAME; import static org.opensearch.dataprepper.model.annotations.DataPrepperPlugin.DEFAULT_DEPRECATED_NAME; /** @@ -69,30 +72,34 @@ private Map, Class>> scanForPlugins() { final Map, Class>> pluginsMap = new HashMap<>(dataPrepperPluginClasses.size()); for (final Class concretePluginClass : dataPrepperPluginClasses) { - final DataPrepperPlugin dataPrepperPluginAnnotation = concretePluginClass.getAnnotation(DataPrepperPlugin.class); - final String pluginName = dataPrepperPluginAnnotation.name(); - final Class supportedType = dataPrepperPluginAnnotation.pluginType(); - - final Map, Class> supportTypeToPluginTypeMap = - pluginsMap.computeIfAbsent(pluginName, k -> new HashMap<>()); - supportTypeToPluginTypeMap.put(supportedType, concretePluginClass); - - addOptionalDeprecatedPluginName(pluginsMap, concretePluginClass); + // plugin name + addPossiblePluginName(pluginsMap, concretePluginClass, DataPrepperPlugin::name, name -> true); + // deprecated plugin name + addPossiblePluginName(pluginsMap, concretePluginClass, DataPrepperPlugin::deprecatedName, + deprecatedPluginName -> !deprecatedPluginName.equals(DEFAULT_DEPRECATED_NAME)); + // alternate plugin names + for (final String alternateName: concretePluginClass.getAnnotation(DataPrepperPlugin.class).alternateNames()) { + addPossiblePluginName(pluginsMap, concretePluginClass, DataPrepperPlugin -> alternateName, + alternatePluginName -> !alternatePluginName.equals(DEFAULT_ALTERNATE_NAME)); + } } return pluginsMap; } - private void addOptionalDeprecatedPluginName( + private void addPossiblePluginName( final Map, Class>> pluginsMap, - final Class concretePluginClass) { + final Class concretePluginClass, + final Function possiblePluginNameFunction, + final Predicate possiblePluginNamePredicate + ) { final DataPrepperPlugin dataPrepperPluginAnnotation = concretePluginClass.getAnnotation(DataPrepperPlugin.class); - final String deprecatedPluginName = dataPrepperPluginAnnotation.deprecatedName(); + final String possiblePluginName = possiblePluginNameFunction.apply(dataPrepperPluginAnnotation); final Class supportedType = dataPrepperPluginAnnotation.pluginType(); - if (!deprecatedPluginName.equals(DEFAULT_DEPRECATED_NAME)) { + if (possiblePluginNamePredicate.test(possiblePluginName)) { final Map, Class> supportTypeToPluginTypeMap = - pluginsMap.computeIfAbsent(deprecatedPluginName, k -> new HashMap<>()); + pluginsMap.computeIfAbsent(possiblePluginName, k -> new HashMap<>()); supportTypeToPluginTypeMap.put(supportedType, concretePluginClass); } } diff --git a/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/ClasspathPluginProviderTest.java b/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/ClasspathPluginProviderTest.java index 88763164dd..6cda169636 100644 --- 
a/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/ClasspathPluginProviderTest.java +++ b/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/ClasspathPluginProviderTest.java @@ -8,6 +8,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin; import org.opensearch.dataprepper.model.sink.Sink; import org.opensearch.dataprepper.model.source.Source; @@ -105,6 +107,29 @@ void findPlugin_should_return_plugin_if_found_for_deprecated_name_and_type_using assertThat(optionalPlugin.get(), equalTo(TestSource.class)); } + @Test + void findPlugin_should_return_empty_for_default_alternate_name() { + given(reflections.getTypesAnnotatedWith(DataPrepperPlugin.class)) + .willReturn(new HashSet<>(List.of(TestSource.class))); + + final Optional> optionalPlugin = createObjectUnderTest() + .findPluginClass(Source.class, UUID.randomUUID().toString()); + assertThat(optionalPlugin, notNullValue()); + assertThat(optionalPlugin.isPresent(), equalTo(false)); + } + + @ParameterizedTest + @ValueSource(strings = {"test_source_alternate_name1", "test_source_alternate_name2"}) + void findPlugin_should_return_plugin_if_found_for_alternate_name_and_type_using_pluginType(final String alternateSourceName) { + given(reflections.getTypesAnnotatedWith(DataPrepperPlugin.class)) + .willReturn(new HashSet<>(List.of(TestSource.class))); + + final Optional> optionalPlugin = createObjectUnderTest().findPluginClass(Source.class, alternateSourceName); + assertThat(optionalPlugin, notNullValue()); + assertThat(optionalPlugin.isPresent(), equalTo(true)); + assertThat(optionalPlugin.get(), equalTo(TestSource.class)); + } + @Nested class WithPredefinedPlugins { diff --git a/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/DefaultPluginFactoryTest.java b/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/DefaultPluginFactoryTest.java index 64b053a924..495d003bb3 100644 --- a/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/DefaultPluginFactoryTest.java +++ b/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugin/DefaultPluginFactoryTest.java @@ -401,4 +401,33 @@ void loadPlugin_should_create_a_new_instance_of_the_first_plugin_found_with_corr verify(beanFactoryProvider).get(); } } + + @Nested + class WithAlternatePluginName { + private static final String TEST_SINK_ALTERNATE_NAME = "test_sink_alternate_name"; + private Class expectedPluginClass; + + @BeforeEach + void setUp() { + expectedPluginClass = TestSink.class; + given(pluginSetting.getName()).willReturn(TEST_SINK_ALTERNATE_NAME); + + given(firstPluginProvider.findPluginClass(baseClass, TEST_SINK_ALTERNATE_NAME)) + .willReturn(Optional.of(expectedPluginClass)); + } + + @Test + void loadPlugin_should_create_a_new_instance_of_the_first_plugin_found_with_correct_name_and_alternate_name() { + final TestSink expectedInstance = mock(TestSink.class); + final Object convertedConfiguration = mock(Object.class); + given(pluginConfigurationConverter.convert(PluginSetting.class, pluginSetting)) + .willReturn(convertedConfiguration); + given(pluginCreator.newPluginInstance(eq(expectedPluginClass), any(ComponentPluginArgumentsContext.class), eq(TEST_SINK_ALTERNATE_NAME))) + .willReturn(expectedInstance); + 
+ assertThat(createObjectUnderTest().loadPlugin(baseClass, pluginSetting), equalTo(expectedInstance)); + MatcherAssert.assertThat(expectedInstance.getClass().getAnnotation(DataPrepperPlugin.class).alternateNames(), equalTo(new String[]{TEST_SINK_ALTERNATE_NAME})); + verify(beanFactoryProvider).get(); + } + } } diff --git a/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugins/test/TestSink.java b/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugins/test/TestSink.java index fc54428ba2..1e2742a0ba 100644 --- a/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugins/test/TestSink.java +++ b/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugins/test/TestSink.java @@ -16,7 +16,7 @@ import java.util.stream.Collectors; import java.time.Instant; -@DataPrepperPlugin(name = "test_sink", deprecatedName = "test_sink_deprecated_name", pluginType = Sink.class) +@DataPrepperPlugin(name = "test_sink", alternateNames = "test_sink_alternate_name", deprecatedName = "test_sink_deprecated_name", pluginType = Sink.class) public class TestSink implements Sink> { public boolean isShutdown = false; private final List> collectedRecords; diff --git a/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugins/test/TestSource.java b/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugins/test/TestSource.java index 9a7192a370..2ad29f2650 100644 --- a/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugins/test/TestSource.java +++ b/data-prepper-plugin-framework/src/test/java/org/opensearch/dataprepper/plugins/test/TestSource.java @@ -16,7 +16,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -@DataPrepperPlugin(name = "test_source", deprecatedName = "test_source_deprecated_name", pluginType = Source.class) +@DataPrepperPlugin(name = "test_source", alternateNames = { "test_source_alternate_name1", "test_source_alternate_name2" }, deprecatedName = "test_source_deprecated_name", pluginType = Source.class) public class TestSource implements Source> { public static final List> TEST_DATA = Stream.of("TEST") .map(Record::new).collect(Collectors.toList()); diff --git a/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/codec/Codec.java b/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/codec/Codec.java index 49b28b9aa2..e81254abf4 100644 --- a/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/codec/Codec.java +++ b/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/codec/Codec.java @@ -24,31 +24,25 @@ public interface Codec { T parse(HttpData httpData) throws IOException; /** - * Serializes parsed data back into a UTF-8 string. + * Validates the content of the HTTP request. + * + * @param content The content of the original HTTP request + * @throws IOException A failure validating data. + */ + void validate(HttpData content) throws IOException; + + /* + * Serializes the HttpData and split into multiple bodies based on splitLength. + *
<p>
+ * The serialized bodies are passed to the serializedBodyConsumer. *
<p>
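* A minimal usage sketch (illustrative; assumes an accumulating consumer):
* <pre>{@code
*   List<String> chunks = new ArrayList<>();
*   codec.serializeSplit(content, chunks::add, 1024);
* }</pre>
* <p>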
* This API will split into multiple bodies based on splitLength. Note that if a single * item is larger than this, it will be output and exceed that length. * - * @param parsedData The parsed data + * @param content The content of the original HTTP request * @param serializedBodyConsumer A {@link Consumer} to accept each serialized body * @param splitLength The length at which to split serialized bodies. * @throws IOException A failure writing data. */ - void serialize(final T parsedData, - final Consumer serializedBodyConsumer, - final int splitLength) throws IOException; - - - /** - * Serializes parsed data back into a UTF-8 string. - *
<p>
- * This API will not split the data into chunks. - * - * @param parsedData The parsed data - * @param serializedBodyConsumer A {@link Consumer} to accept the serialized body - * @throws IOException A failure writing data. - */ - default void serialize(final T parsedData, final Consumer serializedBodyConsumer) throws IOException { - serialize(parsedData, serializedBodyConsumer, Integer.MAX_VALUE); - } + void serializeSplit(HttpData content, Consumer serializedBodyConsumer, int splitLength) throws IOException; } diff --git a/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/codec/JsonCodec.java b/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/codec/JsonCodec.java index 378af9c2d9..6306366816 100644 --- a/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/codec/JsonCodec.java +++ b/data-prepper-plugins/http-source-common/src/main/java/org/opensearch/dataprepper/http/codec/JsonCodec.java @@ -8,16 +8,17 @@ import com.fasterxml.jackson.core.JsonEncoding; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.io.CountingOutputStream; import com.linecorp.armeria.common.HttpData; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.InputStream; import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -33,6 +34,7 @@ public class JsonCodec implements Codec> { private static final TypeReference>> LIST_OF_MAP_TYPE_REFERENCE = new TypeReference>>() { }; + private static final JsonFactory JSON_FACTORY = new JsonFactory(); @Override @@ -48,38 +50,56 @@ public List parse(final HttpData httpData) throws IOException { return jsonList; } - public void serialize(final List jsonList, - final Consumer serializedBodyConsumer, - final int splitLength) throws IOException { - if (splitLength < 0) - throw new IllegalArgumentException("The splitLength must be greater than or equal to 0."); + @Override + public void validate(final HttpData content) throws IOException { + mapper.readValue(content.toInputStream(), + LIST_OF_MAP_TYPE_REFERENCE); + } + @Override + public void serializeSplit(final HttpData content, final Consumer serializedBodyConsumer, final int splitLength) throws IOException { + final InputStream contentInputStream = content.toInputStream(); if (splitLength == 0) { - performSerialization(jsonList, serializedBodyConsumer, Integer.MAX_VALUE); + performSerialization(contentInputStream, serializedBodyConsumer, Integer.MAX_VALUE); } else { - performSerialization(jsonList, serializedBodyConsumer, splitLength); + performSerialization(contentInputStream, serializedBodyConsumer, splitLength); } } - private void performSerialization(final List jsonList, + + private void performSerialization(final InputStream inputStream, final Consumer serializedBodyConsumer, final int splitLength) throws IOException { - JsonArrayWriter jsonArrayWriter = new JsonArrayWriter(splitLength, serializedBodyConsumer); + try (final JsonParser jsonParser = JSON_FACTORY.createParser(inputStream)) { + if (jsonParser.nextToken() != JsonToken.START_ARRAY) { + throw new RuntimeException("Input is not a 
valid JSON array."); + } + + JsonArrayWriter jsonArrayWriter = new JsonArrayWriter(splitLength, serializedBodyConsumer); - for (final String individualJsonLine : jsonList) { - if (jsonArrayWriter.willExceedByWriting(individualJsonLine)) { - jsonArrayWriter.close(); + while (jsonParser.nextToken() != JsonToken.END_ARRAY) { + final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + final JsonGenerator objectJsonGenerator = JSON_FACTORY + .createGenerator(outputStream, JsonEncoding.UTF8); + objectJsonGenerator.copyCurrentStructure(jsonParser); + objectJsonGenerator.close(); - jsonArrayWriter = new JsonArrayWriter(splitLength, serializedBodyConsumer); + if (jsonArrayWriter.willExceedByWriting(outputStream)) { + jsonArrayWriter.close(); + + jsonArrayWriter = new JsonArrayWriter(splitLength, serializedBodyConsumer); + + } + jsonArrayWriter.write(outputStream); } - jsonArrayWriter.write(individualJsonLine); - } - jsonArrayWriter.close(); + jsonArrayWriter.close(); + } } + private static class JsonArrayWriter { private static final JsonFactory JSON_FACTORY = new JsonFactory().setCodec(mapper); private static final int BUFFER_SIZE = 16 * 1024; @@ -100,15 +120,15 @@ private static class JsonArrayWriter { generator.writeStartArray(); } - boolean willExceedByWriting(final String individualJsonLine) { - final int lengthToWrite = individualJsonLine.getBytes(StandardCharsets.UTF_8).length; + boolean willExceedByWriting(final ByteArrayOutputStream byteArrayOutputStream) { + final int lengthToWrite = byteArrayOutputStream.size(); final long lengthOfDataWritten = countingOutputStream.getCount(); return lengthToWrite + lengthOfDataWritten + NECESSARY_CHARACTERS_TO_WRITE.length() > splitLength; } - void write(final String individualJsonLine) throws IOException { - final JsonNode jsonNode = mapper.readTree(individualJsonLine); - generator.writeTree(jsonNode); + void write(final ByteArrayOutputStream individualJsonLine) throws IOException { + final String jsonLineString = individualJsonLine.toString(Charset.defaultCharset()); + generator.writeRawValue(jsonLineString); generator.flush(); hasItem = true; } @@ -126,5 +146,4 @@ void close() throws IOException { outputStream.close(); } } - } diff --git a/data-prepper-plugins/http-source-common/src/test/java/org/opensearch/dataprepper/http/codec/JsonCodecTest.java b/data-prepper-plugins/http-source-common/src/test/java/org/opensearch/dataprepper/http/codec/JsonCodecTest.java index b58f9e6cde..ca2a483eee 100644 --- a/data-prepper-plugins/http-source-common/src/test/java/org/opensearch/dataprepper/http/codec/JsonCodecTest.java +++ b/data-prepper-plugins/http-source-common/src/test/java/org/opensearch/dataprepper/http/codec/JsonCodecTest.java @@ -38,19 +38,6 @@ class JsonCodecTest { private static final HttpData GOOD_TEST_DATA = HttpData.ofUtf8("[{\"a\":\"b\"}, {\"c\":\"d\"}]"); private static final HttpData GOOD_LARGE_TEST_DATA = HttpData.ofUtf8("[{\"a1\":\"b1\"}, {\"a2\":\"b2\"}, {\"a3\":\"b3\"}, {\"a4\":\"b4\"}, {\"a5\":\"b5\"}]"); private static final HttpData GOOD_LARGE_TEST_DATA_UNICODE = HttpData.ofUtf8("[{\"ὊὊὊ1\":\"ὊὊὊ1\"}, {\"ὊὊὊ2\":\"ὊὊὊ2\"}, {\"a3\":\"b3\"}, {\"ὊὊὊ4\":\"ὊὊὊ4\"}]"); - public static final List JSON_BODIES_LIST = List.of( - "{\"a1\":\"b1\"}", - "{\"a2\":\"b2\"}", - "{\"a3\":\"b3\"}", - "{\"a4\":\"b4\"}", - "{\"a5\":\"b5\"}" - ); - public static final List JSON_BODIES_UNICODE_MIXED_LIST = List.of( - "{\"ὊὊὊ1\":\"ὊὊὊ1\"}", - "{\"ὊὊὊ2\":\"ὊὊὊ2\"}", - "{\"a3\":\"b3\"}", - "{\"ὊὊὊ4\":\"ὊὊὊ4\"}" - ); private final HttpData 
badTestDataJsonLine = HttpData.ofUtf8("{\"a\":\"b\"}"); private final HttpData badTestDataMultiJsonLines = HttpData.ofUtf8("{\"a\":\"b\"}{\"c\":\"d\"}"); private final HttpData badTestDataNonJson = HttpData.ofUtf8("non json content"); @@ -84,16 +71,16 @@ public void testParseSuccessWithMaxSize() throws IOException { @ParameterizedTest @ValueSource(ints = {-1, -2, Integer.MIN_VALUE}) - void serialize_with_invalid_splitLength(final int splitLength) { + void serializeSplit_with_invalid_splitLength(final int splitLength) { final Consumer serializedBodyConsumer = mock(Consumer.class); - assertThrows(IllegalArgumentException.class, () -> objectUnderTest.serialize(JSON_BODIES_LIST, serializedBodyConsumer, splitLength)); + assertThrows(IllegalArgumentException.class, () -> objectUnderTest.serializeSplit(GOOD_LARGE_TEST_DATA, serializedBodyConsumer, splitLength)); } @ParameterizedTest @ValueSource(ints = {1, 2, 24}) - void serialize_with_split_length_leading_to_groups_of_one(final int splitLength) throws IOException { + void serializeSplit_with_split_length_leading_to_groups_of_one(final int splitLength) throws IOException { final Consumer serializedBodyConsumer = mock(Consumer.class); - objectUnderTest.serialize(JSON_BODIES_LIST, serializedBodyConsumer, splitLength); + objectUnderTest.serializeSplit(GOOD_LARGE_TEST_DATA, serializedBodyConsumer, splitLength); final ArgumentCaptor actualSerializedBodyCaptor = ArgumentCaptor.forClass(String.class); verify(serializedBodyConsumer, times(5)).accept(actualSerializedBodyCaptor.capture()); @@ -109,9 +96,9 @@ void serialize_with_split_length_leading_to_groups_of_one(final int splitLength) @ParameterizedTest @ValueSource(ints = {25, 30, 36}) - void serialize_with_split_length_leading_to_groups_of_two(final int splitLength) throws IOException { + void serializeSplit_with_split_length_leading_to_groups_of_two(final int splitLength) throws IOException { final Consumer serializedBodyConsumer = mock(Consumer.class); - objectUnderTest.serialize(JSON_BODIES_LIST, serializedBodyConsumer, splitLength); + objectUnderTest.serializeSplit(GOOD_LARGE_TEST_DATA, serializedBodyConsumer, splitLength); final ArgumentCaptor actualSerializedBodyCaptor = ArgumentCaptor.forClass(String.class); verify(serializedBodyConsumer, times(3)).accept(actualSerializedBodyCaptor.capture()); @@ -129,9 +116,9 @@ void serialize_with_split_length_leading_to_groups_of_two(final int splitLength) @ParameterizedTest @ValueSource(ints = {37, 48}) - void serialize_with_split_length_leading_to_groups_up_to_three(final int splitLength) throws IOException { + void serializeSplit_with_split_length_leading_to_groups_up_to_three(final int splitLength) throws IOException { final Consumer serializedBodyConsumer = mock(Consumer.class); - objectUnderTest.serialize(JSON_BODIES_LIST, serializedBodyConsumer, splitLength); + objectUnderTest.serializeSplit(GOOD_LARGE_TEST_DATA, serializedBodyConsumer, splitLength); final ArgumentCaptor actualSerializedBodyCaptor = ArgumentCaptor.forClass(String.class); verify(serializedBodyConsumer, times(2)).accept(actualSerializedBodyCaptor.capture()); @@ -147,9 +134,9 @@ void serialize_with_split_length_leading_to_groups_up_to_three(final int splitLe @ParameterizedTest @ValueSource(ints = {0, Integer.MAX_VALUE}) - void serialize_with_split_size_that_does_not_split(final int splitLength) throws IOException { + void serializeSplit_with_split_size_that_does_not_split(final int splitLength) throws IOException { final Consumer serializedBodyConsumer = mock(Consumer.class); - 
objectUnderTest.serialize(JSON_BODIES_LIST, serializedBodyConsumer, splitLength); + objectUnderTest.serializeSplit(GOOD_LARGE_TEST_DATA, serializedBodyConsumer, splitLength); final ArgumentCaptor actualSerializedBodyCaptor = ArgumentCaptor.forClass(String.class); verify(serializedBodyConsumer, times(1)).accept(actualSerializedBodyCaptor.capture()); @@ -160,9 +147,9 @@ void serialize_with_split_size_that_does_not_split(final int splitLength) throws @ParameterizedTest @ValueSource(ints = {58, 68}) - void serialize_with_split_length_unicode(final int splitLength) throws IOException { + void serializeSplit_with_split_length_unicode(final int splitLength) throws IOException { final Consumer serializedBodyConsumer = mock(Consumer.class); - objectUnderTest.serialize(JSON_BODIES_UNICODE_MIXED_LIST, serializedBodyConsumer, splitLength); + objectUnderTest.serializeSplit(GOOD_LARGE_TEST_DATA_UNICODE, serializedBodyConsumer, splitLength); final ArgumentCaptor actualSerializedBodyCaptor = ArgumentCaptor.forClass(String.class); verify(serializedBodyConsumer, times(2)).accept(actualSerializedBodyCaptor.capture()); @@ -178,17 +165,19 @@ void serialize_with_split_length_unicode(final int splitLength) throws IOExcepti @ParameterizedTest @ArgumentsSource(GoodTestData.class) - void parse_and_serialize_symmetry(final HttpData httpData) throws IOException { - final List parsedList = objectUnderTest.parse(httpData); - + void serializeSplit_and_parse_symmetry(final HttpData httpData) throws IOException { + final List parsedFromOriginal = objectUnderTest.parse(httpData); final Consumer serializedBodyConsumer = mock(Consumer.class); - objectUnderTest.serialize(parsedList, serializedBodyConsumer); + objectUnderTest.serializeSplit(httpData, serializedBodyConsumer, Integer.MAX_VALUE); final ArgumentCaptor actualSerializedBodyCaptor = ArgumentCaptor.forClass(String.class); verify(serializedBodyConsumer, times(1)).accept(actualSerializedBodyCaptor.capture()); final String actualString = actualSerializedBodyCaptor.getValue(); final String expectedJsonString = httpData.toStringUtf8().replace(" ", ""); assertThat(actualString, equalTo(expectedJsonString)); + + final List parsedFromRewritten = objectUnderTest.parse(HttpData.ofUtf8(actualString)); + assertThat(parsedFromRewritten, equalTo(parsedFromOriginal)); } @@ -196,10 +185,8 @@ void parse_and_serialize_symmetry(final HttpData httpData) throws IOException { @ArgumentsSource(JsonArrayWithKnownFirstArgumentsProvider.class) public void parse_should_return_lists_smaller_than_provided_length( final String inputJsonArray, final String knownFirstPart, final int maxSize, final List> expectedChunks, final List exceedsMaxSize) throws IOException { - List individualJsonLines = objectUnderTest.parse(HttpData.ofUtf8(inputJsonArray)); - Consumer serializedBodyConsumer = mock(Consumer.class); - objectUnderTest.serialize(individualJsonLines, serializedBodyConsumer, maxSize); + objectUnderTest.serializeSplit(HttpData.ofUtf8(inputJsonArray), serializedBodyConsumer, maxSize); ArgumentCaptor actualSerializedBodyCaptor = ArgumentCaptor.forClass(String.class); verify(serializedBodyConsumer, times(expectedChunks.size())).accept(actualSerializedBodyCaptor.capture()); @@ -301,4 +288,26 @@ public Stream provideArguments(ExtensionContext extensionCo ); } } + + + @ParameterizedTest + @ArgumentsSource(GoodTestData.class) + void validate_with_known_good_Json(final HttpData httpData) throws IOException { + objectUnderTest.validate(httpData); + } + + @Test + void 
validate_with_valid_JSON_but_not_array_should_throw() { + assertThrows(IOException.class, () -> objectUnderTest.validate(badTestDataJsonLine)); + } + + @Test + void validate_with_multiline_JSON_should_throw() { + assertThrows(IOException.class, () -> objectUnderTest.validate(badTestDataMultiJsonLines)); + } + + @Test + void validate_with_invalid_JSON_should_throw() { + assertThrows(IOException.class, () -> objectUnderTest.validate(badTestDataNonJson)); + } } diff --git a/data-prepper-plugins/http-source/build.gradle b/data-prepper-plugins/http-source/build.gradle index 7d54d5f177..2d5c5ceceb 100644 --- a/data-prepper-plugins/http-source/build.gradle +++ b/data-prepper-plugins/http-source/build.gradle @@ -5,6 +5,7 @@ plugins { id 'java' + id 'me.champeau.jmh' version '0.7.2' } dependencies { diff --git a/data-prepper-plugins/http-source/src/jmh/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPServiceMeasure.java b/data-prepper-plugins/http-source/src/jmh/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPServiceMeasure.java new file mode 100644 index 0000000000..40d5b7c3e2 --- /dev/null +++ b/data-prepper-plugins/http-source/src/jmh/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPServiceMeasure.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.loghttp; + +import com.linecorp.armeria.common.AggregatedHttpRequest; +import com.linecorp.armeria.common.HttpData; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.server.ServiceRequestContext; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.buffer.Buffer; + +import java.io.IOException; +import java.time.Duration; +import java.util.Optional; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +public class LogHTTPServiceMeasure { + + @State(Scope.Benchmark) + public static class BenchmarkState { + private HttpData httpData; + private Buffer buffer; + private LogHTTPService logHTTPService; + private ServiceRequestContext serviceRequestContext; + private RequestHeaders requestHeaders; + + @Setup + public void setUp() throws IOException { + byte[] jsonContent = new TestGenerator().createJson(10 * 1024 * 1024); + httpData = HttpData.ofUtf8(new String(jsonContent)); + + buffer = mock(Buffer.class, withSettings().stubOnly()); + when(buffer.isByteBuffer()).thenReturn(true); + when(buffer.getMaxRequestSize()).thenReturn(Optional.of(512 * 1024)); + when(buffer.getOptimalRequestSize()).thenReturn(Optional.of(256 * 1024)); + + serviceRequestContext = mock(ServiceRequestContext.class); + logHTTPService = new LogHTTPService((int) Duration.ofSeconds(10).toMillis(), buffer, PluginMetrics.fromPrefix("testing")); + + requestHeaders = RequestHeaders.builder() + .method(HttpMethod.POST) + .path("/test") + .build(); + } + } + + @Benchmark + @BenchmarkMode(Mode.Throughput) + @Warmup(iterations = 1) + 
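// Five measurement iterations of 10 (seconds, JMH's default time unit for @Measurement);
// Mode.Throughput reports doPost invocations per unit of time.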
@Measurement(iterations = 5, time = 10) + public HttpResponse measure_doPost(BenchmarkState benchmarkState) throws Exception { + AggregatedHttpRequest aggregatedHttpRequest = AggregatedHttpRequest.of(benchmarkState.requestHeaders, benchmarkState.httpData); + return benchmarkState.logHTTPService.doPost(benchmarkState.serviceRequestContext, aggregatedHttpRequest); + } +} diff --git a/data-prepper-plugins/http-source/src/jmh/java/org/opensearch/dataprepper/plugins/source/loghttp/TestGenerator.java b/data-prepper-plugins/http-source/src/jmh/java/org/opensearch/dataprepper/plugins/source/loghttp/TestGenerator.java new file mode 100644 index 0000000000..4c09ef99d7 --- /dev/null +++ b/data-prepper-plugins/http-source/src/jmh/java/org/opensearch/dataprepper/plugins/source/loghttp/TestGenerator.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.loghttp; + +import com.fasterxml.jackson.core.JsonEncoding; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonGenerator; +import org.apache.commons.io.output.CountingOutputStream; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.time.Instant; +import java.util.Random; +import java.util.UUID; + +public class TestGenerator { + private final Random random = new Random(); + + public byte[] createJson(final int roughMaximumSize) throws IOException { + final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(roughMaximumSize); + writeLog(roughMaximumSize, byteArrayOutputStream); + return byteArrayOutputStream.toByteArray(); + } + + private void writeLog(final int roughMaximumSize, final OutputStream fileOutputStream) throws IOException { + try (final CountingOutputStream countingOutputStream = new CountingOutputStream(fileOutputStream)) { + writeJson(roughMaximumSize, countingOutputStream); + } + } + + private void writeJson(final int roughMaximumSize, final CountingOutputStream countingOutputStream) throws IOException { + final JsonFactory jsonFactory = new JsonFactory(); + final JsonGenerator jsonGenerator = jsonFactory + .createGenerator(countingOutputStream, JsonEncoding.UTF8); + + jsonGenerator.writeStartArray(); + + while (countingOutputStream.getCount() < roughMaximumSize) { + writeSingleRecord(jsonGenerator); + jsonGenerator.flush(); // Need to flush the JsonGenerator in order to get the bytes to write to the counting output stream + } + + jsonGenerator.writeEndArray(); + jsonGenerator.close(); + + countingOutputStream.flush(); + } + + private void writeSingleRecord(final JsonGenerator jsonGenerator) throws IOException { + final StringBuilder logStringBuilder = new StringBuilder(); + logStringBuilder.append(Instant.now()); + logStringBuilder.append(" "); + logStringBuilder.append(UUID.randomUUID()); + logStringBuilder.append(" "); + logStringBuilder.append(UUID.randomUUID()); + logStringBuilder.append(" "); + logStringBuilder.append(random.nextInt(100_000)); + + jsonGenerator.writeStartObject(); + + jsonGenerator.writeStringField("log", logStringBuilder.toString()); + + jsonGenerator.writeEndObject(); + } +} diff --git a/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSource.java b/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSource.java index cea9e252f6..d56946f334 100644 --- 
a/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSource.java +++ b/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSource.java @@ -141,7 +141,7 @@ public void start(final Buffer> buffer) { final String httpSourcePath = sourceConfig.getPath().replace(PIPELINE_NAME_PLACEHOLDER, pipelineName); sb.decorator(httpSourcePath, ThrottlingService.newDecorator(logThrottlingStrategy, logThrottlingRejectHandler)); - final LogHTTPService logHTTPService = new LogHTTPService(sourceConfig.getBufferTimeoutInMillis(), buffer, byteDecoder, pluginMetrics); + final LogHTTPService logHTTPService = new LogHTTPService(sourceConfig.getBufferTimeoutInMillis(), buffer, pluginMetrics); if (CompressionOption.NONE.equals(sourceConfig.getCompression())) { sb.annotatedService(httpSourcePath, logHTTPService, httpRequestExceptionHandler); diff --git a/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPService.java b/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPService.java index 2163f322ef..c2bd344fc5 100644 --- a/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPService.java +++ b/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPService.java @@ -11,7 +11,6 @@ import org.opensearch.dataprepper.model.log.JacksonLog; import org.opensearch.dataprepper.model.log.Log; import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.model.codec.ByteDecoder; import com.linecorp.armeria.common.AggregatedHttpRequest; import com.linecorp.armeria.common.HttpData; import com.linecorp.armeria.common.HttpResponse; @@ -61,7 +60,6 @@ public class LogHTTPService { public LogHTTPService(final int bufferWriteTimeoutInMillis, final Buffer> buffer, - final ByteDecoder decoder, final PluginMetrics pluginMetrics) { this.buffer = buffer; this.bufferWriteTimeoutInMillis = bufferWriteTimeoutInMillis; @@ -89,33 +87,47 @@ public HttpResponse doPost(final ServiceRequestContext serviceRequestContext, fi HttpResponse processRequest(final AggregatedHttpRequest aggregatedHttpRequest) throws Exception { final HttpData content = aggregatedHttpRequest.content(); - final List jsonList; - try { - jsonList = jsonCodec.parse(content); - } catch (IOException e) { - LOG.error("Failed to parse the request of size {} due to: {}", content.length(), e.getMessage()); - throw new IOException("Bad request data format. 
Needs to be json array.", e.getCause()); - } - try { - if (buffer.isByteBuffer()) { - if (bufferMaxRequestLength != null && bufferOptimalRequestLength != null && content.array().length > bufferOptimalRequestLength) { - jsonCodec.serialize(jsonList, this::writeChunkedBody, bufferOptimalRequestLength - SERIALIZATION_OVERHEAD); - } else { - // jsonList is ignored in this path but parse() was done to make - // sure that the data is in the expected json format + if (buffer.isByteBuffer()) { + if (bufferMaxRequestLength != null && bufferOptimalRequestLength != null && content.array().length > bufferOptimalRequestLength) { + jsonCodec.serializeSplit(content, this::writeChunkedBody, bufferOptimalRequestLength - SERIALIZATION_OVERHEAD); + } else { + try { + jsonCodec.validate(content); + } catch (IOException e) { + LOG.error("Failed to parse the request of size {} due to: {}", content.length(), e.getMessage()); + throw new IOException("Bad request data format. Needs to be json array.", e.getCause()); + } + + try { buffer.writeBytes(content.array(), null, bufferWriteTimeoutInMillis); + } catch (Exception e) { + LOG.error("Failed to write the request of size {} due to: {}", content.length(), e.getMessage()); + throw e; } - } else { - final List> records = jsonList.stream() - .map(this::buildRecordLog) - .collect(Collectors.toList()); + } + } else { + final List jsonList; + + try { + jsonList = jsonCodec.parse(content); + } catch (IOException e) { + LOG.error("Failed to parse the request of size {} due to: {}", content.length(), e.getMessage()); + throw new IOException("Bad request data format. Needs to be json array.", e.getCause()); + } + + final List> records = jsonList.stream() + .map(this::buildRecordLog) + .collect(Collectors.toList()); + + try { buffer.writeAll(records, bufferWriteTimeoutInMillis); + } catch (Exception e) { + LOG.error("Failed to write the request of size {} due to: {}", content.length(), e.getMessage()); + throw e; } - } catch (Exception e) { - LOG.error("Failed to write the request of size {} due to: {}", content.length(), e.getMessage()); - throw e; } + successRequestsCounter.increment(); return HttpResponse.of(HttpStatus.OK); } diff --git a/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPServiceTest.java b/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPServiceTest.java index 2e9b802f32..5e2ffe2c7d 100644 --- a/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPServiceTest.java +++ b/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPServiceTest.java @@ -6,6 +6,7 @@ package org.opensearch.dataprepper.plugins.source.loghttp; import com.linecorp.armeria.server.ServiceRequestContext; +import org.junit.jupiter.api.Nested; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.buffer.SizeOverflowException; @@ -49,8 +50,12 @@ import java.io.ByteArrayInputStream; import java.io.InputStream; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.never; import static 
org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -107,7 +112,7 @@ public void setUp() throws Exception { ); Buffer> blockingBuffer = new BlockingBuffer<>(TEST_BUFFER_CAPACITY, 8, "test-pipeline"); - logHTTPService = new LogHTTPService(TEST_TIMEOUT_IN_MILLIS, blockingBuffer, null, pluginMetrics); + logHTTPService = new LogHTTPService(TEST_TIMEOUT_IN_MILLIS, blockingBuffer, pluginMetrics); } @Test @@ -182,71 +187,186 @@ public void testHTTPRequestTimeout() throws Exception { verify(requestProcessDuration, times(2)).recordCallable(ArgumentMatchers.>any()); } - @Test - public void testChunking() throws Exception { - byteBuffer = mock(Buffer.class); - when(byteBuffer.isByteBuffer()).thenReturn(true); - when(byteBuffer.getMaxRequestSize()).thenReturn(Optional.of(4*1024*1024)); - when(byteBuffer.getOptimalRequestSize()).thenReturn(Optional.of(1024*1024)); - - logHTTPService = new LogHTTPService(TEST_TIMEOUT_IN_MILLIS, byteBuffer, null, pluginMetrics); - AggregatedHttpRequest aggregatedHttpRequest = mock(AggregatedHttpRequest.class); - HttpData httpData = mock(HttpData.class); - // Test small json data - String testString ="{\"key1\":\"value1\"},{\"key2\":\"value2\"},{\"key3\":\"value3\"},{\"key4\":\"value4\"},{\"key5\":\"value5\"}"; - String exampleString = "[ " + testString + "]"; - when(httpData.array()).thenReturn(exampleString.getBytes()); - InputStream stream = new ByteArrayInputStream(exampleString.getBytes(StandardCharsets.UTF_8)); - when(httpData.toInputStream()).thenReturn(stream); - - when(aggregatedHttpRequest.content()).thenReturn(httpData); - logHTTPService.processRequest(aggregatedHttpRequest); - verify(byteBuffer, times(1)).writeBytes(any(), (String)isNull(), any(Integer.class)); - - // Test more than 1MB json data - StringBuilder sb = new StringBuilder(1024*1024+10240); - for (int i =0; i < 12500; i++) { - sb.append(testString); - if (i+1 != 12500) - sb.append(","); + @Nested + class ChunkingCapableBuffer { + private String testString = "{\"key1\":\"value1\"},{\"key2\":\"value2\"},{\"key3\":\"value3\"},{\"key4\":\"value4\"},{\"key5\":\"value5\"}"; + private AggregatedHttpRequest aggregatedHttpRequest; + private HttpData httpData; + private String largeTestString; + + @BeforeEach + void setUp() { + byteBuffer = mock(Buffer.class); + when(byteBuffer.isByteBuffer()).thenReturn(true); + when(byteBuffer.getMaxRequestSize()).thenReturn(Optional.of(4 * 1024 * 1024)); + when(byteBuffer.getOptimalRequestSize()).thenReturn(Optional.of(1024 * 1024)); + + aggregatedHttpRequest = mock(AggregatedHttpRequest.class); + httpData = mock(HttpData.class); + + logHTTPService = new LogHTTPService(TEST_TIMEOUT_IN_MILLIS, byteBuffer, pluginMetrics); + + StringBuilder sb = new StringBuilder(1024 * 1024 + 10240); + for (int i = 0; i < 12500; i++) { + sb.append(testString); + if (i + 1 != 12500) + sb.append(","); + } + largeTestString = sb.toString(); + } + + @Test + void invalid_JSON_returns() { + // Test small json data + String exampleString = testString; + when(httpData.array()).thenReturn(exampleString.getBytes()); + InputStream stream = new ByteArrayInputStream(exampleString.getBytes(StandardCharsets.UTF_8)); + when(httpData.toInputStream()).thenReturn(stream); + + when(aggregatedHttpRequest.content()).thenReturn(httpData); + IOException actualException = assertThrows(IOException.class, () -> logHTTPService.processRequest(aggregatedHttpRequest)); + assertThat(actualException.getMessage(), containsString("Needs to be json array")); + } + + @Test + void too_small_to_chunk() throws 
Exception { + // Test small json data + String exampleString = "[ " + testString + "]"; + when(httpData.array()).thenReturn(exampleString.getBytes()); + InputStream stream = new ByteArrayInputStream(exampleString.getBytes(StandardCharsets.UTF_8)); + when(httpData.toInputStream()).thenReturn(stream); + + when(aggregatedHttpRequest.content()).thenReturn(httpData); + logHTTPService.processRequest(aggregatedHttpRequest); + verify(byteBuffer, times(1)).writeBytes(any(), (String) isNull(), eq(TEST_TIMEOUT_IN_MILLIS)); + } + + @Test + void chunking_with_1mb() throws Exception { + // Test more than 1MB json data + String exampleString = "[" + largeTestString + "]"; + when(httpData.array()).thenReturn(exampleString.getBytes()); + ByteArrayInputStream stream = new ByteArrayInputStream(exampleString.getBytes(StandardCharsets.UTF_8)); + when(httpData.toInputStream()).thenReturn(stream); + + when(aggregatedHttpRequest.content()).thenReturn(httpData); + logHTTPService.processRequest(aggregatedHttpRequest); + verify(byteBuffer, times(2)).writeBytes(any(), anyString(), eq(TEST_TIMEOUT_IN_MILLIS)); + } + + @Test + void chunking_with_4mb() throws Exception { + // Test more than 4MB json data + String exampleString = "[" + largeTestString + "," + largeTestString + "," + largeTestString + "," + largeTestString + "]"; + when(httpData.array()).thenReturn(exampleString.getBytes()); + ByteArrayInputStream stream = new ByteArrayInputStream(exampleString.getBytes(StandardCharsets.UTF_8)); + when(httpData.toInputStream()).thenReturn(stream); + + when(aggregatedHttpRequest.content()).thenReturn(httpData); + logHTTPService.processRequest(aggregatedHttpRequest); + verify(byteBuffer, times(5)).writeBytes(any(), anyString(), eq(TEST_TIMEOUT_IN_MILLIS)); + } + + @Test + void chunking_with_4mb_single_json_object_should_throw() { + String exampleString; + // Test more than 4MB single json object, should throw exception + int length = 3 * 1024 * 1024; + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) { + sb.append('A'); + } + String value = sb.toString(); + exampleString = "[{\"key\":\"" + value + "\"}]"; + + when(httpData.array()).thenReturn(exampleString.getBytes()); + ByteArrayInputStream stream = new ByteArrayInputStream(exampleString.getBytes(StandardCharsets.UTF_8)); + when(httpData.toInputStream()).thenReturn(stream); + + when(aggregatedHttpRequest.content()).thenReturn(httpData); + assertThrows(RuntimeException.class, () -> logHTTPService.processRequest(aggregatedHttpRequest)); } - String largeTestString = sb.toString(); - exampleString = "[" + largeTestString + "]"; - when(httpData.array()).thenReturn(exampleString.getBytes()); - stream = new ByteArrayInputStream(exampleString.getBytes(StandardCharsets.UTF_8)); - when(httpData.toInputStream()).thenReturn(stream); - - when(aggregatedHttpRequest.content()).thenReturn(httpData); - logHTTPService.processRequest(aggregatedHttpRequest); - verify(byteBuffer, times(2)).writeBytes(any(), anyString(), any(Integer.class)); - // Test more than 4MB json data - exampleString = "[" + largeTestString + "," + largeTestString + ","+largeTestString +","+largeTestString+"]"; - when(httpData.array()).thenReturn(exampleString.getBytes()); - stream = new ByteArrayInputStream(exampleString.getBytes(StandardCharsets.UTF_8)); - when(httpData.toInputStream()).thenReturn(stream); - - when(aggregatedHttpRequest.content()).thenReturn(httpData); - logHTTPService.processRequest(aggregatedHttpRequest); - verify(byteBuffer, times(7)).writeBytes(any(), anyString(), 
any(Integer.class)); - - // Test more than 4MB single json object, should throw exception - int length = 3*1024*1024; - sb = new StringBuilder(length); - for (int i = 0; i < length; i++) { - sb.append('A'); + } + + @Nested + class NonChunkingByteBuffer { + private String testString = "{\"key1\":\"value1\"},{\"key2\":\"value2\"},{\"key3\":\"value3\"},{\"key4\":\"value4\"},{\"key5\":\"value5\"}"; + private AggregatedHttpRequest aggregatedHttpRequest; + private HttpData httpData; + private String largeTestString; + + @BeforeEach + void setUp() { + byteBuffer = mock(Buffer.class); + when(byteBuffer.isByteBuffer()).thenReturn(true); + + aggregatedHttpRequest = mock(AggregatedHttpRequest.class); + httpData = mock(HttpData.class); + + logHTTPService = new LogHTTPService(TEST_TIMEOUT_IN_MILLIS, byteBuffer, pluginMetrics); + + StringBuilder sb = new StringBuilder(1024 * 1024 + 10240); + for (int i = 0; i < 12500; i++) { + sb.append(testString); + if (i + 1 != 12500) + sb.append(","); + } + largeTestString = sb.toString(); } - String value = sb.toString(); - exampleString = "[{\"key\":\""+value+"\"}]"; - when(httpData.array()).thenReturn(exampleString.getBytes()); - stream = new ByteArrayInputStream(exampleString.getBytes(StandardCharsets.UTF_8)); - when(httpData.toInputStream()).thenReturn(stream); - when(aggregatedHttpRequest.content()).thenReturn(httpData); - assertThrows(RuntimeException.class, () -> logHTTPService.processRequest(aggregatedHttpRequest)); + @Test + void invalid_JSON_returns() { + // Test small json data + String exampleString = testString; + InputStream stream = new ByteArrayInputStream(exampleString.getBytes(StandardCharsets.UTF_8)); + when(httpData.toInputStream()).thenReturn(stream); + + when(aggregatedHttpRequest.content()).thenReturn(httpData); + IOException actualException = assertThrows(IOException.class, () -> logHTTPService.processRequest(aggregatedHttpRequest)); + assertThat(actualException.getMessage(), containsString("Needs to be json array")); + } + + @Test + void chunking_with_1mb() throws Exception { + // Test more than 1MB json data + String exampleString = "[" + largeTestString + "]"; + byte[] bytes = exampleString.getBytes(); + when(httpData.array()).thenReturn(bytes); + ByteArrayInputStream stream = new ByteArrayInputStream(exampleString.getBytes(StandardCharsets.UTF_8)); + when(httpData.toInputStream()).thenReturn(stream); + + when(aggregatedHttpRequest.content()).thenReturn(httpData); + logHTTPService.processRequest(aggregatedHttpRequest); + ArgumentCaptor byteContentCaptor = ArgumentCaptor.forClass(byte[].class); + verify(byteBuffer).writeBytes(byteContentCaptor.capture(), isNull(), eq(TEST_TIMEOUT_IN_MILLIS)); + + final byte[] actualBytesWritten = byteContentCaptor.getValue(); + assertThat(actualBytesWritten.length, equalTo(bytes.length)); + } + + @Test + void chunking_with_4mb() throws Exception { + // Test more than 4MB json data + String exampleString = "[" + largeTestString + "," + largeTestString + "," + largeTestString + "," + largeTestString + "]"; + byte[] bytes = exampleString.getBytes(); + when(httpData.array()).thenReturn(bytes); + ByteArrayInputStream stream = new ByteArrayInputStream(exampleString.getBytes(StandardCharsets.UTF_8)); + when(httpData.toInputStream()).thenReturn(stream); + + when(aggregatedHttpRequest.content()).thenReturn(httpData); + logHTTPService.processRequest(aggregatedHttpRequest); + ArgumentCaptor byteContentCaptor = ArgumentCaptor.forClass(byte[].class); + verify(byteBuffer).writeBytes(byteContentCaptor.capture(), 
isNull(), eq(TEST_TIMEOUT_IN_MILLIS)); + + final byte[] actualBytesWritten = byteContentCaptor.getValue(); + assertThat(actualBytesWritten.length, equalTo(bytes.length)); + assertThat(actualBytesWritten, equalTo(bytes)); + } } + private AggregatedHttpRequest generateRandomValidHTTPRequest(int numJson) throws JsonProcessingException, ExecutionException, InterruptedException { RequestHeaders requestHeaders = RequestHeaders.builder() diff --git a/data-prepper-plugins/kinesis-source/build.gradle b/data-prepper-plugins/kinesis-source/build.gradle new file mode 100644 index 0000000000..c4a0614e36 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/build.gradle @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +plugins { + id 'java' +} + +dependencies { + implementation project(':data-prepper-api') + implementation project(':data-prepper-plugins:common') + implementation project(path: ':data-prepper-plugins:buffer-common') + implementation project(path: ':data-prepper-plugins:aws-plugin-api') + implementation 'com.fasterxml.jackson.core:jackson-core' + implementation 'io.micrometer:micrometer-core' + implementation 'software.amazon.kinesis:amazon-kinesis-client:2.6.0' + compileOnly 'org.projectlombok:lombok:1.18.20' + annotationProcessor 'org.projectlombok:lombok:1.18.20' + + testImplementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml' + testImplementation 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310' + testImplementation project(':data-prepper-test-common') + testImplementation project(':data-prepper-test-event') + testImplementation project(':data-prepper-core') + testImplementation project(':data-prepper-plugin-framework') + testImplementation project(':data-prepper-pipeline-parser') + testImplementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml' + testImplementation project(':data-prepper-plugins:parse-json-processor') + testImplementation project(':data-prepper-plugins:newline-codecs') +} + +jacocoTestCoverageVerification { + dependsOn jacocoTestReport + violationRules { + rule { //in addition to core projects rule + limit { + minimum = 1.0 + } + } + } +} \ No newline at end of file diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfig.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfig.java new file mode 100644 index 0000000000..68981c5cba --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfig.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
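*
* Binds the kinesis extension settings: lease_coordination names the DynamoDB table
* the KCL uses for lease coordination (an illustrative YAML example follows under
* KinesisLeaseCoordinationTableConfig).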
+ * + */ + +package org.opensearch.dataprepper.plugins.kinesis.extension; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Getter; + +@Getter +public class KinesisLeaseConfig { + @JsonProperty("lease_coordination") + private KinesisLeaseCoordinationTableConfig leaseCoordinationTable; +} diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigExtension.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigExtension.java new file mode 100644 index 0000000000..2cca52e565 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigExtension.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.kinesis.extension; + +import org.opensearch.dataprepper.model.annotations.DataPrepperExtensionPlugin; +import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; +import org.opensearch.dataprepper.model.plugin.ExtensionPlugin; +import org.opensearch.dataprepper.model.plugin.ExtensionPoints; + +@DataPrepperExtensionPlugin(modelType = KinesisLeaseConfig.class, rootKeyJsonPath = "/kinesis", allowInPipelineConfigurations = true) +public class KinesisLeaseConfigExtension implements ExtensionPlugin { + + private KinesisLeaseConfigSupplier kinesisLeaseConfigSupplier; + @DataPrepperPluginConstructor + public KinesisLeaseConfigExtension(final KinesisLeaseConfig kinesisLeaseConfig) { + this.kinesisLeaseConfigSupplier = new KinesisLeaseConfigSupplier(kinesisLeaseConfig); + } + + @Override + public void apply(final ExtensionPoints extensionPoints) { + extensionPoints.addExtensionProvider(new KinesisLeaseConfigProvider(this.kinesisLeaseConfigSupplier)); + } +} diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigProvider.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigProvider.java new file mode 100644 index 0000000000..9140ca9e92 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigProvider.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
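*
* Hands the shared KinesisLeaseConfigSupplier to components that request it through
* the extension points (see provideInstance/supportedClass below).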
+ * + */ + +package org.opensearch.dataprepper.plugins.kinesis.extension; + +import org.opensearch.dataprepper.model.plugin.ExtensionProvider; + +import java.util.Optional; + +class KinesisLeaseConfigProvider implements ExtensionProvider { + private final KinesisLeaseConfigSupplier kinesisLeaseConfigSupplier; + + public KinesisLeaseConfigProvider(final KinesisLeaseConfigSupplier kinesisLeaseConfigSupplier) { + this.kinesisLeaseConfigSupplier = kinesisLeaseConfigSupplier; + } + + @Override + public Optional provideInstance(Context context) { + return Optional.of(this.kinesisLeaseConfigSupplier); + } + + @Override + public Class supportedClass() { + return KinesisLeaseConfigSupplier.class; + } +} diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigSupplier.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigSupplier.java new file mode 100644 index 0000000000..6c00e40405 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigSupplier.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.kinesis.extension; + +import java.util.Optional; + +public class KinesisLeaseConfigSupplier { + + private KinesisLeaseConfig kinesisLeaseConfig; + + public KinesisLeaseConfigSupplier(final KinesisLeaseConfig kinesisLeaseConfig) { + this.kinesisLeaseConfig = kinesisLeaseConfig; + } + + public Optional getKinesisExtensionLeaseConfig() { + return Optional.ofNullable(kinesisLeaseConfig); + } +} diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseCoordinationTableConfig.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseCoordinationTableConfig.java new file mode 100644 index 0000000000..d497f01369 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseCoordinationTableConfig.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
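*
* An illustrative extension configuration this model binds (root key "/kinesis" per
* KinesisLeaseConfigExtension; the table name and region here are made-up values):
*
*   kinesis:
*     lease_coordination:
*       table_name: "kinesis-pipeline-lease-table"
*       region: "us-east-1"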
+ * + */ + +package org.opensearch.dataprepper.plugins.kinesis.extension; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Getter; +import lombok.NonNull; +import software.amazon.awssdk.regions.Region; + +@Getter +public class KinesisLeaseCoordinationTableConfig { + + @JsonProperty("table_name") + @NonNull + private String tableName; + + @JsonProperty("region") + @NonNull + private String region; + + public Region getAwsRegion() { + return Region.of(region); + } +} diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/HostNameWorkerIdentifierGenerator.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/HostNameWorkerIdentifierGenerator.java new file mode 100644 index 0000000000..61383304d0 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/HostNameWorkerIdentifierGenerator.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.kinesis.source; + +import java.net.InetAddress; +import java.net.UnknownHostException; + +/** + * Generate a unique ID to represent a consumer application instance. + */ +public class HostNameWorkerIdentifierGenerator implements WorkerIdentifierGenerator { + + private static final String hostName; + + static { + try { + hostName = InetAddress.getLocalHost().getHostName(); + } catch (final UnknownHostException e) { + throw new RuntimeException(e); + } + } + + + /** + * @return Default to use host name. + */ + @Override + public String generate() { + return hostName; + } +} diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisClientFactory.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisClientFactory.java new file mode 100644 index 0000000000..8f3bac38aa --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisClientFactory.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
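*
* Note on credentials (from the constructor below): the Kinesis client is built with
* the source's STS-role credentials (awsCredentialsProvider), while the DynamoDB
* lease-table and CloudWatch metrics clients use the default credentials options
* (defaultCredentialsProvider).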
+ * + */ + +package org.opensearch.dataprepper.plugins.kinesis.source; + +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.kinesis.source.configuration.AwsAuthenticationConfig; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.kinesis.common.KinesisClientUtil; + +public class KinesisClientFactory { + private final AwsCredentialsProvider awsCredentialsProvider; + private final AwsCredentialsProvider defaultCredentialsProvider; + private final AwsAuthenticationConfig awsAuthenticationConfig; + + public KinesisClientFactory(final AwsCredentialsSupplier awsCredentialsSupplier, + final AwsAuthenticationConfig awsAuthenticationConfig) { + awsCredentialsProvider = awsCredentialsSupplier.getProvider(AwsCredentialsOptions.builder() + .withRegion(awsAuthenticationConfig.getAwsRegion()) + .withStsRoleArn(awsAuthenticationConfig.getAwsStsRoleArn()) + .withStsExternalId(awsAuthenticationConfig.getAwsStsExternalId()) + .withStsHeaderOverrides(awsAuthenticationConfig.getAwsStsHeaderOverrides()) + .build()); + defaultCredentialsProvider = awsCredentialsSupplier.getProvider(AwsCredentialsOptions.defaultOptions()); + this.awsAuthenticationConfig = awsAuthenticationConfig; + } + + public DynamoDbAsyncClient buildDynamoDBClient(Region region) { + return DynamoDbAsyncClient.builder() + .credentialsProvider(defaultCredentialsProvider) + .region(region) + .build(); + } + + public KinesisAsyncClient buildKinesisAsyncClient(Region region) { + return KinesisClientUtil.createKinesisAsyncClient( + KinesisAsyncClient.builder() + .credentialsProvider(awsCredentialsProvider) + .region(region) + ); + } + + public CloudWatchAsyncClient buildCloudWatchAsyncClient(Region region) { + return CloudWatchAsyncClient.builder() + .credentialsProvider(defaultCredentialsProvider) + .region(region) + .build(); + } +} diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisMultiStreamTracker.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisMultiStreamTracker.java new file mode 100644 index 0000000000..638751f17e --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisMultiStreamTracker.java @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
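*
* Stream identifiers are assembled in the KCL multi-stream format
* accountId:streamName:creationEpochSecond, for example (made-up values)
* 123456789012:my-stream:1718000000.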
+ * + */ + +package org.opensearch.dataprepper.plugins.kinesis.source; + +import com.amazonaws.arn.Arn; +import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisSourceConfig; +import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisStreamConfig; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamRequest; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamResponse; +import software.amazon.awssdk.services.kinesis.model.StreamDescription; +import software.amazon.kinesis.common.InitialPositionInStreamExtended; +import software.amazon.kinesis.common.StreamConfig; +import software.amazon.kinesis.common.StreamIdentifier; +import software.amazon.kinesis.processor.FormerStreamsLeasesDeletionStrategy; +import software.amazon.kinesis.processor.MultiStreamTracker; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; + + +public class KinesisMultiStreamTracker implements MultiStreamTracker { + private static final String COLON = ":"; + + private final KinesisAsyncClient kinesisClient; + private final KinesisSourceConfig sourceConfig; + private final String applicationName; + + public KinesisMultiStreamTracker(KinesisAsyncClient kinesisClient, final KinesisSourceConfig sourceConfig, final String applicationName) { + this.kinesisClient = kinesisClient; + this.sourceConfig = sourceConfig; + this.applicationName = applicationName; + } + + @Override + public List streamConfigList() { + List streamConfigList = new ArrayList<>(); + for (KinesisStreamConfig kinesisStreamConfig : sourceConfig.getStreams()) { + StreamConfig streamConfig = getStreamConfig(kinesisStreamConfig); + streamConfigList.add(streamConfig); + } + return streamConfigList; + } + + private StreamConfig getStreamConfig(KinesisStreamConfig kinesisStreamConfig) { + StreamIdentifier sourceStreamIdentifier = getStreamIdentifier(kinesisStreamConfig); + return new StreamConfig(sourceStreamIdentifier, + InitialPositionInStreamExtended.newInitialPosition(kinesisStreamConfig.getInitialPosition())); + } + + private StreamIdentifier getStreamIdentifier(KinesisStreamConfig kinesisStreamConfig) { + DescribeStreamRequest describeStreamRequest = DescribeStreamRequest.builder() + .streamName(kinesisStreamConfig.getName()) + .build(); + DescribeStreamResponse describeStreamResponse = kinesisClient.describeStream(describeStreamRequest).join(); + String streamIdentifierString = getStreamIdentifierString(describeStreamResponse.streamDescription()); + return StreamIdentifier.multiStreamInstance(streamIdentifierString); + } + + private String getStreamIdentifierString(StreamDescription streamDescription) { + String accountId = Arn.fromString(streamDescription.streamARN()).getAccountId(); + long creationEpochSecond = streamDescription.streamCreationTimestamp().getEpochSecond(); + return String.join(COLON, accountId, streamDescription.streamName(), String.valueOf(creationEpochSecond)); + } + + /** + * Setting the deletion policy as autodetect and release shard lease with a wait time of 10 sec + */ + @Override + public FormerStreamsLeasesDeletionStrategy formerStreamsLeasesDeletionStrategy() { + return new FormerStreamsLeasesDeletionStrategy.AutoDetectionAndDeferredDeletionStrategy() { + @Override + public Duration waitPeriodToDeleteFormerStreams() { + return Duration.ofSeconds(10); + } + }; + + } +} diff --git 
a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisService.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisService.java new file mode 100644 index 0000000000..4ed15833f6 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisService.java @@ -0,0 +1,169 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.kinesis.source; + +import lombok.Setter; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.codec.InputCodec; +import org.opensearch.dataprepper.model.configuration.PipelineDescription; +import org.opensearch.dataprepper.model.configuration.PluginModel; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.plugin.InvalidPluginConfigurationException; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.kinesis.extension.KinesisLeaseConfig; +import org.opensearch.dataprepper.plugins.kinesis.extension.KinesisLeaseConfigSupplier; +import org.opensearch.dataprepper.plugins.kinesis.source.configuration.ConsumerStrategy; +import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisSourceConfig; +import org.opensearch.dataprepper.plugins.kinesis.source.processor.KinesisShardRecordProcessorFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.dynamodb.model.BillingMode; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.kinesis.common.ConfigsBuilder; +import software.amazon.kinesis.coordinator.Scheduler; +import software.amazon.kinesis.processor.ShardRecordProcessorFactory; +import software.amazon.kinesis.retrieval.polling.PollingConfig; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class KinesisService { + private static final Logger LOG = LoggerFactory.getLogger(KinesisService.class); + private static final int GRACEFUL_SHUTDOWN_WAIT_INTERVAL_SECONDS = 20; + + private final PluginMetrics pluginMetrics; + private final PluginFactory pluginFactory; + + private final String applicationName; + private final String tableName; + private final String kclMetricsNamespaceName; + private final String pipelineName; + private final AcknowledgementSetManager acknowledgementSetManager; + private final KinesisSourceConfig kinesisSourceConfig; + private final KinesisAsyncClient kinesisClient; + private final DynamoDbAsyncClient dynamoDbClient; + private final CloudWatchAsyncClient cloudWatchClient; + 
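// Identifies this consumer instance in the KCL lease table; the default implementation
// (HostNameWorkerIdentifierGenerator) uses the local host name.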
+    private final WorkerIdentifierGenerator workerIdentifierGenerator;
+    private final InputCodec codec;
+
+    @Setter
+    private Scheduler scheduler;
+    private final ExecutorService executorService;
+
+    public KinesisService(final KinesisSourceConfig kinesisSourceConfig,
+                          final KinesisClientFactory kinesisClientFactory,
+                          final PluginMetrics pluginMetrics,
+                          final PluginFactory pluginFactory,
+                          final PipelineDescription pipelineDescription,
+                          final AcknowledgementSetManager acknowledgementSetManager,
+                          final KinesisLeaseConfigSupplier kinesisLeaseConfigSupplier,
+                          final WorkerIdentifierGenerator workerIdentifierGenerator
+                          ){
+        this.kinesisSourceConfig = kinesisSourceConfig;
+        this.pluginMetrics = pluginMetrics;
+        this.pluginFactory = pluginFactory;
+        this.acknowledgementSetManager = acknowledgementSetManager;
+        if (kinesisLeaseConfigSupplier.getKinesisExtensionLeaseConfig().isEmpty()) {
+            throw new IllegalStateException("Lease Coordination table should be provided!");
+        }
+        KinesisLeaseConfig kinesisLeaseConfig =
+                kinesisLeaseConfigSupplier.getKinesisExtensionLeaseConfig().get();
+        this.tableName = kinesisLeaseConfig.getLeaseCoordinationTable().getTableName();
+        this.kclMetricsNamespaceName = this.tableName;
+        this.dynamoDbClient = kinesisClientFactory.buildDynamoDBClient(kinesisLeaseConfig.getLeaseCoordinationTable().getAwsRegion());
+        this.kinesisClient = kinesisClientFactory.buildKinesisAsyncClient(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsRegion());
+        this.cloudWatchClient = kinesisClientFactory.buildCloudWatchAsyncClient(kinesisLeaseConfig.getLeaseCoordinationTable().getAwsRegion());
+        this.pipelineName = pipelineDescription.getPipelineName();
+        this.applicationName = pipelineName;
+        this.workerIdentifierGenerator = workerIdentifierGenerator;
+        this.executorService = Executors.newFixedThreadPool(1);
+        final PluginModel codecConfiguration = kinesisSourceConfig.getCodec();
+        final PluginSetting codecPluginSettings = new PluginSetting(codecConfiguration.getPluginName(), codecConfiguration.getPluginSettings());
+        this.codec = pluginFactory.loadPlugin(InputCodec.class, codecPluginSettings);
+    }
+
+    public void start(final Buffer<Record<Event>> buffer) {
+        if (buffer == null) {
+            throw new IllegalStateException("Buffer provided is null.");
+        }
+
+        if (kinesisSourceConfig.getStreams() == null || kinesisSourceConfig.getStreams().isEmpty()) {
+            throw new InvalidPluginConfigurationException("No Kinesis streams provided.");
+        }
+
+        scheduler = getScheduler(buffer);
+        executorService.execute(scheduler);
+    }
+
+    public void shutDown() {
+        LOG.info("Stop request received for Kinesis Source");
+
+        Future<Boolean> gracefulShutdownFuture = scheduler.startGracefulShutdown();
+        LOG.info("Waiting up to {} seconds for shutdown to complete.", GRACEFUL_SHUTDOWN_WAIT_INTERVAL_SECONDS);
+        try {
+            gracefulShutdownFuture.get(GRACEFUL_SHUTDOWN_WAIT_INTERVAL_SECONDS, TimeUnit.SECONDS);
+        } catch (InterruptedException | ExecutionException | TimeoutException ex) {
+            LOG.error("Exception while executing kinesis consumer graceful shutdown, doing force shutdown", ex);
+            scheduler.shutdown();
+        }
+        LOG.info("Completed, shutting down now.");
+    }
+
+    public Scheduler getScheduler(final Buffer<Record<Event>> buffer) {
+        if (scheduler == null) {
+            return createScheduler(buffer);
+        }
+        return scheduler;
+    }
+
+    public Scheduler createScheduler(final Buffer<Record<Event>> buffer) {
+        final ShardRecordProcessorFactory processorFactory = new KinesisShardRecordProcessorFactory(
+                buffer, kinesisSourceConfig, acknowledgementSetManager, pluginMetrics, codec);
+
+        ConfigsBuilder configsBuilder =
+                new ConfigsBuilder(
+                        new KinesisMultiStreamTracker(kinesisClient, kinesisSourceConfig, applicationName),
+                        applicationName, kinesisClient, dynamoDbClient, cloudWatchClient,
+                        workerIdentifierGenerator.generate(), processorFactory
+                )
+                .tableName(tableName)
+                .namespace(kclMetricsNamespaceName);
+
+        // Each configsBuilder.retrievalConfig() call builds a new RetrievalConfig, so configure the same instance the Scheduler receives.
+        RetrievalConfig retrievalConfig = configsBuilder.retrievalConfig();
+        ConsumerStrategy consumerStrategy = kinesisSourceConfig.getConsumerStrategy();
+        if (consumerStrategy == ConsumerStrategy.POLLING) {
+            retrievalConfig.retrievalSpecificConfig(
+                    new PollingConfig(kinesisClient)
+                            .maxRecords(kinesisSourceConfig.getPollingConfig().getMaxPollingRecords())
+                            .idleTimeBetweenReadsInMillis(
+                                    kinesisSourceConfig.getPollingConfig().getIdleTimeBetweenReads().toMillis()));
+        }
+
+        return new Scheduler(
+                configsBuilder.checkpointConfig(),
+                configsBuilder.coordinatorConfig(),
+                configsBuilder.leaseManagementConfig().billingMode(BillingMode.PAY_PER_REQUEST),
+                configsBuilder.lifecycleConfig(),
+                configsBuilder.metricsConfig(),
+                configsBuilder.processorConfig(),
+                retrievalConfig
+        );
+    }
+}
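The shape of createScheduler above reflects a KCL 2.x subtlety: ConfigsBuilder constructs a fresh config object on every accessor call, so retrieval settings must be applied to the very instance that is handed to the Scheduler. A sketch of the safe pattern, reusing configsBuilder and kinesisClient from the method above (the maxRecords value is arbitrary):

    RetrievalConfig retrievalConfig = configsBuilder.retrievalConfig();
    retrievalConfig.retrievalSpecificConfig(new PollingConfig(kinesisClient).maxRecords(100));

    Scheduler scheduler = new Scheduler(
            configsBuilder.checkpointConfig(),
            configsBuilder.coordinatorConfig(),
            configsBuilder.leaseManagementConfig(),
            configsBuilder.lifecycleConfig(),
            configsBuilder.metricsConfig(),
            configsBuilder.processorConfig(),
            retrievalConfig); // NOT configsBuilder.retrievalConfig(), which would be a new instance without the polling settings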
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisSource.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisSource.java
new file mode 100644
index 0000000000..220d19cac8
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisSource.java
@@ -0,0 +1,71 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source;
+
+import lombok.Setter;
+import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
+import org.opensearch.dataprepper.metrics.PluginMetrics;
+import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
+import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
+import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
+import org.opensearch.dataprepper.model.buffer.Buffer;
+import org.opensearch.dataprepper.model.event.Event;
+import org.opensearch.dataprepper.model.record.Record;
+import org.opensearch.dataprepper.model.source.Source;
+import org.opensearch.dataprepper.model.configuration.PipelineDescription;
+import org.opensearch.dataprepper.model.plugin.PluginFactory;
+import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisSourceConfig;
+import org.opensearch.dataprepper.plugins.kinesis.extension.KinesisLeaseConfigSupplier;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+@DataPrepperPlugin(name = "kinesis", pluginType = Source.class, pluginConfigurationType = KinesisSourceConfig.class)
+public class KinesisSource implements Source<Record<Event>> {
+    private static final Logger LOG = LoggerFactory.getLogger(KinesisSource.class);
+    private final KinesisSourceConfig kinesisSourceConfig;
+    private final KinesisLeaseConfigSupplier kinesisLeaseConfigSupplier;
+
+    @Setter
+    private KinesisService kinesisService;
+
+    @DataPrepperPluginConstructor
+    public KinesisSource(final KinesisSourceConfig kinesisSourceConfig,
+                         final PluginMetrics pluginMetrics,
+                         final PluginFactory pluginFactory,
+                         final PipelineDescription pipelineDescription,
+                         final AwsCredentialsSupplier awsCredentialsSupplier,
+                         final AcknowledgementSetManager acknowledgementSetManager,
+                         final KinesisLeaseConfigSupplier kinesisLeaseConfigSupplier) {
+        this.kinesisSourceConfig = kinesisSourceConfig;
+        this.kinesisLeaseConfigSupplier = kinesisLeaseConfigSupplier;
+        KinesisClientFactory kinesisClientFactory = new KinesisClientFactory(awsCredentialsSupplier, kinesisSourceConfig.getAwsAuthenticationConfig());
+        this.kinesisService = new KinesisService(kinesisSourceConfig, kinesisClientFactory, pluginMetrics, pluginFactory,
+                pipelineDescription, acknowledgementSetManager, kinesisLeaseConfigSupplier, new HostNameWorkerIdentifierGenerator());
+    }
+
+    @Override
+    public void start(final Buffer<Record<Event>> buffer) {
+        if (buffer == null) {
+            throw new IllegalStateException("Buffer provided is null");
+        }
+
+        kinesisService.start(buffer);
+    }
+
+    @Override
+    public void stop() {
+        kinesisService.shutDown();
+    }
+
+    @Override
+    public boolean areAcknowledgementsEnabled() {
+        return kinesisSourceConfig.isAcknowledgments();
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/WorkerIdentifierGenerator.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/WorkerIdentifierGenerator.java
new file mode 100644
index 0000000000..75bad8761a
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/WorkerIdentifierGenerator.java
@@ -0,0 +1,16 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source;
+
+public interface WorkerIdentifierGenerator {
+
+    String generate();
+}
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/AwsAuthenticationConfig.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/AwsAuthenticationConfig.java
new file mode 100644
index 0000000000..6a98f70c3b
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/AwsAuthenticationConfig.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.configuration;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import jakarta.validation.constraints.Size;
+import lombok.Getter;
+import software.amazon.awssdk.regions.Region;
+
+import java.util.Map;
+
+public class AwsAuthenticationConfig {
+    private static final String AWS_IAM_ROLE = "role";
+    private static final String AWS_IAM = "iam";
+
+    @JsonProperty("region")
+    @Size(min = 1, message = "Region cannot be empty string")
+    private String awsRegion;
+
+    @Getter
+    @JsonProperty("sts_role_arn")
+    @Size(min = 20, max = 2048, message = "awsStsRoleArn length should be between 20 and 2048 characters")
+    private String awsStsRoleArn;
+
+    @Getter
+    @JsonProperty("sts_external_id")
+    @Size(min = 2, max = 1224, message = "awsStsExternalId length should be between 2 and 1224 characters")
+    private String awsStsExternalId;
+
+    @Getter
+    @JsonProperty("sts_header_overrides")
+    @Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override")
+    private Map<String, String> awsStsHeaderOverrides;
+
+    public Region getAwsRegion() {
+        return awsRegion != null ? Region.of(awsRegion) : null;
+    }
+}
\ No newline at end of file
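KinesisSource wires in a HostNameWorkerIdentifierGenerator, which is not part of this hunk. A plausible implementation consistent with its name, offered as a sketch rather than the actual class:

    package org.opensearch.dataprepper.plugins.kinesis.source;

    import java.net.InetAddress;
    import java.net.UnknownHostException;

    public class HostNameWorkerIdentifierGenerator implements WorkerIdentifierGenerator {
        @Override
        public String generate() {
            try {
                // One KCL worker per host: using the hostname keeps lease ownership stable across restarts.
                return InetAddress.getLocalHost().getHostName();
            } catch (UnknownHostException e) {
                throw new RuntimeException(e);
            }
        }
    }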
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/ConsumerStrategy.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/ConsumerStrategy.java
new file mode 100644
index 0000000000..05fc88f62a
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/ConsumerStrategy.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.configuration;
+
+import com.fasterxml.jackson.annotation.JsonValue;
+
+/**
+ * @see Enhanced Consumers
+ */
+
+public enum ConsumerStrategy {
+
+    POLLING("polling"),
+
+    ENHANCED_FAN_OUT("fan-out");
+
+    private final String value;
+
+    ConsumerStrategy(String value) {
+        this.value = value;
+    }
+
+    @JsonValue
+    public String getValue() {
+        return value;
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/InitialPositionInStreamConfig.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/InitialPositionInStreamConfig.java
new file mode 100644
index 0000000000..37019cc9af
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/InitialPositionInStreamConfig.java
@@ -0,0 +1,47 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.configuration;
+
+import lombok.Getter;
+import software.amazon.kinesis.common.InitialPositionInStream;
+
+import java.util.Arrays;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+@Getter
+public enum InitialPositionInStreamConfig {
+    LATEST("latest", InitialPositionInStream.LATEST),
+    EARLIEST("earliest", InitialPositionInStream.TRIM_HORIZON);
+
+    private final String position;
+
+    private final InitialPositionInStream positionInStream;
+
+    InitialPositionInStreamConfig(final String position, final InitialPositionInStream positionInStream) {
+        this.position = position;
+        this.positionInStream = positionInStream;
+    }
+
+    private static final Map<String, InitialPositionInStreamConfig> POSITIONS_MAP = Arrays.stream(InitialPositionInStreamConfig.values())
+            .collect(Collectors.toMap(
+                    value -> value.position,
+                    value -> value
+            ));
+
+    public static InitialPositionInStreamConfig fromPositionValue(final String position) {
+        return POSITIONS_MAP.get(position.toLowerCase());
+    }
+
+    public String toString() {
+        return this.position;
+    }
+}
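Two mappings in these enums are worth seeing concretely: the user-facing "earliest" token maps to the KCL's TRIM_HORIZON position, and @JsonValue controls the wire form of the consumer strategy. A small sketch, assuming it lives in the same configuration package so both enums are visible:

    import com.fasterxml.jackson.databind.ObjectMapper;
    import software.amazon.kinesis.common.InitialPositionInStream;

    public class PositionMappingExample {
        public static void main(String[] args) throws Exception {
            // "earliest" is surfaced to users but maps to the KCL's TRIM_HORIZON.
            InitialPositionInStreamConfig position = InitialPositionInStreamConfig.fromPositionValue("earliest");
            System.out.println(position.getPositionInStream() == InitialPositionInStream.TRIM_HORIZON); // true

            // @JsonValue makes the enum serialize as its config-file token.
            System.out.println(new ObjectMapper().writeValueAsString(ConsumerStrategy.ENHANCED_FAN_OUT)); // "fan-out"
        }
    }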
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/KinesisSourceConfig.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/KinesisSourceConfig.java
new file mode 100644
index 0000000000..1414229813
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/KinesisSourceConfig.java
@@ -0,0 +1,75 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.configuration;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import jakarta.validation.Valid;
+import jakarta.validation.constraints.NotNull;
+import jakarta.validation.constraints.Size;
+import lombok.Getter;
+import org.opensearch.dataprepper.model.configuration.PluginModel;
+
+import java.time.Duration;
+import java.util.List;
+
+public class KinesisSourceConfig {
+    static final Duration DEFAULT_TIME_OUT_IN_MILLIS = Duration.ofMillis(1000);
+    static final int DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE = 100;
+    static final Duration DEFAULT_SHARD_ACKNOWLEDGEMENT_TIMEOUT = Duration.ofMinutes(10);
+
+    @Getter
+    @JsonProperty("streams")
+    @NotNull
+    @Valid
+    @Size(min = 1, max = 4, message = "Provide 1-4 streams to read from.")
+    private List<KinesisStreamConfig> streams;
+
+    @Getter
+    @JsonProperty("aws")
+    @NotNull
+    @Valid
+    private AwsAuthenticationConfig awsAuthenticationConfig;
+
+    @Getter
+    @JsonProperty("buffer_timeout")
+    private Duration bufferTimeout = DEFAULT_TIME_OUT_IN_MILLIS;
+
+    @Getter
+    @JsonProperty("records_to_accumulate")
+    private int numberOfRecordsToAccumulate = DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE;
+
+    @JsonProperty("acknowledgments")
+    @Getter
+    private boolean acknowledgments = false;
+
+    @Getter
+    @JsonProperty("consumer_strategy")
+    private ConsumerStrategy consumerStrategy = ConsumerStrategy.ENHANCED_FAN_OUT;
+
+    @Getter
+    @JsonProperty("polling")
+    private KinesisStreamPollingConfig pollingConfig;
+
+    @Getter
+    @NotNull
+    @JsonProperty("codec")
+    private PluginModel codec;
+
+    @JsonProperty("shard_acknowledgment_timeout")
+    private Duration shardAcknowledgmentTimeout = DEFAULT_SHARD_ACKNOWLEDGEMENT_TIMEOUT;
+
+    public Duration getShardAcknowledgmentTimeout() {
+        return shardAcknowledgmentTimeout;
+    }
+}
+
+
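To see how these @JsonProperty mappings line up with a pipeline definition, here is a hedged binding sketch. It assumes Jackson honors @JsonValue for enum deserialization (available in modern Jackson releases) and that PluginModel accepts the single-key map form used for codecs in pipeline YAML; bean-validation annotations such as @NotNull are not enforced by the plain ObjectMapper used here:

    import com.fasterxml.jackson.databind.ObjectMapper;
    import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;

    public class SourceConfigBindingExample {
        public static void main(String[] args) throws Exception {
            ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
            String yaml =
                    "streams:\n" +
                    "  - stream_name: orders-stream\n" +   // hypothetical stream name
                    "aws:\n" +
                    "  region: us-east-1\n" +
                    "codec:\n" +
                    "  json: {}\n" +
                    "consumer_strategy: polling\n";
            KinesisSourceConfig config = mapper.readValue(yaml, KinesisSourceConfig.class);
            System.out.println(config.getStreams().size());    // 1
            System.out.println(config.getConsumerStrategy());  // POLLING
        }
    }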
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/KinesisStreamConfig.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/KinesisStreamConfig.java
new file mode 100644
index 0000000000..b26732e357
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/KinesisStreamConfig.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.configuration;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import jakarta.validation.Valid;
+import jakarta.validation.constraints.NotNull;
+import lombok.Getter;
+import software.amazon.kinesis.common.InitialPositionInStream;
+
+import java.time.Duration;
+
+@Getter
+public class KinesisStreamConfig {
+    // Checkpointing interval
+    private static final Duration MINIMAL_CHECKPOINT_INTERVAL = Duration.ofMillis(2 * 60 * 1000); // 2 minutes
+    private static final boolean DEFAULT_ENABLE_CHECKPOINT = false;
+
+    @JsonProperty("stream_name")
+    @NotNull
+    @Valid
+    private String name;
+
+    @JsonProperty("initial_position")
+    private InitialPositionInStreamConfig initialPosition = InitialPositionInStreamConfig.LATEST;
+
+    @JsonProperty("checkpoint_interval")
+    private Duration checkPointInterval = MINIMAL_CHECKPOINT_INTERVAL;
+
+    public InitialPositionInStream getInitialPosition() {
+        return initialPosition.getPositionInStream();
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/KinesisStreamPollingConfig.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/KinesisStreamPollingConfig.java
new file mode 100644
index 0000000000..cd7b7a59f6
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/KinesisStreamPollingConfig.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.configuration;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Getter;
+
+import java.time.Duration;
+
+public class KinesisStreamPollingConfig {
+    private static final int DEFAULT_MAX_RECORDS = 10000;
+    private static final Duration IDLE_TIME_BETWEEN_READS = Duration.ofMillis(250);
+    @Getter
+    @JsonProperty("max_polling_records")
+    private int maxPollingRecords = DEFAULT_MAX_RECORDS;
+
+    @Getter
+    @JsonProperty("idle_time_between_reads")
+    private Duration idleTimeBetweenReads = IDLE_TIME_BETWEEN_READS;
+
+}
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/converter/KinesisRecordConverter.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/converter/KinesisRecordConverter.java
new file mode 100644
index 0000000000..5a70b95c10
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/converter/KinesisRecordConverter.java
@@ -0,0 +1,47 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.converter;
+
+import org.opensearch.dataprepper.model.codec.InputCodec;
+import org.opensearch.dataprepper.model.event.Event;
+import org.opensearch.dataprepper.model.record.Record;
+import software.amazon.kinesis.retrieval.KinesisClientRecord;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Consumer;
+
+public class KinesisRecordConverter {
+
+    private final InputCodec codec;
+
+    public KinesisRecordConverter(final InputCodec codec) {
+        this.codec = codec;
+    }
+
+    public List<Record<Event>> convert(List<KinesisClientRecord> kinesisClientRecords) throws IOException {
+        List<Record<Event>> records = new ArrayList<>();
+        for (KinesisClientRecord record : kinesisClientRecords) {
+            processRecord(record, records::add);
+        }
+        return records;
+    }
+
+    private void processRecord(KinesisClientRecord record, Consumer<Record<Event>> eventConsumer) throws IOException {
+        // Copy the record payload out of the ByteBuffer before handing it to the codec.
+        byte[] arr = new byte[record.data().remaining()];
+        record.data().get(arr);
+        ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(arr);
+        codec.parse(byteArrayInputStream, eventConsumer);
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisCheckpointerRecord.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisCheckpointerRecord.java
new file mode 100644
index 0000000000..b891de2bd0
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisCheckpointerRecord.java
@@ -0,0 +1,26 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.processor;
+
+import lombok.Builder;
+import lombok.Getter;
+import lombok.Setter;
+import software.amazon.kinesis.processor.RecordProcessorCheckpointer;
+import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
+
+@Builder
+@Getter
+@Setter
+public class KinesisCheckpointerRecord {
+    private RecordProcessorCheckpointer checkpointer;
+    private ExtendedSequenceNumber extendedSequenceNumber;
+    private boolean readyToCheckpoint;
+}
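A quick sketch of the conversion path: the converter takes a codec that KinesisService has already loaded via the PluginFactory and turns a Kinesis payload into Data Prepper events. The codec parameter stands in for any InputCodec (for example the json codec), and the payload is a made-up JSON document:

    import org.opensearch.dataprepper.model.codec.InputCodec;
    import org.opensearch.dataprepper.model.event.Event;
    import org.opensearch.dataprepper.model.record.Record;
    import software.amazon.kinesis.retrieval.KinesisClientRecord;

    import java.io.IOException;
    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;
    import java.util.List;

    public class ConverterExample {
        // 'codec' stands in for a codec loaded via PluginFactory.
        static List<Record<Event>> decode(InputCodec codec) throws IOException {
            KinesisRecordConverter converter = new KinesisRecordConverter(codec);
            KinesisClientRecord kinesisRecord = KinesisClientRecord.builder()
                    .data(ByteBuffer.wrap("{\"order_id\":1}".getBytes(StandardCharsets.UTF_8)))
                    .build();
            // One Kinesis record may decode into one or more Data Prepper events.
            return converter.convert(List.of(kinesisRecord));
        }
    }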
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisCheckpointerTracker.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisCheckpointerTracker.java
new file mode 100644
index 0000000000..8fb7c5ec6c
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisCheckpointerTracker.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.processor;
+
+import software.amazon.kinesis.processor.RecordProcessorCheckpointer;
+import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+public class KinesisCheckpointerTracker {
+    private final Map<ExtendedSequenceNumber, KinesisCheckpointerRecord> checkpointerRecordList = new LinkedHashMap<>();
+
+    public synchronized void addRecordForCheckpoint(final ExtendedSequenceNumber extendedSequenceNumber,
+                                                    final RecordProcessorCheckpointer checkpointer) {
+        checkpointerRecordList.put(extendedSequenceNumber, KinesisCheckpointerRecord.builder()
+                .extendedSequenceNumber(extendedSequenceNumber)
+                .checkpointer(checkpointer)
+                .readyToCheckpoint(false)
+                .build());
+    }
+
+    public synchronized void markSequenceNumberForCheckpoint(final ExtendedSequenceNumber extendedSequenceNumber) {
+        if (!checkpointerRecordList.containsKey(extendedSequenceNumber)) {
+            throw new IllegalArgumentException("checkpointer not available");
+        }
+        checkpointerRecordList.get(extendedSequenceNumber).setReadyToCheckpoint(true);
+    }
+
+    public synchronized Optional<KinesisCheckpointerRecord> popLatestReadyToCheckpointRecord() {
+        Optional<KinesisCheckpointerRecord> kinesisCheckpointerRecordOptional = Optional.empty();
+        List<ExtendedSequenceNumber> toRemoveRecords = new ArrayList<>();
+
+        for (Map.Entry<ExtendedSequenceNumber, KinesisCheckpointerRecord> entry: checkpointerRecordList.entrySet()) {
+            KinesisCheckpointerRecord kinesisCheckpointerRecord = entry.getValue();
+
+            // Break out of the loop on the first record which is not ready for checkpoint
+            if (!kinesisCheckpointerRecord.isReadyToCheckpoint()) {
+                break;
+            }
+
+            kinesisCheckpointerRecordOptional = Optional.of(kinesisCheckpointerRecord);
+            toRemoveRecords.add(entry.getKey());
+        }
+
+        // Clean up the entries that were already marked for checkpoint
+        for (ExtendedSequenceNumber extendedSequenceNumber: toRemoveRecords) {
+            checkpointerRecordList.remove(extendedSequenceNumber);
+        }
+
+        return kinesisCheckpointerRecordOptional;
+    }
+
+    public synchronized int size() {
+        return checkpointerRecordList.size();
+    }
+}
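The tracker's insertion-ordered map is what enforces in-order checkpointing: the record that may be checkpointed is the last one in an unbroken acknowledged prefix. A small demonstration, with checkpointers left as null stubs for brevity (the real processor passes the KCL checkpointer):

    import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
    import java.util.Optional;

    public class CheckpointOrderingExample {
        public static void main(String[] args) {
            KinesisCheckpointerTracker tracker = new KinesisCheckpointerTracker();
            ExtendedSequenceNumber s1 = new ExtendedSequenceNumber("1");
            ExtendedSequenceNumber s2 = new ExtendedSequenceNumber("2");
            ExtendedSequenceNumber s3 = new ExtendedSequenceNumber("3");
            tracker.addRecordForCheckpoint(s1, null);
            tracker.addRecordForCheckpoint(s2, null);
            tracker.addRecordForCheckpoint(s3, null);

            // s1 and s3 are acknowledged, s2 is still in flight.
            tracker.markSequenceNumberForCheckpoint(s1);
            tracker.markSequenceNumberForCheckpoint(s3);

            // The scan stops at s2, so only s1 is safe to checkpoint; s2 and s3 remain tracked.
            Optional<KinesisCheckpointerRecord> ready = tracker.popLatestReadyToCheckpointRecord();
            System.out.println(ready.get().getExtendedSequenceNumber()); // s1
            System.out.println(tracker.size());                          // 2
        }
    }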
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisRecordProcessor.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisRecordProcessor.java
new file mode 100644
index 0000000000..6df0760ca3
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisRecordProcessor.java
@@ -0,0 +1,270 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.processor;
+
+import com.google.common.annotations.VisibleForTesting;
+import io.micrometer.core.instrument.Counter;
+import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
+import org.opensearch.dataprepper.common.concurrent.BackgroundThreadFactory;
+import org.opensearch.dataprepper.metrics.PluginMetrics;
+import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
+import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
+import org.opensearch.dataprepper.model.event.Event;
+import org.opensearch.dataprepper.model.event.EventMetadata;
+import org.opensearch.dataprepper.model.record.Record;
+import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisSourceConfig;
+import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisStreamConfig;
+import org.opensearch.dataprepper.plugins.kinesis.source.converter.KinesisRecordConverter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import software.amazon.kinesis.common.StreamIdentifier;
+import software.amazon.kinesis.exceptions.InvalidStateException;
+import software.amazon.kinesis.exceptions.ShutdownException;
+import software.amazon.kinesis.exceptions.ThrottlingException;
+import software.amazon.kinesis.lifecycle.events.InitializationInput;
+import software.amazon.kinesis.lifecycle.events.LeaseLostInput;
+import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
+import software.amazon.kinesis.lifecycle.events.ShardEndedInput;
+import software.amazon.kinesis.lifecycle.events.ShutdownRequestedInput;
+import software.amazon.kinesis.processor.RecordProcessorCheckpointer;
+import software.amazon.kinesis.processor.ShardRecordProcessor;
+import software.amazon.kinesis.retrieval.KinesisClientRecord;
+import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
+
+import java.time.Duration;
+import java.util.List;
+import java.util.ListIterator;
+import java.util.Optional;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+public class KinesisRecordProcessor implements ShardRecordProcessor {
+    private static final Logger LOG = LoggerFactory.getLogger(KinesisRecordProcessor.class);
+
+    private static final int DEFAULT_MONITOR_WAIT_TIME_MS = 15_000;
+    private static final Duration ACKNOWLEDGEMENT_SET_TIMEOUT = Duration.ofSeconds(20);
+
+    private final StreamIdentifier streamIdentifier;
+    private final KinesisStreamConfig kinesisStreamConfig;
+    private final Duration checkpointInterval;
+    private final KinesisSourceConfig kinesisSourceConfig;
+    private final BufferAccumulator<Record<Event>> bufferAccumulator;
+    private final KinesisRecordConverter kinesisRecordConverter;
+    private final KinesisCheckpointerTracker kinesisCheckpointerTracker;
+    private final ExecutorService executorService;
+    private String kinesisShardId;
+    private long lastCheckpointTimeInMillis;
+    private final int bufferTimeoutMillis;
+    private final AcknowledgementSetManager acknowledgementSetManager;
+
+    private final Counter acknowledgementSetSuccesses;
+    private final Counter acknowledgementSetFailures;
+    private final Counter recordsProcessed;
+    private final Counter recordProcessingErrors;
+    private final Counter checkpointFailures;
+    public static final String ACKNOWLEDGEMENT_SET_SUCCESS_METRIC_NAME = "acknowledgementSetSuccesses";
+    public static final String ACKNOWLEDGEMENT_SET_FAILURES_METRIC_NAME = "acknowledgementSetFailures";
"acknowledgementSetFailures"; + public static final String KINESIS_RECORD_PROCESSED = "recordProcessed"; + public static final String KINESIS_RECORD_PROCESSING_ERRORS = "recordProcessingErrors"; + public static final String KINESIS_CHECKPOINT_FAILURES = "checkpointFailures"; + public static final String KINESIS_STREAM_TAG_KEY = "stream"; + private AtomicBoolean isStopRequested; + + public KinesisRecordProcessor(final BufferAccumulator> bufferAccumulator, + final KinesisSourceConfig kinesisSourceConfig, + final AcknowledgementSetManager acknowledgementSetManager, + final PluginMetrics pluginMetrics, + final KinesisRecordConverter kinesisRecordConverter, + final KinesisCheckpointerTracker kinesisCheckpointerTracker, + final StreamIdentifier streamIdentifier) { + this.bufferTimeoutMillis = (int) kinesisSourceConfig.getBufferTimeout().toMillis(); + this.streamIdentifier = streamIdentifier; + this.kinesisSourceConfig = kinesisSourceConfig; + this.kinesisStreamConfig = getStreamConfig(kinesisSourceConfig); + this.kinesisRecordConverter = kinesisRecordConverter; + this.acknowledgementSetManager = acknowledgementSetManager; + this.acknowledgementSetSuccesses = pluginMetrics.counterWithTags(ACKNOWLEDGEMENT_SET_SUCCESS_METRIC_NAME, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName()); + this.acknowledgementSetFailures = pluginMetrics.counterWithTags(ACKNOWLEDGEMENT_SET_FAILURES_METRIC_NAME, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName()); + this.recordsProcessed = pluginMetrics.counterWithTags(KINESIS_RECORD_PROCESSED, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName()); + this.recordProcessingErrors = pluginMetrics.counterWithTags(KINESIS_RECORD_PROCESSING_ERRORS, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName()); + this.checkpointFailures = pluginMetrics.counterWithTags(KINESIS_CHECKPOINT_FAILURES, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName()); + this.checkpointInterval = kinesisStreamConfig.getCheckPointInterval(); + this.bufferAccumulator = bufferAccumulator; + this.kinesisCheckpointerTracker = kinesisCheckpointerTracker; + this.executorService = Executors.newSingleThreadExecutor(BackgroundThreadFactory.defaultExecutorThreadFactory("kinesis-ack-monitor")); + this.isStopRequested = new AtomicBoolean(false); + } + + private KinesisStreamConfig getStreamConfig(final KinesisSourceConfig kinesisSourceConfig) { + return kinesisSourceConfig.getStreams().stream().filter(streamConfig -> streamConfig.getName().equals(streamIdentifier.streamName())).findAny().get(); + } + + @Override + public void initialize(InitializationInput initializationInput) { + // Called once when the processor is initialized. 
+        kinesisShardId = initializationInput.shardId();
+        String kinesisStreamName = streamIdentifier.streamName();
+        LOG.info("Initializing processor for stream: {}, shard: {}", kinesisStreamName, kinesisShardId);
+        lastCheckpointTimeInMillis = System.currentTimeMillis();
+
+        if (kinesisSourceConfig.isAcknowledgments()) {
+            executorService.submit(() -> monitorCheckpoint(executorService));
+        }
+    }
+
+    private void monitorCheckpoint(final ExecutorService executorService) {
+        while (!isStopRequested.get()) {
+            if (System.currentTimeMillis() - lastCheckpointTimeInMillis >= checkpointInterval.toMillis()) {
+                doCheckpoint();
+            }
+            try {
+                Thread.sleep(DEFAULT_MONITOR_WAIT_TIME_MS);
+            } catch (InterruptedException ex) {
+                break;
+            }
+        }
+        executorService.shutdown();
+    }
+
+    private AcknowledgementSet createAcknowledgmentSet(final ProcessRecordsInput processRecordsInput,
+                                                       final ExtendedSequenceNumber extendedSequenceNumber) {
+        return acknowledgementSetManager.create((result) -> {
+            String kinesisStreamName = streamIdentifier.streamName();
+            if (result) {
+                acknowledgementSetSuccesses.increment();
+                kinesisCheckpointerTracker.markSequenceNumberForCheckpoint(extendedSequenceNumber);
+                LOG.debug("acknowledgements received for stream: {}, shardId: {}", kinesisStreamName, kinesisShardId);
+            } else {
+                acknowledgementSetFailures.increment();
+                LOG.debug("acknowledgements received with false for stream: {}, shardId: {}", kinesisStreamName, kinesisShardId);
+            }
+        }, ACKNOWLEDGEMENT_SET_TIMEOUT);
+    }
+
+    @Override
+    public void processRecords(ProcessRecordsInput processRecordsInput) {
+        try {
+            Optional<AcknowledgementSet> acknowledgementSetOpt = Optional.empty();
+            boolean acknowledgementsEnabled = kinesisSourceConfig.isAcknowledgments();
+            ExtendedSequenceNumber extendedSequenceNumber = getLatestSequenceNumberFromInput(processRecordsInput);
+            if (acknowledgementsEnabled) {
+                acknowledgementSetOpt = Optional.of(createAcknowledgmentSet(processRecordsInput, extendedSequenceNumber));
+            }
+
+            // Track the records for checkpoint purpose
+            kinesisCheckpointerTracker.addRecordForCheckpoint(extendedSequenceNumber, processRecordsInput.checkpointer());
+            List<Record<Event>> records = kinesisRecordConverter.convert(processRecordsInput.records());
+
+            int eventCount = 0;
+            for (Record<Event> record: records) {
+                Event event = record.getData();
+                acknowledgementSetOpt.ifPresent(acknowledgementSet -> acknowledgementSet.add(event));
+                EventMetadata eventMetadata = event.getMetadata();
+                eventMetadata.setAttribute(MetadataKeyAttributes.KINESIS_STREAM_NAME_METADATA_ATTRIBUTE,
+                        streamIdentifier.streamName().toLowerCase());
+                bufferAccumulator.add(record);
+                eventCount++;
+            }
+
+            // Flush buffer at the end
+            bufferAccumulator.flush();
+            recordsProcessed.increment(eventCount);
+
+            // If acks are not enabled, mark the sequence number for checkpoint
+            if (!acknowledgementsEnabled) {
+                kinesisCheckpointerTracker.markSequenceNumberForCheckpoint(extendedSequenceNumber);
+            }
+
+            LOG.debug("Wrote {} events from {} records to the buffer for stream: {}, shardId: {}", eventCount, records.size(), streamIdentifier.streamName(), kinesisShardId);
+
+            acknowledgementSetOpt.ifPresent(AcknowledgementSet::complete);
+
+            // Checkpoint for shard
+            if (!acknowledgementsEnabled && (System.currentTimeMillis() - lastCheckpointTimeInMillis >= checkpointInterval.toMillis())) {
+                doCheckpoint();
+            }
+        } catch (Exception ex) {
+            recordProcessingErrors.increment();
+            LOG.error("Failed writing shard data to buffer: ", ex);
+        }
+    }
+
+    @Override
+    public void leaseLost(LeaseLostInput leaseLostInput) {
+ LOG.debug("Lease Lost"); + } + + @Override + public void shardEnded(ShardEndedInput shardEndedInput) { + String kinesisStream = streamIdentifier.streamName(); + LOG.debug("Reached shard end, checkpointing for stream: {}, shardId: {}", kinesisStream, kinesisShardId); + checkpoint(shardEndedInput.checkpointer()); + } + + @Override + public void shutdownRequested(ShutdownRequestedInput shutdownRequestedInput) { + String kinesisStream = streamIdentifier.streamName(); + isStopRequested.set(true); + LOG.debug("Scheduler is shutting down, checkpointing for stream: {}, shardId: {}", kinesisStream, kinesisShardId); + checkpoint(shutdownRequestedInput.checkpointer()); + } + + @VisibleForTesting + public void checkpoint(RecordProcessorCheckpointer checkpointer, String sequenceNumber, long subSequenceNumber) { + try { + String kinesisStream = streamIdentifier.streamName(); + LOG.debug("Checkpoint for stream: {}, shardId: {}, sequence: {}, subsequence: {}", kinesisStream, kinesisShardId, sequenceNumber, subSequenceNumber); + checkpointer.checkpoint(sequenceNumber, subSequenceNumber); + } catch (ShutdownException | ThrottlingException | InvalidStateException ex) { + LOG.debug("Caught exception at checkpoint, skipping checkpoint.", ex); + checkpointFailures.increment(); + } + } + + private void doCheckpoint() { + LOG.debug("Regular checkpointing for shard {}", kinesisShardId); + Optional kinesisCheckpointerRecordOptional = kinesisCheckpointerTracker.popLatestReadyToCheckpointRecord(); + if (kinesisCheckpointerRecordOptional.isPresent()) { + ExtendedSequenceNumber lastExtendedSequenceNumber = kinesisCheckpointerRecordOptional.get().getExtendedSequenceNumber(); + RecordProcessorCheckpointer recordProcessorCheckpointer = kinesisCheckpointerRecordOptional.get().getCheckpointer(); + checkpoint(recordProcessorCheckpointer, lastExtendedSequenceNumber.sequenceNumber(), lastExtendedSequenceNumber.subSequenceNumber()); + lastCheckpointTimeInMillis = System.currentTimeMillis(); + } + } + + private void checkpoint(RecordProcessorCheckpointer checkpointer) { + try { + String kinesisStream = streamIdentifier.streamName(); + LOG.debug("Checkpoint for stream: {}, shardId: {}", kinesisStream, kinesisShardId); + checkpointer.checkpoint(); + } catch (ShutdownException | ThrottlingException | InvalidStateException ex) { + LOG.debug("Caught exception at checkpoint, skipping checkpoint.", ex); + checkpointFailures.increment(); + } + } + + private ExtendedSequenceNumber getLatestSequenceNumberFromInput(final ProcessRecordsInput processRecordsInput) { + ListIterator recordIterator = processRecordsInput.records().listIterator(); + ExtendedSequenceNumber largestExtendedSequenceNumber = null; + while (recordIterator.hasNext()) { + KinesisClientRecord record = recordIterator.next(); + ExtendedSequenceNumber extendedSequenceNumber = + new ExtendedSequenceNumber(record.sequenceNumber(), record.subSequenceNumber()); + + if (largestExtendedSequenceNumber == null + || largestExtendedSequenceNumber.compareTo(extendedSequenceNumber) < 0) { + largestExtendedSequenceNumber = extendedSequenceNumber; + } + } + return largestExtendedSequenceNumber; + } +} diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisShardRecordProcessorFactory.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisShardRecordProcessorFactory.java new file mode 100644 index 0000000000..ff9943a41d --- /dev/null +++ 
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisShardRecordProcessorFactory.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisShardRecordProcessorFactory.java
new file mode 100644
index 0000000000..ff9943a41d
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisShardRecordProcessorFactory.java
@@ -0,0 +1,59 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.processor;
+
+import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
+import org.opensearch.dataprepper.metrics.PluginMetrics;
+import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
+import org.opensearch.dataprepper.model.buffer.Buffer;
+import org.opensearch.dataprepper.model.codec.InputCodec;
+import org.opensearch.dataprepper.model.event.Event;
+import org.opensearch.dataprepper.model.record.Record;
+import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisSourceConfig;
+import org.opensearch.dataprepper.plugins.kinesis.source.converter.KinesisRecordConverter;
+import software.amazon.kinesis.common.StreamIdentifier;
+import software.amazon.kinesis.processor.ShardRecordProcessor;
+import software.amazon.kinesis.processor.ShardRecordProcessorFactory;
+
+public class KinesisShardRecordProcessorFactory implements ShardRecordProcessorFactory {
+
+    private final Buffer<Record<Event>> buffer;
+    private final KinesisSourceConfig kinesisSourceConfig;
+    private final AcknowledgementSetManager acknowledgementSetManager;
+    private final PluginMetrics pluginMetrics;
+    private final KinesisRecordConverter kinesisRecordConverter;
+
+    public KinesisShardRecordProcessorFactory(Buffer<Record<Event>> buffer,
+                                              KinesisSourceConfig kinesisSourceConfig,
+                                              final AcknowledgementSetManager acknowledgementSetManager,
+                                              final PluginMetrics pluginMetrics,
+                                              final InputCodec codec) {
+        this.kinesisSourceConfig = kinesisSourceConfig;
+        this.buffer = buffer;
+        this.acknowledgementSetManager = acknowledgementSetManager;
+        this.pluginMetrics = pluginMetrics;
+        this.kinesisRecordConverter = new KinesisRecordConverter(codec);
+    }
+
+    @Override
+    public ShardRecordProcessor shardRecordProcessor() {
+        throw new UnsupportedOperationException("Use the method with stream details!");
+    }
+
+    @Override
+    public ShardRecordProcessor shardRecordProcessor(StreamIdentifier streamIdentifier) {
+        BufferAccumulator<Record<Event>> bufferAccumulator = BufferAccumulator.create(buffer,
+                kinesisSourceConfig.getNumberOfRecordsToAccumulate(), kinesisSourceConfig.getBufferTimeout());
+        KinesisCheckpointerTracker kinesisCheckpointerTracker = new KinesisCheckpointerTracker();
+        return new KinesisRecordProcessor(bufferAccumulator, kinesisSourceConfig, acknowledgementSetManager,
+                pluginMetrics, kinesisRecordConverter, kinesisCheckpointerTracker, streamIdentifier);
+    }
+}
\ No newline at end of file
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/MetadataKeyAttributes.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/MetadataKeyAttributes.java
new file mode 100644
index 0000000000..e2debba54e
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/MetadataKeyAttributes.java
@@ -0,0 +1,15 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.processor;
+
+public class MetadataKeyAttributes {
+    static final String KINESIS_STREAM_NAME_METADATA_ATTRIBUTE = "kinesis_stream_name";
+}
diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigExtensionTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigExtensionTest.java
new file mode 100644
index 0000000000..852baab195
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigExtensionTest.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.extension;
+
+import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.opensearch.dataprepper.model.plugin.ExtensionPoints;
+import org.opensearch.dataprepper.model.plugin.ExtensionProvider;
+
+import static org.hamcrest.CoreMatchers.instanceOf;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+
+public class KinesisLeaseConfigExtensionTest {
+    @Mock
+    private ExtensionPoints extensionPoints;
+
+    @Mock
+    private KinesisLeaseConfig kinesisLeaseConfig;
+
+    private KinesisLeaseConfigExtension createObjectUnderTest() {
+        return new KinesisLeaseConfigExtension(kinesisLeaseConfig);
+    }
+
+    @Test
+    void applyShouldAddExtensionProvider() {
+        extensionPoints = mock(ExtensionPoints.class);
+        createObjectUnderTest().apply(extensionPoints);
+        final ArgumentCaptor<ExtensionProvider> extensionProviderArgumentCaptor =
+                ArgumentCaptor.forClass(ExtensionProvider.class);
+
+        verify(extensionPoints).addExtensionProvider(extensionProviderArgumentCaptor.capture());
+
+        final ExtensionProvider actualExtensionProvider = extensionProviderArgumentCaptor.getValue();
+
+        assertThat(actualExtensionProvider, instanceOf(KinesisLeaseConfigProvider.class));
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigProviderTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigProviderTest.java
new file mode 100644
index 0000000000..1fa17f5f42
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigProviderTest.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.extension;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.dataprepper.model.plugin.ExtensionProvider;
+
+import java.util.Optional;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.notNullValue;
+import static org.hamcrest.CoreMatchers.sameInstance;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+@ExtendWith(MockitoExtension.class)
+public class KinesisLeaseConfigProviderTest {
+    @Mock
+    private KinesisLeaseConfigSupplier kinesisLeaseConfigSupplier;
+
+    @Mock
+    private ExtensionProvider.Context context;
+
+    private KinesisLeaseConfigProvider createObjectUnderTest() {
+        return new KinesisLeaseConfigProvider(kinesisLeaseConfigSupplier);
+    }
+
+    @Test
+    void supportedClassReturnsKinesisSourceConfigSupplier() {
+        assertThat(createObjectUnderTest().supportedClass(), equalTo(KinesisLeaseConfigSupplier.class));
+    }
+
+    @Test
+    void provideInstanceReturnsKinesisSourceConfigSupplierFromConstructor() {
+        final KinesisLeaseConfigProvider objectUnderTest = createObjectUnderTest();
+
+        final Optional<KinesisLeaseConfigSupplier> optionalKinesisSourceConfigSupplier = objectUnderTest.provideInstance(context);
+        assertThat(optionalKinesisSourceConfigSupplier, notNullValue());
+        assertThat(optionalKinesisSourceConfigSupplier.isPresent(), equalTo(true));
+        assertThat(optionalKinesisSourceConfigSupplier.get(), equalTo(kinesisLeaseConfigSupplier));
+
+        final Optional<KinesisLeaseConfigSupplier> anotherOptionalKinesisSourceConfigSupplier = objectUnderTest.provideInstance(context);
+        assertThat(anotherOptionalKinesisSourceConfigSupplier, notNullValue());
+        assertThat(anotherOptionalKinesisSourceConfigSupplier.isPresent(), equalTo(true));
+        assertThat(anotherOptionalKinesisSourceConfigSupplier.get(), sameInstance(optionalKinesisSourceConfigSupplier.get()));
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigSupplierTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigSupplierTest.java
new file mode 100644
index 0000000000..4cfc323ed5
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigSupplierTest.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ * + */ + +package org.opensearch.dataprepper.plugins.kinesis.extension; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.Optional; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +public class KinesisLeaseConfigSupplierTest { + private static final String LEASE_COORDINATION_TABLE = "lease-table"; + @Mock + KinesisLeaseConfig kinesisLeaseConfig; + + @Mock + KinesisLeaseCoordinationTableConfig kinesisLeaseCoordinationTableConfig; + + private KinesisLeaseConfigSupplier createObjectUnderTest() { + return new KinesisLeaseConfigSupplier(kinesisLeaseConfig); + } + + @Test + void testGetters() { + when(kinesisLeaseConfig.getLeaseCoordinationTable()).thenReturn(kinesisLeaseCoordinationTableConfig); + when(kinesisLeaseCoordinationTableConfig.getTableName()).thenReturn(LEASE_COORDINATION_TABLE); + when(kinesisLeaseCoordinationTableConfig.getRegion()).thenReturn("us-east-1"); + KinesisLeaseConfigSupplier kinesisLeaseConfigSupplier = createObjectUnderTest(); + assertThat(kinesisLeaseConfigSupplier.getKinesisExtensionLeaseConfig().get().getLeaseCoordinationTable().getTableName(), equalTo(LEASE_COORDINATION_TABLE)); + assertThat(kinesisLeaseConfigSupplier.getKinesisExtensionLeaseConfig().get().getLeaseCoordinationTable().getRegion(), equalTo("us-east-1")); + } + + @Test + void testGettersWithNullTableConfig() { + when(kinesisLeaseConfig.getLeaseCoordinationTable()).thenReturn(null); + KinesisLeaseConfigSupplier defaultKinesisLeaseConfigSupplier = createObjectUnderTest(); + assertThat(defaultKinesisLeaseConfigSupplier.getKinesisExtensionLeaseConfig().get().getLeaseCoordinationTable(), equalTo(null)); + + } + + @Test + void testGettersWithNullConfig() { + KinesisLeaseConfigSupplier kinesisLeaseConfigSupplier = new KinesisLeaseConfigSupplier(null); + assertThat(kinesisLeaseConfigSupplier.getKinesisExtensionLeaseConfig(), equalTo(Optional.empty())); + } +} diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigTest.java new file mode 100644 index 0000000000..30194a9659 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/extension/KinesisLeaseConfigTest.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.extension;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.module.SimpleModule;
+import org.junit.jupiter.api.Test;
+import org.opensearch.dataprepper.model.types.ByteCount;
+import org.opensearch.dataprepper.parser.model.DataPrepperConfiguration;
+import org.opensearch.dataprepper.pipeline.parser.ByteCountDeserializer;
+import org.opensearch.dataprepper.pipeline.parser.DataPrepperDurationDeserializer;
+import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
+import software.amazon.awssdk.regions.Region;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.Reader;
+import java.io.StringReader;
+import java.time.Duration;
+import java.util.Map;
+
+import static org.hamcrest.CoreMatchers.notNullValue;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+
+public class KinesisLeaseConfigTest {
+    private static final SimpleModule simpleModule = new SimpleModule()
+            .addDeserializer(Duration.class, new DataPrepperDurationDeserializer())
+            .addDeserializer(ByteCount.class, new ByteCountDeserializer());
+    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(new YAMLFactory()).registerModule(simpleModule);
+
+    private KinesisLeaseConfig makeConfig(String filePath) throws IOException {
+        final File configurationFile = new File(filePath);
+        final DataPrepperConfiguration dataPrepperConfiguration = OBJECT_MAPPER.readValue(configurationFile, DataPrepperConfiguration.class);
+        assertThat(dataPrepperConfiguration, notNullValue());
+        assertThat(dataPrepperConfiguration.getPipelineExtensions(), notNullValue());
+        final Map<String, Object> kinesisLeaseConfigMap =
+                (Map<String, Object>) dataPrepperConfiguration.getPipelineExtensions().getExtensionMap().get("kinesis");
+        String json = OBJECT_MAPPER.writeValueAsString(kinesisLeaseConfigMap);
+        Reader reader = new StringReader(json);
+        return OBJECT_MAPPER.readValue(reader, KinesisLeaseConfig.class);
+    }
+
+
+    @Test
+    void testConfigWithTestExtension() throws IOException {
+        final KinesisLeaseConfig kinesisLeaseConfig = makeConfig(
+                "src/test/resources/simple_pipeline_with_extensions.yaml");
+
+        assertNotNull(kinesisLeaseConfig.getLeaseCoordinationTable());
+        assertEquals("kinesis-pipeline-kcl", kinesisLeaseConfig.getLeaseCoordinationTable().getTableName());
+        assertEquals("us-east-1", kinesisLeaseConfig.getLeaseCoordinationTable().getRegion());
+        assertEquals(Region.US_EAST_1, kinesisLeaseConfig.getLeaseCoordinationTable().getAwsRegion());
+    }
+
+}
diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisClientFactoryTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisClientFactoryTest.java
new file mode 100644
index 0000000000..f476754eb9
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisClientFactoryTest.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source;
+
+import org.junit.jupiter.api.Test;
+import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
+import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
+import org.opensearch.dataprepper.plugins.kinesis.source.configuration.AwsAuthenticationConfig;
+import org.opensearch.dataprepper.test.helper.ReflectivelySetField;
+import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient;
+import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
+import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.UUID;
+
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class KinesisClientFactoryTest {
+    private Region region = Region.US_EAST_1;
+    private String roleArn;
+    private Map<String, String> stsHeader;
+    private AwsCredentialsSupplier awsCredentialsSupplier;
+
+    @Test
+    void testCreateClient() throws NoSuchFieldException, IllegalAccessException {
+        roleArn = "arn:aws:iam::278936200144:role/test-role";
+        stsHeader = new HashMap<>();
+        stsHeader.put(UUID.randomUUID().toString(), UUID.randomUUID().toString());
+        awsCredentialsSupplier = mock(AwsCredentialsSupplier.class);
+
+        AwsAuthenticationConfig awsAuthenticationOptionsConfig = new AwsAuthenticationConfig();
+        ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsRegion", "us-east-1");
+        ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsStsRoleArn", roleArn);
+
+        AwsCredentialsProvider defaultCredentialsProvider = mock(AwsCredentialsProvider.class);
+        when(awsCredentialsSupplier.getProvider(eq(AwsCredentialsOptions.defaultOptions()))).thenReturn(defaultCredentialsProvider);
+
+        KinesisClientFactory clientFactory = new KinesisClientFactory(awsCredentialsSupplier, awsAuthenticationOptionsConfig);
+
+        final DynamoDbAsyncClient dynamoDbAsyncClient = clientFactory.buildDynamoDBClient(Region.US_EAST_1);
+        assertNotNull(dynamoDbAsyncClient);
+
+        final KinesisAsyncClient kinesisAsyncClient = clientFactory.buildKinesisAsyncClient(Region.US_EAST_1);
+        assertNotNull(kinesisAsyncClient);
+
+        final CloudWatchAsyncClient cloudWatchAsyncClient = clientFactory.buildCloudWatchAsyncClient(Region.US_EAST_1);
+        assertNotNull(cloudWatchAsyncClient);
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisMultiStreamTrackerTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisMultiStreamTrackerTest.java
new file mode 100644
index 0000000000..edf23b8033
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisMultiStreamTrackerTest.java
@@ -0,0 +1,153 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisSourceConfig;
+import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisStreamConfig;
+import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
+import software.amazon.awssdk.services.kinesis.model.DescribeStreamRequest;
+import software.amazon.awssdk.services.kinesis.model.DescribeStreamResponse;
+import software.amazon.awssdk.services.kinesis.model.StreamDescription;
+import software.amazon.kinesis.common.InitialPositionInStream;
+import software.amazon.kinesis.common.InitialPositionInStreamExtended;
+import software.amazon.kinesis.common.StreamConfig;
+import software.amazon.kinesis.common.StreamIdentifier;
+import software.amazon.kinesis.processor.FormerStreamsLeasesDeletionStrategy;
+
+import java.time.Duration;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class KinesisMultiStreamTrackerTest {
+    private static final String APPLICATION_NAME = "multi-stream-application";
+    private static final String awsAccountId = "1234";
+    private static final String streamArnFormat = "arn:aws:kinesis:us-east-1:%s:stream/%s";
+    private static final Instant streamCreationTime = Instant.now();
+    private static final List<String> STREAMS_LIST = ImmutableList.of("stream-1", "stream-2", "stream-3");
+
+    private KinesisMultiStreamTracker kinesisMultiStreamTracker;
+    @Mock
+    private KinesisAsyncClient kinesisClient;
+    private List<StreamConfig> streamConfigList;
+
+    private Map<String, KinesisStreamConfig> streamConfigMap;
+
+    @Mock
+    KinesisSourceConfig kinesisSourceConfig;
+
+    @BeforeEach
+    public void setUp() {
+        MockitoAnnotations.openMocks(this);
+        List<KinesisStreamConfig> kinesisStreamConfigs = new ArrayList<>();
+        streamConfigMap = new HashMap<>();
+        STREAMS_LIST.forEach(stream -> {
+            KinesisStreamConfig kinesisStreamConfig = mock(KinesisStreamConfig.class);
+            when(kinesisStreamConfig.getName()).thenReturn(stream);
+            when(kinesisStreamConfig.getInitialPosition()).thenReturn(InitialPositionInStream.LATEST);
+
+            StreamDescription streamDescription = StreamDescription.builder()
+                    .streamARN(String.format(streamArnFormat, awsAccountId, stream))
+                    .streamCreationTimestamp(streamCreationTime)
+                    .streamName(stream)
+                    .build();
+
+            DescribeStreamRequest describeStreamRequest = DescribeStreamRequest.builder()
+                    .streamName(stream)
+                    .build();
+
+            DescribeStreamResponse describeStreamResponse = DescribeStreamResponse.builder()
+                    .streamDescription(streamDescription)
+                    .build();
+
+            when(kinesisClient.describeStream(describeStreamRequest)).thenReturn(CompletableFuture.completedFuture(describeStreamResponse));
+            kinesisStreamConfigs.add(kinesisStreamConfig);
+
+            streamConfigMap.put(stream, kinesisStreamConfig);
+        });
+
+        when(kinesisSourceConfig.getStreams()).thenReturn(kinesisStreamConfigs);
+        kinesisMultiStreamTracker = new KinesisMultiStreamTracker(kinesisClient, kinesisSourceConfig, APPLICATION_NAME);
+    }
+
+    @Test
+    public void testStreamConfigList() {
+        streamConfigList = kinesisMultiStreamTracker.streamConfigList();
+        assertEquals(kinesisSourceConfig.getStreams().size(), streamConfigList.size());
+
+        int totalStreams = streamConfigList.size();
+        for (int i = 0; i < totalStreams; i++) {
+            // Assumed assertions: each StreamConfig should carry the expected
+            // multi-stream identifier and the LATEST initial position stubbed in setUp().
+            StreamConfig streamConfig = streamConfigList.get(i);
+            assertEquals(getStreamIdentifier(STREAMS_LIST.get(i)), streamConfig.streamIdentifier());
+            assertEquals(InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST),
+                    streamConfig.initialPositionInStreamExtended());
+        }
+    }
+
+    @Test
+    public void testStreamConfigListWithException() {
+        List<KinesisStreamConfig> kinesisStreamConfigs = new ArrayList<>();
+        streamConfigMap = new HashMap<>();
+        STREAMS_LIST.forEach(stream -> {
+            KinesisStreamConfig kinesisStreamConfig = mock(KinesisStreamConfig.class);
+            when(kinesisStreamConfig.getName()).thenReturn(stream);
+            when(kinesisStreamConfig.getInitialPosition()).thenReturn(InitialPositionInStream.LATEST);
+
+            DescribeStreamRequest describeStreamRequest = DescribeStreamRequest.builder()
+                    .streamName(stream)
+                    .build();
+
+            when(kinesisClient.describeStream(describeStreamRequest)).thenThrow(new RuntimeException());
+            kinesisStreamConfigs.add(kinesisStreamConfig);
+
+            streamConfigMap.put(stream, kinesisStreamConfig);
+        });
+
+        when(kinesisSourceConfig.getStreams()).thenReturn(kinesisStreamConfigs);
+        kinesisMultiStreamTracker = new KinesisMultiStreamTracker(kinesisClient, kinesisSourceConfig, APPLICATION_NAME);
+
+        assertThrows(RuntimeException.class, () -> kinesisMultiStreamTracker.streamConfigList());
+    }
+
+    @Test
+    public void formerStreamsLeasesDeletionStrategy() {
+
+        FormerStreamsLeasesDeletionStrategy formerStreamsLeasesDeletionStrategy =
+                kinesisMultiStreamTracker.formerStreamsLeasesDeletionStrategy();
+
+        Duration duration = formerStreamsLeasesDeletionStrategy.waitPeriodToDeleteFormerStreams();
+
+        Assertions.assertTrue(formerStreamsLeasesDeletionStrategy instanceof
+                FormerStreamsLeasesDeletionStrategy.AutoDetectionAndDeferredDeletionStrategy);
+        assertEquals(10, duration.getSeconds());
+    }
+
+    private StreamIdentifier getStreamIdentifier(final String streamName) {
+        return StreamIdentifier.multiStreamInstance(String.join(":", awsAccountId, streamName, String.valueOf(streamCreationTime.getEpochSecond())));
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisServiceTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisServiceTest.java
new file mode 100644
index 0000000000..12986d9969
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisServiceTest.java
@@ -0,0 +1,353 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ * + */ + +package org.opensearch.dataprepper.plugins.kinesis.source; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.codec.InputCodec; +import org.opensearch.dataprepper.model.configuration.PipelineDescription; +import org.opensearch.dataprepper.model.configuration.PluginModel; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.plugin.InvalidPluginConfigurationException; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.kinesis.extension.KinesisLeaseConfig; +import org.opensearch.dataprepper.plugins.kinesis.extension.KinesisLeaseConfigSupplier; +import org.opensearch.dataprepper.plugins.kinesis.extension.KinesisLeaseCoordinationTableConfig; +import org.opensearch.dataprepper.plugins.kinesis.source.configuration.AwsAuthenticationConfig; +import org.opensearch.dataprepper.plugins.kinesis.source.configuration.ConsumerStrategy; +import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisSourceConfig; +import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisStreamConfig; +import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisStreamPollingConfig; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.awssdk.services.kinesis.KinesisServiceClientConfiguration; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamRequest; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamResponse; +import software.amazon.awssdk.services.kinesis.model.StreamDescription; +import software.amazon.kinesis.common.InitialPositionInStream; +import software.amazon.kinesis.coordinator.Scheduler; +import software.amazon.kinesis.metrics.MetricsLevel; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class KinesisServiceTest { + private final String PIPELINE_NAME = "kinesis-pipeline-test"; + private final String streamId = "stream-1"; + private static final String codec_plugin_name = "json"; + + private static final Duration CHECKPOINT_INTERVAL = Duration.ofMillis(0); + private static final int NUMBER_OF_RECORDS_TO_ACCUMULATE = 10; + 
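+    // Mirrors the KinesisStreamPollingConfig defaults (10,000 records, 250 ms idle time)
+    // exercised in KinesisStreamPollingConfigTest; used to stub the polling config mock in setup().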
private static final int DEFAULT_MAX_RECORDS = 10000;
+    private static final int IDLE_TIME_BETWEEN_READS_IN_MILLIS = 250;
+    private static final String awsAccountId = "123456789012";
+    private static final String streamArnFormat = "arn:aws:kinesis:us-east-1:%s:stream/%s";
+    private static final Instant streamCreationTime = Instant.now();
+
+    @Mock
+    private PluginMetrics pluginMetrics;
+
+    @Mock
+    private PluginFactory pluginFactory;
+
+    @Mock
+    private KinesisSourceConfig kinesisSourceConfig;
+
+    @Mock
+    private KinesisStreamConfig kinesisStreamConfig;
+
+    @Mock
+    private KinesisStreamPollingConfig kinesisStreamPollingConfig;
+
+    @Mock
+    private AwsAuthenticationConfig awsAuthenticationConfig;
+    @Mock
+    private AcknowledgementSetManager acknowledgementSetManager;
+
+    @Mock
+    private PipelineDescription pipelineDescription;
+
+    @Mock
+    private KinesisClientFactory kinesisClientFactory;
+
+    @Mock
+    private KinesisAsyncClient kinesisClient;
+
+    @Mock
+    private DynamoDbAsyncClient dynamoDbClient;
+
+    @Mock
+    private CloudWatchAsyncClient cloudWatchClient;
+
+    @Mock
+    Buffer<Record<Event>> buffer;
+
+    @Mock
+    private Scheduler scheduler;
+
+    @Mock
+    KinesisLeaseConfigSupplier kinesisLeaseConfigSupplier;
+
+    @Mock
+    KinesisLeaseConfig kinesisLeaseConfig;
+
+    @Mock
+    KinesisLeaseCoordinationTableConfig kinesisLeaseCoordinationTableConfig;
+
+    @Mock
+    WorkerIdentifierGenerator workerIdentifierGenerator;
+
+    @BeforeEach
+    void setup() {
+        awsAuthenticationConfig = mock(AwsAuthenticationConfig.class);
+        kinesisSourceConfig = mock(KinesisSourceConfig.class);
+        kinesisStreamConfig = mock(KinesisStreamConfig.class);
+        kinesisStreamPollingConfig = mock(KinesisStreamPollingConfig.class);
+        kinesisClient = mock(KinesisAsyncClient.class);
+        dynamoDbClient = mock(DynamoDbAsyncClient.class);
+        cloudWatchClient = mock(CloudWatchAsyncClient.class);
+        kinesisClientFactory = mock(KinesisClientFactory.class);
+        scheduler = mock(Scheduler.class);
+        pipelineDescription = mock(PipelineDescription.class);
+        buffer = mock(Buffer.class);
+        kinesisLeaseConfigSupplier = mock(KinesisLeaseConfigSupplier.class);
+        kinesisLeaseConfig = mock(KinesisLeaseConfig.class);
+        workerIdentifierGenerator = mock(WorkerIdentifierGenerator.class);
+        kinesisLeaseCoordinationTableConfig = mock(KinesisLeaseCoordinationTableConfig.class);
+        when(kinesisLeaseConfig.getLeaseCoordinationTable()).thenReturn(kinesisLeaseCoordinationTableConfig);
+        when(kinesisLeaseCoordinationTableConfig.getTableName()).thenReturn("kinesis-lease-table");
+        when(kinesisLeaseCoordinationTableConfig.getRegion()).thenReturn("us-east-1");
+        when(kinesisLeaseCoordinationTableConfig.getAwsRegion()).thenReturn(Region.US_EAST_1);
+        when(kinesisLeaseConfigSupplier.getKinesisExtensionLeaseConfig()).thenReturn(Optional.ofNullable(kinesisLeaseConfig));
+
+        when(awsAuthenticationConfig.getAwsRegion()).thenReturn(Region.of("us-west-2"));
+        when(awsAuthenticationConfig.getAwsStsRoleArn()).thenReturn(UUID.randomUUID().toString());
+        when(awsAuthenticationConfig.getAwsStsExternalId()).thenReturn(UUID.randomUUID().toString());
+        final Map<String, String> stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString());
+        when(awsAuthenticationConfig.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides);
+        StreamDescription streamDescription = StreamDescription.builder()
+                .streamARN(String.format(streamArnFormat, awsAccountId, streamId))
+                .streamCreationTimestamp(streamCreationTime)
+                .streamName(streamId)
+                .build();
+
+        DescribeStreamRequest describeStreamRequest = DescribeStreamRequest.builder()
+                .streamName(streamId)
+                .build();
+
+        DescribeStreamResponse describeStreamResponse = DescribeStreamResponse.builder()
+                .streamDescription(streamDescription)
+                .build();
+
+        when(kinesisClient.describeStream(describeStreamRequest)).thenReturn(CompletableFuture.completedFuture(describeStreamResponse));
+
+        when(kinesisSourceConfig.getAwsAuthenticationConfig()).thenReturn(awsAuthenticationConfig);
+        when(kinesisStreamConfig.getName()).thenReturn(streamId);
+        when(kinesisStreamConfig.getCheckPointInterval()).thenReturn(CHECKPOINT_INTERVAL);
+        when(kinesisStreamConfig.getInitialPosition()).thenReturn(InitialPositionInStream.LATEST);
+        when(kinesisSourceConfig.getConsumerStrategy()).thenReturn(ConsumerStrategy.ENHANCED_FAN_OUT);
+        when(kinesisSourceConfig.getPollingConfig()).thenReturn(kinesisStreamPollingConfig);
+        when(kinesisStreamPollingConfig.getMaxPollingRecords()).thenReturn(DEFAULT_MAX_RECORDS);
+        when(kinesisStreamPollingConfig.getIdleTimeBetweenReads()).thenReturn(Duration.ofMillis(IDLE_TIME_BETWEEN_READS_IN_MILLIS));
+
+        List<KinesisStreamConfig> streamConfigs = new ArrayList<>();
+        streamConfigs.add(kinesisStreamConfig);
+        when(kinesisSourceConfig.getStreams()).thenReturn(streamConfigs);
+        when(kinesisSourceConfig.getNumberOfRecordsToAccumulate()).thenReturn(NUMBER_OF_RECORDS_TO_ACCUMULATE);
+
+        PluginModel pluginModel = mock(PluginModel.class);
+        when(pluginModel.getPluginName()).thenReturn(codec_plugin_name);
+        when(pluginModel.getPluginSettings()).thenReturn(Collections.emptyMap());
+        when(kinesisSourceConfig.getCodec()).thenReturn(pluginModel);
+
+        pluginFactory = mock(PluginFactory.class);
+        InputCodec codec = mock(InputCodec.class);
+        when(pluginFactory.loadPlugin(eq(InputCodec.class), any())).thenReturn(codec);
+
+        when(kinesisClientFactory.buildDynamoDBClient(kinesisLeaseCoordinationTableConfig.getAwsRegion())).thenReturn(dynamoDbClient);
+        when(kinesisClientFactory.buildKinesisAsyncClient(awsAuthenticationConfig.getAwsRegion())).thenReturn(kinesisClient);
+        when(kinesisClientFactory.buildCloudWatchAsyncClient(kinesisLeaseCoordinationTableConfig.getAwsRegion())).thenReturn(cloudWatchClient);
+        when(kinesisClient.serviceClientConfiguration()).thenReturn(KinesisServiceClientConfiguration.builder().region(Region.US_EAST_1).build());
+        when(scheduler.startGracefulShutdown()).thenReturn(CompletableFuture.completedFuture(true));
+        when(pipelineDescription.getPipelineName()).thenReturn(PIPELINE_NAME);
+        when(workerIdentifierGenerator.generate()).thenReturn(UUID.randomUUID().toString());
+    }
+
+    public KinesisService createObjectUnderTest() {
+        return new KinesisService(kinesisSourceConfig, kinesisClientFactory, pluginMetrics, pluginFactory,
+                pipelineDescription, acknowledgementSetManager, kinesisLeaseConfigSupplier, workerIdentifierGenerator);
+    }
+
+    @Test
+    void testServiceStart() {
+        KinesisService kinesisService = createObjectUnderTest();
+        kinesisService.start(buffer);
+        assertNotNull(kinesisService.getScheduler(buffer));
+        verify(workerIdentifierGenerator, times(1)).generate();
+    }
+
+    @Test
+    void testServiceThrowsWhenLeaseConfigIsInvalid() {
+        when(kinesisLeaseConfigSupplier.getKinesisExtensionLeaseConfig()).thenReturn(Optional.empty());
+        assertThrows(IllegalStateException.class, () -> new KinesisService(kinesisSourceConfig, kinesisClientFactory, pluginMetrics, pluginFactory,
+                pipelineDescription, acknowledgementSetManager, kinesisLeaseConfigSupplier, workerIdentifierGenerator));
+    }
+
+    @Test
+    void testCreateScheduler() {
+        KinesisService kinesisService =
new KinesisService(kinesisSourceConfig, kinesisClientFactory, pluginMetrics, pluginFactory, + pipelineDescription, acknowledgementSetManager, kinesisLeaseConfigSupplier, workerIdentifierGenerator); + Scheduler schedulerObjectUnderTest = kinesisService.createScheduler(buffer); + + assertNotNull(schedulerObjectUnderTest); + assertNotNull(schedulerObjectUnderTest.checkpointConfig()); + assertNotNull(schedulerObjectUnderTest.leaseManagementConfig()); + assertSame(schedulerObjectUnderTest.leaseManagementConfig().initialPositionInStream().getInitialPositionInStream(), InitialPositionInStream.TRIM_HORIZON); + assertNotNull(schedulerObjectUnderTest.lifecycleConfig()); + assertNotNull(schedulerObjectUnderTest.metricsConfig()); + assertSame(schedulerObjectUnderTest.metricsConfig().metricsLevel(), MetricsLevel.DETAILED); + assertNotNull(schedulerObjectUnderTest.processorConfig()); + assertNotNull(schedulerObjectUnderTest.retrievalConfig()); + verify(workerIdentifierGenerator, times(1)).generate(); + } + + @Test + void testCreateSchedulerWithPollingStrategy() { + when(kinesisSourceConfig.getConsumerStrategy()).thenReturn(ConsumerStrategy.POLLING); + KinesisService kinesisService = new KinesisService(kinesisSourceConfig, kinesisClientFactory, pluginMetrics, pluginFactory, + pipelineDescription, acknowledgementSetManager, kinesisLeaseConfigSupplier, workerIdentifierGenerator); + Scheduler schedulerObjectUnderTest = kinesisService.createScheduler(buffer); + + assertNotNull(schedulerObjectUnderTest); + assertNotNull(schedulerObjectUnderTest.checkpointConfig()); + assertNotNull(schedulerObjectUnderTest.leaseManagementConfig()); + assertSame(schedulerObjectUnderTest.leaseManagementConfig().initialPositionInStream().getInitialPositionInStream(), InitialPositionInStream.TRIM_HORIZON); + assertNotNull(schedulerObjectUnderTest.lifecycleConfig()); + assertNotNull(schedulerObjectUnderTest.metricsConfig()); + assertSame(schedulerObjectUnderTest.metricsConfig().metricsLevel(), MetricsLevel.DETAILED); + assertNotNull(schedulerObjectUnderTest.processorConfig()); + assertNotNull(schedulerObjectUnderTest.retrievalConfig()); + verify(workerIdentifierGenerator, times(1)).generate(); + } + + + @Test + void testServiceStartNullBufferThrows() { + KinesisService kinesisService = createObjectUnderTest(); + assertThrows(IllegalStateException.class, () -> kinesisService.start(null)); + + verify(scheduler, times(0)).run(); + } + + @Test + void testServiceStartNullStreams() { + when(kinesisSourceConfig.getStreams()).thenReturn(null); + + KinesisService kinesisService = createObjectUnderTest(); + assertThrows(InvalidPluginConfigurationException.class, () -> kinesisService.start(buffer)); + + verify(scheduler, times(0)).run(); + } + + @Test + void testServiceStartEmptyStreams() { + when(kinesisSourceConfig.getStreams()).thenReturn(new ArrayList<>()); + + KinesisService kinesisService = createObjectUnderTest(); + assertThrows(InvalidPluginConfigurationException.class, () -> kinesisService.start(buffer)); + + verify(scheduler, times(0)).run(); + } + + @Test + public void testShutdownGraceful() { + KinesisService kinesisService = createObjectUnderTest(); + kinesisService.setScheduler(scheduler); + kinesisService.shutDown(); + + verify(scheduler).startGracefulShutdown(); + verify(scheduler, times(0)).shutdown(); + } + + @Test + public void testShutdownGracefulThrowInterruptedException() { + KinesisService kinesisService = createObjectUnderTest(); + + when(scheduler.startGracefulShutdown()).thenReturn(CompletableFuture.failedFuture(new 
InterruptedException())); + kinesisService.setScheduler(scheduler); + assertDoesNotThrow(kinesisService::shutDown); + + verify(scheduler).startGracefulShutdown(); + verify(scheduler, times(1)).shutdown(); + } + + @Test + public void testShutdownGracefulThrowTimeoutException() { + KinesisService kinesisService = createObjectUnderTest(); + kinesisService.setScheduler(scheduler); + when(scheduler.startGracefulShutdown()).thenReturn(CompletableFuture.failedFuture(new TimeoutException())); + assertDoesNotThrow(kinesisService::shutDown); + + verify(scheduler).startGracefulShutdown(); + verify(scheduler, times(1)).shutdown(); + } + + @Test + public void testShutdownGracefulThrowExecutionException() { + KinesisService kinesisService = createObjectUnderTest(); + kinesisService.setScheduler(scheduler); + when(scheduler.startGracefulShutdown()).thenReturn(CompletableFuture.failedFuture(new ExecutionException(new Throwable()))); + assertDoesNotThrow(kinesisService::shutDown); + + verify(scheduler).startGracefulShutdown(); + verify(scheduler, times(1)).shutdown(); + } + + @Test + public void testShutdownExecutorServiceInterruptedException() { + when(scheduler.startGracefulShutdown()).thenReturn(CompletableFuture.failedFuture(new InterruptedException())); + + KinesisService kinesisService = createObjectUnderTest(); + kinesisService.setScheduler(scheduler); + kinesisService.shutDown(); + + verify(scheduler).startGracefulShutdown(); + verify(scheduler).shutdown(); + } + +} diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisSourceTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisSourceTest.java new file mode 100644 index 0000000000..fad335dd63 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/KinesisSourceTest.java @@ -0,0 +1,190 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ * + */ + +package org.opensearch.dataprepper.plugins.kinesis.source; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.codec.InputCodec; +import org.opensearch.dataprepper.model.configuration.PipelineDescription; +import org.opensearch.dataprepper.model.configuration.PluginModel; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.kinesis.extension.KinesisLeaseConfig; +import org.opensearch.dataprepper.plugins.kinesis.extension.KinesisLeaseConfigSupplier; +import org.opensearch.dataprepper.plugins.kinesis.extension.KinesisLeaseCoordinationTableConfig; +import org.opensearch.dataprepper.plugins.kinesis.source.configuration.AwsAuthenticationConfig; +import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisSourceConfig; +import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisStreamConfig; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class KinesisSourceTest { + private final String PIPELINE_NAME = "kinesis-pipeline-test"; + private final String streamId = "stream-1"; + private static final String codec_plugin_name = "json"; + + @Mock + private PluginMetrics pluginMetrics; + + @Mock + private PluginFactory pluginFactory; + + @Mock + private KinesisSourceConfig kinesisSourceConfig; + + @Mock + private AwsAuthenticationConfig awsAuthenticationConfig; + + private KinesisSource source; + + @Mock + private AcknowledgementSetManager acknowledgementSetManager; + + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + + @Mock + private PipelineDescription pipelineDescription; + + @Mock KinesisService kinesisService; + + @Mock + KinesisLeaseConfigSupplier kinesisLeaseConfigSupplier; + + @Mock + KinesisLeaseConfig kinesisLeaseConfig; + + @Mock + KinesisLeaseCoordinationTableConfig kinesisLeaseCoordinationTableConfig; + + @BeforeEach + void setup() { + pluginMetrics = mock(PluginMetrics.class); + pluginFactory = mock(PluginFactory.class); + kinesisSourceConfig = mock(KinesisSourceConfig.class); + this.pipelineDescription = mock(PipelineDescription.class); + awsCredentialsSupplier = mock(AwsCredentialsSupplier.class); + awsAuthenticationConfig = mock(AwsAuthenticationConfig.class); + acknowledgementSetManager = mock(AcknowledgementSetManager.class); + kinesisService = mock(KinesisService.class); + + PluginModel pluginModel = 
mock(PluginModel.class);
+        when(pluginModel.getPluginName()).thenReturn(codec_plugin_name);
+        when(pluginModel.getPluginSettings()).thenReturn(Collections.emptyMap());
+        when(kinesisSourceConfig.getCodec()).thenReturn(pluginModel);
+
+        pluginFactory = mock(PluginFactory.class);
+        InputCodec codec = mock(InputCodec.class);
+        when(pluginFactory.loadPlugin(eq(InputCodec.class), any())).thenReturn(codec);
+
+        kinesisLeaseConfigSupplier = mock(KinesisLeaseConfigSupplier.class);
+        kinesisLeaseConfig = mock(KinesisLeaseConfig.class);
+        kinesisLeaseCoordinationTableConfig = mock(KinesisLeaseCoordinationTableConfig.class);
+        when(kinesisLeaseConfig.getLeaseCoordinationTable()).thenReturn(kinesisLeaseCoordinationTableConfig);
+        when(kinesisLeaseCoordinationTableConfig.getTableName()).thenReturn("table-name");
+        when(kinesisLeaseCoordinationTableConfig.getRegion()).thenReturn("us-east-1");
+        when(kinesisLeaseCoordinationTableConfig.getAwsRegion()).thenReturn(Region.US_EAST_1);
+        when(kinesisLeaseConfigSupplier.getKinesisExtensionLeaseConfig()).thenReturn(Optional.ofNullable(kinesisLeaseConfig));
+        when(awsAuthenticationConfig.getAwsRegion()).thenReturn(Region.US_EAST_1);
+        when(awsAuthenticationConfig.getAwsStsRoleArn()).thenReturn(UUID.randomUUID().toString());
+        when(awsAuthenticationConfig.getAwsStsExternalId()).thenReturn(UUID.randomUUID().toString());
+        final Map<String, String> stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString());
+        AwsCredentialsProvider defaultCredentialsProvider = mock(AwsCredentialsProvider.class);
+        when(awsCredentialsSupplier.getProvider(AwsCredentialsOptions.defaultOptions())).thenReturn(defaultCredentialsProvider);
+        when(awsAuthenticationConfig.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides);
+        when(kinesisSourceConfig.getAwsAuthenticationConfig()).thenReturn(awsAuthenticationConfig);
+        when(pipelineDescription.getPipelineName()).thenReturn(PIPELINE_NAME);
+    }
+
+    public KinesisSource createObjectUnderTest() {
+        return new KinesisSource(kinesisSourceConfig, pluginMetrics, pluginFactory, pipelineDescription, awsCredentialsSupplier, acknowledgementSetManager, kinesisLeaseConfigSupplier);
+    }
+
+    @Test
+    public void testSourceWithoutAcknowledgements() {
+        when(kinesisSourceConfig.isAcknowledgments()).thenReturn(false);
+        source = createObjectUnderTest();
+        assertThat(source.areAcknowledgementsEnabled(), equalTo(false));
+    }
+
+    @Test
+    public void testSourceWithAcknowledgements() {
+        when(kinesisSourceConfig.isAcknowledgments()).thenReturn(true);
+        source = createObjectUnderTest();
+        assertThat(source.areAcknowledgementsEnabled(), equalTo(true));
+    }
+
+    @Test
+    public void testSourceStart() {
+
+        source = createObjectUnderTest();
+
+        Buffer<Record<Event>> buffer = mock(Buffer.class);
+        when(kinesisSourceConfig.getNumberOfRecordsToAccumulate()).thenReturn(100);
+        KinesisStreamConfig kinesisStreamConfig = mock(KinesisStreamConfig.class);
+        when(kinesisStreamConfig.getName()).thenReturn(streamId);
+        when(kinesisSourceConfig.getStreams()).thenReturn(List.of(kinesisStreamConfig));
+        source.setKinesisService(kinesisService);
+
+        source.start(buffer);
+
+        verify(kinesisService, times(1)).start(any(Buffer.class));
+
+    }
+
+    @Test
+    public void testSourceStartBufferNull() {
+
+        source = createObjectUnderTest();
+
+        assertThrows(IllegalStateException.class, () -> source.start(null));
+
+        verify(kinesisService, times(0)).start(any(Buffer.class));
+
+    }
+
+    @Test
+    public void testSourceStop() {
+
+        source = createObjectUnderTest();
+
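+        // Hand the mocked KinesisService to the source so stop() can be verified
+        // without creating a real KCL scheduler.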
+        source.setKinesisService(kinesisService);
+
+        source.stop();
+
+        verify(kinesisService, times(1)).shutDown();
+
+    }
+
+}
diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/AwsAuthenticationConfigTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/AwsAuthenticationConfigTest.java
new file mode 100644
index 0000000000..499711c4a9
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/AwsAuthenticationConfigTest.java
@@ -0,0 +1,90 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.configuration;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import software.amazon.awssdk.regions.Region;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.UUID;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.nullValue;
+
+public class AwsAuthenticationConfigTest {
+    private ObjectMapper objectMapper = new ObjectMapper();
+
+    @ParameterizedTest
+    @ValueSource(strings = {"us-east-1", "us-west-2", "eu-central-1"})
+    void getAwsRegionReturnsRegion(final String regionString) {
+        final Region expectedRegionObject = Region.of(regionString);
+        final Map<String, Object> jsonMap = Map.of("region", regionString);
+        final AwsAuthenticationConfig objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationConfig.class);
+        assertThat(objectUnderTest.getAwsRegion(), equalTo(expectedRegionObject));
+    }
+
+    @Test
+    void getAwsRegionReturnsNullWhenRegionIsNull() {
+        final Map<String, Object> jsonMap = Collections.emptyMap();
+        final AwsAuthenticationConfig objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationConfig.class);
+        assertThat(objectUnderTest.getAwsRegion(), nullValue());
+    }
+
+    @Test
+    void getAwsStsRoleArnReturnsValueFromDeserializedJSON() {
+        final String stsRoleArn = UUID.randomUUID().toString();
+        final Map<String, Object> jsonMap = Map.of("sts_role_arn", stsRoleArn);
+        final AwsAuthenticationConfig objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationConfig.class);
+        assertThat(objectUnderTest.getAwsStsRoleArn(), equalTo(stsRoleArn));
+    }
+
+    @Test
+    void getAwsStsRoleArnReturnsNullIfNotInJSON() {
+        final Map<String, Object> jsonMap = Collections.emptyMap();
+        final AwsAuthenticationConfig objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationConfig.class);
+        assertThat(objectUnderTest.getAwsStsRoleArn(), nullValue());
+    }
+
+    @Test
+    void getAwsStsExternalIdReturnsValueFromDeserializedJSON() {
+        final String stsExternalId = UUID.randomUUID().toString();
+        final Map<String, Object> jsonMap = Map.of("sts_external_id", stsExternalId);
+        final AwsAuthenticationConfig objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationConfig.class);
+        assertThat(objectUnderTest.getAwsStsExternalId(), equalTo(stsExternalId));
+    }
+
+    @Test
+    void getAwsStsExternalIdReturnsNullIfNotInJSON() {
+        final Map<String, Object> jsonMap = Collections.emptyMap();
+        final AwsAuthenticationConfig objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationConfig.class);
+        assertThat(objectUnderTest.getAwsStsExternalId(), nullValue());
+    }
+
+    @Test
+    void getAwsStsHeaderOverridesReturnsValueFromDeserializedJSON() {
+        final Map<String, String> stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString());
+        final Map<String, Object> jsonMap = Map.of("sts_header_overrides", stsHeaderOverrides);
+        final AwsAuthenticationConfig objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationConfig.class);
+        assertThat(objectUnderTest.getAwsStsHeaderOverrides(), equalTo(stsHeaderOverrides));
+    }
+
+    @Test
+    void getAwsStsHeaderOverridesReturnsNullIfNotInJSON() {
+        final Map<String, Object> jsonMap = Collections.emptyMap();
+        final AwsAuthenticationConfig objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationConfig.class);
+        assertThat(objectUnderTest.getAwsStsHeaderOverrides(), nullValue());
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/InitialPositionInStreamConfigTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/InitialPositionInStreamConfigTest.java
new file mode 100644
index 0000000000..2e1b638342
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/InitialPositionInStreamConfigTest.java
@@ -0,0 +1,38 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.configuration;
+
+import org.junit.jupiter.api.Test;
+import software.amazon.kinesis.common.InitialPositionInStream;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class InitialPositionInStreamConfigTest {
+
+    @Test
+    void testInitialPositionGetByNameLATEST() {
+        final InitialPositionInStreamConfig initialPositionInStreamConfig = InitialPositionInStreamConfig.fromPositionValue("latest");
+        assertEquals(initialPositionInStreamConfig, InitialPositionInStreamConfig.LATEST);
+        assertEquals(initialPositionInStreamConfig.toString(), "latest");
+        assertEquals(initialPositionInStreamConfig.getPosition(), "latest");
+        assertEquals(initialPositionInStreamConfig.getPositionInStream(), InitialPositionInStream.LATEST);
+    }
+
+    @Test
+    void testInitialPositionGetByNameEarliest() {
+        final InitialPositionInStreamConfig initialPositionInStreamConfig = InitialPositionInStreamConfig.fromPositionValue("earliest");
+        assertEquals(initialPositionInStreamConfig, InitialPositionInStreamConfig.EARLIEST);
+        assertEquals(initialPositionInStreamConfig.toString(), "earliest");
+        assertEquals(initialPositionInStreamConfig.getPosition(), "earliest");
+        assertEquals(initialPositionInStreamConfig.getPositionInStream(), InitialPositionInStream.TRIM_HORIZON);
+    }
+
+}
diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/KinesisSourceConfigTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/KinesisSourceConfigTest.java
new file mode 100644
index 0000000000..5846fe4b04
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/KinesisSourceConfigTest.java
@@ -0,0 +1,161 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.configuration;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.module.SimpleModule;
+import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
+import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Tag;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
+import org.opensearch.dataprepper.pipeline.parser.DataPrepperDurationDeserializer;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.kinesis.common.InitialPositionInStream;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.Reader;
+import java.io.StringReader;
+import java.time.Duration;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class KinesisSourceConfigTest {
+    private static final String PIPELINE_CONFIG_WITH_ACKS_ENABLED = "pipeline_with_acks_enabled.yaml";
+    private static final String PIPELINE_CONFIG_WITH_POLLING_CONFIG_ENABLED = "pipeline_with_polling_config_enabled.yaml";
+    private static final String PIPELINE_CONFIG_CHECKPOINT_ENABLED = "pipeline_with_checkpoint_enabled.yaml";
+    private static final Duration MINIMAL_CHECKPOINT_INTERVAL = Duration.ofMillis(2 * 60 * 1000); // 2 minutes
+
+    KinesisSourceConfig kinesisSourceConfig;
+
+    ObjectMapper objectMapper;
+
+    @BeforeEach
+    void setUp(TestInfo testInfo) throws IOException {
+        String fileName = testInfo.getTags().stream().findFirst().orElse("");
+        final File configurationFile = new File(getClass().getClassLoader().getResource(fileName).getFile());
+        objectMapper = new ObjectMapper(new YAMLFactory());
+        SimpleModule simpleModule = new SimpleModule();
+        simpleModule.addDeserializer(Duration.class, new DataPrepperDurationDeserializer());
+        objectMapper.registerModule(new JavaTimeModule());
+        objectMapper.registerModule(simpleModule);
+
+        final Map<String, Object> pipelineConfig = objectMapper.readValue(configurationFile, Map.class);
+        final Map<String, Object> sourceMap = (Map<String, Object>) pipelineConfig.get("source");
+        final Map<String, Object> kinesisConfigMap = (Map<String, Object>) sourceMap.get("kinesis");
+        String json = objectMapper.writeValueAsString(kinesisConfigMap);
+        final Reader reader = new StringReader(json);
+        kinesisSourceConfig = objectMapper.readValue(reader, KinesisSourceConfig.class);
+    }
+
+    @Test
+    @Tag(PIPELINE_CONFIG_WITH_ACKS_ENABLED)
+    void testSourceConfig() {
+
+        assertThat(kinesisSourceConfig, notNullValue());
+        assertEquals(KinesisSourceConfig.DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE, kinesisSourceConfig.getNumberOfRecordsToAccumulate());
+        assertEquals(KinesisSourceConfig.DEFAULT_TIME_OUT_IN_MILLIS, kinesisSourceConfig.getBufferTimeout());
+        assertTrue(kinesisSourceConfig.isAcknowledgments());
+        assertEquals(KinesisSourceConfig.DEFAULT_SHARD_ACKNOWLEDGEMENT_TIMEOUT, kinesisSourceConfig.getShardAcknowledgmentTimeout());
+        assertThat(kinesisSourceConfig.getAwsAuthenticationConfig(), notNullValue());
+        assertEquals(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsRegion(), Region.US_EAST_1);
+        assertEquals(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsStsRoleArn(), "arn:aws:iam::123456789012:role/OSI-PipelineRole");
+        assertNull(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsStsExternalId());
+        assertNull(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsStsHeaderOverrides());
+
+        List<KinesisStreamConfig> streamConfigs = kinesisSourceConfig.getStreams();
+        assertNotNull(kinesisSourceConfig.getCodec());
+        assertEquals(kinesisSourceConfig.getConsumerStrategy(), ConsumerStrategy.ENHANCED_FAN_OUT);
+        assertNull(kinesisSourceConfig.getPollingConfig());
+
+        assertEquals(streamConfigs.size(), 3);
+
+        for (KinesisStreamConfig kinesisStreamConfig: streamConfigs) {
+            assertTrue(kinesisStreamConfig.getName().contains("stream"));
+            assertEquals(kinesisStreamConfig.getInitialPosition(), InitialPositionInStream.LATEST);
+            assertEquals(kinesisStreamConfig.getCheckPointInterval(), MINIMAL_CHECKPOINT_INTERVAL);
+        }
+    }
+
+    @Test
+    @Tag(PIPELINE_CONFIG_WITH_POLLING_CONFIG_ENABLED)
+    void testSourceConfigWithStreamCodec() {
+
+        assertThat(kinesisSourceConfig, notNullValue());
+        assertEquals(KinesisSourceConfig.DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE, kinesisSourceConfig.getNumberOfRecordsToAccumulate());
+        assertEquals(KinesisSourceConfig.DEFAULT_TIME_OUT_IN_MILLIS, kinesisSourceConfig.getBufferTimeout());
+        assertFalse(kinesisSourceConfig.isAcknowledgments());
+        assertEquals(KinesisSourceConfig.DEFAULT_SHARD_ACKNOWLEDGEMENT_TIMEOUT, kinesisSourceConfig.getShardAcknowledgmentTimeout());
+        assertThat(kinesisSourceConfig.getAwsAuthenticationConfig(), notNullValue());
+        assertEquals(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsRegion(), Region.US_EAST_1);
+        assertEquals(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsStsRoleArn(), "arn:aws:iam::123456789012:role/OSI-PipelineRole");
+        assertNull(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsStsExternalId());
+        assertNull(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsStsHeaderOverrides());
+        assertNotNull(kinesisSourceConfig.getCodec());
+        List<KinesisStreamConfig> streamConfigs = kinesisSourceConfig.getStreams();
+        assertEquals(kinesisSourceConfig.getConsumerStrategy(), ConsumerStrategy.POLLING);
+        assertNotNull(kinesisSourceConfig.getPollingConfig());
+        assertEquals(kinesisSourceConfig.getPollingConfig().getMaxPollingRecords(), 10);
+        assertEquals(kinesisSourceConfig.getPollingConfig().getIdleTimeBetweenReads(), Duration.ofSeconds(10));
+
+        assertEquals(streamConfigs.size(), 3);
+
+        for (KinesisStreamConfig kinesisStreamConfig: streamConfigs) {
+            assertTrue(kinesisStreamConfig.getName().contains("stream"));
+            assertEquals(kinesisStreamConfig.getInitialPosition(), InitialPositionInStream.LATEST);
+            assertEquals(kinesisStreamConfig.getCheckPointInterval(), MINIMAL_CHECKPOINT_INTERVAL);
+        }
+    }
+
+    @Test
+    @Tag(PIPELINE_CONFIG_CHECKPOINT_ENABLED)
+    void testSourceConfigWithInitialPosition() {
+
+        assertThat(kinesisSourceConfig, notNullValue());
+        assertEquals(KinesisSourceConfig.DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE, kinesisSourceConfig.getNumberOfRecordsToAccumulate());
+        assertEquals(KinesisSourceConfig.DEFAULT_TIME_OUT_IN_MILLIS, kinesisSourceConfig.getBufferTimeout());
+        assertFalse(kinesisSourceConfig.isAcknowledgments());
+        assertEquals(KinesisSourceConfig.DEFAULT_SHARD_ACKNOWLEDGEMENT_TIMEOUT, kinesisSourceConfig.getShardAcknowledgmentTimeout());
+        assertThat(kinesisSourceConfig.getAwsAuthenticationConfig(), notNullValue());
+        assertEquals(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsRegion(), Region.US_EAST_1);
+        assertEquals(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsStsRoleArn(), "arn:aws:iam::123456789012:role/OSI-PipelineRole");
+        assertNull(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsStsExternalId());
+        assertNull(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsStsHeaderOverrides());
+        assertNotNull(kinesisSourceConfig.getCodec());
+        List<KinesisStreamConfig> streamConfigs = kinesisSourceConfig.getStreams();
+        assertEquals(kinesisSourceConfig.getConsumerStrategy(), ConsumerStrategy.ENHANCED_FAN_OUT);
+
+        Map<String, Duration> expectedCheckpointIntervals = new HashMap<>();
+        expectedCheckpointIntervals.put("stream-1", Duration.ofSeconds(20));
+        expectedCheckpointIntervals.put("stream-2", Duration.ofMinutes(15));
+        expectedCheckpointIntervals.put("stream-3", Duration.ofHours(2));
+
+        assertEquals(streamConfigs.size(), 3);
+
+        for (KinesisStreamConfig kinesisStreamConfig: streamConfigs) {
+            assertTrue(kinesisStreamConfig.getName().contains("stream"));
+            assertEquals(kinesisStreamConfig.getInitialPosition(), InitialPositionInStream.TRIM_HORIZON);
+            assertEquals(kinesisStreamConfig.getCheckPointInterval(), expectedCheckpointIntervals.get(kinesisStreamConfig.getName()));
+        }
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/KinesisStreamPollingConfigTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/KinesisStreamPollingConfigTest.java
new file mode 100644
index 0000000000..02ac1960ed
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/configuration/KinesisStreamPollingConfigTest.java
@@ -0,0 +1,30 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ * + */ + +package org.opensearch.dataprepper.plugins.kinesis.source.configuration; + +import org.junit.jupiter.api.Test; + +import java.time.Duration; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class KinesisStreamPollingConfigTest { + private static final int DEFAULT_MAX_RECORDS = 10000; + private static final int IDLE_TIME_BETWEEN_READS_IN_MILLIS = 250; + + @Test + void testConfig() { + KinesisStreamPollingConfig kinesisStreamPollingConfig = new KinesisStreamPollingConfig(); + assertEquals(kinesisStreamPollingConfig.getMaxPollingRecords(), DEFAULT_MAX_RECORDS); + assertEquals(kinesisStreamPollingConfig.getIdleTimeBetweenReads(), Duration.ofMillis(IDLE_TIME_BETWEEN_READS_IN_MILLIS)); + } + +} diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/converter/KinesisRecordConverterTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/converter/KinesisRecordConverterTest.java new file mode 100644 index 0000000000..6b0646e993 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/converter/KinesisRecordConverterTest.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.kinesis.source.converter; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; +import org.opensearch.dataprepper.event.TestEventFactory; +import org.opensearch.dataprepper.model.codec.InputCodec; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.codec.json.NdjsonInputCodec; +import org.opensearch.dataprepper.plugins.codec.json.NdjsonInputConfig; +import software.amazon.kinesis.retrieval.KinesisClientRecord; + +import java.io.IOException; +import java.io.InputStream; +import java.io.StringWriter; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class KinesisRecordConverterTest { + + @Test + void setup() throws IOException { + InputCodec codec = mock(InputCodec.class); + KinesisRecordConverter kinesisRecordConverter = new KinesisRecordConverter(codec); + doNothing().when(codec).parse(any(InputStream.class), any(Consumer.class)); + + String sample_record_data = "sample record data"; + KinesisClientRecord kinesisClientRecord = KinesisClientRecord.builder() + .data(ByteBuffer.wrap(sample_record_data.getBytes())) + .build(); + kinesisRecordConverter.convert(List.of(kinesisClientRecord)); + verify(codec, times(1)).parse(any(InputStream.class), any(Consumer.class)); + } + + @Test + public void testRecordConverterWithNdJsonInputCodec() throws IOException { + + ObjectMapper objectMapper = new ObjectMapper(); + + int 
numRecords = 10;
+        final List<Map<String, Object>> jsonObjects = IntStream.range(0, numRecords)
+                .mapToObj(i -> generateJson())
+                .collect(Collectors.toList());
+
+        final StringWriter writer = new StringWriter();
+
+        for (final Map<String, Object> jsonObject : jsonObjects) {
+            writer.append(objectMapper.writeValueAsString(jsonObject));
+            writer.append(System.lineSeparator());
+        }
+
+        KinesisRecordConverter kinesisRecordConverter = new KinesisRecordConverter(
+                new NdjsonInputCodec(new NdjsonInputConfig(), TestEventFactory.getTestEventFactory()));
+
+        KinesisClientRecord kinesisClientRecord = KinesisClientRecord.builder()
+                .data(ByteBuffer.wrap(writer.toString().getBytes()))
+                .build();
+        List<Record<Event>> events = kinesisRecordConverter.convert(List.of(kinesisClientRecord));
+
+        assertEquals(events.size(), numRecords);
+    }
+
+    private static Map<String, Object> generateJson() {
+        final Map<String, Object> jsonObject = new LinkedHashMap<>();
+        for (int i = 0; i < 1; i++) {
+            jsonObject.put(UUID.randomUUID().toString(), UUID.randomUUID().toString());
+        }
+        jsonObject.put(UUID.randomUUID().toString(), Arrays.asList(UUID.randomUUID().toString(), UUID.randomUUID().toString(), UUID.randomUUID().toString()));
+
+        return jsonObject;
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisCheckpointerRecordTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisCheckpointerRecordTest.java
new file mode 100644
index 0000000000..a2cf8fecaf
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisCheckpointerRecordTest.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ * + */ +package org.opensearch.dataprepper.plugins.kinesis.source.processor; + +import org.junit.jupiter.api.Test; +import software.amazon.kinesis.processor.RecordProcessorCheckpointer; +import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.mockito.Mockito.mock; + +public class KinesisCheckpointerRecordTest { + private String shardId = "shardId-123"; + private String testConcurrencyToken = "testToken"; + + @Test + public void validateTwoRecords() { + + KinesisCheckpointerRecord kinesisCheckpointerRecord1 = KinesisCheckpointerRecord.builder() + .extendedSequenceNumber(ExtendedSequenceNumber.LATEST) + .readyToCheckpoint(false) + .build(); + KinesisCheckpointerRecord kinesisCheckpointerRecord2 = KinesisCheckpointerRecord.builder() + .extendedSequenceNumber(ExtendedSequenceNumber.LATEST) + .readyToCheckpoint(false) + .build(); + + assertEquals(kinesisCheckpointerRecord1.isReadyToCheckpoint(), kinesisCheckpointerRecord2.isReadyToCheckpoint()); + assertEquals(kinesisCheckpointerRecord1.getCheckpointer(), kinesisCheckpointerRecord2.getCheckpointer()); + assertEquals(kinesisCheckpointerRecord1.getExtendedSequenceNumber(), kinesisCheckpointerRecord2.getExtendedSequenceNumber()); + } + + @Test + public void validateTwoRecordsWithSetterMethods() { + RecordProcessorCheckpointer recordProcessorCheckpointer = mock(RecordProcessorCheckpointer.class); + KinesisCheckpointerRecord kinesisCheckpointerRecord1 = KinesisCheckpointerRecord.builder().build(); + kinesisCheckpointerRecord1.setCheckpointer(recordProcessorCheckpointer); + kinesisCheckpointerRecord1.setExtendedSequenceNumber(ExtendedSequenceNumber.LATEST); + kinesisCheckpointerRecord1.setReadyToCheckpoint(false); + + KinesisCheckpointerRecord kinesisCheckpointerRecord2 = KinesisCheckpointerRecord.builder().build(); + kinesisCheckpointerRecord2.setCheckpointer(recordProcessorCheckpointer); + kinesisCheckpointerRecord2.setExtendedSequenceNumber(ExtendedSequenceNumber.LATEST); + kinesisCheckpointerRecord2.setReadyToCheckpoint(false); + + assertEquals(kinesisCheckpointerRecord1.isReadyToCheckpoint(), kinesisCheckpointerRecord2.isReadyToCheckpoint()); + assertEquals(kinesisCheckpointerRecord1.getCheckpointer(), kinesisCheckpointerRecord2.getCheckpointer()); + assertEquals(kinesisCheckpointerRecord1.getExtendedSequenceNumber(), kinesisCheckpointerRecord2.getExtendedSequenceNumber()); + } + + @Test + public void testInvalidRecords() { + KinesisCheckpointerRecord kinesisCheckpointerRecord = KinesisCheckpointerRecord.builder().build(); + assertNotNull(kinesisCheckpointerRecord); + } +} diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisCheckpointerTrackerTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisCheckpointerTrackerTest.java new file mode 100644 index 0000000000..fe0ab06877 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisCheckpointerTrackerTest.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.processor;
+
+import org.junit.jupiter.api.Test;
+import software.amazon.kinesis.processor.RecordProcessorCheckpointer;
+import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.Random;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
+
+public class KinesisCheckpointerTrackerTest {
+
+    private Random random = new Random();
+
+    @Test
+    void testCheckPointerAddAndGet() {
+        KinesisCheckpointerTracker kinesisCheckpointerTracker = new KinesisCheckpointerTracker();
+
+        List<ExtendedSequenceNumber> extendedSequenceNumberList = new ArrayList<>();
+        int numRecords = 10;
+        for (int i = 0; i < numRecords; i++) {
+            // Assumed loop body (addRecordForCheckpoint assumed as the tracker's
+            // registration API): register one checkpointer per sequence number so the
+            // records exist in the tracker but are not yet marked ready.
+            ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber(String.valueOf(i));
+            extendedSequenceNumberList.add(extendedSequenceNumber);
+            kinesisCheckpointerTracker.addRecordForCheckpoint(extendedSequenceNumber, mock(RecordProcessorCheckpointer.class));
+        }
+
+        Optional<KinesisCheckpointerRecord> checkpointRecord = kinesisCheckpointerTracker.popLatestReadyToCheckpointRecord();
+        assertTrue(checkpointRecord.isEmpty());
+        assertEquals(kinesisCheckpointerTracker.size(), numRecords);
+
+        int idx = random.nextInt(numRecords);
+        ExtendedSequenceNumber extendedSequenceNumber1 = extendedSequenceNumberList.get(idx);
+        kinesisCheckpointerTracker.markSequenceNumberForCheckpoint(extendedSequenceNumber1);
+
+        Optional<KinesisCheckpointerRecord> firstCheckpointer = kinesisCheckpointerTracker.popLatestReadyToCheckpointRecord();
+        if (idx != 0) {
+            assertTrue(firstCheckpointer.isEmpty());
+            assertEquals(kinesisCheckpointerTracker.size(), numRecords);
+        } else {
+            assertFalse(firstCheckpointer.isEmpty());
+            assertEquals(kinesisCheckpointerTracker.size(), numRecords - 1);
+        }
+    }
+
+    @Test
+    void testGetLastCheckpointerAndStoreIsEmpty() {
+        KinesisCheckpointerTracker kinesisCheckpointerTracker = new KinesisCheckpointerTracker();
+
+        List<ExtendedSequenceNumber> extendedSequenceNumberList = new ArrayList<>();
+        int numRecords = 10;
+        for (int i = 0; i < numRecords; i++) {
+            // Assumed loop body: add each checkpointer and immediately mark it ready,
+            // so popping the latest ready record drains the whole store.
+            ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber(String.valueOf(i));
+            extendedSequenceNumberList.add(extendedSequenceNumber);
+            kinesisCheckpointerTracker.addRecordForCheckpoint(extendedSequenceNumber, mock(RecordProcessorCheckpointer.class));
+            kinesisCheckpointerTracker.markSequenceNumberForCheckpoint(extendedSequenceNumber);
+        }
+
+        Optional<KinesisCheckpointerRecord> checkpointer = kinesisCheckpointerTracker.popLatestReadyToCheckpointRecord();
+        assertTrue(checkpointer.isPresent());
+        assertEquals(0, kinesisCheckpointerTracker.size());
+    }
+
+    @Test
+    public void testMarkCheckpointerReadyForCheckpoint() {
+
+        KinesisCheckpointerTracker kinesisCheckpointerTracker = new KinesisCheckpointerTracker();
+
+        ExtendedSequenceNumber extendedSequenceNumber = mock(ExtendedSequenceNumber.class);
+        assertThrows(IllegalArgumentException.class, () -> kinesisCheckpointerTracker.markSequenceNumberForCheckpoint(extendedSequenceNumber));
+
+        Optional<KinesisCheckpointerRecord> checkpointer = kinesisCheckpointerTracker.popLatestReadyToCheckpointRecord();
+        assertTrue(checkpointer.isEmpty());
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisRecordProcessorTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisRecordProcessorTest.java
new file mode 100644
index 0000000000..ea002e27e9
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisRecordProcessorTest.java
@@ -0,0 +1,517 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisRecordProcessorTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisRecordProcessorTest.java
new file mode 100644
index 0000000000..ea002e27e9
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisRecordProcessorTest.java
@@ -0,0 +1,517 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.processor;
+
+import io.micrometer.core.instrument.Counter;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
+import org.opensearch.dataprepper.metrics.PluginMetrics;
+import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
+import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
+import org.opensearch.dataprepper.model.codec.InputCodec;
+import org.opensearch.dataprepper.model.configuration.PluginModel;
+import org.opensearch.dataprepper.model.event.Event;
+import org.opensearch.dataprepper.model.event.EventMetadata;
+import org.opensearch.dataprepper.model.event.JacksonEvent;
+import org.opensearch.dataprepper.model.plugin.PluginFactory;
+import org.opensearch.dataprepper.model.record.Record;
+import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisSourceConfig;
+import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisStreamConfig;
+import org.opensearch.dataprepper.plugins.kinesis.source.converter.KinesisRecordConverter;
+import org.opensearch.dataprepper.plugins.kinesis.source.converter.MetadataKeyAttributes;
+import software.amazon.kinesis.common.StreamIdentifier;
+import software.amazon.kinesis.exceptions.InvalidStateException;
+import software.amazon.kinesis.exceptions.ShutdownException;
+import software.amazon.kinesis.exceptions.ThrottlingException;
+import software.amazon.kinesis.lifecycle.events.InitializationInput;
+import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
+import software.amazon.kinesis.lifecycle.events.ShardEndedInput;
+import software.amazon.kinesis.lifecycle.events.ShutdownRequestedInput;
+import software.amazon.kinesis.processor.RecordProcessorCheckpointer;
+import software.amazon.kinesis.retrieval.KinesisClientRecord;
+import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
+
+import java.nio.ByteBuffer;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.UUID;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyDouble;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
+import static org.opensearch.dataprepper.plugins.kinesis.source.processor.KinesisRecordProcessor.ACKNOWLEDGEMENT_SET_FAILURES_METRIC_NAME;
+import static org.opensearch.dataprepper.plugins.kinesis.source.processor.KinesisRecordProcessor.ACKNOWLEDGEMENT_SET_SUCCESS_METRIC_NAME;
+import static org.opensearch.dataprepper.plugins.kinesis.source.processor.KinesisRecordProcessor.KINESIS_CHECKPOINT_FAILURES;
+import static org.opensearch.dataprepper.plugins.kinesis.source.processor.KinesisRecordProcessor.KINESIS_RECORD_PROCESSED;
+import static org.opensearch.dataprepper.plugins.kinesis.source.processor.KinesisRecordProcessor.KINESIS_RECORD_PROCESSING_ERRORS;
+import static org.opensearch.dataprepper.plugins.kinesis.source.processor.KinesisRecordProcessor.KINESIS_STREAM_TAG_KEY;
+
+public class KinesisRecordProcessorTest {
+    private KinesisRecordProcessor kinesisRecordProcessor;
+    private static final String shardId = "123";
+    private static final String streamId = "stream-1";
+    private static final String codec_plugin_name = "json";
+    private static final String sequence_number = "10001";
+    private static final Long sub_sequence_number = 1L;
+
+    private static final Duration CHECKPOINT_INTERVAL = Duration.ofMillis(1000);
+    private static final int NUMBER_OF_RECORDS_TO_ACCUMULATE = 10;
+
+    @Mock
+    private PluginMetrics pluginMetrics;
+
+    @Mock
+    private PluginFactory pluginFactory;
+
+    @Mock
+    private KinesisSourceConfig kinesisSourceConfig;
+
+    @Mock
+    private KinesisStreamConfig kinesisStreamConfig;
+
+    @Mock
+    private InitializationInput initializationInput;
+
+    @Mock
+    private ProcessRecordsInput processRecordsInput;
+
+    @Mock
+    private RecordProcessorCheckpointer checkpointer;
+
+    @Mock
+    StreamIdentifier streamIdentifier;
+
+    @Mock
+    private AcknowledgementSetManager acknowledgementSetManager;
+
+    @Mock
+    private AcknowledgementSet acknowledgementSet;
+
+    @Mock
+    private Counter recordProcessed;
+
+    @Mock
+    private Counter recordProcessingErrors;
+
+    @Mock
+    private Counter checkpointFailures;
+
+    @Mock
+    private Counter acknowledgementSetSuccesses;
+
+    @Mock
+    private Counter acknowledgementSetFailures;
+
+    @Mock
+    private InputCodec codec;
+
+    @Mock
+    private BufferAccumulator<Record<Event>> bufferAccumulator;
+
+    @Mock
+    private KinesisRecordConverter kinesisRecordConverter;
+
+    @Mock
+    private KinesisCheckpointerTracker kinesisCheckpointerTracker;
+
+    @BeforeEach
+    public void setup() {
+        MockitoAnnotations.initMocks(this);
+        pluginMetrics = mock(PluginMetrics.class);
+        pluginFactory = mock(PluginFactory.class);
+        acknowledgementSet = mock(AcknowledgementSet.class);
+        bufferAccumulator = mock(BufferAccumulator.class);
+        kinesisRecordConverter = mock(KinesisRecordConverter.class);
+        kinesisCheckpointerTracker = mock(KinesisCheckpointerTracker.class);
+
+        when(initializationInput.shardId()).thenReturn(shardId);
+        when(streamIdentifier.streamName()).thenReturn(streamId);
+        when(kinesisStreamConfig.getName()).thenReturn(streamId);
+        PluginModel pluginModel = mock(PluginModel.class);
+        when(pluginModel.getPluginName()).thenReturn(codec_plugin_name);
+        when(pluginModel.getPluginSettings()).thenReturn(Collections.emptyMap());
+        when(kinesisSourceConfig.getCodec()).thenReturn(pluginModel);
+
+        codec = mock(InputCodec.class);
+        when(pluginFactory.loadPlugin(eq(InputCodec.class), any())).thenReturn(codec);
+        when(kinesisStreamConfig.getCheckPointInterval()).thenReturn(CHECKPOINT_INTERVAL);
+        when(kinesisSourceConfig.getNumberOfRecordsToAccumulate()).thenReturn(NUMBER_OF_RECORDS_TO_ACCUMULATE);
+        when(kinesisSourceConfig.getStreams()).thenReturn(List.of(kinesisStreamConfig));
+        when(processRecordsInput.checkpointer()).thenReturn(checkpointer);
+        when(pluginMetrics.counterWithTags(ACKNOWLEDGEMENT_SET_SUCCESS_METRIC_NAME, KINESIS_STREAM_TAG_KEY,
+                streamIdentifier.streamName())).thenReturn(acknowledgementSetSuccesses);
+        when(pluginMetrics.counterWithTags(ACKNOWLEDGEMENT_SET_FAILURES_METRIC_NAME, KINESIS_STREAM_TAG_KEY,
+                streamIdentifier.streamName())).thenReturn(acknowledgementSetFailures);
+
+        recordProcessed = mock(Counter.class);
+        when(pluginMetrics.counterWithTags(KINESIS_RECORD_PROCESSED, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName())).thenReturn(recordProcessed);
+
+        recordProcessingErrors = mock(Counter.class);
+        when(pluginMetrics.counterWithTags(KINESIS_RECORD_PROCESSING_ERRORS, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName())).thenReturn(recordProcessingErrors);
+    }
+
+    @Test
+    void testProcessRecordsWithoutAcknowledgementsWithCheckpointApplied()
+            throws Exception {
+        List<KinesisClientRecord> kinesisClientRecords = createInputKinesisClientRecords();
+        when(processRecordsInput.records()).thenReturn(kinesisClientRecords);
+        when(kinesisSourceConfig.isAcknowledgments()).thenReturn(false);
+        when(kinesisStreamConfig.getCheckPointInterval()).thenReturn(Duration.ofMillis(0));
+        when(acknowledgementSetManager.create(any(), any(Duration.class))).thenReturn(acknowledgementSet);
+
+        List<Record<Event>> records = new ArrayList<>();
+        Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString());
+        Record<Event> record = new Record<>(event);
+        records.add(record);
+        when(kinesisRecordConverter.convert(eq(kinesisClientRecords))).thenReturn(records);
+
+        kinesisRecordProcessor = new KinesisRecordProcessor(bufferAccumulator, kinesisSourceConfig,
+                acknowledgementSetManager, pluginMetrics, kinesisRecordConverter, kinesisCheckpointerTracker, streamIdentifier);
+        KinesisCheckpointerRecord kinesisCheckpointerRecord = mock(KinesisCheckpointerRecord.class);
+        ExtendedSequenceNumber extendedSequenceNumber = mock(ExtendedSequenceNumber.class);
+        when(extendedSequenceNumber.sequenceNumber()).thenReturn(sequence_number);
+        when(extendedSequenceNumber.subSequenceNumber()).thenReturn(sub_sequence_number);
+        when(kinesisCheckpointerRecord.getExtendedSequenceNumber()).thenReturn(extendedSequenceNumber);
+        when(kinesisCheckpointerRecord.getCheckpointer()).thenReturn(checkpointer);
+        when(kinesisCheckpointerTracker.popLatestReadyToCheckpointRecord()).thenReturn(Optional.of(kinesisCheckpointerRecord));
+        kinesisRecordProcessor.initialize(initializationInput);
+
+        kinesisRecordProcessor.processRecords(processRecordsInput);
+
+        verify(checkpointer).checkpoint(eq(sequence_number), eq(sub_sequence_number));
+
+        final ArgumentCaptor<Record<Event>> recordArgumentCaptor = ArgumentCaptor.forClass(Record.class);
+
+        verify(bufferAccumulator).add(recordArgumentCaptor.capture());
+        verify(bufferAccumulator).flush();
+
+        List<Record<Event>> recordsCaptured = recordArgumentCaptor.getAllValues();
+        assertEquals(recordsCaptured.size(), records.size());
+        for (Record<Event> eventRecord: recordsCaptured) {
+            EventMetadata eventMetadata = eventRecord.getData().getMetadata();
+            assertEquals(eventMetadata.getAttribute(MetadataKeyAttributes.KINESIS_STREAM_NAME_METADATA_ATTRIBUTE), streamIdentifier.streamName());
+        }
+
+        verify(acknowledgementSetManager, times(0)).create(any(), any(Duration.class));
+        verify(recordProcessed, times(1)).increment(anyDouble());
+    }
+
+    @Test
+    public void testProcessRecordsWithoutAcknowledgementsEnabled()
+            throws Exception {
+        List<KinesisClientRecord> kinesisClientRecords = createInputKinesisClientRecords();
+        when(processRecordsInput.records()).thenReturn(kinesisClientRecords);
+        when(kinesisSourceConfig.isAcknowledgments()).thenReturn(false);
+        when(kinesisStreamConfig.getCheckPointInterval()).thenReturn(Duration.ofMillis(0));
+        when(acknowledgementSetManager.create(any(), any(Duration.class))).thenReturn(acknowledgementSet);
+
+        List<Record<Event>> records = new ArrayList<>();
+        Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString());
+        Record<Event> record = new Record<>(event);
+        records.add(record);
+        when(kinesisRecordConverter.convert(eq(kinesisClientRecords))).thenReturn(records);
+
+        kinesisRecordProcessor = new KinesisRecordProcessor(bufferAccumulator, kinesisSourceConfig,
+                acknowledgementSetManager, pluginMetrics, kinesisRecordConverter, kinesisCheckpointerTracker, streamIdentifier);
+        when(kinesisCheckpointerTracker.popLatestReadyToCheckpointRecord()).thenReturn(Optional.empty());
+        kinesisRecordProcessor.initialize(initializationInput);
+
+        kinesisRecordProcessor.processRecords(processRecordsInput);
+
+        verifyNoInteractions(checkpointer);
+
+        final ArgumentCaptor<Record<Event>> recordArgumentCaptor = ArgumentCaptor.forClass(Record.class);
+
+        verify(bufferAccumulator).add(recordArgumentCaptor.capture());
+        verify(bufferAccumulator).flush();
+
+        List<Record<Event>> recordsCaptured = recordArgumentCaptor.getAllValues();
+        assertEquals(recordsCaptured.size(), records.size());
+        for (Record<Event> eventRecord: recordsCaptured) {
+            EventMetadata eventMetadata = eventRecord.getData().getMetadata();
+            assertEquals(eventMetadata.getAttribute(MetadataKeyAttributes.KINESIS_STREAM_NAME_METADATA_ATTRIBUTE), streamIdentifier.streamName());
+        }
+
+        verify(acknowledgementSetManager, times(0)).create(any(), any(Duration.class));
+        verify(recordProcessed, times(1)).increment(anyDouble());
+    }
+
+    @Test
+    void testProcessRecordsWithAcknowledgementsEnabled()
+            throws Exception {
+        List<KinesisClientRecord> kinesisClientRecords = createInputKinesisClientRecords();
+        when(processRecordsInput.records()).thenReturn(kinesisClientRecords);
+        when(kinesisSourceConfig.isAcknowledgments()).thenReturn(true);
+        when(kinesisStreamConfig.getCheckPointInterval()).thenReturn(Duration.ofMillis(0));
+        AtomicReference<Integer> numEventsAdded = new AtomicReference<>(0);
+        doAnswer(a -> {
+            numEventsAdded.getAndSet(numEventsAdded.get() + 1);
+            return null;
+        }).when(acknowledgementSet).add(any());
+
+        doAnswer(invocation -> {
+            Consumer<Boolean> consumer = invocation.getArgument(0);
+            consumer.accept(true);
+            return acknowledgementSet;
+        }).when(acknowledgementSetManager).create(any(Consumer.class), any(Duration.class));
+
+        List<Record<Event>> records = new ArrayList<>();
+        Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString());
+        Record<Event> record = new Record<>(event);
+        records.add(record);
+        when(kinesisRecordConverter.convert(eq(kinesisClientRecords))).thenReturn(records);
+
+        kinesisRecordProcessor = new KinesisRecordProcessor(bufferAccumulator, kinesisSourceConfig,
+                acknowledgementSetManager, pluginMetrics, kinesisRecordConverter, kinesisCheckpointerTracker, streamIdentifier);
+        KinesisCheckpointerRecord kinesisCheckpointerRecord = mock(KinesisCheckpointerRecord.class);
+        ExtendedSequenceNumber extendedSequenceNumber = mock(ExtendedSequenceNumber.class);
+        when(extendedSequenceNumber.sequenceNumber()).thenReturn(sequence_number);
+        when(extendedSequenceNumber.subSequenceNumber()).thenReturn(sub_sequence_number);
+        when(kinesisCheckpointerRecord.getExtendedSequenceNumber()).thenReturn(extendedSequenceNumber);
+        when(kinesisCheckpointerRecord.getCheckpointer()).thenReturn(checkpointer);
+        when(kinesisCheckpointerTracker.popLatestReadyToCheckpointRecord()).thenReturn(Optional.of(kinesisCheckpointerRecord));
+        kinesisRecordProcessor.initialize(initializationInput);
+
+        kinesisRecordProcessor.processRecords(processRecordsInput);
+
+        final ArgumentCaptor<Record<Event>> recordArgumentCaptor = ArgumentCaptor.forClass(Record.class);
+
+        verify(bufferAccumulator).add(recordArgumentCaptor.capture());
+        verify(bufferAccumulator).flush();
+
+        List<Record<Event>> recordsCaptured = recordArgumentCaptor.getAllValues();
+        assertEquals(recordsCaptured.size(), records.size());
+        for (Record<Event> eventRecord: recordsCaptured) {
+            EventMetadata eventMetadata = eventRecord.getData().getMetadata();
+            assertEquals(eventMetadata.getAttribute(MetadataKeyAttributes.KINESIS_STREAM_NAME_METADATA_ATTRIBUTE), streamIdentifier.streamName());
+        }
+        verify(acknowledgementSetManager, times(1)).create(any(), any(Duration.class));
+        verify(acknowledgementSetSuccesses, atLeastOnce()).increment();
+        verify(recordProcessed, times(1)).increment(anyDouble());
+        verifyNoInteractions(recordProcessingErrors);
+    }
+
+    @Test
+    void testProcessRecordsWithNDJsonInputCodec()
+            throws Exception {
+        List<KinesisClientRecord> kinesisClientRecords = createInputKinesisClientRecords();
+        when(processRecordsInput.records()).thenReturn(kinesisClientRecords);
+        when(kinesisSourceConfig.isAcknowledgments()).thenReturn(false);
+        when(kinesisStreamConfig.getCheckPointInterval()).thenReturn(Duration.ofMillis(0));
+
+        PluginModel pluginModel = mock(PluginModel.class);
+        when(pluginModel.getPluginName()).thenReturn("ndjson");
+        when(pluginModel.getPluginSettings()).thenReturn(Collections.emptyMap());
+        when(kinesisSourceConfig.getCodec()).thenReturn(pluginModel);
+
+        InputCodec codec = mock(InputCodec.class);
+        when(pluginFactory.loadPlugin(eq(InputCodec.class), any())).thenReturn(codec);
+
+        when(acknowledgementSetManager.create(any(), any(Duration.class))).thenReturn(acknowledgementSet);
+
+        List<Record<Event>> records = new ArrayList<>();
+        Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString());
+        Record<Event> record = new Record<>(event);
+        records.add(record);
+        when(kinesisRecordConverter.convert(eq(kinesisClientRecords))).thenReturn(records);
+
+        kinesisRecordProcessor = new KinesisRecordProcessor(bufferAccumulator, kinesisSourceConfig,
+                acknowledgementSetManager, pluginMetrics, kinesisRecordConverter, kinesisCheckpointerTracker, streamIdentifier);
+        KinesisCheckpointerRecord kinesisCheckpointerRecord = mock(KinesisCheckpointerRecord.class);
+        ExtendedSequenceNumber extendedSequenceNumber = mock(ExtendedSequenceNumber.class);
+        when(extendedSequenceNumber.sequenceNumber()).thenReturn(sequence_number);
+        when(extendedSequenceNumber.subSequenceNumber()).thenReturn(sub_sequence_number);
+        when(kinesisCheckpointerRecord.getCheckpointer()).thenReturn(checkpointer);
+        when(kinesisCheckpointerRecord.getExtendedSequenceNumber()).thenReturn(extendedSequenceNumber);
+        when(kinesisCheckpointerTracker.popLatestReadyToCheckpointRecord()).thenReturn(Optional.of(kinesisCheckpointerRecord));
+        kinesisRecordProcessor.initialize(initializationInput);
+
+        kinesisRecordProcessor.processRecords(processRecordsInput);
+
+        verify(checkpointer).checkpoint(eq(sequence_number), eq(sub_sequence_number));
+        final ArgumentCaptor<Record<Event>> recordArgumentCaptor = ArgumentCaptor.forClass(Record.class);
+
+        verify(bufferAccumulator).add(recordArgumentCaptor.capture());
+
+        List<Record<Event>> recordsCaptured = recordArgumentCaptor.getAllValues();
+        assertEquals(recordsCaptured.size(), records.size());
+        for (Record<Event> eventRecord: recordsCaptured) {
+            EventMetadata eventMetadata = eventRecord.getData().getMetadata();
+            assertEquals(eventMetadata.getAttribute(MetadataKeyAttributes.KINESIS_STREAM_NAME_METADATA_ATTRIBUTE), streamIdentifier.streamName());
+        }
+
+        verify(acknowledgementSetManager, times(0)).create(any(), any(Duration.class));
+        verify(recordProcessed, times(1)).increment(anyDouble());
+    }
+
+    @Test
+    void testProcessRecordsNoThrowException()
+            throws Exception {
+        List<KinesisClientRecord> kinesisClientRecords = createInputKinesisClientRecords();
+        when(processRecordsInput.records()).thenReturn(kinesisClientRecords);
+        when(kinesisSourceConfig.isAcknowledgments()).thenReturn(false);
+
+        List<Record<Event>> records = new ArrayList<>();
+        Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString());
+        Record<Event> record = new Record<>(event);
+        records.add(record);
+        when(kinesisRecordConverter.convert(eq(kinesisClientRecords))).thenReturn(records);
+        final Throwable exception = mock(RuntimeException.class);
+        doThrow(exception).when(bufferAccumulator).add(any(Record.class));
+
+        kinesisRecordProcessor = new KinesisRecordProcessor(bufferAccumulator, kinesisSourceConfig,
+                acknowledgementSetManager, pluginMetrics, kinesisRecordConverter, kinesisCheckpointerTracker, streamIdentifier);
+        kinesisRecordProcessor.initialize(initializationInput);
+
+        assertDoesNotThrow(() -> kinesisRecordProcessor.processRecords(processRecordsInput));
+        verify(recordProcessingErrors, times(1)).increment();
+        verify(recordProcessed, times(0)).increment(anyDouble());
+    }
+
+    @Test
+    void testProcessRecordsBufferFlushNoThrowException()
+            throws Exception {
+        List<KinesisClientRecord> kinesisClientRecords = createInputKinesisClientRecords();
+        when(processRecordsInput.records()).thenReturn(kinesisClientRecords);
+        when(kinesisSourceConfig.isAcknowledgments()).thenReturn(false);
+
+        List<Record<Event>> records = new ArrayList<>();
+        Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString());
+        Record<Event> record = new Record<>(event);
+        records.add(record);
+        when(kinesisRecordConverter.convert(eq(kinesisClientRecords))).thenReturn(records);
+        final Throwable exception = mock(RuntimeException.class);
+        doThrow(exception).when(bufferAccumulator).flush();
+
+        kinesisRecordProcessor = new KinesisRecordProcessor(bufferAccumulator, kinesisSourceConfig,
+                acknowledgementSetManager, pluginMetrics, kinesisRecordConverter, kinesisCheckpointerTracker, streamIdentifier);
+        kinesisRecordProcessor.initialize(initializationInput);
+
+        assertDoesNotThrow(() -> kinesisRecordProcessor.processRecords(processRecordsInput));
+        verify(recordProcessingErrors, times(1)).increment();
+        verify(recordProcessed, times(0)).increment(anyDouble());
+
+    }
+
+    @Test
+    void testShardEndedLatestCheckpoint() {
+        KinesisRecordProcessor mockKinesisRecordProcessor = new KinesisRecordProcessor(bufferAccumulator, kinesisSourceConfig,
+                acknowledgementSetManager, pluginMetrics, kinesisRecordConverter, kinesisCheckpointerTracker, streamIdentifier);
+        ShardEndedInput shardEndedInput = mock(ShardEndedInput.class);
+        when(shardEndedInput.checkpointer()).thenReturn(checkpointer);
+
+        mockKinesisRecordProcessor.shardEnded(shardEndedInput);
+
+        verify(shardEndedInput).checkpointer();
+    }
+
+    @ParameterizedTest
+    @ValueSource(classes = {ShutdownException.class, ThrottlingException.class, InvalidStateException.class})
+    void testShardEndedCheckpointerThrowsNoThrowException(final Class<? extends Throwable> exceptionType) throws Exception {
+        checkpointFailures = mock(Counter.class);
+        when(pluginMetrics.counterWithTags(KINESIS_CHECKPOINT_FAILURES, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName())).thenReturn(checkpointFailures);
+
+        KinesisRecordProcessor mockKinesisRecordProcessor = new KinesisRecordProcessor(bufferAccumulator, kinesisSourceConfig,
+                acknowledgementSetManager, pluginMetrics, kinesisRecordConverter, kinesisCheckpointerTracker, streamIdentifier);
+        ShardEndedInput shardEndedInput = mock(ShardEndedInput.class);
+        when(shardEndedInput.checkpointer()).thenReturn(checkpointer);
+        doThrow(exceptionType).when(checkpointer).checkpoint();
+
+        assertDoesNotThrow(() -> mockKinesisRecordProcessor.shardEnded(shardEndedInput));
+
+        verify(checkpointer).checkpoint();
+        verify(shardEndedInput, times(1)).checkpointer();
+        verify(checkpointFailures, times(1)).increment();
+    }
+
+    @Test
+    void testShutdownRequestedWithLatestCheckpoint() {
+        checkpointFailures = mock(Counter.class);
+        when(pluginMetrics.counterWithTags(KINESIS_CHECKPOINT_FAILURES, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName())).thenReturn(checkpointFailures);
+
+        KinesisRecordProcessor mockKinesisRecordProcessor = new KinesisRecordProcessor(bufferAccumulator, kinesisSourceConfig,
+                acknowledgementSetManager, pluginMetrics, kinesisRecordConverter, kinesisCheckpointerTracker, streamIdentifier);
+        ShutdownRequestedInput shutdownRequestedInput = mock(ShutdownRequestedInput.class);
+        when(shutdownRequestedInput.checkpointer()).thenReturn(checkpointer);
+
+        mockKinesisRecordProcessor.shutdownRequested(shutdownRequestedInput);
+
+        verify(shutdownRequestedInput).checkpointer();
+        verify(checkpointFailures, times(0)).increment();
+    }
+
+    @ParameterizedTest
+    @ValueSource(classes = {ShutdownException.class, ThrottlingException.class, InvalidStateException.class})
+    void testShutdownRequestedCheckpointerThrowsNoThrowException(final Class<? extends Throwable> exceptionType) throws Exception {
+        checkpointFailures = mock(Counter.class);
+        when(pluginMetrics.counterWithTags(KINESIS_CHECKPOINT_FAILURES, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName())).thenReturn(checkpointFailures);
+
+        KinesisRecordProcessor mockKinesisRecordProcessor = new KinesisRecordProcessor(bufferAccumulator, kinesisSourceConfig,
+                acknowledgementSetManager, pluginMetrics, kinesisRecordConverter, kinesisCheckpointerTracker, streamIdentifier);
+        doThrow(exceptionType).when(checkpointer).checkpoint(eq(sequence_number), eq(sub_sequence_number));
+
+        assertDoesNotThrow(() -> mockKinesisRecordProcessor.checkpoint(checkpointer, sequence_number, sub_sequence_number));
+
+        verify(checkpointer).checkpoint(eq(sequence_number), eq(sub_sequence_number));
+        verify(checkpointFailures, times(1)).increment();
+    }
+
+    @ParameterizedTest
+    @ValueSource(classes = {ShutdownException.class, ThrottlingException.class, InvalidStateException.class})
+    void testShutdownRequestedCheckpointerThrowsNoThrowExceptionRegularCheckpoint(final Class<? extends Throwable> exceptionType) throws Exception {
+        checkpointFailures = mock(Counter.class);
+        when(pluginMetrics.counterWithTags(KINESIS_CHECKPOINT_FAILURES, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName())).thenReturn(checkpointFailures);
+
+        KinesisRecordProcessor mockKinesisRecordProcessor = new KinesisRecordProcessor(bufferAccumulator, kinesisSourceConfig,
+                acknowledgementSetManager, pluginMetrics, kinesisRecordConverter, kinesisCheckpointerTracker, streamIdentifier);
+        ShutdownRequestedInput shutdownRequestedInput = mock(ShutdownRequestedInput.class);
+        when(shutdownRequestedInput.checkpointer()).thenReturn(checkpointer);
+        doThrow(exceptionType).when(checkpointer).checkpoint();
+
+        assertDoesNotThrow(() -> mockKinesisRecordProcessor.shutdownRequested(shutdownRequestedInput));
+
+        verify(checkpointer).checkpoint();
+        verify(shutdownRequestedInput, times(1)).checkpointer();
+        verify(checkpointFailures, times(1)).increment();
+    }
+
+    private List<KinesisClientRecord> createInputKinesisClientRecords() {
+        List<KinesisClientRecord> kinesisClientRecords = new ArrayList<>();
+        for (int i = 0; i < KinesisRecordProcessorTest.NUMBER_OF_RECORDS_TO_ACCUMULATE; i++) {
+            Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString());
+            KinesisClientRecord kinesisClientRecord = KinesisClientRecord.builder()
+                    .data(ByteBuffer.wrap(event.toJsonString().getBytes()))
+                    .sequenceNumber(Integer.toString(100 + i)).subSequenceNumber(i).build();
+            kinesisClientRecords.add(kinesisClientRecord);
+        }
+        return kinesisClientRecords;
+    }
+}
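
Note: the failure-path tests above all follow one rule: a failed KCL checkpoint call must be swallowed and counted, never propagated. A compilable sketch of that shape follows, with the surrounding class and counter wiring assumed rather than taken from the PR; the catch list is the checked set the KCL method declares, of which the @ValueSource lists exercise three.

import io.micrometer.core.instrument.Counter;
import software.amazon.kinesis.exceptions.InvalidStateException;
import software.amazon.kinesis.exceptions.KinesisClientLibDependencyException;
import software.amazon.kinesis.exceptions.ShutdownException;
import software.amazon.kinesis.exceptions.ThrottlingException;
import software.amazon.kinesis.processor.RecordProcessorCheckpointer;

class CheckpointSketch {
    private final Counter checkpointFailures;

    CheckpointSketch(final Counter checkpointFailures) {
        this.checkpointFailures = checkpointFailures;
    }

    void checkpoint(final RecordProcessorCheckpointer checkpointer,
                    final String sequenceNumber, final long subSequenceNumber) {
        try {
            // Commit shard progress up to (sequenceNumber, subSequenceNumber).
            checkpointer.checkpoint(sequenceNumber, subSequenceNumber);
        } catch (final ShutdownException | ThrottlingException | InvalidStateException |
                       KinesisClientLibDependencyException e) {
            // The tests assert nothing escapes; a failure only bumps the metric.
            checkpointFailures.increment();
        }
    }
}
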
diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisShardRecordProcessorFactoryTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisShardRecordProcessorFactoryTest.java
new file mode 100644
index 0000000000..9f0a555253
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/kinesis/source/processor/KinesisShardRecordProcessorFactoryTest.java
@@ -0,0 +1,98 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ */
+
+package org.opensearch.dataprepper.plugins.kinesis.source.processor;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.opensearch.dataprepper.metrics.PluginMetrics;
+import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
+import org.opensearch.dataprepper.model.buffer.Buffer;
+import org.opensearch.dataprepper.model.codec.InputCodec;
+import org.opensearch.dataprepper.model.configuration.PluginModel;
+import org.opensearch.dataprepper.model.event.Event;
+import org.opensearch.dataprepper.model.plugin.PluginFactory;
+import org.opensearch.dataprepper.model.record.Record;
+import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisSourceConfig;
+import org.opensearch.dataprepper.plugins.kinesis.source.configuration.KinesisStreamConfig;
+import software.amazon.kinesis.common.StreamIdentifier;
+
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class KinesisShardRecordProcessorFactoryTest {
+    private KinesisShardRecordProcessorFactory kinesisShardRecordProcessorFactory;
+
+    private static final String streamId = "stream-1";
+    private static final String codec_plugin_name = "json";
+
+    @Mock
+    private Buffer<Record<Event>> buffer;
+
+    @Mock
+    StreamIdentifier streamIdentifier;
+
+    @Mock
+    private PluginMetrics pluginMetrics;
+
+    @Mock
+    private PluginFactory pluginFactory;
+
+    @Mock
+    private KinesisSourceConfig kinesisSourceConfig;
+
+    @Mock
+    private KinesisStreamConfig kinesisStreamConfig;
+
+    @Mock
+    private AcknowledgementSetManager acknowledgementSetManager;
+
+    @Mock
+    private InputCodec codec;
+
+    @BeforeEach
+    void setup() {
+        MockitoAnnotations.initMocks(this);
+
+        PluginModel pluginModel = mock(PluginModel.class);
+        when(pluginModel.getPluginName()).thenReturn(codec_plugin_name);
+        when(pluginModel.getPluginSettings()).thenReturn(Collections.emptyMap());
+        when(kinesisSourceConfig.getCodec()).thenReturn(pluginModel);
+        when(kinesisSourceConfig.getNumberOfRecordsToAccumulate()).thenReturn(100);
+
+        codec = mock(InputCodec.class);
+        when(pluginFactory.loadPlugin(eq(InputCodec.class), any())).thenReturn(codec);
+
+        when(streamIdentifier.streamName()).thenReturn(streamId);
+        when(kinesisStreamConfig.getName()).thenReturn(streamId);
+        when(kinesisSourceConfig.getStreams()).thenReturn(List.of(kinesisStreamConfig));
+    }
+
+    @Test
+    void testKinesisRecordProcessFactoryReturnsKinesisRecordProcessor() {
+        kinesisShardRecordProcessorFactory = new KinesisShardRecordProcessorFactory(buffer, kinesisSourceConfig, acknowledgementSetManager, pluginMetrics, codec);
+        assertInstanceOf(KinesisRecordProcessor.class, kinesisShardRecordProcessorFactory.shardRecordProcessor(streamIdentifier));
+    }
+
+    @Test
+    void testKinesisRecordProcessFactoryDefaultUnsupported() {
+        kinesisShardRecordProcessorFactory = new KinesisShardRecordProcessorFactory(buffer, kinesisSourceConfig, acknowledgementSetManager, pluginMetrics, codec);
+        assertThrows(UnsupportedOperationException.class, () -> kinesisShardRecordProcessorFactory.shardRecordProcessor());
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/test/resources/pipeline_with_acks_enabled.yaml b/data-prepper-plugins/kinesis-source/src/test/resources/pipeline_with_acks_enabled.yaml
new file mode 100644
index 0000000000..e5260372f5
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/resources/pipeline_with_acks_enabled.yaml
@@ -0,0 +1,12 @@
+source:
+  kinesis:
+    streams:
+      - stream_name: "stream-1"
+      - stream_name: "stream-2"
+      - stream_name: "stream-3"
+    codec:
+      ndjson:
+    aws:
+      sts_role_arn: "arn:aws:iam::123456789012:role/OSI-PipelineRole"
+      region: "us-east-1"
+    acknowledgments: true
\ No newline at end of file
diff --git a/data-prepper-plugins/kinesis-source/src/test/resources/pipeline_with_checkpoint_enabled.yaml b/data-prepper-plugins/kinesis-source/src/test/resources/pipeline_with_checkpoint_enabled.yaml
new file mode 100644
index 0000000000..c8b58725fd
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/resources/pipeline_with_checkpoint_enabled.yaml
@@ -0,0 +1,17 @@
+source:
+  kinesis:
+    streams:
+      - stream_name: "stream-1"
+        initial_position: "EARLIEST"
+        checkpoint_interval: "20s"
+      - stream_name: "stream-2"
+        initial_position: "EARLIEST"
+        checkpoint_interval: "PT15M"
+      - stream_name: "stream-3"
+        initial_position: "EARLIEST"
+        checkpoint_interval: "PT2H"
+    codec:
+      ndjson:
+    aws:
+      sts_role_arn: "arn:aws:iam::123456789012:role/OSI-PipelineRole"
+      region: "us-east-1"
\ No newline at end of file
diff --git a/data-prepper-plugins/kinesis-source/src/test/resources/pipeline_with_polling_config_enabled.yaml b/data-prepper-plugins/kinesis-source/src/test/resources/pipeline_with_polling_config_enabled.yaml
new file mode 100644
index 0000000000..4a3156ec2a
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/resources/pipeline_with_polling_config_enabled.yaml
@@ -0,0 +1,15 @@
+source:
+  kinesis:
+    streams:
+      - stream_name: "stream-1"
+      - stream_name: "stream-2"
+      - stream_name: "stream-3"
+    codec:
+      ndjson:
+    aws:
+      sts_role_arn: "arn:aws:iam::123456789012:role/OSI-PipelineRole"
+      region: "us-east-1"
+    consumer_strategy: "polling"
+    polling:
+      max_polling_records: 10
+      idle_time_between_reads: 10s
\ No newline at end of file
diff --git a/data-prepper-plugins/kinesis-source/src/test/resources/simple_pipeline_with_extensions.yaml b/data-prepper-plugins/kinesis-source/src/test/resources/simple_pipeline_with_extensions.yaml
new file mode 100644
index 0000000000..4f964cae7f
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/test/resources/simple_pipeline_with_extensions.yaml
@@ -0,0 +1,5 @@
+extensions:
+  kinesis:
+    lease_coordination:
+      table_name: "kinesis-pipeline-kcl"
+      region: "us-east-1"
\ No newline at end of file
diff --git a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/documentdb/DocumentDBSource.java b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/documentdb/DocumentDBSource.java
index b6ff1fbdf1..6a8e82a8a9 100644
--- a/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/documentdb/DocumentDBSource.java
+++ b/data-prepper-plugins/mongodb/src/main/java/org/opensearch/dataprepper/plugins/mongo/documentdb/DocumentDBSource.java
@@ -23,7 +23,7 @@
 import java.util.function.Function;
 
-@DataPrepperPlugin(name = "documentdb", pluginType = Source.class, pluginConfigurationType = MongoDBSourceConfig.class)
+@DataPrepperPlugin(name = "documentdb", alternateNames = "mongodb", pluginType = Source.class, pluginConfigurationType = MongoDBSourceConfig.class)
 public class DocumentDBSource implements Source<Record<Event>>, UsesEnhancedSourceCoordination {
 
     private static final Logger LOG = LoggerFactory.getLogger(DocumentDBSource.class);
diff --git a/data-prepper-plugins/mongodb/src/main/resources/org/opensearch/dataprepper/transforms/rules/mongodb-rule.yaml b/data-prepper-plugins/mongodb/src/main/resources/org/opensearch/dataprepper/transforms/rules/mongodb-rule.yaml
new file mode 100644
index 0000000000..33cc703072
--- /dev/null
+++ b/data-prepper-plugins/mongodb/src/main/resources/org/opensearch/dataprepper/transforms/rules/mongodb-rule.yaml
@@ -0,0 +1,4 @@
+plugin_name: "mongodb"
+apply_when:
+  - "$..source.mongodb"
+  - "$..source.mongodb.s3_bucket"
\ No newline at end of file
diff --git a/data-prepper-plugins/mongodb/src/main/resources/org/opensearch/dataprepper/transforms/templates/mongodb-template.yaml b/data-prepper-plugins/mongodb/src/main/resources/org/opensearch/dataprepper/transforms/templates/mongodb-template.yaml
new file mode 100644
index 0000000000..88208e631f
--- /dev/null
+++ b/data-prepper-plugins/mongodb/src/main/resources/org/opensearch/dataprepper/transforms/templates/mongodb-template.yaml
@@ -0,0 +1,81 @@
+"<<pipeline-name>>":
+  workers: "<<$.<<pipeline-name>>.workers>>"
+  delay: "<<$.<<pipeline-name>>.delay>>"
+  buffer: "<<$.<<pipeline-name>>.buffer>>"
+  source:
+    mongodb: "<<$.<<pipeline-name>>.source.mongodb>>"
+  routes:
+    - initial_load: 'getMetadata("ingestion_type") == "EXPORT"'
+    - stream_load: 'getMetadata("ingestion_type") == "STREAM"'
+  sink:
+    - s3:
+        routes:
+          - initial_load
+        aws:
+          region: "<<$.<<pipeline-name>>.source.mongodb.s3_region>>"
+          sts_role_arn: "<<$.<<pipeline-name>>.source.mongodb.aws.sts_role_arn>>"
+          sts_external_id: "<<$.<<pipeline-name>>.source.mongodb.aws.sts_external_id>>"
+          sts_header_overrides: "<<$.<<pipeline-name>>.source.mongodb.aws.sts_header_overrides>>"
+        bucket: "<<$.<<pipeline-name>>.source.mongodb.s3_bucket>>"
+        threshold:
+          event_collect_timeout: "120s"
+          maximum_size: "2mb"
+        aggregate_threshold:
+          maximum_size: "128mb"
+          flush_capacity_ratio: 0
+        object_key:
+          path_prefix: "${getMetadata(\"s3_partition_key\")}"
+        codec:
+          event_json:
+        default_bucket_owner: "<<FUNCTION_NAME:getAccountIdFromRole,PARAMETER:$.<<pipeline-name>>.source.mongodb.aws.sts_role_arn>>"
+    - s3:
+        routes:
+          - stream_load
+        aws:
+          region: "<<$.<<pipeline-name>>.source.mongodb.s3_region>>"
+          sts_role_arn: "<<$.<<pipeline-name>>.source.mongodb.aws.sts_role_arn>>"
+          sts_external_id: "<<$.<<pipeline-name>>.source.mongodb.aws.sts_external_id>>"
+          sts_header_overrides: "<<$.<<pipeline-name>>.source.mongodb.aws.sts_header_overrides>>"
+        bucket: "<<$.<<pipeline-name>>.source.mongodb.s3_bucket>>"
+        threshold:
+          event_collect_timeout: "15s"
+          maximum_size: "1mb"
+        aggregate_threshold:
+          maximum_size: "128mb"
+          flush_capacity_ratio: 0
+        object_key:
+          path_prefix: "${getMetadata(\"s3_partition_key\")}"
+        codec:
+          event_json:
+        default_bucket_owner: "<<FUNCTION_NAME:getAccountIdFromRole,PARAMETER:$.<<pipeline-name>>.source.mongodb.aws.sts_role_arn>>"
+"<<pipeline-name>>-s3":
+  workers: "<<$.<<pipeline-name>>.workers>>"
+  delay: "<<$.<<pipeline-name>>.delay>>"
+  buffer: "<<$.<<pipeline-name>>.buffer>>"
+  source:
+    s3:
+      codec:
+        event_json:
+      compression: "none"
+      aws:
+        region: "<<$.<<pipeline-name>>.source.mongodb.s3_region>>"
+        sts_role_arn: "<<$.<<pipeline-name>>.source.mongodb.aws.sts_role_arn>>"
+        sts_external_id: "<<$.<<pipeline-name>>.source.mongodb.aws.sts_external_id>>"
+        sts_header_overrides: "<<$.<<pipeline-name>>.source.mongodb.aws.sts_header_overrides>>"
+      acknowledgments: true
+      delete_s3_objects_on_read: true
+      disable_s3_metadata_in_event: true
+      scan:
+        folder_partitions:
+          depth: "<<FUNCTION_NAME:calculateDepth,PARAMETER:$.<<pipeline-name>>.source.mongodb.s3_prefix>>"
+          max_objects_per_ownership: 50
+        buckets:
+          - bucket:
+              name: "<<$.<<pipeline-name>>.source.mongodb.s3_bucket>>"
+              filter:
+                include_prefix: ["<<FUNCTION_NAME:getIncludePrefix,PARAMETER:$.<<pipeline-name>>.source.mongodb.s3_prefix>>"]
+        scheduling:
+          interval: "60s"
+  processor: "<<$.<<pipeline-name>>.processor>>"
+  sink: "<<$.<<pipeline-name>>.sink>>"
+  routes: "<<$.<<pipeline-name>>.routes>>" # In placeholder, routes or route (defined as alias) will be transformed to route in json as route will be primarily picked in pipelineModel.
\ No newline at end of file
diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java
index cd12711f27..fbe1aae36f 100644
--- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java
+++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java
@@ -146,6 +146,7 @@ public void onEvent(com.github.shyiko.mysql.binlog.event.Event event) {
     public void stopClient() {
         try {
             binaryLogClient.disconnect();
+            binaryLogClient.unregisterEventListener(this);
             binlogEventExecutorService.shutdownNow();
             LOG.info("Binary log client disconnected.");
         } catch (Exception e) {
diff --git a/data-prepper-plugins/rds-source/src/main/resources/org/opensearch/dataprepper/transforms/templates/rds-template.yaml b/data-prepper-plugins/rds-source/src/main/resources/org/opensearch/dataprepper/transforms/templates/rds-template.yaml
index b439068cad..cbd0fec44e 100644
--- a/data-prepper-plugins/rds-source/src/main/resources/org/opensearch/dataprepper/transforms/templates/rds-template.yaml
+++ b/data-prepper-plugins/rds-source/src/main/resources/org/opensearch/dataprepper/transforms/templates/rds-template.yaml
@@ -48,6 +48,9 @@
         codec:
           event_json:
         default_bucket_owner: "<<FUNCTION_NAME:getAccountIdFromRole,PARAMETER:$.<<pipeline-name>>.source.rds.aws.sts_role_arn>>"
+        client:
+          max_connections: 100
+          acquire_timeout: 20s
 "<<pipeline-name>>-s3":
   workers: "<<$.<<pipeline-name>>.workers>>"
   delay: "<<$.<<pipeline-name>>.delay>>"
diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListenerTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListenerTest.java
index 27f3fa9037..1312607821 100644
--- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListenerTest.java
+++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListenerTest.java
@@ -16,6 +16,7 @@
 import org.junit.jupiter.params.provider.EnumSource;
 import org.mockito.Answers;
 import org.mockito.ArgumentCaptor;
+import org.mockito.InOrder;
 import org.mockito.Mock;
 import org.mockito.MockedStatic;
 import org.mockito.junit.jupiter.MockitoExtension;
@@ -26,6 +27,7 @@
 import org.opensearch.dataprepper.model.record.Record;
 import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig;
 
+import java.io.IOException;
 import java.util.UUID;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
@@ -34,6 +36,7 @@
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.mockStatic;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
@@ -103,6 +106,16 @@ void test_given_TableMap_event_then_calls_correct_handler() {
         verify(objectUnderTest).handleTableMapEvent(binlogEvent);
     }
 
+    @Test
+    void test_stopClient() throws IOException {
+        objectUnderTest.stopClient();
+
+        InOrder inOrder = inOrder(binaryLogClient, eventListnerExecutorService);
+        inOrder.verify(binaryLogClient).disconnect();
+        inOrder.verify(binaryLogClient).unregisterEventListener(objectUnderTest);
+        inOrder.verify(eventListnerExecutorService).shutdownNow();
+    }
+
     @ParameterizedTest
     @EnumSource(names = {"WRITE_ROWS", "EXT_WRITE_ROWS"})
     void test_given_WriteRows_event_then_calls_correct_handler(EventType eventType) {
diff --git a/data-prepper-plugins/s3-sink/build.gradle b/data-prepper-plugins/s3-sink/build.gradle
index 57198bf274..f102974ed4 100644
--- a/data-prepper-plugins/s3-sink/build.gradle
+++ b/data-prepper-plugins/s3-sink/build.gradle
@@ -14,6 +14,7 @@ dependencies {
     implementation 'joda-time:joda-time:2.12.7'
     implementation 'org.hibernate.validator:hibernate-validator:8.0.1.Final'
     implementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-csv'
+    implementation 'software.amazon.awssdk:netty-nio-client'
    implementation 'software.amazon.awssdk:s3'
    implementation 'software.amazon.awssdk:sts'
    implementation 'software.amazon.awssdk:securitylake:2.26.18'
diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/ClientFactory.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/ClientFactory.java
index 910f3966cc..f647057af4 100644
--- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/ClientFactory.java
+++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/ClientFactory.java
@@ -8,10 +8,14 @@
 import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
 import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
 import org.opensearch.dataprepper.plugins.sink.s3.configuration.AwsAuthenticationOptions;
+import org.opensearch.dataprepper.plugins.sink.s3.configuration.ClientOptions;
 import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
 import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
 import software.amazon.awssdk.core.retry.RetryPolicy;
+import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
+import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
 import software.amazon.awssdk.services.s3.S3AsyncClient;
+import software.amazon.awssdk.services.s3.S3AsyncClientBuilder;
 import software.amazon.awssdk.services.s3.S3Client;
 
 public final class ClientFactory {
@@ -31,10 +35,21 @@ static S3AsyncClient createS3AsyncClient(final S3SinkConfig s3SinkConfig, final
         final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(s3SinkConfig.getAwsAuthenticationOptions());
         final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(awsCredentialsOptions);
 
-        return S3AsyncClient.builder()
+        S3AsyncClientBuilder s3AsyncClientBuilder = S3AsyncClient.builder()
                 .region(s3SinkConfig.getAwsAuthenticationOptions().getAwsRegion())
                 .credentialsProvider(awsCredentialsProvider)
-                .overrideConfiguration(createOverrideConfiguration(s3SinkConfig)).build();
+                .overrideConfiguration(createOverrideConfiguration(s3SinkConfig));
+
+        if (s3SinkConfig.getClientOptions() != null) {
+            final ClientOptions clientOptions = s3SinkConfig.getClientOptions();
+            SdkAsyncHttpClient httpClient = NettyNioAsyncHttpClient.builder()
+                    .connectionAcquisitionTimeout(clientOptions.getAcquireTimeout())
+                    .maxConcurrency(clientOptions.getMaxConnections())
+                    .build();
+            s3AsyncClientBuilder.httpClient(httpClient);
+        }
+
+        return s3AsyncClientBuilder.build();
     }
 
     private static ClientOverrideConfiguration createOverrideConfiguration(final S3SinkConfig s3SinkConfig) {
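
Note: for context on the SDK surface involved, the new client options feed exactly two knobs into the Netty async HTTP client behind S3AsyncClient. A standalone sketch with illustrative hard-coded values follows; the region and numbers are placeholders, but the builder calls are the same ones ClientFactory makes above.

import java.time.Duration;

import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;

class AsyncClientSketch {
    static S3AsyncClient build() {
        // Bound the connection pool and how long a request may wait for a pooled connection.
        SdkAsyncHttpClient httpClient = NettyNioAsyncHttpClient.builder()
                .maxConcurrency(100)                                  // maps to client.max_connections
                .connectionAcquisitionTimeout(Duration.ofSeconds(20)) // maps to client.acquire_timeout
                .build();
        return S3AsyncClient.builder()
                .region(Region.US_EAST_1)
                .httpClient(httpClient)
                .build();
    }
}
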
diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkConfig.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkConfig.java
index 71e523e5f6..9e690d739a 100644
--- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkConfig.java
+++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkConfig.java
@@ -16,6 +16,7 @@
 import org.opensearch.dataprepper.plugins.sink.s3.compression.CompressionOption;
 import org.opensearch.dataprepper.plugins.sink.s3.configuration.AggregateThresholdOptions;
 import org.opensearch.dataprepper.plugins.sink.s3.configuration.AwsAuthenticationOptions;
+import org.opensearch.dataprepper.plugins.sink.s3.configuration.ClientOptions;
 import org.opensearch.dataprepper.plugins.sink.s3.configuration.ObjectKeyOptions;
 import org.opensearch.dataprepper.plugins.sink.s3.configuration.ThresholdOptions;
 
@@ -95,6 +96,9 @@ private boolean isValidBucketConfig() {
     @AwsAccountId
     private String defaultBucketOwner;
 
+    @JsonProperty("client")
+    private ClientOptions clientOptions;
+
     /**
      * Aws Authentication configuration Options.
      * @return aws authentication options.
@@ -195,4 +199,8 @@ public Map<String, String> getBucketOwners() {
     public String getDefaultBucketOwner() {
         return defaultBucketOwner;
     }
+
+    public ClientOptions getClientOptions() {
+        return clientOptions;
+    }
 }
diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/configuration/ClientOptions.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/configuration/ClientOptions.java
new file mode 100644
index 0000000000..a5e83948a5
--- /dev/null
+++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/configuration/ClientOptions.java
@@ -0,0 +1,37 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.dataprepper.plugins.sink.s3.configuration;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import jakarta.validation.constraints.Max;
+import jakarta.validation.constraints.Min;
+import org.hibernate.validator.constraints.time.DurationMax;
+import org.hibernate.validator.constraints.time.DurationMin;
+
+import java.time.Duration;
+
+public class ClientOptions {
+    private static final int DEFAULT_MAX_CONNECTIONS = 50;
+    private static final Duration DEFAULT_ACQUIRE_TIMEOUT = Duration.ofSeconds(10);
+
+    @JsonProperty("max_connections")
+    @Min(1)
+    @Max(5000)
+    private int maxConnections = DEFAULT_MAX_CONNECTIONS;
+
+    @JsonProperty("acquire_timeout")
+    @DurationMin(seconds = 1)
+    @DurationMax(seconds = 3600)
+    private Duration acquireTimeout = DEFAULT_ACQUIRE_TIMEOUT;
+
+    public int getMaxConnections() {
+        return maxConnections;
+    }
+
+    public Duration getAcquireTimeout() {
+        return acquireTimeout;
+    }
+}
diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/ClientFactoryTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/ClientFactoryTest.java
index bf70dde593..947bc728e9 100644
--- a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/ClientFactoryTest.java
+++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/ClientFactoryTest.java
@@ -17,12 +17,18 @@
 import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
 import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
 import org.opensearch.dataprepper.plugins.sink.s3.configuration.AwsAuthenticationOptions;
+import org.opensearch.dataprepper.plugins.sink.s3.configuration.ClientOptions;
 import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
 import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
+import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
 import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.s3.S3AsyncClient;
+import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
+import software.amazon.awssdk.services.s3.S3AsyncClientBuilder;
 import software.amazon.awssdk.services.s3.S3Client;
 import software.amazon.awssdk.services.s3.S3ClientBuilder;
 
+import java.time.Duration;
 import java.util.Map;
 import java.util.UUID;
 
@@ -30,6 +36,7 @@
 import static org.hamcrest.CoreMatchers.notNullValue;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mockStatic;
 import static org.mockito.Mockito.verify;
@@ -44,6 +51,8 @@ class ClientFactoryTest {
     @Mock
     private AwsAuthenticationOptions awsAuthenticationOptions;
+    @Mock
+    private ClientOptions clientOptions;
 
     @BeforeEach
     void setUp() {
@@ -51,7 +60,7 @@ void setUp() {
     }
 
     @Test
-    void createS3Client_with_real_S3Client() {
+    void createS3AsyncClient_with_real_S3AsyncClient() {
         when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.US_EAST_1);
 
         final S3Client s3Client = ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier);
@@ -99,4 +108,66 @@ void createS3Client_provides_correct_inputs(final String regionString) {
         assertThat(actualCredentialsOptions.getStsExternalId(), equalTo(externalId));
         assertThat(actualCredentialsOptions.getStsHeaderOverrides(), equalTo(stsHeaderOverrides));
     }
+
+    @Test
+    void createS3AsyncClient_with_client_options_returns_expected_client() {
+        final Region region = Region.of("us-east-1");
+        final String stsRoleArn = UUID.randomUUID().toString();
+        final String externalId = UUID.randomUUID().toString();
+        final Map<String, String> stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString());
+        when(awsAuthenticationOptions.getAwsRegion()).thenReturn(region);
+        when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(stsRoleArn);
+        when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(externalId);
+        when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides);
+
+        final AwsCredentialsProvider expectedCredentialsProvider = mock(AwsCredentialsProvider.class);
+        when(awsCredentialsSupplier.getProvider(any())).thenReturn(expectedCredentialsProvider);
+
+        final S3AsyncClientBuilder s3AsyncClientBuilder = mock(S3AsyncClientBuilder.class);
+        when(s3AsyncClientBuilder.region(region)).thenReturn(s3AsyncClientBuilder);
+        when(s3AsyncClientBuilder.credentialsProvider(any())).thenReturn(s3AsyncClientBuilder);
+        when(s3AsyncClientBuilder.overrideConfiguration(any(ClientOverrideConfiguration.class))).thenReturn(s3AsyncClientBuilder);
+
+        when(s3SinkConfig.getClientOptions()).thenReturn(clientOptions);
+
+        final int maxConnections = 100;
+        final Duration acquireTimeout = Duration.ofSeconds(30);
+        when(clientOptions.getMaxConnections()).thenReturn(maxConnections);
+        when(clientOptions.getAcquireTimeout()).thenReturn(acquireTimeout);
+
+        final NettyNioAsyncHttpClient.Builder httpClientBuilder = mock(NettyNioAsyncHttpClient.Builder.class);
+        final SdkAsyncHttpClient httpClient = mock(SdkAsyncHttpClient.class);
+        when(httpClientBuilder.connectionAcquisitionTimeout(any(Duration.class))).thenReturn(httpClientBuilder);
+        when(httpClientBuilder.maxConcurrency(anyInt())).thenReturn(httpClientBuilder);
+        when(httpClientBuilder.build()).thenReturn(httpClient);
+
+        try(final MockedStatic<S3AsyncClient> s3AsyncClientMockedStatic = mockStatic(S3AsyncClient.class);
+            final MockedStatic<NettyNioAsyncHttpClient> httpClientMockedStatic = mockStatic(NettyNioAsyncHttpClient.class)) {
+            s3AsyncClientMockedStatic.when(S3AsyncClient::builder)
+                    .thenReturn(s3AsyncClientBuilder);
+            httpClientMockedStatic.when(NettyNioAsyncHttpClient::builder)
+                    .thenReturn(httpClientBuilder);
+            ClientFactory.createS3AsyncClient(s3SinkConfig, awsCredentialsSupplier);
+        }
+
+        final ArgumentCaptor<AwsCredentialsProvider> credentialsProviderArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsProvider.class);
+        verify(s3AsyncClientBuilder).credentialsProvider(credentialsProviderArgumentCaptor.capture());
+
+        final AwsCredentialsProvider actualCredentialsProvider = credentialsProviderArgumentCaptor.getValue();
+
+        assertThat(actualCredentialsProvider, equalTo(expectedCredentialsProvider));
+
+        final ArgumentCaptor<AwsCredentialsOptions> optionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class);
+        verify(awsCredentialsSupplier).getProvider(optionsArgumentCaptor.capture());
+
+        final AwsCredentialsOptions actualCredentialsOptions = optionsArgumentCaptor.getValue();
+        assertThat(actualCredentialsOptions.getRegion(), equalTo(region));
+        assertThat(actualCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn));
+        assertThat(actualCredentialsOptions.getStsExternalId(), equalTo(externalId));
+        assertThat(actualCredentialsOptions.getStsHeaderOverrides(), equalTo(stsHeaderOverrides));
+
+        verify(httpClientBuilder).connectionAcquisitionTimeout(acquireTimeout);
+        verify(httpClientBuilder).maxConcurrency(maxConnections);
+        verify(s3AsyncClientBuilder).httpClient(httpClient);
+    }
 }
\ No newline at end of file
diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkConfigTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkConfigTest.java
index d1660ebc63..bbd831b9cd 100644
--- a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkConfigTest.java
+++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkConfigTest.java
@@ -76,4 +76,9 @@ void get_AWS_Auth_options_in_sinkconfig_exception() {
     void get_json_codec_test() {
         assertNull(new S3SinkConfig().getCodec());
     }
+
+    @Test
+    void get_client_option_test() {
+        assertNull(new S3SinkConfig().getClientOptions());
+    }
 }
\ No newline at end of file
diff --git a/settings.gradle b/settings.gradle
index 97742e8576..4328fa9aac 100644
--- a/settings.gradle
+++ b/settings.gradle
@@ -183,4 +183,4 @@ include 'data-prepper-plugins:http-common'
 include 'data-prepper-plugins:aws-lambda'
 //include 'data-prepper-plugins:dummy-plugin'
 include 'data-prepper-plugin-schema'
-
+include 'data-prepper-plugins:kinesis-source'
\ No newline at end of file
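
A closing note on defaults: S3SinkConfig keeps clientOptions null unless a client block is configured (get_client_option_test asserts exactly that), while ClientOptions itself falls back to 50 connections and a 10-second acquire timeout when the block is present but its fields are omitted. A hypothetical defaults check in the style of the tests above:

import java.time.Duration;

import org.junit.jupiter.api.Test;
import org.opensearch.dataprepper.plugins.sink.s3.configuration.ClientOptions;

import static org.junit.jupiter.api.Assertions.assertEquals;

class ClientOptionsDefaultsSketch {
    @Test
    void defaults_match_documented_values() {
        // Matches DEFAULT_MAX_CONNECTIONS and DEFAULT_ACQUIRE_TIMEOUT in ClientOptions above.
        final ClientOptions options = new ClientOptions();
        assertEquals(50, options.getMaxConnections());
        assertEquals(Duration.ofSeconds(10), options.getAcquireTimeout());
    }
}
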