Skip to content

Commit

Permalink
Added backoff and sqsclient auth
Browse files Browse the repository at this point in the history
Signed-off-by: Asif Sohail Mohammed <nsifmoh@amazon.com>
  • Loading branch information
asifsmohammed committed May 27, 2022
1 parent 4862ef3 commit 824eee3
Show file tree
Hide file tree
Showing 11 changed files with 395 additions and 59 deletions.
3 changes: 1 addition & 2 deletions data-prepper-plugins/s3-source/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ dependencies {
implementation project(':data-prepper-api')
implementation project(':data-prepper-plugins:blocking-buffer')
implementation 'com.fasterxml.jackson.core:jackson-databind:2.13.3'
implementation 'org.apache.commons:commons-lang3:3.12.0'
implementation 'software.amazon.awssdk:s3:2.17.191'
implementation 'software.amazon.awssdk:sts:2.17.191'
implementation 'software.amazon.awssdk:sqs:2.17.191'
implementation 'com.amazonaws:aws-java-sdk-s3'
implementation 'com.amazonaws:aws-java-sdk-s3:1.12.220'
testImplementation 'org.hamcrest:hamcrest:2.2'
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package com.amazon.dataprepper.plugins.source;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Random;

public class BackoffUtils {
private static final Logger LOG = LoggerFactory.getLogger(BackoffUtils.class);

private int numberOfRetries;
private long timeToWait;

private final Random random = new Random();

public BackoffUtils(final int numberOfRetries, final long timeToWait) {
this.numberOfRetries = numberOfRetries;
this.timeToWait = timeToWait;
}

public boolean shouldRetry() {
return numberOfRetries > 0;
}

public void errorOccurred() {
numberOfRetries -= 1;

if (shouldRetry()) {
waitUntilNextTry();
timeToWait += random.nextInt(1000);
}
}

private void waitUntilNextTry() {
try {
Thread.sleep(timeToWait);
} catch (InterruptedException e) {
LOG.error("Thread is interrupted.", e);
}
}

public void doNotRetry() {
numberOfRetries = 0;
}

public int getNumberOfTriesLeft() {
return numberOfRetries;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package com.amazon.dataprepper.plugins.source;

import com.amazon.dataprepper.plugins.source.configuration.AwsAuthenticationOptions;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.sts.StsClient;

public class S3ClientAuthentication {
private S3SourceConfig s3SourceConfig;
private AwsAuthenticationOptions awsAuthenticationOptions;

public S3ClientAuthentication(final S3SourceConfig s3SourceConfig) {
this.s3SourceConfig = s3SourceConfig;
this.awsAuthenticationOptions = s3SourceConfig.getAWSAuthentication();
}

public S3Client createS3Client(final StsClient stsClient) {

return software.amazon.awssdk.services.s3.S3Client.builder()
.region(Region.of(s3SourceConfig.getAWSAuthentication().getAwsRegion()))
.credentialsProvider(awsAuthenticationOptions.authenticateAwsConfiguration(stsClient))
.build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,25 @@
import com.amazon.dataprepper.model.event.Event;
import com.amazon.dataprepper.model.record.Record;
import com.amazon.dataprepper.model.source.Source;
import org.apache.commons.lang3.NotImplementedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.sqs.SqsClient;
import software.amazon.awssdk.services.sts.StsClient;

import java.util.Objects;

@DataPrepperPlugin(name = "s3", pluginType = Source.class, pluginConfigurationType = S3SourceConfig.class)
public class S3Source implements Source<Record<Event>> {
private static final Logger LOG = LoggerFactory.getLogger(S3Source.class);

private final PluginMetrics pluginMetrics;
private final S3SourceConfig s3SourceConfig;
private AwsCredentialsProvider awsCredentialsProvider;
private S3Client s3Client;
private software.amazon.awssdk.services.s3.S3Client s3Client;
private SqsClient sqsClient;
private Thread sqsWorkerThread;

@DataPrepperPluginConstructor
public S3Source(PluginMetrics pluginMetrics, final S3SourceConfig s3SourceConfig) {
this.pluginMetrics = pluginMetrics;
this.s3SourceConfig = s3SourceConfig;

awsCredentialsProvider = createCredentialsProvider();
LOG.info("Creating S3 client.");
s3Client = createS3Client(awsCredentialsProvider);
sqsClient = SqsClient.create();

throw new NotImplementedException();
}

@Override
Expand All @@ -52,25 +39,16 @@ public void start(Buffer<Record<Event>> buffer) {
throw new IllegalStateException("Buffer provided is null");
}

for(int i = 0; i < s3SourceConfig.getSqsOptions().getThreadCount(); i++) {
Thread sqsWorkerThread = new Thread(new SqsWorker(sqsClient, s3Client, s3SourceConfig));
sqsWorkerThread.start();
}
LOG.info("Creating SQS and S3 client");
StsClient stsClient = StsClient.create();
this.s3Client = new S3ClientAuthentication(s3SourceConfig).createS3Client(stsClient);
this.sqsClient = new SqsClientAuthentication(s3SourceConfig).createSqsClient(stsClient);

sqsWorkerThread = new Thread(new SqsWorker(sqsClient, s3Client, s3SourceConfig));
}

@Override
public void stop() {

}

public AwsCredentialsProvider createCredentialsProvider() {
return Objects.requireNonNull(s3SourceConfig.getAWSAuthentication().authenticateAwsConfiguration(StsClient.create()));
}

private S3Client createS3Client(final AwsCredentialsProvider awsCredentialsProvider) {
return S3Client.builder()
.region(Region.of(s3SourceConfig.getAWSAuthentication().getAwsRegion()))
.credentialsProvider(awsCredentialsProvider)
.build();
Thread.currentThread().interrupt();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package com.amazon.dataprepper.plugins.source;

import com.amazon.dataprepper.plugins.source.configuration.AwsAuthenticationOptions;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sqs.SqsClient;
import software.amazon.awssdk.services.sts.StsClient;

public class SqsClientAuthentication {
private S3SourceConfig s3SourceConfig;
private AwsAuthenticationOptions awsAuthenticationOptions;

public SqsClientAuthentication(final S3SourceConfig s3SourceConfig) {
this.s3SourceConfig = s3SourceConfig;
awsAuthenticationOptions = s3SourceConfig.getAWSAuthentication();
}

public SqsClient createSqsClient(final StsClient stsClient) {

return SqsClient.builder()
.credentialsProvider(awsAuthenticationOptions.authenticateAwsConfiguration(stsClient))
.region(Region.of(s3SourceConfig.getAWSAuthentication().getAwsRegion()))
.build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,43 +23,64 @@
public class SqsWorker implements Runnable {
private static final Logger LOG = LoggerFactory.getLogger(SqsWorker.class);

private static final int NUMBER_OF_RETRIES = 5;
private static final long TIME_TO_WAIT = 5000;

private final S3SourceConfig s3SourceConfig;
private final SqsClient sqsClient;
private final S3Client s3Client;
private final ObjectMapper objectMapper;

public SqsWorker(SqsClient sqsClient, S3Client s3Client, S3SourceConfig s3SourceConfig) {
SqsOptions sqsOptions;
private BackoffUtils sqsBackoff;

public SqsWorker(final SqsClient sqsClient, final S3Client s3Client, final S3SourceConfig s3SourceConfig) {
this.s3SourceConfig = s3SourceConfig;
this.sqsClient = sqsClient;
this.s3Client = s3Client;
objectMapper = new ObjectMapper();
sqsOptions = s3SourceConfig.getSqsOptions();
}

@Override
public void run() {
SqsOptions sqsOptions = s3SourceConfig.getSqsOptions();

sqsBackoff = new BackoffUtils(NUMBER_OF_RETRIES, TIME_TO_WAIT);

do {
// read messages from sqs
List<Message> messages = new ArrayList<>();

try {
ReceiveMessageRequest receiveMessageRequest = ReceiveMessageRequest.builder()
.queueUrl(sqsOptions.getSqsUrl())
.maxNumberOfMessages(sqsOptions.getMaximumMessages())
.visibilityTimeout((int) sqsOptions.getVisibilityTimeout().getSeconds())
.waitTimeSeconds((int) sqsOptions.getWaitTime().getSeconds())
.build();

messages.addAll(sqsClient.receiveMessage(receiveMessageRequest).messages());
} catch (SqsException e) {
LOG.error(e.awsErrorDetails().errorMessage());
System.exit(1);
while (sqsBackoff.shouldRetry()) {
try {
ReceiveMessageRequest receiveMessageRequest = createReceiveMessageRequest();
messages.addAll(sqsClient.receiveMessage(receiveMessageRequest).messages());

sqsBackoff.doNotRetry();
} catch (SqsException e) {
sqsBackoff.errorOccurred();
if (sqsBackoff.getNumberOfTriesLeft() == 0)
LOG.error("Error reading from SQS: {}", e.awsErrorDetails().errorMessage());
}
}


// try {
// ReceiveMessageRequest receiveMessageRequest = ReceiveMessageRequest.builder()
// .queueUrl(sqsOptions.getSqsUrl())
// .maxNumberOfMessages(sqsOptions.getMaximumMessages())
// .visibilityTimeout((int) sqsOptions.getVisibilityTimeout().getSeconds())
// .waitTimeSeconds((int) sqsOptions.getWaitTime().getSeconds())
// .build();
//
// messages.addAll(sqsClient.receiveMessage(receiveMessageRequest).messages());
// } catch (SqsException e) {
// LOG.error("Error reading from SQS: {}", e.awsErrorDetails().errorMessage());
// }

// read each message as S3 event message
List<S3EventNotification.S3EventNotificationRecord> s3EventNotificationRecords = messages.stream()
.map(this::getS3EventMessages)
.map(this::convertS3EventMessages)
.collect(Collectors.toList());


Expand All @@ -81,14 +102,22 @@ public void run() {
try {
Thread.sleep(s3SourceConfig.getSqsOptions().getPollDelay().toMillis());
} catch (InterruptedException e) {
e.printStackTrace();
LOG.error("Thread is interrupted while polling.", e);
}
}
} while (true);
}

private S3EventNotification.S3EventNotificationRecord getS3EventMessages(Message message) {
return objectMapper.convertValue(message, S3EventNotification.S3EventNotificationRecord.class);
private ReceiveMessageRequest createReceiveMessageRequest() {
return ReceiveMessageRequest.builder()
.queueUrl(sqsOptions.getSqsUrl())
.maxNumberOfMessages(sqsOptions.getMaximumMessages())
.visibilityTimeout((int) sqsOptions.getVisibilityTimeout().getSeconds())
.waitTimeSeconds((int) sqsOptions.getWaitTime().getSeconds())
.build();
}

private S3EventNotification.S3EventNotificationRecord convertS3EventMessages(final Message message) {
return objectMapper.convertValue(message, S3EventNotification.S3EventNotificationRecord.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
import java.time.Duration;

public class SqsOptions {
private final int DEFAULT_MAXIMUM_MESSAGES = 10;
private final int DEFAULT_VISIBILITY_TIMEOUT_SECONDS = 30;
private final int DEFAULT_WAIT_TIME_SECONDS = 0;
private final int DEFAULT_POLL_DELAY_SECONDS = 0;
private final int DEFAULT_THREAD_COUNT = 1;
private static final int DEFAULT_MAXIMUM_MESSAGES = 10;
private static final Duration DEFAULT_VISIBILITY_TIMEOUT_SECONDS = Duration.ofSeconds(30);
private static final Duration DEFAULT_WAIT_TIME_SECONDS = Duration.ofSeconds(20);
private static final Duration DEFAULT_POLL_DELAY_SECONDS = Duration.ofSeconds(0);
private static final int DEFAULT_THREAD_COUNT = 1;

@JsonProperty("queue_url")
@NotBlank(message = "SQS URL cannot be null or empty")
Expand All @@ -29,14 +29,14 @@ public class SqsOptions {
@JsonProperty("visibility_timeout")
@Min(0)
@Max(43200)
private Duration visibilityTimeout = Duration.ofSeconds(DEFAULT_VISIBILITY_TIMEOUT_SECONDS);
private Duration visibilityTimeout = DEFAULT_VISIBILITY_TIMEOUT_SECONDS;

@JsonProperty("wait_time")
@Max(20)
private Duration waitTime = Duration.ofSeconds(DEFAULT_WAIT_TIME_SECONDS);
private Duration waitTime = DEFAULT_WAIT_TIME_SECONDS;

@JsonProperty("poll_delay")
private Duration pollDelay = Duration.ofSeconds(DEFAULT_POLL_DELAY_SECONDS);
private Duration pollDelay = DEFAULT_POLL_DELAY_SECONDS;

@JsonProperty("thread_count")
private int threadCount = DEFAULT_THREAD_COUNT;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package com.amazon.dataprepper.plugins.source;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

class BackoffUtilsTest {
private static final int NUMBER_OF_RETRIES = 5;
private static final int TIME_TO_WAIT = 50;

BackoffUtils backoffUtils;

@BeforeEach
void setUp() {
backoffUtils = new BackoffUtils(NUMBER_OF_RETRIES, TIME_TO_WAIT);
}


@Test
void shouldRetry_should_return_true_for_value_greater_than_zero() {
assertTrue(backoffUtils.shouldRetry());
}

@Test
void shouldRetry_should_return_false_for_zero() {
backoffUtils = new BackoffUtils(0, 1000);
assertFalse(backoffUtils.shouldRetry());
}

@Test
void errorOccurred_should_decrement_retires() {
backoffUtils.errorOccurred();
assertThat(backoffUtils.getNumberOfTriesLeft(), equalTo(NUMBER_OF_RETRIES - 1));
}

@Test
void doNotRetry_should_set_retries_to_zero() {
backoffUtils.doNotRetry();
assertThat(backoffUtils.getNumberOfTriesLeft(), equalTo(0));
}
}
Loading

0 comments on commit 824eee3

Please sign in to comment.