Skip to content

Commit

Permalink
Add support for dynamic bucket in S3 sink
Browse files Browse the repository at this point in the history
Signed-off-by: Taylor Gray <tylgry@amazon.com>
  • Loading branch information
graytaylor0 committed Apr 9, 2024
1 parent bcb1145 commit ed5956a
Show file tree
Hide file tree
Showing 37 changed files with 570 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,12 @@ public interface Event extends Serializable {
* of a Data Prepper expression
* @param format input format
* @param expressionEvaluator - The expression evaluator that will support formatting from Data Prepper expressions
* @param replacementForFailures - The String to use as a replacement for when keys in Events can't be found
* @return returns a string with no formatted parts, returns null if no value is found
* @throws RuntimeException if the input string is not properly formatted
* @since 2.1
*/
String formatString(String format, ExpressionEvaluator expressionEvaluator);
String formatString(final String format, final ExpressionEvaluator expressionEvaluator, final String replacementForFailures);

/**
* Returns event handle
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ public String getAsJsonString(final String key) {
*/
@Override
public String formatString(final String format) {
return formatStringInternal(format, null);
return formatStringInternal(format, null, null);
}

/**
Expand All @@ -333,11 +333,11 @@ public String formatString(final String format) {
* @throws RuntimeException if the format is incorrect or the value is not a string
*/
@Override
public String formatString(final String format, final ExpressionEvaluator expressionEvaluator) {
return formatStringInternal(format, expressionEvaluator);
public String formatString(final String format, final ExpressionEvaluator expressionEvaluator, final String replacementForFailures) {
return formatStringInternal(format, expressionEvaluator, replacementForFailures);
}

private String formatStringInternal(final String format, final ExpressionEvaluator expressionEvaluator) {
private String formatStringInternal(final String format, final ExpressionEvaluator expressionEvaluator, final String replacementForFailures) {
int fromIndex = 0;
String result = "";
int position = 0;
Expand All @@ -361,7 +361,11 @@ private String formatStringInternal(final String format, final ExpressionEvaluat
if (expressionEvaluator != null && expressionEvaluator.isValidExpressionStatement(name)) {
val = expressionEvaluator.evaluate(name, this);
} else {
throw new EventKeyNotFoundException(String.format("The key %s could not be found in the Event when formatting", name));
if (replacementForFailures == null) {
throw new EventKeyNotFoundException(String.format("The key %s could not be found in the Event when formatting", name));
}

val = replacementForFailures;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ public void formatString_with_expression_evaluator_catches_exception_when_Event_
when(expressionEvaluator.evaluate(invalidKeyExpression, event)).thenReturn(invalidKeyExpressionResult);
when(expressionEvaluator.evaluate(expressionStatement, event)).thenReturn(expressionEvaluationResult);

assertThat(event.formatString(formatString, expressionEvaluator), is(equalTo(finalString)));
assertThat(event.formatString(formatString, expressionEvaluator, null), is(equalTo(finalString)));
}

@Test
Expand All @@ -630,7 +630,7 @@ public void testBuild_withFormatStringWithExpressionEvaluator() {
verify(expressionEvaluator, never()).evaluate(eq("foo"), any(Event.class));
when(expressionEvaluator.evaluate(expressionStatement, event)).thenReturn(expressionEvaluationResult);

assertThat(event.formatString(formatString, expressionEvaluator), is(equalTo(finalString)));
assertThat(event.formatString(formatString, expressionEvaluator, null), is(equalTo(finalString)));
}

@ParameterizedTest
Expand Down Expand Up @@ -662,6 +662,21 @@ public void testBuild_withFormatStringWithValueNotFound() {
assertThrows(EventKeyNotFoundException.class, () -> event.formatString("test-${boo}-string"));
}

@Test
public void testBuild_withFormatStringWithValueNotFound_and_replacement_failure() {

final String replacementForMissingKeys = "REPLACED";
final String jsonString = "{\"foo\": \"bar\", \"info\": {\"ids\": {\"id\":\"idx\"}}}";
final ExpressionEvaluator expressionEvaluator = mock(ExpressionEvaluator.class);
event = JacksonEvent.builder()
.withEventType(eventType)
.withData(jsonString)
.getThis()
.build();
final String result = event.formatString("test-${boo}-string", expressionEvaluator, replacementForMissingKeys);
assertThat(result, equalTo("test-" + replacementForMissingKeys + "-string"));
}

@Test
public void testBuild_withFormatStringWithInvalidFormat() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public void produceRawData(final byte[] bytes, final String key) throws Exceptio
public void produceRecords(final Record<Event> record) throws Exception {
bufferedEventHandles.add(record.getData().getEventHandle());
Event event = getEvent(record);
final String key = event.formatString(kafkaProducerConfig.getPartitionKey(), expressionEvaluator);
final String key = event.formatString(kafkaProducerConfig.getPartitionKey(), expressionEvaluator, null);
try {
if (Objects.equals(serdeFormat, MessageFormat.JSON.toString())) {
publishJsonMessage(record, key);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public Collection<Record<Event>> doExecute(final Collection<Record<Event>> recor
}

try {
final String key = (entry.getKey() == null) ? null : recordEvent.formatString(entry.getKey(), expressionEvaluator);
final String key = (entry.getKey() == null) ? null : recordEvent.formatString(entry.getKey(), expressionEvaluator, null);
final String metadataKey = entry.getMetadataKey();
Object value;
if (!Objects.isNull(entry.getValueExpression())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ public void doOutput(final Collection<Record<Event>> records) {
final SerializedJson document = getDocument(event);
String indexName = configuredIndexAlias;
try {
indexName = indexManager.getIndexName(event.formatString(indexName, expressionEvaluator));
indexName = indexManager.getIndexName(event.formatString(indexName, expressionEvaluator, null));
} catch (final Exception e) {
LOG.error("There was an exception when constructing the index name. Check the dlq if configured to see details about the affected Event: {}", e.getMessage());
dynamicIndexDroppedEvents.increment();
Expand All @@ -403,8 +403,8 @@ public void doOutput(final Collection<Record<Event>> records) {
String versionExpressionEvaluationResult = null;
if (versionExpression != null) {
try {
versionExpressionEvaluationResult = event.formatString(versionExpression, expressionEvaluator);
version = Long.valueOf(event.formatString(versionExpression, expressionEvaluator));
versionExpressionEvaluationResult = event.formatString(versionExpression, expressionEvaluator, null);
version = Long.valueOf(event.formatString(versionExpression, expressionEvaluator, null));
} catch (final NumberFormatException e) {
final String errorMessage = String.format(
"Unable to convert the result of evaluating document_version '%s' to Long for an Event. The evaluation result '%s' must be a valid Long type", versionExpression, versionExpressionEvaluationResult
Expand Down Expand Up @@ -433,7 +433,7 @@ public void doOutput(final Collection<Record<Event>> records) {
}
}
if (eventAction.contains("${")) {
eventAction = event.formatString(eventAction, expressionEvaluator);
eventAction = event.formatString(eventAction, expressionEvaluator, null);
}
if (OpenSearchBulkActions.fromOptionValue(eventAction) == null) {
LOG.error("Unknown action {}, skipping the event", eventAction);
Expand Down Expand Up @@ -485,7 +485,7 @@ SerializedJson getDocument(final Event event) {
docId = event.get(documentIdField, String.class);
} else if (Objects.nonNull(documentId)) {
try {
docId = event.formatString(documentId, expressionEvaluator);
docId = event.formatString(documentId, expressionEvaluator, null);
} catch (final ExpressionEvaluationException | EventKeyNotFoundException e) {
LOG.error("Unable to construct document_id with format {}, the document_id will be generated by OpenSearch", documentId, e);
}
Expand All @@ -496,7 +496,7 @@ SerializedJson getDocument(final Event event) {
routingValue = event.get(routingField, String.class);
} else if (routing != null) {
try {
routingValue = event.formatString(routing, expressionEvaluator);
routingValue = event.formatString(routing, expressionEvaluator, null);
} catch (final ExpressionEvaluationException | EventKeyNotFoundException e) {
LOG.error("Unable to construct routing with format {}, the routing will be generated by OpenSearch", routing, e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public class S3SinkIT {

@Mock
private PluginSetting pluginSetting;
@Mock
@Mock(stubOnly = true)
private S3SinkConfig s3SinkConfig;
@Mock
private PluginFactory pluginFactory;
Expand Down Expand Up @@ -166,6 +166,8 @@ void setUp() {
.build();

when(expressionEvaluator.isValidFormatExpression(anyString())).thenReturn(true);

when(s3SinkConfig.getDefaultBucket()).thenReturn(null);
}

private S3Sink createObjectUnderTest() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import software.amazon.awssdk.services.s3.model.CompletedPart;
import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse;
import software.amazon.awssdk.services.s3.model.NoSuchBucketException;
import software.amazon.awssdk.services.s3.model.S3Exception;
import software.amazon.awssdk.services.s3.model.UploadPartRequest;
import software.amazon.awssdk.services.s3.model.UploadPartResponse;

Expand All @@ -27,6 +29,8 @@
public class S3OutputStream extends PositionOutputStream {
private static final Logger LOG = LoggerFactory.getLogger(S3OutputStream.class);

static final String ACCESS_DENIED = "Access Denied";

/**
* Default chunk size is 10MB
*/
Expand All @@ -42,6 +46,8 @@ public class S3OutputStream extends PositionOutputStream {
*/
private final String key;

private String targetBucket;

/**
* The temporary buffer used for storing the chunks
*/
Expand All @@ -65,21 +71,31 @@ public class S3OutputStream extends PositionOutputStream {
*/
private boolean open;

/**
* The default bucket to send to when upload fails with dynamic bucket
*/
private String defaultBucket;

/**
* Creates a new S3 OutputStream
*
* @param s3Client the AmazonS3 client
* @param bucketSupplier name of the bucket
* @param keySupplier path within the bucket
*/
public S3OutputStream(final S3Client s3Client, Supplier<String> bucketSupplier, Supplier<String> keySupplier) {
public S3OutputStream(final S3Client s3Client,
final Supplier<String> bucketSupplier,
final Supplier<String> keySupplier,
final String defaultBucket) {
this.s3Client = s3Client;
this.bucket = bucketSupplier.get();
this.targetBucket = bucketSupplier.get();
this.key = keySupplier.get();
buf = new byte[BUFFER_SIZE];
position = 0;
etags = new ArrayList<>();
open = true;
this.defaultBucket = defaultBucket;
}

@Override
Expand Down Expand Up @@ -157,7 +173,7 @@ public void close() {
.parts(completedParts)
.build();
CompleteMultipartUploadRequest completeMultipartUploadRequest = CompleteMultipartUploadRequest.builder()
.bucket(bucket)
.bucket(targetBucket)
.key(key)
.uploadId(uploadId)
.multipartUpload(completedMultipartUpload)
Expand All @@ -184,21 +200,27 @@ private void flushBufferAndRewind() {

private void possiblyStartMultipartUpload() {
if (uploadId == null) {
CreateMultipartUploadRequest uploadRequest = CreateMultipartUploadRequest.builder()
.bucket(bucket)
.key(key)
.build();
CreateMultipartUploadResponse multipartUpload = s3Client.createMultipartUpload(uploadRequest);
uploadId = multipartUpload.uploadId();

LOG.debug("Created multipart upload {} bucket='{}',key='{}'.", uploadId, bucket, key);
try {
createMultipartUpload();
} catch (final S3Exception e) {
if (defaultBucket != null && (e instanceof NoSuchBucketException || e.getMessage().contains(ACCESS_DENIED))) {
targetBucket = defaultBucket;
LOG.warn("Bucket {} could not be accessed to create multi-part upload, attempting to create multi-part upload to default_bucket {}", bucket, defaultBucket);
createMultipartUpload();
} else {
throw e;
}
}

LOG.debug("Created multipart upload {} bucket='{}',key='{}'.", uploadId, targetBucket, key);
}
}

private void uploadPart() {
int partNumber = etags.size() + 1;
UploadPartRequest uploadRequest = UploadPartRequest.builder()
.bucket(bucket)
.bucket(targetBucket)
.key(key)
.uploadId(uploadId)
.partNumber(partNumber)
Expand All @@ -217,5 +239,14 @@ private void uploadPart() {
public long getPos() throws IOException {
return position + (long) etags.size() * (long) BUFFER_SIZE;
}

private void createMultipartUpload() {
CreateMultipartUploadRequest uploadRequest = CreateMultipartUploadRequest.builder()
.bucket(targetBucket)
.key(key)
.build();
CreateMultipartUploadResponse multipartUpload = s3Client.createMultipartUpload(uploadRequest);
uploadId = multipartUpload.uploadId();
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.plugins.sink.s3.accumulator.ObjectKey;


public class KeyGenerator {
private final S3SinkConfig s3SinkConfig;
private final ExtensionProvider extensionProvider;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ public S3Sink(final PluginSetting pluginSetting,
throw new InvalidPluginConfigurationException("name_pattern is not a valid format expression");
}

if (s3SinkConfig.getBucketName() != null &&
!expressionEvaluator.isValidFormatExpression(s3SinkConfig.getBucketName())) {
throw new InvalidPluginConfigurationException("bucket name is not a valid format expression");
}

S3OutputCodecContext s3OutputCodecContext = new S3OutputCodecContext(OutputCodecContext.fromSinkContext(sinkContext), compressionOption);

codec.validateAgainstCodecContext(s3OutputCodecContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ public class S3SinkConfig {
@Size(min = 3, max = 500, message = "bucket length should be at least 3 characters")
private String bucketName;

/**
* The default bucket to send to if using a dynamic bucket name and failures occur
* for any reason when sending to a dynamic bucket
*/
@JsonProperty("default_bucket")
@Size(min = 3, max = 500, message = "default_bucket length should be at least 3 characters")
private String defaultBucket;


@JsonProperty("object_key")
@Valid
private ObjectKeyOptions objectKeyOptions = new ObjectKeyOptions();
Expand Down Expand Up @@ -143,4 +152,6 @@ public int getMaxUploadRetries() {
public CompressionOption getCompression() {
return compression;
}

public String getDefaultBucket() { return defaultBucket; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@
import java.util.function.Supplier;

public interface BufferFactory {
Buffer getBuffer(S3Client s3Client, Supplier<String> bucketSupplier, Supplier<String> keySupplier);
Buffer getBuffer(S3Client s3Client, Supplier<String> bucketSupplier, Supplier<String> keySupplier, String defaultBucket);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.sink.s3.accumulator;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.NoSuchBucketException;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.S3Exception;

public class BufferUtilities {

private static final Logger LOG = LoggerFactory.getLogger(BufferUtilities.class);

static final String ACCESS_DENIED = "Access Denied";

static void putObjectOrSendToDefaultBucket(final S3Client s3Client,
final RequestBody requestBody,
final String objectKey,
final String targetBucket,
final String defaultBucket) {
try {
s3Client.putObject(
PutObjectRequest.builder().bucket(targetBucket).key(objectKey).build(),
requestBody);
} catch (final S3Exception e) {
if (defaultBucket != null && (e instanceof NoSuchBucketException || e.getMessage().contains(ACCESS_DENIED))) {
LOG.warn("Bucket {} could not be accessed, attempting to send to default_bucket {}", targetBucket, defaultBucket);
s3Client.putObject(
PutObjectRequest.builder().bucket(defaultBucket).key(objectKey).build(),
requestBody);
} else {
throw e;
}
}
}
}
Loading

0 comments on commit ed5956a

Please sign in to comment.