diff --git a/.changes/next-release/feature-S3TransferManager-a02ba8b.json b/.changes/next-release/feature-S3TransferManager-a02ba8b.json new file mode 100644 index 000000000000..f3bbe2a7ffed --- /dev/null +++ b/.changes/next-release/feature-S3TransferManager-a02ba8b.json @@ -0,0 +1,6 @@ +{ + "type": "feature", + "category": "S3 Transfer Manager", + "contributor": "", + "description": "This change enables multipart download for S3 Transfer Manager with the java-based Multipart S3 Async Client." +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/FileTransformerConfiguration.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/FileTransformerConfiguration.java index 902815f96c49..640c37224c9c 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/FileTransformerConfiguration.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/FileTransformerConfiguration.java @@ -23,6 +23,7 @@ import java.util.concurrent.ExecutorService; import software.amazon.awssdk.annotations.SdkPublicApi; import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.utils.ToString; import software.amazon.awssdk.utils.Validate; import software.amazon.awssdk.utils.builder.CopyableBuilder; import software.amazon.awssdk.utils.builder.ToCopyableBuilder; @@ -41,11 +42,19 @@ public final class FileTransformerConfiguration implements ToCopyableBuilder executorService() { return Optional.ofNullable(executorService); } + /** + * Exclusively used with {@link FileWriteOption#WRITE_TO_POSITION}. Configures the position, where to start writing to the + * existing file. The location correspond to the first byte where new data will be written. For example, if {@code 128} is + * configured, bytes 0-127 of the existing file will remain untouched and data will be written starting at byte 128. If not + * specified, defaults to 0. + * + * @return The offset at which to start overwriting data in the file. + */ + public Long position() { + return position; + } + /** * Create a {@link Builder}, used to create a {@link FileTransformerConfiguration}. */ @@ -137,6 +158,9 @@ public boolean equals(Object o) { if (failureBehavior != that.failureBehavior) { return false; } + if (!Objects.equals(position, that.position)) { + return false; + } return Objects.equals(executorService, that.executorService); } @@ -145,6 +169,7 @@ public int hashCode() { int result = fileWriteOption != null ? fileWriteOption.hashCode() : 0; result = 31 * result + (failureBehavior != null ? failureBehavior.hashCode() : 0); result = 31 * result + (executorService != null ? executorService.hashCode() : 0); + result = 31 * result + (position != null ? position.hashCode() : 0); return result; } @@ -165,7 +190,15 @@ public enum FileWriteOption { /** * Create a new file if it doesn't exist, otherwise append to the existing file. */ - CREATE_OR_APPEND_TO_EXISTING + CREATE_OR_APPEND_TO_EXISTING, + + /** + * Write to an existing file at the specified position, defined by the {@link FileTransformerConfiguration#position()}. If + * the file does not exist, a {@link java.nio.file.NoSuchFileException} will be thrown. If + * {@link FileTransformerConfiguration#position()} is not configured, start overwriting data at the beginning of the file + * (byte 0). + */ + WRITE_TO_POSITION } /** @@ -209,12 +242,24 @@ public interface Builder extends CopyableBuilder { + + private final Long bufferSizeInBytes; + + private SplittingTransformerConfiguration(DefaultBuilder builder) { + this.bufferSizeInBytes = Validate.paramNotNull(builder.bufferSize, "bufferSize"); + } + + /** + * Create a {@link Builder}, used to create a {@link SplittingTransformerConfiguration}. + */ + public static Builder builder() { + return new DefaultBuilder(); + } + + /** + * @return the buffer size + */ + public Long bufferSizeInBytes() { + return bufferSizeInBytes; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + SplittingTransformerConfiguration that = (SplittingTransformerConfiguration) o; + + return Objects.equals(bufferSizeInBytes, that.bufferSizeInBytes); + } + + @Override + public int hashCode() { + return bufferSizeInBytes != null ? bufferSizeInBytes.hashCode() : 0; + } + + @Override + public String toString() { + return ToString.builder("SplittingTransformerConfiguration") + .add("bufferSizeInBytes", bufferSizeInBytes) + .build(); + } + + @Override + public Builder toBuilder() { + return new DefaultBuilder(this); + } + + public interface Builder extends CopyableBuilder { + + /** + * Configures the maximum amount of memory in bytes buffered by the {@link SplittingTransformer}. + * + * @param bufferSize the buffer size in bytes + * @return This object for method chaining. + */ + Builder bufferSizeInBytes(Long bufferSize); + } + + private static final class DefaultBuilder implements Builder { + private Long bufferSize; + + private DefaultBuilder(SplittingTransformerConfiguration configuration) { + this.bufferSize = configuration.bufferSizeInBytes; + } + + private DefaultBuilder() { + } + + @Override + public Builder bufferSizeInBytes(Long bufferSize) { + this.bufferSize = bufferSize; + return this; + } + + @Override + public SplittingTransformerConfiguration build() { + return new SplittingTransformerConfiguration(this); + } + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncResponseTransformer.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncResponseTransformer.java index 64565a62a204..6550497d52ed 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncResponseTransformer.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncResponseTransformer.java @@ -26,11 +26,16 @@ import software.amazon.awssdk.core.ResponseBytes; import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.core.SplittingTransformerConfiguration; import software.amazon.awssdk.core.internal.async.ByteArrayAsyncResponseTransformer; +import software.amazon.awssdk.core.internal.async.DefaultAsyncResponseTransformerSplitResult; import software.amazon.awssdk.core.internal.async.FileAsyncResponseTransformer; import software.amazon.awssdk.core.internal.async.InputStreamResponseTransformer; import software.amazon.awssdk.core.internal.async.PublisherAsyncResponseTransformer; +import software.amazon.awssdk.core.internal.async.SplittingTransformer; import software.amazon.awssdk.utils.Validate; +import software.amazon.awssdk.utils.builder.CopyableBuilder; +import software.amazon.awssdk.utils.builder.ToCopyableBuilder; /** * Callback interface to handle a streaming asynchronous response. @@ -38,8 +43,8 @@ *

Synchronization

*

* All operations, including those called on the {@link org.reactivestreams.Subscriber} of the stream are guaranteed to be - * synchronized externally; i.e. no two methods on this interface or on the {@link org.reactivestreams.Subscriber} will be - * invoked concurrently. It is not guaranteed that the methods will being invoked by the same thread. + * synchronized externally; i.e. no two methods on this interface or on the {@link org.reactivestreams.Subscriber} will be invoked + * concurrently. It is not guaranteed that the methods will being invoked by the same thread. *

*

Invocation Order

*

@@ -81,11 +86,10 @@ public interface AsyncResponseTransformer { /** * Initial call to enable any setup required before the response is handled. *

- * Note that this will be called for each request attempt, up to the number of retries allowed by the configured {@link - * software.amazon.awssdk.core.retry.RetryPolicy}. + * Note that this will be called for each request attempt, up to the number of retries allowed by the configured + * {@link software.amazon.awssdk.core.retry.RetryPolicy}. *

- * This method is guaranteed to be called before the request is executed, and before {@link #onResponse(Object)} is - * signaled. + * This method is guaranteed to be called before the request is executed, and before {@link #onResponse(Object)} is signaled. * * @return The future holding the transformed response. */ @@ -106,18 +110,58 @@ public interface AsyncResponseTransformer { void onStream(SdkPublisher publisher); /** - * Called when an error is encountered while making the request or receiving the response. - * Implementations should free up any resources in this method. This method may be called - * multiple times during the lifecycle of a request if automatic retries are enabled. + * Called when an error is encountered while making the request or receiving the response. Implementations should free up any + * resources in this method. This method may be called multiple times during the lifecycle of a request if automatic retries + * are enabled. * * @param error Error that occurred. */ void exceptionOccurred(Throwable error); /** - * Creates an {@link AsyncResponseTransformer} that writes all the content to the given file. In the event of an error, - * the SDK will attempt to delete the file (whatever has been written to it so far). If the file already exists, an - * exception will be thrown. + * Creates an {@link SplitResult} which contains an {@link SplittingTransformer} that splits the + * {@link AsyncResponseTransformer} into multiple ones, publishing them as a {@link SdkPublisher}. + * + * @param splitConfig configuration for the split transformer + * @return SplitAsyncResponseTransformer instance. + * @see SplittingTransformer + * @see SplitResult + */ + default SplitResult split(SplittingTransformerConfiguration splitConfig) { + Validate.notNull(splitConfig, "splitConfig must not be null"); + CompletableFuture future = new CompletableFuture<>(); + SdkPublisher> transformer = SplittingTransformer + .builder() + .upstreamResponseTransformer(this) + .maximumBufferSizeInBytes(splitConfig.bufferSizeInBytes()) + .resultFuture(future) + .build(); + return AsyncResponseTransformer.SplitResult.builder() + .publisher(transformer) + .resultFuture(future) + .build(); + } + + /** + * Creates an {@link SplitResult} which contains an {@link SplittingTransformer} that splits the + * {@link AsyncResponseTransformer} into multiple ones, publishing them as a {@link SdkPublisher}. + * + * @param splitConfig configuration for the split transformer + * @return SplitAsyncResponseTransformer instance. + * @see SplittingTransformer + * @see SplitResult + */ + default SplitResult split(Consumer splitConfig) { + SplittingTransformerConfiguration conf = SplittingTransformerConfiguration.builder() + .applyMutation(splitConfig) + .build(); + return split(conf); + } + + /** + * Creates an {@link AsyncResponseTransformer} that writes all the content to the given file. In the event of an error, the + * SDK will attempt to delete the file (whatever has been written to it so far). If the file already exists, an exception will + * be thrown. * * @param path Path to file to write to. * @param Pojo Response type. @@ -129,8 +173,8 @@ static AsyncResponseTransformer toFile(Path pa } /** - * Creates an {@link AsyncResponseTransformer} that writes all the content to the given file with the specified {@link - * FileTransformerConfiguration}. + * Creates an {@link AsyncResponseTransformer} that writes all the content to the given file with the specified + * {@link FileTransformerConfiguration}. * * @param path Path to file to write to. * @param config configuration for the transformer @@ -143,8 +187,8 @@ static AsyncResponseTransformer toFile(Path pa } /** - * This is a convenience method that creates an instance of the {@link FileTransformerConfiguration} builder, - * avoiding the need to create one manually via {@link FileTransformerConfiguration#builder()}. + * This is a convenience method that creates an instance of the {@link FileTransformerConfiguration} builder, avoiding the + * need to create one manually via {@link FileTransformerConfiguration#builder()}. * * @see #toFile(Path, FileTransformerConfiguration) */ @@ -155,9 +199,9 @@ static AsyncResponseTransformer toFile( } /** - * Creates an {@link AsyncResponseTransformer} that writes all the content to the given file. In the event of an error, - * the SDK will attempt to delete the file (whatever has been written to it so far). If the file already exists, an - * exception will be thrown. + * Creates an {@link AsyncResponseTransformer} that writes all the content to the given file. In the event of an error, the + * SDK will attempt to delete the file (whatever has been written to it so far). If the file already exists, an exception will + * be thrown. * * @param file File to write to. * @param Pojo Response type. @@ -168,8 +212,8 @@ static AsyncResponseTransformer toFile(File fi } /** - * Creates an {@link AsyncResponseTransformer} that writes all the content to the given file with the specified {@link - * FileTransformerConfiguration}. + * Creates an {@link AsyncResponseTransformer} that writes all the content to the given file with the specified + * {@link FileTransformerConfiguration}. * * @param file File to write to. * @param config configuration for the transformer @@ -182,8 +226,8 @@ static AsyncResponseTransformer toFile(File fi } /** - * This is a convenience method that creates an instance of the {@link FileTransformerConfiguration} builder, - * avoiding the need to create one manually via {@link FileTransformerConfiguration#builder()}. + * This is a convenience method that creates an instance of the {@link FileTransformerConfiguration} builder, avoiding the + * need to create one manually via {@link FileTransformerConfiguration#builder()}. * * @see #toFile(File, FileTransformerConfiguration) */ @@ -237,16 +281,14 @@ static AsyncResponseTransformer - * When this transformer is used with an async client, the {@link CompletableFuture} that the client returns will - * be completed once the {@link SdkResponse} is available and the response body begins streaming. This - * behavior differs from some other transformers, like {@link #toFile(Path)} and {@link #toBytes()}, which only - * have their {@link CompletableFuture} completed after the entire response body has finished streaming. + * When this transformer is used with an async client, the {@link CompletableFuture} that the client returns will be completed + * once the {@link SdkResponse} is available and the response body begins streaming. This behavior differs from some + * other transformers, like {@link #toFile(Path)} and {@link #toBytes()}, which only have their {@link CompletableFuture} + * completed after the entire response body has finished streaming. *

- * You are responsible for performing blocking reads from this input stream and closing the stream when you are - * finished. + * You are responsible for performing blocking reads from this input stream and closing the stream when you are finished. *

* Example usage: *

@@ -260,7 +302,71 @@ static  AsyncResponseTransformer
      */
     static 
-            AsyncResponseTransformer> toBlockingInputStream() {
+        AsyncResponseTransformer> toBlockingInputStream() {
         return new InputStreamResponseTransformer<>();
     }
+
+    /**
+     * Helper interface containing the result of {@link AsyncResponseTransformer#split(SplittingTransformerConfiguration)
+     * splitting} an AsyncResponseTransformer. This class holds both the publisher of the individual
+     * {@code AsyncResponseTransformer} and the {@code CompletableFuture } which will
+     * complete when the {@code AsyncResponseTransformer} that was split itself would complete.
+     *
+     * @param  ResponseT of the original AsyncResponseTransformer that was split.
+     * @param    ResultT of the original AsyncResponseTransformer that was split.
+     * @see AsyncResponseTransformer#split(SplittingTransformerConfiguration)
+     */
+    interface SplitResult
+        extends ToCopyableBuilder,
+        AsyncResponseTransformer.SplitResult> {
+
+        /**
+         * The individual {@link AsyncResponseTransformer} will be available through the publisher returned by this method.
+         *
+         * @return the publisher which publishes the individual {@link AsyncResponseTransformer}
+         */
+        SdkPublisher> publisher();
+
+        /**
+         * The future returned by this method will be completed when the future returned by calling the
+         * {@link AsyncResponseTransformer#prepare()} method on the AsyncResponseTransformer which was split completes.
+         *
+         * @return The future
+         */
+        CompletableFuture resultFuture();
+
+        static  Builder builder() {
+            return DefaultAsyncResponseTransformerSplitResult.builder();
+        }
+
+        interface Builder
+            extends CopyableBuilder,
+            AsyncResponseTransformer.SplitResult> {
+
+            /**
+             * @return the publisher which was configured on this Builder instance.
+             */
+            SdkPublisher> publisher();
+
+            /**
+             * Sets the publisher publishing the individual {@link AsyncResponseTransformer}
+             * @param publisher the publisher
+             * @return an instance of this Builder
+             */
+            Builder publisher(SdkPublisher> publisher);
+
+            /**
+             * @return The future which was configured an this Builder instance.
+             */
+            CompletableFuture resultFuture();
+
+            /**
+             * Sets the future that will be completed when the future returned by calling the
+             * {@link AsyncResponseTransformer#prepare()} method on the AsyncResponseTransformer which was split completes.
+             * @param future the future
+             * @return an instance of this Builder
+             */
+            Builder resultFuture(CompletableFuture future);
+        }
+    }
 }
diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/DefaultAsyncResponseTransformerSplitResult.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/DefaultAsyncResponseTransformerSplitResult.java
new file mode 100644
index 000000000000..ed64b1d8eae4
--- /dev/null
+++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/DefaultAsyncResponseTransformerSplitResult.java
@@ -0,0 +1,105 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ *  http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package software.amazon.awssdk.core.internal.async;
+
+import java.util.concurrent.CompletableFuture;
+import software.amazon.awssdk.annotations.SdkInternalApi;
+import software.amazon.awssdk.core.async.AsyncResponseTransformer;
+import software.amazon.awssdk.core.async.SdkPublisher;
+import software.amazon.awssdk.utils.Validate;
+
+@SdkInternalApi
+public final class DefaultAsyncResponseTransformerSplitResult
+    implements AsyncResponseTransformer.SplitResult {
+
+    private final SdkPublisher> publisher;
+    private final CompletableFuture future;
+
+    private DefaultAsyncResponseTransformerSplitResult(Builder builder) {
+        this.publisher = Validate.paramNotNull(
+            builder.publisher(), "asyncResponseTransformerPublisher");
+        this.future = Validate.paramNotNull(
+            builder.resultFuture(), "future");
+    }
+
+    /**
+     * The individual {@link AsyncResponseTransformer} will be available through the publisher returned by this method.
+     * @return the publisher which publishes the individual {@link AsyncResponseTransformer}
+     */
+    public SdkPublisher> publisher() {
+        return this.publisher;
+    }
+
+    /**
+     * The future returned by this method will be completed when the future returned by calling the
+     * {@link AsyncResponseTransformer#prepare()} method on the AsyncResponseTransformer which was split completes.
+     * @return The future
+     */
+    public CompletableFuture resultFuture() {
+        return this.future;
+    }
+
+    @Override
+    public AsyncResponseTransformer.SplitResult.Builder toBuilder() {
+        return new DefaultBuilder<>(this);
+    }
+
+    public static  DefaultBuilder builder() {
+        return new DefaultBuilder<>();
+    }
+
+    public static class DefaultBuilder
+        implements AsyncResponseTransformer.SplitResult.Builder {
+        private SdkPublisher> publisher;
+        private CompletableFuture future;
+
+        DefaultBuilder() {
+        }
+
+        DefaultBuilder(DefaultAsyncResponseTransformerSplitResult split) {
+            this.publisher = split.publisher;
+            this.future = split.future;
+        }
+
+        @Override
+        public SdkPublisher> publisher() {
+            return this.publisher;
+        }
+
+        @Override
+        public AsyncResponseTransformer.SplitResult.Builder publisher(
+            SdkPublisher> publisher) {
+            this.publisher = publisher;
+            return this;
+        }
+
+        @Override
+        public CompletableFuture resultFuture() {
+            return this.future;
+        }
+
+        @Override
+        public AsyncResponseTransformer.SplitResult.Builder resultFuture(CompletableFuture future) {
+            this.future = future;
+            return this;
+        }
+
+        @Override
+        public AsyncResponseTransformer.SplitResult build() {
+            return new DefaultAsyncResponseTransformerSplitResult<>(this);
+        }
+    }
+}
diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformer.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformer.java
index 27901ff61fd9..9d0bdf560af2 100644
--- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformer.java
+++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformer.java
@@ -16,6 +16,7 @@
 package software.amazon.awssdk.core.internal.async;
 
 import static software.amazon.awssdk.core.FileTransformerConfiguration.FileWriteOption.CREATE_OR_APPEND_TO_EXISTING;
+import static software.amazon.awssdk.core.FileTransformerConfiguration.FileWriteOption.WRITE_TO_POSITION;
 import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely;
 import static software.amazon.awssdk.utils.FunctionalUtils.runAndLogError;
 
@@ -44,6 +45,7 @@
 import software.amazon.awssdk.core.async.SdkPublisher;
 import software.amazon.awssdk.core.exception.SdkClientException;
 import software.amazon.awssdk.utils.Logger;
+import software.amazon.awssdk.utils.Validate;
 
 /**
  * {@link AsyncResponseTransformer} that writes the data to the specified file.
@@ -61,19 +63,21 @@ public final class FileAsyncResponseTransformer implements AsyncRespo
     private final FileTransformerConfiguration configuration;
 
     public FileAsyncResponseTransformer(Path path) {
-        this.path = path;
-        this.configuration = FileTransformerConfiguration.defaultCreateNew();
-        this.position = 0L;
+        this(path, FileTransformerConfiguration.defaultCreateNew(), 0L);
     }
 
     public FileAsyncResponseTransformer(Path path, FileTransformerConfiguration fileConfiguration) {
+        this(path, fileConfiguration, determineFilePositionToWrite(path, fileConfiguration));
+    }
+
+    private FileAsyncResponseTransformer(Path path, FileTransformerConfiguration fileTransformerConfiguration, long position) {
         this.path = path;
-        this.configuration = fileConfiguration;
-        this.position = determineFilePositionToWrite(path);
+        this.configuration = fileTransformerConfiguration;
+        this.position = position;
     }
 
-    private long determineFilePositionToWrite(Path path) {
-        if (configuration.fileWriteOption() == CREATE_OR_APPEND_TO_EXISTING) {
+    private static long determineFilePositionToWrite(Path path, FileTransformerConfiguration fileConfiguration) {
+        if (fileConfiguration.fileWriteOption() == CREATE_OR_APPEND_TO_EXISTING) {
             try {
                 return Files.size(path);
             } catch (NoSuchFileException e) {
@@ -82,6 +86,9 @@ private long determineFilePositionToWrite(Path path) {
                 throw SdkClientException.create("Cannot determine the current file size " + path, exception);
             }
         }
+        if (fileConfiguration.fileWriteOption() == WRITE_TO_POSITION) {
+            return Validate.getOrDefault(fileConfiguration.position(), () -> 0L);
+        }
         return  0L;
     }
 
@@ -98,6 +105,9 @@ private AsynchronousFileChannel createChannel(Path path) throws IOException {
             case CREATE_NEW:
                 Collections.addAll(options, StandardOpenOption.WRITE, StandardOpenOption.CREATE_NEW);
                 break;
+            case WRITE_TO_POSITION:
+                Collections.addAll(options, StandardOpenOption.WRITE);
+                break;
             default:
                 throw new IllegalArgumentException("Unsupported file write option: " + configuration.fileWriteOption());
         }
@@ -151,7 +161,12 @@ public void exceptionOccurred(Throwable throwable) {
                                () -> Files.deleteIfExists(path));
             }
         }
-        cf.completeExceptionally(throwable);
+        if (cf != null) {
+            cf.completeExceptionally(throwable);
+        } else {
+            log.warn(() -> "An exception occurred before the call to prepare() was able to instantiate the CompletableFuture."
+                           + "The future cannot be completed exceptionally because it is null");
+        }
     }
 
     /**
@@ -234,11 +249,14 @@ public void onError(Throwable t) {
 
         @Override
         public void onComplete() {
+            log.trace(() -> "onComplete");
             // if write in progress, tell write to close on finish.
             synchronized (this) {
                 if (writeInProgress) {
+                    log.trace(() -> "writeInProgress = true, not closing");
                     closeOnLastWrite = true;
                 } else {
+                    log.trace(() -> "writeInProgress = false, closing");
                     close();
                 }
             }
@@ -249,6 +267,7 @@ private void close() {
                 if (fileChannel != null) {
                     invokeSafely(fileChannel::close);
                 }
+                log.trace(() -> "Completing File async transformer future future");
                 future.complete(null);
             } catch (RuntimeException exception) {
                 future.completeExceptionally(exception);
diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingTransformer.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingTransformer.java
new file mode 100644
index 000000000000..2c76bbc1d88f
--- /dev/null
+++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingTransformer.java
@@ -0,0 +1,443 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ *  http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package software.amazon.awssdk.core.internal.async;
+
+import java.nio.ByteBuffer;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicLong;
+import org.reactivestreams.Subscriber;
+import org.reactivestreams.Subscription;
+import software.amazon.awssdk.annotations.SdkInternalApi;
+import software.amazon.awssdk.core.SplittingTransformerConfiguration;
+import software.amazon.awssdk.core.async.AsyncResponseTransformer;
+import software.amazon.awssdk.core.async.SdkPublisher;
+import software.amazon.awssdk.utils.CompletableFutureUtils;
+import software.amazon.awssdk.utils.Logger;
+import software.amazon.awssdk.utils.Validate;
+import software.amazon.awssdk.utils.async.DelegatingBufferingSubscriber;
+import software.amazon.awssdk.utils.async.SimplePublisher;
+
+/**
+ * Split a {@link AsyncResponseTransformer} into multiple ones, publishing them as a {@link SdkPublisher}. Created using the
+ * {@link AsyncResponseTransformer#split(SplittingTransformerConfiguration) split} method. The upstream
+ * {@link AsyncResponseTransformer} that is split will receive data from the individual transformers.
+ * 

+ * This publisher also buffers an amount of data before sending it to the upstream transformer, as specified by the + * maximumBufferSize. ByteBuffers will be published once the buffer has been reached, or when the subscription to this publisher + * is cancelled. + *

+ * Cancelling the subscription to this publisher signals that no more data needs to be sent to the upstream transformer. This + * publisher will then send all data currently buffered to the upstream transformer and complete the downstream subscriber. + */ +@SdkInternalApi +public class SplittingTransformer implements SdkPublisher> { + + private static final Logger log = Logger.loggerFor(SplittingTransformer.class); + + /** + * The AsyncResponseTransformer on which the {@link AsyncResponseTransformer#split(SplittingTransformerConfiguration) split} + * method was called. + */ + private final AsyncResponseTransformer upstreamResponseTransformer; + + /** + * Set to true once {@code .prepare()} is called on the upstreamResponseTransformer + */ + private final AtomicBoolean preparedCalled = new AtomicBoolean(false); + + /** + * Set to true once {@code .onResponse()} is called on the upstreamResponseTransformer + */ + private final AtomicBoolean onResponseCalled = new AtomicBoolean(false); + + /** + * Set to true once {@code .onStream()} is called on the upstreamResponseTransformer + */ + private final AtomicBoolean onStreamCalled = new AtomicBoolean(false); + + /** + * Set to true once {@code .cancel()} is called in the subscription of the downstream subscriber, or if the + * {@code resultFuture} is cancelled. + */ + private final AtomicBoolean isCancelled = new AtomicBoolean(false); + + /** + * Future to track the status of the upstreamResponseTransformer. Will be completed when the future returned by calling + * {@code prepare()} on the upstreamResponseTransformer itself completes. + */ + private final CompletableFuture resultFuture; + + /** + * The buffer size used to buffer the content received from the downstream subscriber + */ + private final long maximumBufferInBytes; + + /** + * This publisher is used to send the bytes received from the downstream subscriber's transformers to a + * {@link DelegatingBufferingSubscriber} that will buffer a number of bytes up to {@code maximumBufferSize}. + */ + private final SimplePublisher publisherToUpstream = new SimplePublisher<>(); + + /** + * The downstream subscriber that is subscribed to this publisher. + */ + private Subscriber> downstreamSubscriber; + + /** + * The amount requested by the downstream subscriber that is still left to fulfill. Updated. when the + * {@link Subscription#request(long) request} method is called on the downstream subscriber's subscription. Corresponds to the + * number of {@code AsyncResponseTransformer} that will be published to the downstream subscriber. + */ + private final AtomicLong outstandingDemand = new AtomicLong(0); + + /** + * This flag stops the current thread from publishing transformers while another thread is already publishing. + */ + private final AtomicBoolean emitting = new AtomicBoolean(false); + + private final Object cancelLock = new Object(); + + private SplittingTransformer(AsyncResponseTransformer upstreamResponseTransformer, + Long maximumBufferSizeInBytes, + CompletableFuture resultFuture) { + this.upstreamResponseTransformer = Validate.paramNotNull( + upstreamResponseTransformer, "upstreamResponseTransformer"); + this.resultFuture = Validate.paramNotNull( + resultFuture, "resultFuture"); + Validate.notNull(maximumBufferSizeInBytes, "maximumBufferSizeInBytes"); + this.maximumBufferInBytes = Validate.isPositive( + maximumBufferSizeInBytes, "maximumBufferSizeInBytes"); + + this.resultFuture.whenComplete((r, e) -> { + if (e == null) { + return; + } + if (isCancelled.compareAndSet(false, true)) { + handleFutureCancel(e); + } + }); + } + + /** + * @param downstreamSubscriber the {@link Subscriber} to the individual AsyncResponseTransformer + */ + @Override + public void subscribe(Subscriber> downstreamSubscriber) { + if (downstreamSubscriber == null) { + throw new NullPointerException("downstreamSubscriber must not be null"); + } + this.downstreamSubscriber = downstreamSubscriber; + downstreamSubscriber.onSubscribe(new DownstreamSubscription()); + } + + /** + * The subscription implementation for the subscriber to this SplittingTransformer. + */ + private final class DownstreamSubscription implements Subscription { + + @Override + public void request(long n) { + if (n <= 0) { + downstreamSubscriber.onError(new IllegalArgumentException("Amount requested must be positive")); + return; + } + long newDemand = outstandingDemand.updateAndGet(current -> { + if (Long.MAX_VALUE - current < n) { + return Long.MAX_VALUE; + } + return current + n; + }); + log.trace(() -> String.format("new outstanding demand: %s", newDemand)); + emit(); + } + + @Override + public void cancel() { + log.trace(() -> String.format("received cancel signal. Current cancel state is 'isCancelled=%s'", isCancelled.get())); + if (isCancelled.compareAndSet(false, true)) { + handleSubscriptionCancel(); + } + } + } + + private void emit() { + do { + if (!emitting.compareAndSet(false, true)) { + return; + } + try { + if (doEmit()) { + return; + } + } finally { + emitting.compareAndSet(true, false); + } + } while (outstandingDemand.get() > 0); + } + + private boolean doEmit() { + long demand = outstandingDemand.get(); + + while (demand > 0) { + if (isCancelled.get()) { + return true; + } + if (outstandingDemand.get() > 0) { + demand = outstandingDemand.decrementAndGet(); + downstreamSubscriber.onNext(new IndividualTransformer()); + } + } + return false; + } + + /** + * Handle the {@code .cancel()} signal received from the downstream subscription. Data that is being sent to the upstream + * transformer need to finish processing before we complete. One typical use case for this is completing the multipart + * download, the subscriber having reached the final part will signal that it doesn't need more parts by calling + * {@code .cancel()} on the subscription. + */ + private void handleSubscriptionCancel() { + synchronized (cancelLock) { + if (downstreamSubscriber == null) { + log.trace(() -> "downstreamSubscriber already null, skipping downstreamSubscriber.onComplete()"); + return; + } + if (!onStreamCalled.get()) { + // we never subscribe publisherToUpstream to the upstream, it would not complete + downstreamSubscriber = null; + return; + } + publisherToUpstream.complete().whenComplete((v, t) -> { + if (downstreamSubscriber == null) { + return; + } + if (t != null) { + downstreamSubscriber.onError(t); + } else { + log.trace(() -> "calling downstreamSubscriber.onComplete()"); + downstreamSubscriber.onComplete(); + } + downstreamSubscriber = null; + }); + } + } + + /** + * Handle when the {@link SplittingTransformer#resultFuture} is cancelled or completed exceptionally from the outside. Data + * need to stop being sent to the upstream transformer immediately. One typical use case for this is transfer manager needing + * to pause download by calling {@code .cancel(true)} on the future. + * + * @param e The exception the future was complete exceptionally with. + */ + private void handleFutureCancel(Throwable e) { + synchronized (cancelLock) { + publisherToUpstream.error(e); + if (downstreamSubscriber != null) { + downstreamSubscriber.onError(e); + downstreamSubscriber = null; + } + } + } + + /** + * The AsyncResponseTransformer for each of the individual requests that is sent back to the downstreamSubscriber when + * requested. A future is created per request that is completed when onComplete is called on the subscriber for that request + * body publisher. + */ + private class IndividualTransformer implements AsyncResponseTransformer { + private ResponseT response; + private CompletableFuture individualFuture; + + @Override + public CompletableFuture prepare() { + this.individualFuture = new CompletableFuture<>(); + if (preparedCalled.compareAndSet(false, true)) { + if (isCancelled.get()) { + return individualFuture; + } + log.trace(() -> "calling prepare on the upstream transformer"); + CompletableFuture upstreamFuture = upstreamResponseTransformer.prepare(); + if (!resultFuture.isDone()) { + CompletableFutureUtils.forwardResultTo(upstreamFuture, resultFuture); + } + } + resultFuture.whenComplete((r, e) -> { + if (e == null) { + return; + } + individualFuture.completeExceptionally(e); + }); + individualFuture.whenComplete((r, e) -> { + if (isCancelled.get()) { + handleSubscriptionCancel(); + } + }); + return this.individualFuture; + } + + @Override + public void onResponse(ResponseT response) { + if (onResponseCalled.compareAndSet(false, true)) { + log.trace(() -> "calling onResponse on the upstream transformer"); + upstreamResponseTransformer.onResponse(response); + } + this.response = response; + } + + @Override + public void onStream(SdkPublisher publisher) { + if (downstreamSubscriber == null) { + return; + } + synchronized (cancelLock) { + if (onStreamCalled.compareAndSet(false, true)) { + log.trace(() -> "calling onStream on the upstream transformer"); + upstreamResponseTransformer.onStream(upstreamSubscriber -> publisherToUpstream.subscribe( + DelegatingBufferingSubscriber.builder() + .maximumBufferInBytes(maximumBufferInBytes) + .delegate(upstreamSubscriber) + .build()) + ); + } + } + publisher.subscribe(new IndividualPartSubscriber<>(this.individualFuture, response)); + } + + @Override + public void exceptionOccurred(Throwable error) { + publisherToUpstream.error(error); + log.trace(() -> "calling exceptionOccurred on the upstream transformer"); + upstreamResponseTransformer.exceptionOccurred(error); + } + } + + /** + * the Subscriber for each of the individual request's ByteBuffer publisher + */ + class IndividualPartSubscriber implements Subscriber { + + private final CompletableFuture future; + private final T response; + private Subscription subscription; + + IndividualPartSubscriber(CompletableFuture future, T response) { + this.future = future; + this.response = response; + } + + @Override + public void onSubscribe(Subscription s) { + if (this.subscription != null) { + s.cancel(); + return; + } + this.subscription = s; + s.request(1); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + if (byteBuffer == null) { + throw new NullPointerException("onNext must not be called with null byteBuffer"); + } + publisherToUpstream.send(byteBuffer).whenComplete((r, t) -> { + if (t != null) { + handleError(t); + return; + } + if (!isCancelled.get()) { + subscription.request(1); + } + }); + } + + @Override + public void onError(Throwable t) { + handleError(t); + } + + @Override + public void onComplete() { + future.complete(response); + } + + private void handleError(Throwable t) { + publisherToUpstream.error(t); + future.completeExceptionally(t); + } + } + + public static Builder builder() { + return new Builder<>(); + } + + public static final class Builder { + + private Long maximumBufferSize; + private CompletableFuture returnFuture; + private AsyncResponseTransformer upstreamResponseTransformer; + + private Builder() { + } + + /** + * The {@link AsyncResponseTransformer} that will receive the data from each of the individually published + * {@link IndividualTransformer}, usually intended to be the one on which the + * {@link AsyncResponseTransformer#split(SplittingTransformerConfiguration)})} method was called. + * + * @param upstreamResponseTransformer the {@code AsyncResponseTransformer} that was split. + * @return an instance of this builder + */ + public Builder upstreamResponseTransformer( + AsyncResponseTransformer upstreamResponseTransformer) { + this.upstreamResponseTransformer = upstreamResponseTransformer; + return this; + } + + /** + * The amount of data in byte this publisher will buffer into memory before sending it to the upstream transformer. The + * data will be sent if chunk of {@code maximumBufferSize} to the upstream transformer unless the subscription is + * cancelled while less amount is buffered, in which case a chunk with a size less than {@code maximumBufferSize} will be + * sent. + * + * @param maximumBufferSize the amount of data buffered and the size of the chunk of data + * @return an instance of this builder + */ + public Builder maximumBufferSizeInBytes(Long maximumBufferSize) { + this.maximumBufferSize = maximumBufferSize; + return this; + } + + /** + * The future that will be completed when the future which is returned by the call to + * {@link AsyncResponseTransformer#prepare()} completes. + * + * @param returnFuture the future to complete. + * @return an instance of this builder + */ + public Builder resultFuture(CompletableFuture returnFuture) { + this.returnFuture = returnFuture; + return this; + } + + public SplittingTransformer build() { + return new SplittingTransformer<>(this.upstreamResponseTransformer, + this.maximumBufferSize, + this.returnFuture); + } + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/FileTransformerConfigurationTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/FileTransformerConfigurationTest.java index 9eb426ee3ac6..a0e533a9f1a1 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/FileTransformerConfigurationTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/FileTransformerConfigurationTest.java @@ -16,14 +16,33 @@ package software.amazon.awssdk.core; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static software.amazon.awssdk.core.FileTransformerConfiguration.FailureBehavior.DELETE; import static software.amazon.awssdk.core.FileTransformerConfiguration.FileWriteOption.CREATE_NEW; import nl.jqno.equalsverifier.EqualsVerifier; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; class FileTransformerConfigurationTest { + @ParameterizedTest + @EnumSource( + value = FileTransformerConfiguration.FileWriteOption.class, + names = {"CREATE_NEW", "CREATE_OR_REPLACE_EXISTING", "CREATE_OR_APPEND_TO_EXISTING"}) + void position_whenUsedWithNotWriteToPosition_shouldThrowIllegalArgumentException( + FileTransformerConfiguration.FileWriteOption fileWriteOption) { + FileTransformerConfiguration.Builder builder = FileTransformerConfiguration.builder() + .position(123L) + .failureBehavior(DELETE) + .fileWriteOption(fileWriteOption); + assertThatThrownBy(builder::build) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining(fileWriteOption.name()); + } + @Test void equalsHashcode() { EqualsVerifier.forClass(FileTransformerConfiguration.class) diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/SplittingTransformerConfigurationTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/SplittingTransformerConfigurationTest.java new file mode 100644 index 000000000000..68df2a61b8fc --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/SplittingTransformerConfigurationTest.java @@ -0,0 +1,43 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import nl.jqno.equalsverifier.EqualsVerifier; +import org.junit.jupiter.api.Test; + +public class SplittingTransformerConfigurationTest { + + @Test + void equalsHashcode() { + EqualsVerifier.forClass(SplittingTransformerConfiguration.class) + .withNonnullFields("bufferSizeInBytes") + .verify(); + + } + + @Test + void toBuilder() { + SplittingTransformerConfiguration configuration = + SplittingTransformerConfiguration.builder() + .bufferSizeInBytes(4444L) + .build(); + + SplittingTransformerConfiguration another = configuration.toBuilder().build(); + assertThat(configuration).isEqualTo(another); + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformerTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformerTest.java index bbb6891e15d8..1f0973849b32 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformerTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformerTest.java @@ -28,9 +28,9 @@ import java.nio.file.FileAlreadyExistsException; import java.nio.file.FileSystem; import java.nio.file.Files; +import java.nio.file.NoSuchFileException; import java.nio.file.Path; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.concurrent.Callable; @@ -50,6 +50,7 @@ import org.reactivestreams.Subscription; import software.amazon.awssdk.core.FileTransformerConfiguration; import software.amazon.awssdk.core.FileTransformerConfiguration.FileWriteOption; +import software.amazon.awssdk.core.FileTransformerConfiguration.FailureBehavior; import software.amazon.awssdk.core.async.SdkPublisher; /** @@ -186,8 +187,11 @@ void createOrAppendExisting_fileExists_shouldAppend() throws Exception { @MethodSource("configurations") void exceptionOccurred_deleteFileBehavior(FileTransformerConfiguration configuration) throws Exception { Path testPath = testFs.getPath("test_file.txt"); - FileAsyncResponseTransformer transformer = new FileAsyncResponseTransformer<>(testPath, - configuration); + if (configuration.fileWriteOption() == FileWriteOption.WRITE_TO_POSITION) { + // file must exist for WRITE_TO_POSITION + Files.write(testPath, "foobar".getBytes(StandardCharsets.UTF_8)); + } + FileAsyncResponseTransformer transformer = new FileAsyncResponseTransformer<>(testPath, configuration); stubException(RandomStringUtils.random(200), transformer); if (configuration.failureBehavior() == LEAVE) { assertThat(testPath).exists(); @@ -197,28 +201,19 @@ void exceptionOccurred_deleteFileBehavior(FileTransformerConfiguration configura } private static List configurations() { - return Arrays.asList( - FileTransformerConfiguration.defaultCreateNew(), - FileTransformerConfiguration.defaultCreateOrAppend(), - FileTransformerConfiguration.defaultCreateOrReplaceExisting(), - FileTransformerConfiguration.builder() - .fileWriteOption(FileWriteOption.CREATE_NEW) - .failureBehavior(LEAVE).build(), - FileTransformerConfiguration.builder() - .fileWriteOption(FileWriteOption.CREATE_NEW) - .failureBehavior(DELETE).build(), - FileTransformerConfiguration.builder() - .fileWriteOption(FileWriteOption.CREATE_OR_APPEND_TO_EXISTING) - .failureBehavior(DELETE).build(), - FileTransformerConfiguration.builder() - .fileWriteOption(FileWriteOption.CREATE_OR_APPEND_TO_EXISTING) - .failureBehavior(LEAVE).build(), - FileTransformerConfiguration.builder() - .fileWriteOption(FileWriteOption.CREATE_OR_REPLACE_EXISTING) - .failureBehavior(DELETE).build(), - FileTransformerConfiguration.builder() - .fileWriteOption(FileWriteOption.CREATE_OR_REPLACE_EXISTING) - .failureBehavior(LEAVE).build()); + List conf = new ArrayList<>(); + conf.add(FileTransformerConfiguration.defaultCreateNew()); + conf.add(FileTransformerConfiguration.defaultCreateOrAppend()); + conf.add(FileTransformerConfiguration.defaultCreateOrReplaceExisting()); + for (FailureBehavior failureBehavior : FailureBehavior.values()) { + for (FileWriteOption fileWriteOption : FileWriteOption.values()) { + conf.add(FileTransformerConfiguration.builder() + .fileWriteOption(fileWriteOption) + .failureBehavior(failureBehavior) + .build()); + } + } + return conf; } @Test @@ -247,6 +242,70 @@ void explicitExecutor_shouldUseExecutor() throws Exception { } } + @Test + void writeToPosition_fileExists_shouldAppendFromPosition() throws Exception { + int totalSize = 100; + long prefixSize = 80L; + int newContentLength = 20; + + Path testPath = testFs.getPath("test_file.txt"); + String contentBeforeRewrite = RandomStringUtils.randomAlphanumeric(totalSize); + byte[] existingBytes = contentBeforeRewrite.getBytes(StandardCharsets.UTF_8); + Files.write(testPath, existingBytes); + String newContent = RandomStringUtils.randomAlphanumeric(newContentLength); + FileAsyncResponseTransformer transformer = new FileAsyncResponseTransformer<>( + testPath, + FileTransformerConfiguration.builder() + .position(prefixSize) + .failureBehavior(DELETE) + .fileWriteOption(FileWriteOption.WRITE_TO_POSITION) + .build()); + + stubSuccessfulStreaming(newContent, transformer); + + String expectedContent = contentBeforeRewrite.substring(0, 80) + newContent; + assertThat(testPath).hasContent(expectedContent); + } + + @Test + void writeToPosition_fileDoesNotExists_shouldThrowException() throws Exception { + Path path = testFs.getPath("this/file/does/not/exists"); + FileAsyncResponseTransformer transformer = new FileAsyncResponseTransformer<>( + path, + FileTransformerConfiguration.builder() + .position(0L) + .failureBehavior(DELETE) + .fileWriteOption(FileWriteOption.WRITE_TO_POSITION) + .build()); + CompletableFuture future = transformer.prepare(); + transformer.onResponse("foobar"); + assertThatThrownBy(() -> { + transformer.onStream(testPublisher("foo-bar-content")); + future.get(10, TimeUnit.SECONDS); + }).hasRootCauseInstanceOf(NoSuchFileException.class); + } + + @Test + void writeToPosition_fileExists_positionNotDefined_shouldRewriteFromStart() throws Exception { + int totalSize = 100; + Path testPath = testFs.getPath("test_file.txt"); + String contentBeforeRewrite = RandomStringUtils.randomAlphanumeric(totalSize); + byte[] existingBytes = contentBeforeRewrite.getBytes(StandardCharsets.UTF_8); + Files.write(testPath, existingBytes); + String newContent = RandomStringUtils.randomAlphanumeric(totalSize); + FileAsyncResponseTransformer transformer = new FileAsyncResponseTransformer<>( + testPath, + FileTransformerConfiguration.builder() + .failureBehavior(DELETE) + .fileWriteOption(FileWriteOption.WRITE_TO_POSITION) + .build()); + + stubSuccessfulStreaming(newContent, transformer); + + assertThat(testPath).hasContent(newContent); + + } + @Test void onStreamFailed_shouldCompleteFutureExceptionally() { Path testPath = testFs.getPath("test_file.txt"); @@ -275,9 +334,9 @@ private static void stubException(String newContent, FileAsyncResponseTransforme transformer.onStream(SdkPublisher.adapt(Flowable.just(content, content))); transformer.exceptionOccurred(runtimeException); - assertThatThrownBy(() -> future.get(10, TimeUnit.SECONDS)) - .hasRootCause(runtimeException); - assertThat(future.isCompletedExceptionally()).isTrue(); + assertThat(future).failsWithin(1, TimeUnit.SECONDS) + .withThrowableOfType(ExecutionException.class) + .withCause(runtimeException); } private static SdkPublisher testPublisher(String content) { diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/IndividualPartSubscriberTckTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/IndividualPartSubscriberTckTest.java new file mode 100644 index 000000000000..a72a3ab7aa1f --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/IndividualPartSubscriberTckTest.java @@ -0,0 +1,90 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + + +import java.nio.ByteBuffer; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.reactivestreams.tck.SubscriberWhiteboxVerification; +import org.reactivestreams.tck.TestEnvironment; +import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.utils.async.SimplePublisher; + +public class IndividualPartSubscriberTckTest extends SubscriberWhiteboxVerification { + + private static final byte[] DATA = {0, 1, 2, 3, 4, 5, 6, 7}; + + protected IndividualPartSubscriberTckTest() { + super(new TestEnvironment()); + } + + @Override + public Subscriber createSubscriber(WhiteboxSubscriberProbe probe) { + CompletableFuture future = new CompletableFuture<>(); + SimplePublisher publisher = new SimplePublisher<>(); + SplittingTransformer> transformer = + SplittingTransformer.>builder() + .upstreamResponseTransformer(AsyncResponseTransformer.toBytes()) + .maximumBufferSizeInBytes(32L) + .resultFuture(new CompletableFuture<>()) + .build(); + return transformer.new IndividualPartSubscriber(future, ByteBuffer.wrap(new byte[0])) { + @Override + public void onSubscribe(Subscription s) { + super.onSubscribe(s); + probe.registerOnSubscribe(new SubscriberPuppet() { + @Override + public void triggerRequest(long l) { + s.request(l); + } + + @Override + public void signalCancel() { + s.cancel(); + } + }); + } + + @Override + public void onNext(ByteBuffer bb) { + super.onNext(bb); + probe.registerOnNext(bb); + } + + @Override + public void onError(Throwable t) { + super.onError(t); + probe.registerOnError(t); + } + + @Override + public void onComplete() { + super.onComplete(); + probe.registerOnComplete(); + } + + }; + } + + @Override + public ByteBuffer createElement(int element) { + return ByteBuffer.wrap(DATA); + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingTransformerTckTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingTransformerTckTest.java new file mode 100644 index 000000000000..7826fc2e80fc --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingTransformerTckTest.java @@ -0,0 +1,54 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import java.util.concurrent.CompletableFuture; +import org.reactivestreams.Publisher; +import org.reactivestreams.tck.PublisherVerification; +import org.reactivestreams.tck.TestEnvironment; +import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.async.SdkPublisher; + +public class SplittingTransformerTckTest extends PublisherVerification> { + + public SplittingTransformerTckTest() { + super(new TestEnvironment()); + } + + @Override + public Publisher> createPublisher(long l) { + CompletableFuture> future = new CompletableFuture<>(); + AsyncResponseTransformer> upstreamTransformer = AsyncResponseTransformer.toBytes(); + SplittingTransformer> transformer = + SplittingTransformer.>builder() + .upstreamResponseTransformer(upstreamTransformer) + .maximumBufferSizeInBytes(64 * 1024L) + .resultFuture(future) + .build(); + return SdkPublisher.adapt(transformer).limit(Math.toIntExact(l)); + } + + @Override + public Publisher> createFailedPublisher() { + return null; + } + + @Override + public long maxElementsFromPublisher() { + return Long.MAX_VALUE; + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingTransformerTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingTransformerTest.java new file mode 100644 index 000000000000..d0f1e75b68ca --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingTransformerTest.java @@ -0,0 +1,429 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.fail; + +import java.nio.ByteBuffer; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.utils.Logger; + +class SplittingTransformerTest { + private static final Logger log = Logger.loggerFor(SplittingTransformerTest.class); + + @Test + void whenSubscriberCancelSubscription_AllDataSentToTransformer() { + UpstreamTestTransformer upstreamTestTransformer = new UpstreamTestTransformer(); + CompletableFuture future = new CompletableFuture<>(); + SplittingTransformer split = + SplittingTransformer.builder() + .upstreamResponseTransformer(upstreamTestTransformer) + .maximumBufferSizeInBytes(1024 * 1024 * 32L) + .resultFuture(future) + .build(); + split.subscribe(new CancelAfterNTestSubscriber( + 4, n -> AsyncRequestBody.fromString(String.format("This is the body of %d.", n)))); + future.join(); + String expected = "This is the body of 0.This is the body of 1.This is the body of 2.This is the body of 3."; + assertThat(upstreamTestTransformer.contentAsString()).isEqualTo(expected); + } + + @Test + void whenSubscriberFailsAttempt_UpstreamTransformerCompletesExceptionally() { + UpstreamTestTransformer upstreamTestTransformer = new UpstreamTestTransformer(); + CompletableFuture future = new CompletableFuture<>(); + SplittingTransformer split = + SplittingTransformer.builder() + .upstreamResponseTransformer(upstreamTestTransformer) + .maximumBufferSizeInBytes(1024 * 1024 * 32L) + .resultFuture(future) + .build(); + split.subscribe(new FailAfterNTestSubscriber(2)); + assertThatThrownBy(future::join).hasMessageContaining("TEST ERROR 2"); + } + + @Test + void whenDataExceedsBufferSize_UpstreamShouldReceiveAllData() { + Long evenBufferSize = 16 * 1024L; + + // We send 9 split body of 7kb with a buffer size of 16kb. This is to test when uneven body size is used compared to + // the buffer size, this test use a body size which does not evenly divides with the buffer size. + int unevenBodyLength = 7 * 1024; + int splitAmount = 9; + UpstreamTestTransformer upstreamTestTransformer = new UpstreamTestTransformer(); + CompletableFuture future = new CompletableFuture<>(); + SplittingTransformer split = + SplittingTransformer.builder() + .upstreamResponseTransformer(upstreamTestTransformer) + .maximumBufferSizeInBytes(evenBufferSize) + .resultFuture(future) + .build(); + split.subscribe(new CancelAfterNTestSubscriber( + splitAmount, + n -> { + String content = + IntStream.range(0, unevenBodyLength).mapToObj(i -> String.valueOf(n)).collect(Collectors.joining()); + return AsyncRequestBody.fromString(content); + })); + future.join(); + StringBuilder expected = new StringBuilder(); + for (int i = 0; i < splitAmount; i++) { + int value = i; + expected.append(IntStream.range(0, unevenBodyLength).mapToObj(j -> String.valueOf(value)).collect(Collectors.joining())); + } + assertThat(upstreamTestTransformer.contentAsString()).hasSize(unevenBodyLength * splitAmount); + assertThat(upstreamTestTransformer.contentAsString()).isEqualTo(expected.toString()); + } + + @Test + void whenRequestingMany_allDemandGetsFulfilled() { + UpstreamTestTransformer upstreamTestTransformer = new UpstreamTestTransformer(); + CompletableFuture future = new CompletableFuture<>(); + SplittingTransformer split = + SplittingTransformer.builder() + .upstreamResponseTransformer(upstreamTestTransformer) + .maximumBufferSizeInBytes(1024 * 1024 * 32L) + .resultFuture(future) + .build(); + split.subscribe(new RequestingTestSubscriber(4)); + + future.join(); + String expected = "This is the body of 1.This is the body of 2.This is the body of 3.This is the body of 4."; + assertThat(upstreamTestTransformer.contentAsString()).isEqualTo(expected); + } + + @Test + void negativeBufferSize_shouldThrowIllegalArgument() { + assertThatThrownBy(() -> SplittingTransformer.builder() + .maximumBufferSizeInBytes(-1L) + .upstreamResponseTransformer(new UpstreamTestTransformer()) + .resultFuture(new CompletableFuture<>()) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("maximumBufferSizeInBytes"); + } + + @Test + void nullBufferSize_shouldThrowNullPointerException() { + assertThatThrownBy(() -> SplittingTransformer.builder() + .maximumBufferSizeInBytes(null) + .upstreamResponseTransformer(new UpstreamTestTransformer()) + .resultFuture(new CompletableFuture<>()) + .build()) + .isInstanceOf(NullPointerException.class) + .hasMessageContaining("maximumBufferSizeInBytes"); + } + + @Test + void nullUpstreamTransformer_shouldThrowNullPointerException() { + assertThatThrownBy(() -> SplittingTransformer.builder() + .maximumBufferSizeInBytes(1024L) + .upstreamResponseTransformer(null) + .resultFuture(new CompletableFuture<>()) + .build()) + .isInstanceOf(NullPointerException.class) + .hasMessageContaining("upstreamResponseTransformer"); + } + + @Test + void nullFuture_shouldThrowNullPointerException() { + assertThatThrownBy(() -> SplittingTransformer.builder() + .maximumBufferSizeInBytes(1024L) + .upstreamResponseTransformer(new UpstreamTestTransformer()) + .resultFuture(null) + .build()) + .isInstanceOf(NullPointerException.class) + .hasMessageContaining("resultFuture"); + } + + @Test + void resultFutureCancelled_shouldSignalErrorToSubscriberAndCancelTransformerFuture() { + CompletableFuture future = new CompletableFuture<>(); + UpstreamTestTransformer transformer = new UpstreamTestTransformer(); + SplittingTransformer split = + SplittingTransformer.builder() + .upstreamResponseTransformer(transformer) + .maximumBufferSizeInBytes(1024L) + .resultFuture(future) + .build(); + + ErrorCapturingSubscriber subscriber = new ErrorCapturingSubscriber(); + split.subscribe(subscriber); + + future.cancel(true); + + assertThat(subscriber.error).isNotNull(); + assertThat(subscriber.error).isInstanceOf(CancellationException.class); + + CompletableFuture transformerFuture = transformer.future; + assertThat(transformerFuture).isCancelled(); + } + + private static class ErrorCapturingSubscriber + implements Subscriber> { + + private Subscription subscription; + private Throwable error; + + @Override + public void onSubscribe(Subscription s) { + this.subscription = s; + s.request(1); + } + + @Override + public void onNext(AsyncResponseTransformer transformer) { + transformer.prepare(); + transformer.onResponse(new TestResultObject("test")); + transformer.onStream(AsyncRequestBody.fromString("test")); + } + + @Override + public void onError(Throwable t) { + this.error = t; + } + + @Override + public void onComplete() { + /* do nothing, test only */ + } + } + + private static class CancelAfterNTestSubscriber + implements Subscriber> { + + private final int n; + private Subscription subscription; + private int total = 0; + private final Function bodySupplier; + + CancelAfterNTestSubscriber(int n, Function bodySupplier) { + this.n = n; + this.bodySupplier = bodySupplier; + } + + @Override + public void onSubscribe(Subscription s) { + this.subscription = s; + s.request(1); + } + + @Override + public void onNext(AsyncResponseTransformer transformer) { + // simulate what is done during a service call + if (total >= n) { + subscription.cancel(); + return; + } + CompletableFuture future = transformer.prepare(); + future.whenComplete((r, e) -> { + if (e != null) { + fail(e); + } + }); + transformer.onResponse(new TestResultObject("container msg: " + total)); + transformer.onStream(bodySupplier.apply(total)); + total++; + subscription.request(1); + } + + @Override + public void onError(Throwable t) { + fail("Unexpected onError", t); + } + + @Override + public void onComplete() { + // do nothing, test only + } + } + + private static class FailAfterNTestSubscriber + implements Subscriber> { + + private final int n; + private Subscription subscription; + private int total = 0; + + FailAfterNTestSubscriber(int n) { + this.n = n; + } + + @Override + public void onSubscribe(Subscription s) { + this.subscription = s; + s.request(1); + } + + @Override + public void onNext(AsyncResponseTransformer transformer) { + if (total > n) { + fail("Did not expect more than 2 request to be made"); + } + + transformer.prepare(); + if (total == n) { + transformer.exceptionOccurred(new RuntimeException("TEST ERROR " + total)); + return; + } + + transformer.onResponse(new TestResultObject("container msg: " + total)); + transformer.onStream(AsyncRequestBody.fromString(String.format("This is the body of %d.", total))); + total++; + subscription.request(1); + } + + @Override + public void onError(Throwable t) { + // do nothing, test only + } + + @Override + public void onComplete() { + // do nothing, test only + } + } + + private static class RequestingTestSubscriber + implements Subscriber> { + + private final int totalToRequest; + private Subscription subscription; + private int received = 0; + + RequestingTestSubscriber(int totalToRequest) { + this.totalToRequest = totalToRequest; + } + + @Override + public void onSubscribe(Subscription s) { + this.subscription = s; + s.request(totalToRequest); + } + + @Override + public void onNext(AsyncResponseTransformer transformer) { + received++; + transformer.prepare(); + transformer.onResponse(new TestResultObject("container msg: " + received)); + transformer.onStream(AsyncRequestBody.fromString(String.format("This is the body of %d.", received))); + if (received >= totalToRequest) { + subscription.cancel(); + } + } + + @Override + public void onError(Throwable t) { + fail("unexpected onError", t); + } + + @Override + public void onComplete() { + // do nothing, test only + } + } + + + private static class UpstreamTestTransformer implements AsyncResponseTransformer { + + private final CompletableFuture future; + private final StringBuilder content = new StringBuilder(); + + UpstreamTestTransformer() { + this.future = new CompletableFuture<>(); + } + + @Override + public CompletableFuture prepare() { + log.info(() -> "[UpstreamTestTransformer] prepare"); + return this.future; + } + + @Override + public void onResponse(TestResultObject response) { + log.info(() -> String.format("[UpstreamTestTransformer] onResponse: %s", response.toString())); + } + + @Override + public void onStream(SdkPublisher publisher) { + log.info(() -> "[UpstreamTestTransformer] onStream"); + publisher.subscribe(new Subscriber() { + private Subscription subscription; + + @Override + public void onSubscribe(Subscription s) { + this.subscription = s; + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + ByteBuffer dup = byteBuffer.duplicate(); + byte[] dest = new byte[dup.capacity()]; + dup.position(0); + dup.get(dest); + String str = new String(dest); + content.append(str); + } + + @Override + public void onError(Throwable t) { + future.completeExceptionally(t); + } + + @Override + public void onComplete() { + future.complete(new Object()); + } + }); + } + + @Override + public void exceptionOccurred(Throwable error) { + future.completeExceptionally(error); + } + + public String contentAsString() { + return content.toString(); + } + } + + private static class TestResultObject { + + private final String msg; + + TestResultObject(String msg) { + this.msg = msg; + } + + @Override + public String toString() { + return "TestResultObject{'" + msg + "'}"; + } + } +} diff --git a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3IntegrationTestBase.java b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3IntegrationTestBase.java index 26f963177190..1b4f7f105a38 100644 --- a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3IntegrationTestBase.java +++ b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3IntegrationTestBase.java @@ -15,8 +15,10 @@ package software.amazon.awssdk.transfer.s3; +import java.util.stream.Stream; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.params.provider.Arguments; import software.amazon.awssdk.crt.CrtResource; import software.amazon.awssdk.crt.Log; import software.amazon.awssdk.regions.Region; @@ -68,19 +70,17 @@ public static void setUpForAllIntegTests() throws Exception { Log.initLoggingToStdout(Log.LogLevel.Warn); System.setProperty("aws.crt.debugnative", "true"); s3 = s3ClientBuilder().build(); - s3Async = s3AsyncClientBuilder() - .multipartEnabled(true) - .build(); + s3Async = s3AsyncClientBuilder().build(); s3CrtAsync = S3CrtAsyncClient.builder() .credentialsProvider(CREDENTIALS_PROVIDER_CHAIN) .region(DEFAULT_REGION) .build(); tmCrt = S3TransferManager.builder() - .s3Client(s3CrtAsync) - .build(); - tmJava = S3TransferManager.builder() - .s3Client(s3Async) + .s3Client(s3CrtAsync) .build(); + tmJava = S3TransferManager.builder() + .s3Client(s3Async) + .build(); } @@ -101,6 +101,7 @@ protected static S3ClientBuilder s3ClientBuilder() { protected static S3AsyncClientBuilder s3AsyncClientBuilder() { return S3AsyncClient.builder() + .multipartEnabled(true) .region(DEFAULT_REGION) .credentialsProvider(CREDENTIALS_PROVIDER_CHAIN); } @@ -174,4 +175,10 @@ protected static void deleteBucketAndAllContents(String bucketName) { s3.deleteBucket(DeleteBucketRequest.builder().bucket(bucketName).build()); } + static Stream transferManagers() { + return Stream.of( + Arguments.of(tmCrt), + Arguments.of(tmJava)); + } + } diff --git a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerCopyIntegrationTest.java b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerCopyIntegrationTest.java index 9ac9be30e316..63597a302902 100644 --- a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerCopyIntegrationTest.java +++ b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerCopyIntegrationTest.java @@ -57,15 +57,12 @@ enum TmType{ JAVA, CRT } - private static Stream transferManagers() { - return Stream.of( - Arguments.of(TmType.JAVA), - Arguments.of(TmType.CRT) - ); + private static Stream transferManagerTypes() { + return Stream.of(Arguments.of(TmType.JAVA), Arguments.of(TmType.CRT)); } @ParameterizedTest - @MethodSource("transferManagers") + @MethodSource("transferManagerTypes") void copy_copiedObject_hasSameContent(TmType tmType) throws Exception { CaptureTransferListener transferListener = new CaptureTransferListener(); byte[] originalContent = randomBytes(OBJ_SIZE); @@ -75,7 +72,7 @@ void copy_copiedObject_hasSameContent(TmType tmType) throws Exception { } @ParameterizedTest - @MethodSource("transferManagers") + @MethodSource("transferManagerTypes") void copy_specialCharacters_hasSameContent(TmType tmType) throws Exception { CaptureTransferListener transferListener = new CaptureTransferListener(); byte[] originalContent = randomBytes(OBJ_SIZE); diff --git a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerDownloadDirectoryIntegrationTest.java b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerDownloadDirectoryIntegrationTest.java index 14eaf72cbe92..e362af640209 100644 --- a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerDownloadDirectoryIntegrationTest.java +++ b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerDownloadDirectoryIntegrationTest.java @@ -33,9 +33,9 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.opentest4j.AssertionFailedError; import software.amazon.awssdk.testutils.FileUtils; import software.amazon.awssdk.transfer.s3.model.CompletedDirectoryDownload; @@ -97,6 +97,7 @@ public static void teardown() { } closeQuietly(tmCrt, log.logger()); + closeQuietly(tmJava, log.logger()); } /** @@ -116,21 +117,23 @@ public static void teardown() { * } * */ - @Test - public void downloadDirectory() throws Exception { - DirectoryDownload downloadDirectory = tmCrt.downloadDirectory(u -> u.destination(directory) - .bucket(TEST_BUCKET)); + @ParameterizedTest(autoCloseArguments = false) + @MethodSource("transferManagers") + public void downloadDirectory(S3TransferManager tm) throws Exception { + DirectoryDownload downloadDirectory = tm.downloadDirectory(u -> u.destination(directory) + .bucket(TEST_BUCKET)); CompletedDirectoryDownload completedDirectoryDownload = downloadDirectory.completionFuture().get(5, TimeUnit.SECONDS); assertThat(completedDirectoryDownload.failedTransfers()).isEmpty(); assertTwoDirectoriesHaveSameStructure(sourceDirectory, directory); } - @ParameterizedTest - @ValueSource(strings = {"notes/2021", "notes/2021/"}) - void downloadDirectory_withPrefix(String prefix) throws Exception { - DirectoryDownload downloadDirectory = tmCrt.downloadDirectory(u -> u.destination(directory) - .listObjectsV2RequestTransformer(r -> r.prefix(prefix)) - .bucket(TEST_BUCKET)); + @ParameterizedTest(autoCloseArguments = false) + @MethodSource("prefixTestArguments") + void downloadDirectory_withPrefix(S3TransferManager tm, String prefix) throws Exception { + DirectoryDownload downloadDirectory = + tm.downloadDirectory(u -> u.destination(directory) + .listObjectsV2RequestTransformer(r -> r.prefix(prefix)) + .bucket(TEST_BUCKET)); CompletedDirectoryDownload completedDirectoryDownload = downloadDirectory.completionFuture().get(5, TimeUnit.SECONDS); assertThat(completedDirectoryDownload.failedTransfers()).isEmpty(); @@ -152,12 +155,14 @@ void downloadDirectory_withPrefix(String prefix) throws Exception { * } * */ - @Test - void downloadDirectory_containsObjectWithPrefixInTheKey_shouldResolveCorrectly() throws Exception { + @ParameterizedTest(autoCloseArguments = false) + @MethodSource("transferManagers") + void downloadDirectory_containsObjectWithPrefixInTheKey_shouldResolveCorrectly(S3TransferManager tm) + throws Exception { String prefix = "notes"; - DirectoryDownload downloadDirectory = tmCrt.downloadDirectory(u -> u.destination(directory) - .listObjectsV2RequestTransformer(r -> r.prefix(prefix)) - .bucket(TEST_BUCKET)); + DirectoryDownload downloadDirectory = tm.downloadDirectory(u -> u.destination(directory) + .listObjectsV2RequestTransformer(r -> r.prefix(prefix)) + .bucket(TEST_BUCKET)); CompletedDirectoryDownload completedDirectoryDownload = downloadDirectory.completionFuture().get(5, TimeUnit.SECONDS); assertThat(completedDirectoryDownload.failedTransfers()).isEmpty(); @@ -182,14 +187,15 @@ void downloadDirectory_containsObjectWithPrefixInTheKey_shouldResolveCorrectly() * } * */ - @Test - public void downloadDirectory_withPrefixAndDelimiter() throws Exception { + @ParameterizedTest(autoCloseArguments = false) + @MethodSource("transferManagers") + public void downloadDirectory_withPrefixAndDelimiter(S3TransferManager tm) throws Exception { String prefix = "notes-2021"; DirectoryDownload downloadDirectory = - tmCrt.downloadDirectory(u -> u.destination(directory) - .listObjectsV2RequestTransformer(r -> r.delimiter(CUSTOM_DELIMITER) + tm.downloadDirectory(u -> u.destination(directory) + .listObjectsV2RequestTransformer(r -> r.delimiter(CUSTOM_DELIMITER) .prefix(prefix)) - .bucket(TEST_BUCKET_CUSTOM_DELIMITER)); + .bucket(TEST_BUCKET_CUSTOM_DELIMITER)); CompletedDirectoryDownload completedDirectoryDownload = downloadDirectory.completionFuture().get(5, TimeUnit.SECONDS); assertThat(completedDirectoryDownload.failedTransfers()).isEmpty(); assertTwoDirectoriesHaveSameStructure(sourceDirectory.resolve("notes").resolve("2021"), directory); @@ -206,9 +212,10 @@ public void downloadDirectory_withPrefixAndDelimiter() throws Exception { * } * */ - @Test - public void downloadDirectory_withFilter() throws Exception { - DirectoryDownload downloadDirectory = tmCrt.downloadDirectory(u -> u + @ParameterizedTest(autoCloseArguments = false) + @MethodSource("transferManagers") + public void downloadDirectory_withFilter(S3TransferManager tm) throws Exception { + DirectoryDownload downloadDirectory = tm.downloadDirectory(u -> u .destination(directory) .bucket(TEST_BUCKET) .filter(s3Object -> s3Object.key().startsWith("notes/2021/2"))); @@ -296,4 +303,14 @@ private static Path createLocalTestDirectory() throws IOException { RandomStringUtils.random(100).getBytes(StandardCharsets.UTF_8)); return directory; } + + private static Stream prefixTestArguments() { + String[] prefixes = {"notes/2021", "notes/2021/"}; + return Stream.of( + Arguments.of(tmCrt, prefixes[0]), + Arguments.of(tmCrt, prefixes[1]), + Arguments.of(tmJava, prefixes[0]), + Arguments.of(tmJava, prefixes[1]) + ); + } } diff --git a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerDownloadIntegrationTest.java b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerDownloadIntegrationTest.java index 0aa4d5484131..396aa62f00c4 100644 --- a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerDownloadIntegrationTest.java +++ b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerDownloadIntegrationTest.java @@ -29,7 +29,8 @@ import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import software.amazon.awssdk.core.ResponseBytes; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.ResponsePublisher; @@ -66,11 +67,12 @@ public static void cleanup() { deleteBucketAndAllContents(BUCKET); } - @Test - void download_toFile() throws Exception { + @ParameterizedTest + @MethodSource("transferManagers") + void download_toFile(S3TransferManager tm) throws Exception { Path path = RandomTempFile.randomUncreatedFile().toPath(); FileDownload download = - tmCrt.downloadFile(DownloadFileRequest.builder() + tm.downloadFile(DownloadFileRequest.builder() .getObjectRequest(b -> b.bucket(BUCKET).key(KEY)) .destination(path) .addTransferListener(LoggingTransferListener.create()) @@ -80,44 +82,47 @@ void download_toFile() throws Exception { assertThat(completedFileDownload.response().responseMetadata().requestId()).isNotNull(); } - @Test - void download_toFile_shouldReplaceExisting() throws IOException { + @ParameterizedTest + @MethodSource("transferManagers") + void download_toFile_shouldReplaceExisting(S3TransferManager tm) throws IOException { Path path = RandomTempFile.randomUncreatedFile().toPath(); Files.write(path, RandomStringUtils.random(1024).getBytes(StandardCharsets.UTF_8)); assertThat(path).exists(); FileDownload download = - tmCrt.downloadFile(DownloadFileRequest.builder() - .getObjectRequest(b -> b.bucket(BUCKET).key(KEY)) - .destination(path) - .addTransferListener(LoggingTransferListener.create()) - .build()); + tm.downloadFile(DownloadFileRequest.builder() + .getObjectRequest(b -> b.bucket(BUCKET).key(KEY)) + .destination(path) + .addTransferListener(LoggingTransferListener.create()) + .build()); CompletedFileDownload completedFileDownload = download.completionFuture().join(); assertThat(Md5Utils.md5AsBase64(path.toFile())).isEqualTo(Md5Utils.md5AsBase64(file)); assertThat(completedFileDownload.response().responseMetadata().requestId()).isNotNull(); } - @Test - void download_toBytes() throws Exception { + @ParameterizedTest + @MethodSource("transferManagers") + void download_toBytes(S3TransferManager tm) throws Exception { Download> download = - tmCrt.download(DownloadRequest.builder() - .getObjectRequest(b -> b.bucket(BUCKET).key(KEY)) - .responseTransformer(AsyncResponseTransformer.toBytes()) - .addTransferListener(LoggingTransferListener.create()) - .build()); + tm.download(DownloadRequest.builder() + .getObjectRequest(b -> b.bucket(BUCKET).key(KEY)) + .responseTransformer(AsyncResponseTransformer.toBytes()) + .addTransferListener(LoggingTransferListener.create()) + .build()); CompletedDownload> completedDownload = download.completionFuture().join(); ResponseBytes result = completedDownload.result(); assertThat(Md5Utils.md5AsBase64(result.asByteArray())).isEqualTo(Md5Utils.md5AsBase64(file)); assertThat(result.response().responseMetadata().requestId()).isNotNull(); } - @Test - void download_toPublisher() throws Exception { + @ParameterizedTest + @MethodSource("transferManagers") + void download_toPublisher(S3TransferManager tm) throws Exception { Download> download = - tmCrt.download(DownloadRequest.builder() - .getObjectRequest(b -> b.bucket(BUCKET).key(KEY)) - .responseTransformer(AsyncResponseTransformer.toPublisher()) - .addTransferListener(LoggingTransferListener.create()) - .build()); + tm.download(DownloadRequest.builder() + .getObjectRequest(b -> b.bucket(BUCKET).key(KEY)) + .responseTransformer(AsyncResponseTransformer.toPublisher()) + .addTransferListener(LoggingTransferListener.create()) + .build()); CompletedDownload> completedDownload = download.completionFuture().join(); ResponsePublisher responsePublisher = completedDownload.result(); ByteBuffer buf = ByteBuffer.allocate(Math.toIntExact(responsePublisher.response().contentLength())); diff --git a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerDownloadPauseResumeIntegrationTest.java b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerDownloadPauseResumeIntegrationTest.java index ec39139a7a50..cee4049ed04d 100644 --- a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerDownloadPauseResumeIntegrationTest.java +++ b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerDownloadPauseResumeIntegrationTest.java @@ -16,6 +16,7 @@ package software.amazon.awssdk.transfer.s3; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; import static software.amazon.awssdk.testutils.service.S3BucketUtils.temporaryBucketName; import static software.amazon.awssdk.transfer.s3.SizeConstant.MB; @@ -28,7 +29,8 @@ import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import software.amazon.awssdk.core.SdkResponse; import software.amazon.awssdk.core.retry.backoff.FixedDelayBackoffStrategy; import software.amazon.awssdk.core.sync.RequestBody; @@ -68,8 +70,9 @@ public static void cleanup() { sourceFile.delete(); } - @Test - void pauseAndResume_ObjectNotChanged_shouldResumeDownload() { + @ParameterizedTest + @MethodSource("transferManagers") + void pauseAndResume_ObjectNotChanged_shouldResumeDownload(S3TransferManager tm) { Path path = RandomTempFile.randomUncreatedFile().toPath(); TestDownloadListener testDownloadListener = new TestDownloadListener(); DownloadFileRequest request = DownloadFileRequest.builder() @@ -77,13 +80,13 @@ void pauseAndResume_ObjectNotChanged_shouldResumeDownload() { .destination(path) .addTransferListener(testDownloadListener) .build(); - FileDownload download = tmCrt.downloadFile(request); + FileDownload download = tm.downloadFile(request); waitUntilFirstByteBufferDelivered(download); ResumableFileDownload resumableFileDownload = download.pause(); long bytesTransferred = resumableFileDownload.bytesTransferred(); log.debug(() -> "Paused: " + resumableFileDownload); - assertThat(resumableFileDownload.downloadFileRequest()).isEqualTo(request); + assertEqualsBySdkFields(resumableFileDownload.downloadFileRequest(), request); assertThat(testDownloadListener.getObjectResponse).isNotNull(); assertThat(resumableFileDownload.s3ObjectLastModified()).hasValue(testDownloadListener.getObjectResponse.lastModified()); assertThat(bytesTransferred).isEqualTo(path.toFile().length()); @@ -93,18 +96,34 @@ void pauseAndResume_ObjectNotChanged_shouldResumeDownload() { assertThat(download.completionFuture()).isCancelled(); log.debug(() -> "Resuming download "); - verifyFileDownload(path, resumableFileDownload, OBJ_SIZE - bytesTransferred); + verifyFileDownload(path, resumableFileDownload, OBJ_SIZE - bytesTransferred, tm); } - @Test - void pauseAndResume_objectChanged_shouldStartFromBeginning() { + private void assertEqualsBySdkFields(DownloadFileRequest actual, DownloadFileRequest expected) { + // Transfer manager adds an execution attribute to the GetObjectRequest, so both objects are different. + // Need to assert equality by sdk fields, which does not check execution attributes. + assertThat(actual.destination()) + .withFailMessage("ResumableFileDownload destination not equal to the original DownloadFileRequest") + .isEqualTo(expected.destination()); + assertThat(actual.transferListeners()) + .withFailMessage("ResumableFileDownload transferListeners not equal to the original DownloadFileRequest") + .isEqualTo(expected.transferListeners()); + assertTrue(actual.getObjectRequest().equalsBySdkFields(expected.getObjectRequest()), + () -> String.format("ResumableFileDownload GetObjectRequest not equal to the original DownloadFileRequest. " + + "expected: %s. Actual:" + + " %s", actual.getObjectRequest(), expected.getObjectRequest())); + } + + @ParameterizedTest + @MethodSource("transferManagers") + void pauseAndResume_objectChanged_shouldStartFromBeginning(S3TransferManager tm) { try { Path path = RandomTempFile.randomUncreatedFile().toPath(); DownloadFileRequest request = DownloadFileRequest.builder() .getObjectRequest(b -> b.bucket(BUCKET).key(KEY)) .destination(path) .build(); - FileDownload download = tmCrt.downloadFile(request); + FileDownload download = tm.downloadFile(request); waitUntilFirstByteBufferDelivered(download); ResumableFileDownload resumableFileDownload = download.pause(); @@ -118,7 +137,7 @@ void pauseAndResume_objectChanged_shouldStartFromBeginning() { .build(), RequestBody.fromString(newObject)); log.debug(() -> "Resuming download "); - FileDownload resumedFileDownload = tmCrt.resumeDownloadFile(resumableFileDownload); + FileDownload resumedFileDownload = tm.resumeDownloadFile(resumableFileDownload); resumedFileDownload.progress().snapshot(); resumedFileDownload.completionFuture().join(); assertThat(path.toFile()).hasContent(newObject); @@ -131,24 +150,26 @@ void pauseAndResume_objectChanged_shouldStartFromBeginning() { } } - @Test - void pauseAndResume_fileChanged_shouldStartFromBeginning() throws Exception { + @ParameterizedTest + @MethodSource("transferManagers") + void pauseAndResume_fileChanged_shouldStartFromBeginning(S3TransferManager tm) throws Exception { Path path = RandomTempFile.randomUncreatedFile().toPath(); DownloadFileRequest request = DownloadFileRequest.builder() .getObjectRequest(b -> b.bucket(BUCKET).key(KEY)) .destination(path) .build(); - FileDownload download = tmCrt.downloadFile(request); + FileDownload download = tm.downloadFile(request); waitUntilFirstByteBufferDelivered(download); ResumableFileDownload resumableFileDownload = download.pause(); Files.write(path, "helloworld".getBytes(StandardCharsets.UTF_8)); - verifyFileDownload(path, resumableFileDownload, OBJ_SIZE); + verifyFileDownload(path, resumableFileDownload, OBJ_SIZE, tm); } - private static void verifyFileDownload(Path path, ResumableFileDownload resumableFileDownload, long expectedBytesTransferred) { - FileDownload resumedFileDownload = tmCrt.resumeDownloadFile(resumableFileDownload); + private static void verifyFileDownload(Path path, ResumableFileDownload resumableFileDownload, + long expectedBytesTransferred, S3TransferManager tm) { + FileDownload resumedFileDownload = tm.resumeDownloadFile(resumableFileDownload); resumedFileDownload.completionFuture().join(); assertThat(resumedFileDownload.progress().snapshot().totalBytes()).hasValue(expectedBytesTransferred); assertThat(path.toFile()).hasSameBinaryContentAs(sourceFile); diff --git a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerMultipartDownloadPauseResumeIntegrationTest.java b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerMultipartDownloadPauseResumeIntegrationTest.java new file mode 100644 index 000000000000..dcbace143b5f --- /dev/null +++ b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerMultipartDownloadPauseResumeIntegrationTest.java @@ -0,0 +1,121 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.transfer.s3; + +import static org.assertj.core.api.Assertions.assertThat; +import static software.amazon.awssdk.testutils.service.S3BucketUtils.temporaryBucketName; +import static software.amazon.awssdk.transfer.s3.SizeConstant.MB; + +import java.io.File; +import java.nio.file.Path; +import java.time.Duration; +import java.util.List; +import org.apache.logging.log4j.Level; +import org.apache.logging.log4j.core.LogEvent; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.retry.backoff.FixedDelayBackoffStrategy; +import software.amazon.awssdk.core.waiters.Waiter; +import software.amazon.awssdk.core.waiters.WaiterAcceptor; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.testutils.LogCaptor; +import software.amazon.awssdk.testutils.RandomTempFile; +import software.amazon.awssdk.transfer.s3.model.DownloadFileRequest; +import software.amazon.awssdk.transfer.s3.model.FileDownload; +import software.amazon.awssdk.transfer.s3.model.ResumableFileDownload; +import software.amazon.awssdk.transfer.s3.progress.TransferProgressSnapshot; + +public class S3TransferManagerMultipartDownloadPauseResumeIntegrationTest extends S3IntegrationTestBase { + private static final String BUCKET = temporaryBucketName(S3TransferManagerMultipartDownloadPauseResumeIntegrationTest.class); + private static final String KEY = "key"; + + private static final long OBJ_SIZE = 32 * MB; // 32mib for 4 parts of 8 mib + private static File sourceFile; + + @BeforeAll + public static void setup() throws Exception { + createBucket(BUCKET); + sourceFile = new RandomTempFile(OBJ_SIZE); + + // use async client for multipart upload (with default part size) + s3Async.putObject(PutObjectRequest.builder() + .bucket(BUCKET) + .key(KEY) + .build(), sourceFile.toPath()) + .join(); + } + + @AfterAll + public static void cleanup() { + deleteBucketAndAllContents(BUCKET); + sourceFile.delete(); + } + + @Test + void pauseAndResume_shouldResumeDownload() { + Path path = RandomTempFile.randomUncreatedFile().toPath(); + DownloadFileRequest request = DownloadFileRequest.builder() + .getObjectRequest(b -> b.bucket(BUCKET).key(KEY)) + .destination(path) + .build(); + FileDownload download = tmJava.downloadFile(request); + + // wait until we receive enough byte to stop somewhere between part 2 and 3, 18 Mib should do it + waitUntilAmountTransferred(download, 18 * MB); + ResumableFileDownload resumableFileDownload = download.pause(); + FileDownload resumed = tmJava.resumeDownloadFile(resumableFileDownload); + resumed.completionFuture().join(); + assertThat(path.toFile()).hasSameBinaryContentAs(sourceFile); + } + + @Test + void pauseAndResume_whenAlreadyComplete_shouldHandleGracefully() { + Path path = RandomTempFile.randomUncreatedFile().toPath(); + DownloadFileRequest request = DownloadFileRequest.builder() + .getObjectRequest(b -> b.bucket(BUCKET).key(KEY)) + .destination(path) + .build(); + FileDownload download = tmJava.downloadFile(request); + download.completionFuture().join(); + ResumableFileDownload resume = download.pause(); + try (LogCaptor logCaptor = LogCaptor.create(Level.DEBUG)) { + FileDownload resumedDownload = tmJava.resumeDownloadFile(resume); + assertThat(resumedDownload.completionFuture()).isCompleted(); + assertThat(path.toFile()).hasSameBinaryContentAs(sourceFile); + + List logEvents = logCaptor.loggedEvents(); + assertThat(logEvents).noneMatch( + event -> event.getMessage().getFormattedMessage().contains("Sending downloadFileRequest")); + LogEvent firstLog = logEvents.get(0); + assertThat(firstLog.getMessage().getFormattedMessage()) + .contains("The multipart download associated to the provided ResumableFileDownload is already completed, " + + "nothing to resume"); + } + } + + private void waitUntilAmountTransferred(FileDownload download, long amountTransferred) { + Waiter waiter = + Waiter.builder(TransferProgressSnapshot.class) + .addAcceptor(WaiterAcceptor.successOnResponseAcceptor(r -> r.transferredBytes() > amountTransferred)) + .addAcceptor(WaiterAcceptor.retryOnResponseAcceptor(r -> true)) + .overrideConfiguration(o -> o.waitTimeout(Duration.ofMinutes(5)) + .maxAttempts(Integer.MAX_VALUE) + .backoffStrategy(FixedDelayBackoffStrategy.create(Duration.ofMillis(100)))) + .build(); + waiter.run(() -> download.progress().snapshot()); + } +} diff --git a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerUploadDirectoryIntegrationTest.java b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerUploadDirectoryIntegrationTest.java index cfd8155853b9..f0e9134c247d 100644 --- a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerUploadDirectoryIntegrationTest.java +++ b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerUploadDirectoryIntegrationTest.java @@ -26,19 +26,18 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.Arrays; -import java.util.Collection; import java.util.List; import java.util.UUID; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.commons.codec.binary.Hex; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import software.amazon.awssdk.core.sync.ResponseTransformer; import software.amazon.awssdk.services.s3.model.NoSuchBucketException; @@ -77,49 +76,52 @@ public static void teardown() { } } - @Test - void uploadDirectory_filesSentCorrectly() { + @ParameterizedTest(autoCloseArguments = false) + @MethodSource("transferManagers") + void uploadDirectory_filesSentCorrectly(S3TransferManager tm) { String prefix = "yolo"; - DirectoryUpload uploadDirectory = tmCrt.uploadDirectory(u -> u.source(directory) - .bucket(TEST_BUCKET) - .s3Prefix(prefix)); + DirectoryUpload uploadDirectory = tm.uploadDirectory(u -> u.source(directory) + .bucket(TEST_BUCKET) + .s3Prefix(prefix)); CompletedDirectoryUpload completedDirectoryUpload = uploadDirectory.completionFuture().join(); assertThat(completedDirectoryUpload.failedTransfers()).isEmpty(); List keys = s3.listObjectsV2Paginator(b -> b.bucket(TEST_BUCKET).prefix(prefix)).contents().stream().map(S3Object::key) - .collect(Collectors.toList()); + .collect(Collectors.toList()); assertThat(keys).containsOnly(prefix + "/bar.txt", prefix + "/foo/1.txt", prefix + "/foo/2.txt"); keys.forEach(k -> verifyContent(k, k.substring(prefix.length() + 1) + randomString)); } - @Test - void uploadDirectory_nonExistsBucket_shouldAddFailedRequest() { + @ParameterizedTest(autoCloseArguments = false) + @MethodSource("transferManagers") + void uploadDirectory_nonExistsBucket_shouldAddFailedRequest(S3TransferManager tm) { String prefix = "yolo"; - DirectoryUpload uploadDirectory = tmCrt.uploadDirectory(u -> u.source(directory) - .bucket("nonExistingTestBucket" + UUID.randomUUID()) - .s3Prefix(prefix)); + DirectoryUpload uploadDirectory = tm.uploadDirectory(u -> u.source(directory) + .bucket("nonExistingTestBucket" + UUID.randomUUID()) + .s3Prefix(prefix)); CompletedDirectoryUpload completedDirectoryUpload = uploadDirectory.completionFuture().join(); assertThat(completedDirectoryUpload.failedTransfers()).hasSize(3).allSatisfy(f -> - assertThat(f.exception()).isInstanceOf(NoSuchBucketException.class)); + assertThat(f.exception()).isInstanceOf(NoSuchBucketException.class)); } - @Test - void uploadDirectory_withDelimiter_filesSentCorrectly() { + @ParameterizedTest(autoCloseArguments = false) + @MethodSource("transferManagers") + void uploadDirectory_withDelimiter_filesSentCorrectly(S3TransferManager tm) { String prefix = "hello"; String delimiter = "0"; - DirectoryUpload uploadDirectory = tmCrt.uploadDirectory(u -> u.source(directory) - .bucket(TEST_BUCKET) - .s3Delimiter(delimiter) - .s3Prefix(prefix)); + DirectoryUpload uploadDirectory = tm.uploadDirectory(u -> u.source(directory) + .bucket(TEST_BUCKET) + .s3Delimiter(delimiter) + .s3Prefix(prefix)); CompletedDirectoryUpload completedDirectoryUpload = uploadDirectory.completionFuture().join(); assertThat(completedDirectoryUpload.failedTransfers()).isEmpty(); List keys = s3.listObjectsV2Paginator(b -> b.bucket(TEST_BUCKET).prefix(prefix)).contents().stream().map(S3Object::key) - .collect(Collectors.toList()); + .collect(Collectors.toList()); assertThat(keys).containsOnly(prefix + "0bar.txt", prefix + "0foo01.txt", prefix + "0foo02.txt"); keys.forEach(k -> { @@ -128,18 +130,19 @@ void uploadDirectory_withDelimiter_filesSentCorrectly() { }); } - @Test - void uploadDirectory_withRequestTransformer_usesRequestTransformer() throws Exception { + @ParameterizedTest(autoCloseArguments = false) + @MethodSource("transferManagers") + void uploadDirectory_withRequestTransformer_usesRequestTransformer(S3TransferManager tm) throws Exception { String prefix = "requestTransformerTest"; Path newSourceForEachUpload = Paths.get(directory.toString(), "bar.txt"); CompletedDirectoryUpload result = - tmCrt.uploadDirectory(r -> r.source(directory) - .bucket(TEST_BUCKET) - .s3Prefix(prefix) - .uploadFileRequestTransformer(f -> f.source(newSourceForEachUpload))) - .completionFuture() - .get(10, TimeUnit.SECONDS); + tm.uploadDirectory(r -> r.source(directory) + .bucket(TEST_BUCKET) + .s3Prefix(prefix) + .uploadFileRequestTransformer(f -> f.source(newSourceForEachUpload))) + .completionFuture() + .get(10, TimeUnit.SECONDS); assertThat(result.failedTransfers()).isEmpty(); s3.listObjectsV2Paginator(b -> b.bucket(TEST_BUCKET).prefix(prefix)).contents().forEach(object -> { @@ -147,8 +150,8 @@ void uploadDirectory_withRequestTransformer_usesRequestTransformer() throws Exce }); } - public static Collection prefix() { - return Arrays.asList( + public static Stream prefix() { + return Stream.of( /* ASCII, 1-byte UTF-8 */ "E", /* ASCII, 2-byte UTF-8 */ @@ -159,16 +162,17 @@ public static Collection prefix() { "स", /* Non-ASCII, 4-byte UTF-8 */ "\uD808\uDC8C" - ); + ).flatMap(prefix -> Stream.of(Arguments.of(prefix, tmCrt), Arguments.of(prefix, tmJava))); } /** - * Tests the behavior of traversing local directories with special Unicode characters in their path name. These characters have - * known to be problematic when using Java's old File API or with Windows (which uses UTF-16 for file-name encoding). + * Tests the behavior of traversing local directories with special Unicode characters in their path name. These characters + * have known to be problematic when using Java's old File API or with Windows (which uses UTF-16 for file-name encoding). */ - @ParameterizedTest + @ParameterizedTest(autoCloseArguments = false) @MethodSource("prefix") - void uploadDirectory_fileNameWithUnicode_traversedCorrectly(String directoryPrefix) throws IOException { + void uploadDirectory_fileNameWithUnicode_traversedCorrectly(String directoryPrefix, S3TransferManager tm) + throws IOException { assumeTrue(Charset.defaultCharset().equals(StandardCharsets.UTF_8), "Ignoring the test if the test directory can't be " + "created"); Path testDirectory = null; @@ -179,8 +183,8 @@ void uploadDirectory_fileNameWithUnicode_traversedCorrectly(String directoryPref testDirectory = createLocalTestDirectory(directoryPrefix); Path finalTestDirectory = testDirectory; - DirectoryUpload uploadDirectory = tmCrt.uploadDirectory(u -> u.source(finalTestDirectory) - .bucket(TEST_BUCKET)); + DirectoryUpload uploadDirectory = tm.uploadDirectory(u -> u.source(finalTestDirectory) + .bucket(TEST_BUCKET)); CompletedDirectoryUpload completedDirectoryUpload = uploadDirectory.completionFuture().join(); assertThat(completedDirectoryUpload.failedTransfers()).isEmpty(); @@ -224,7 +228,7 @@ private Path createLocalTestDirectory(String directoryPrefix) throws IOException private static void verifyContent(String key, String expectedContent) { String actualContent = s3.getObject(r -> r.bucket(TEST_BUCKET).key(key), - ResponseTransformer.toBytes()).asUtf8String(); + ResponseTransformer.toBytes()).asUtf8String(); assertThat(actualContent).isEqualTo(expectedContent); } diff --git a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerUploadIntegrationTest.java b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerUploadIntegrationTest.java index 4598e388af39..1b720950be41 100644 --- a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerUploadIntegrationTest.java +++ b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerUploadIntegrationTest.java @@ -24,12 +24,10 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.CancellationException; -import java.util.stream.Stream; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.async.AsyncRequestBody; @@ -66,39 +64,34 @@ public static void teardown() throws IOException { deleteBucketAndAllContents(TEST_BUCKET); } - private static Stream transferManagers() { - return Stream.of( - Arguments.of(tmCrt), - Arguments.of(tmJava)); - } @ParameterizedTest @MethodSource("transferManagers") - void upload_file_SentCorrectly(S3TransferManager transferManager) throws IOException { + void upload_file_SentCorrectly(S3TransferManager tm) throws IOException { Map metadata = new HashMap<>(); - CaptureTransferListener transferListener = new CaptureTransferListener(); - metadata.put("x-amz-meta-foobar", "FOO BAR"); + CaptureTransferListener transferListener = new CaptureTransferListener(); + metadata.put("x-amz-meta-foobar", "FOO BAR"); FileUpload fileUpload = - transferManager.uploadFile(u -> u.putObjectRequest(p -> p.bucket(TEST_BUCKET).key(TEST_KEY).metadata(metadata).checksumAlgorithm(ChecksumAlgorithm.CRC32)) - .source(testFile.toPath()) - .addTransferListener(LoggingTransferListener.create()) - .addTransferListener(transferListener) - .build()); + tm.uploadFile(u -> u.putObjectRequest(p -> p.bucket(TEST_BUCKET).key(TEST_KEY).metadata(metadata).checksumAlgorithm(ChecksumAlgorithm.CRC32)) + .source(testFile.toPath()) + .addTransferListener(LoggingTransferListener.create()) + .addTransferListener(transferListener) + .build()); CompletedFileUpload completedFileUpload = fileUpload.completionFuture().join(); assertThat(completedFileUpload.response().responseMetadata().requestId()).isNotNull(); assertThat(completedFileUpload.response().sdkHttpResponse()).isNotNull(); ResponseInputStream obj = s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), - ResponseTransformer.toInputStream()); + ResponseTransformer.toInputStream()); assertThat(ChecksumUtils.computeCheckSum(Files.newInputStream(testFile.toPath()))) - .isEqualTo(ChecksumUtils.computeCheckSum(obj)); + .isEqualTo(ChecksumUtils.computeCheckSum(obj)); assertThat(obj.response().responseMetadata().requestId()).isNotNull(); assertThat(obj.response().metadata()).containsEntry("foobar", "FOO BAR"); assertThat(fileUpload.progress().snapshot().sdkResponse()).isPresent(); assertListenerForSuccessfulTransferComplete(transferListener); - } + } private static void assertListenerForSuccessfulTransferComplete(CaptureTransferListener transferListener) { assertThat(transferListener.isTransferInitiated()).isTrue(); @@ -111,17 +104,17 @@ private static void assertListenerForSuccessfulTransferComplete(CaptureTransferL @ParameterizedTest @MethodSource("transferManagers") - void upload_asyncRequestBodyFromString_SentCorrectly(S3TransferManager transferManager) throws IOException { + void upload_asyncRequestBodyFromString_SentCorrectly(S3TransferManager tm) throws IOException { String content = RandomStringUtils.randomAscii(OBJ_SIZE); CaptureTransferListener transferListener = new CaptureTransferListener(); Upload upload = - transferManager.upload(UploadRequest.builder() - .putObjectRequest(b -> b.bucket(TEST_BUCKET).key(TEST_KEY)) - .requestBody(AsyncRequestBody.fromString(content)) - .addTransferListener(LoggingTransferListener.create()) - .addTransferListener(transferListener) - .build()); + tm.upload(UploadRequest.builder() + .putObjectRequest(b -> b.bucket(TEST_BUCKET).key(TEST_KEY)) + .requestBody(AsyncRequestBody.fromString(content)) + .addTransferListener(LoggingTransferListener.create()) + .addTransferListener(transferListener) + .build()); CompletedUpload completedUpload = upload.completionFuture().join(); assertThat(completedUpload.response().responseMetadata().requestId()).isNotNull(); @@ -140,16 +133,16 @@ void upload_asyncRequestBodyFromString_SentCorrectly(S3TransferManager transferM @ParameterizedTest @MethodSource("transferManagers") - void upload_asyncRequestBodyFromFile_SentCorrectly(S3TransferManager transferManager) throws IOException { + void upload_asyncRequestBodyFromFile_SentCorrectly(S3TransferManager tm) throws IOException { CaptureTransferListener transferListener = new CaptureTransferListener(); Upload upload = - transferManager.upload(UploadRequest.builder() - .putObjectRequest(b -> b.bucket(TEST_BUCKET).key(TEST_KEY)) - .requestBody(FileAsyncRequestBody.builder().chunkSizeInBytes(1024).path(testFile.toPath()).build()) - .addTransferListener(LoggingTransferListener.create()) - .addTransferListener(transferListener) - .build()); + tm.upload(UploadRequest.builder() + .putObjectRequest(b -> b.bucket(TEST_BUCKET).key(TEST_KEY)) + .requestBody(FileAsyncRequestBody.builder().chunkSizeInBytes(1024).path(testFile.toPath()).build()) + .addTransferListener(LoggingTransferListener.create()) + .addTransferListener(transferListener) + .build()); CompletedUpload completedUpload = upload.completionFuture().join(); assertThat(completedUpload.response().responseMetadata().requestId()).isNotNull(); @@ -169,16 +162,19 @@ void upload_asyncRequestBodyFromFile_SentCorrectly(S3TransferManager transferMan @ParameterizedTest @MethodSource("transferManagers") - void upload_file_Interupted_CancelsTheListener(S3TransferManager transferManager) throws IOException, InterruptedException { + void upload_file_Interupted_CancelsTheListener(S3TransferManager tm) { Map metadata = new HashMap<>(); CaptureTransferListener transferListener = new CaptureTransferListener(); metadata.put("x-amz-meta-foobar", "FOO BAR"); FileUpload fileUpload = - transferManager.uploadFile(u -> u.putObjectRequest(p -> p.bucket(TEST_BUCKET).key(TEST_KEY).metadata(metadata).checksumAlgorithm(ChecksumAlgorithm.CRC32)) - .source(testFile.toPath()) - .addTransferListener(LoggingTransferListener.create()) - .addTransferListener(transferListener) - .build()); + tm.uploadFile(u -> u.putObjectRequest(p -> p.bucket(TEST_BUCKET) + .key(TEST_KEY) + .metadata(metadata) + .checksumAlgorithm(ChecksumAlgorithm.CRC32)) + .source(testFile.toPath()) + .addTransferListener(LoggingTransferListener.create()) + .addTransferListener(transferListener) + .build()); fileUpload.completionFuture().cancel(true); assertThat(transferListener.isTransferInitiated()).isTrue(); diff --git a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerUploadPauseResumeIntegrationTest.java b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerUploadPauseResumeIntegrationTest.java index 0e995048e1ae..1a49e5c618a4 100644 --- a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerUploadPauseResumeIntegrationTest.java +++ b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerUploadPauseResumeIntegrationTest.java @@ -35,6 +35,7 @@ import software.amazon.awssdk.core.waiters.AsyncWaiter; import software.amazon.awssdk.core.waiters.Waiter; import software.amazon.awssdk.core.waiters.WaiterAcceptor; +import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.ListMultipartUploadsResponse; import software.amazon.awssdk.services.s3.model.ListPartsResponse; import software.amazon.awssdk.services.s3.model.NoSuchUploadException; @@ -62,6 +63,7 @@ public static void setup() throws Exception { largeFile = new RandomTempFile(LARGE_OBJ_SIZE); smallFile = new RandomTempFile(SMALL_OBJ_SIZE); executorService = Executors.newScheduledThreadPool(3); + } @AfterAll @@ -72,7 +74,7 @@ public static void cleanup() { executorService.shutdown(); } - private static Stream transferManagers() { + private static Stream transferManagersArguments() { return Stream.of( Arguments.of(tmJava, tmJava), Arguments.of(tmCrt, tmCrt), @@ -82,7 +84,7 @@ private static Stream transferManagers() { } @ParameterizedTest - @MethodSource("transferManagers") + @MethodSource("transferManagersArguments") void pause_singlePart_shouldResume(S3TransferManager uploadTm, S3TransferManager resumeTm) { UploadFileRequest request = UploadFileRequest.builder() .putObjectRequest(b -> b.bucket(BUCKET).key(KEY)) @@ -100,7 +102,7 @@ void pause_singlePart_shouldResume(S3TransferManager uploadTm, S3TransferManager } @ParameterizedTest - @MethodSource("transferManagers") + @MethodSource("transferManagersArguments") void pause_fileNotChanged_shouldResume(S3TransferManager uploadTm, S3TransferManager resumeTm) throws Exception { UploadFileRequest request = UploadFileRequest.builder() .putObjectRequest(b -> b.bucket(BUCKET).key(KEY)) @@ -124,7 +126,7 @@ void pause_fileNotChanged_shouldResume(S3TransferManager uploadTm, S3TransferMan } @ParameterizedTest - @MethodSource("transferManagers") + @MethodSource("transferManagersArguments") void pauseImmediately_resume_shouldStartFromBeginning(S3TransferManager uploadTm, S3TransferManager resumeTm) { UploadFileRequest request = UploadFileRequest.builder() .putObjectRequest(b -> b.bucket(BUCKET).key(KEY)) @@ -142,7 +144,7 @@ void pauseImmediately_resume_shouldStartFromBeginning(S3TransferManager uploadTm } @ParameterizedTest - @MethodSource("transferManagers") + @MethodSource("transferManagersArguments") void pause_fileChanged_resumeShouldStartFromBeginning(S3TransferManager uploadTm, S3TransferManager resumeTm) throws Exception { UploadFileRequest request = UploadFileRequest.builder() .putObjectRequest(b -> b.bucket(BUCKET).key(KEY)) @@ -191,13 +193,14 @@ private void verifyMultipartUploadIdNotExist(ResumableFileUpload resumableFileUp } private static void waitUntilMultipartUploadExists() { - Waiter waiter = Waiter.builder(ListMultipartUploadsResponse.class) - .addAcceptor(WaiterAcceptor.successOnResponseAcceptor(ListMultipartUploadsResponse::hasUploads)) - .addAcceptor(WaiterAcceptor.retryOnResponseAcceptor(r -> true)) - .overrideConfiguration(o -> o.waitTimeout(Duration.ofMinutes(1)) - .maxAttempts(10) - .backoffStrategy(FixedDelayBackoffStrategy.create(Duration.ofMillis(100)))) - .build(); + Waiter waiter = + Waiter.builder(ListMultipartUploadsResponse.class) + .addAcceptor(WaiterAcceptor.successOnResponseAcceptor(ListMultipartUploadsResponse::hasUploads)) + .addAcceptor(WaiterAcceptor.retryOnResponseAcceptor(r -> true)) + .overrideConfiguration(o -> o.waitTimeout(Duration.ofMinutes(1)) + .maxAttempts(10) + .backoffStrategy(FixedDelayBackoffStrategy.create(Duration.ofMillis(100)))) + .build(); waiter.run(() -> s3.listMultipartUploads(l -> l.bucket(BUCKET))); } diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/S3TransferManager.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/S3TransferManager.java index 362bf1419e2d..5bf9b55ed657 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/S3TransferManager.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/S3TransferManager.java @@ -75,9 +75,24 @@ * * S3TransferManager transferManager = * S3TransferManager.builder() - * .s3AsyncClient(s3AsyncClient) + * .s3Client(s3AsyncClient) * .build(); * } + * + * Create an S3 Transfer Manager with S3 Multipart Async Client. The S3 Multipart Async Client is an alternative to the CRT + * client that offers the same multipart upload/download feature. + * {@snippet : + * S3AsyncClient s3AsyncClient = s3AsyncClient.builder() + * .multipartEnabled(true) + * .multipartConfiguration(conf -> conf.apiCallBufferSizeInBytes(32 * MB)) + * .build(); + * + * S3TransferManager transferManager = + * S3TransferManager.builder() + * .s3Client(s3AsyncClient) + * .build(); + * } + * *

Common Usage Patterns

* Upload a file to S3 * {@snippet : @@ -157,6 +172,11 @@ * // Wait for the transfer to complete * CompletedCopy completedCopy = copy.completionFuture().join(); * } + * The automatic parallel transfer feature (multipart upload/download) is available + * through the AWS-CRT based S3 client {@code S3AsyncClient.crtBuilder().build)} + * and Java-based S3 multipart client {@code S3AsyncClient.builder().multipartEnabled(true).build()}. + * If no client is configured, AWS-CRT based S3 client will be used if AWS CRT is in the classpath, + * otherwise, Java-based S3 multipart client will be used. */ @SdkPublicApi @ThreadSafe diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/GenericS3TransferManager.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/GenericS3TransferManager.java index ee8353f46c3c..0e6e47ce8d90 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/GenericS3TransferManager.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/GenericS3TransferManager.java @@ -15,12 +15,15 @@ package software.amazon.awssdk.transfer.s3.internal; +import static software.amazon.awssdk.services.s3.internal.multipart.MultipartDownloadUtils.multipartDownloadResumeContext; import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.JAVA_PROGRESS_LISTENER; +import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.MULTIPART_DOWNLOAD_RESUME_CONTEXT; import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.PAUSE_OBSERVABLE; import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.RESUME_TOKEN; import static software.amazon.awssdk.transfer.s3.SizeConstant.MB; import static software.amazon.awssdk.transfer.s3.internal.utils.ResumableRequestConverter.toDownloadFileRequestAndTransformer; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.function.Consumer; @@ -34,8 +37,8 @@ import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.internal.async.FileAsyncRequestBody; -import software.amazon.awssdk.services.s3.DelegatingS3AsyncClient; import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.internal.multipart.MultipartDownloadResumeContext; import software.amazon.awssdk.services.s3.internal.multipart.MultipartS3AsyncClient; import software.amazon.awssdk.services.s3.internal.resource.S3AccessPointResource; import software.amazon.awssdk.services.s3.internal.resource.S3ArnConverter; @@ -57,6 +60,8 @@ import software.amazon.awssdk.transfer.s3.internal.model.DefaultFileDownload; import software.amazon.awssdk.transfer.s3.internal.model.DefaultFileUpload; import software.amazon.awssdk.transfer.s3.internal.model.DefaultUpload; +import software.amazon.awssdk.transfer.s3.internal.progress.DefaultTransferProgress; +import software.amazon.awssdk.transfer.s3.internal.progress.DefaultTransferProgressSnapshot; import software.amazon.awssdk.transfer.s3.internal.progress.ResumeTransferProgress; import software.amazon.awssdk.transfer.s3.internal.progress.TransferProgressUpdater; import software.amazon.awssdk.transfer.s3.model.CompletedCopy; @@ -198,7 +203,6 @@ public FileUpload uploadFile(UploadFileRequest uploadFileRequest) { pauseObservable = null; } - try { assertNotUnsupportedArn(putObjectRequest.bucket(), "upload"); @@ -291,16 +295,31 @@ private CopyObjectRequest attachSdkAttribute(CopyObjectRequest copyObjectRequest Consumer builderMutation) { AwsRequestOverrideConfiguration modifiedRequestOverrideConfig = copyObjectRequest.overrideConfiguration() - .map(o -> o.toBuilder().applyMutation(builderMutation).build()) - .orElseGet(() -> AwsRequestOverrideConfiguration.builder() - .applyMutation(builderMutation) - .build()); + .map(o -> o.toBuilder().applyMutation(builderMutation).build()) + .orElseGet(() -> AwsRequestOverrideConfiguration.builder() + .applyMutation(builderMutation) + .build()); return copyObjectRequest.toBuilder() - .overrideConfiguration(modifiedRequestOverrideConfig) - .build(); + .overrideConfiguration(modifiedRequestOverrideConfig) + .build(); + } + + private GetObjectRequest attachSdkAttribute(GetObjectRequest request, + Consumer builderMutation) { + AwsRequestOverrideConfiguration modifiedRequestOverrideConfig = + request.overrideConfiguration() + .map(o -> o.toBuilder().applyMutation(builderMutation).build()) + .orElseGet(() -> AwsRequestOverrideConfiguration.builder() + .applyMutation(builderMutation) + .build()); + + return request.toBuilder() + .overrideConfiguration(modifiedRequestOverrideConfig) + .build(); } + @Override public final DirectoryUpload uploadDirectory(UploadDirectoryRequest uploadDirectoryRequest) { Validate.paramNotNull(uploadDirectoryRequest, "uploadDirectoryRequest"); @@ -325,13 +344,16 @@ public final Download download(DownloadRequest downl TransferProgressUpdater progressUpdater = new TransferProgressUpdater(downloadRequest, null); progressUpdater.transferInitiated(); - responseTransformer = progressUpdater.wrapResponseTransformer(responseTransformer); + responseTransformer = isS3ClientMultipartEnabled() + ? progressUpdater.wrapResponseTransformerForMultipartDownload( + responseTransformer, downloadRequest.getObjectRequest()) + : progressUpdater.wrapResponseTransformer(responseTransformer); progressUpdater.registerCompletion(returnFuture); try { assertNotUnsupportedArn(downloadRequest.getObjectRequest().bucket(), "download"); - CompletableFuture future = doGetObject(downloadRequest.getObjectRequest(), responseTransformer); + CompletableFuture future = s3AsyncClient.getObject(downloadRequest.getObjectRequest(), responseTransformer); // Forward download cancellation to future CompletableFutureUtils.forwardExceptionTo(returnFuture, future); @@ -351,14 +373,21 @@ public final Download download(DownloadRequest downl public final FileDownload downloadFile(DownloadFileRequest downloadRequest) { Validate.paramNotNull(downloadRequest, "downloadFileRequest"); + GetObjectRequest getObjectRequestWithAttributes = attachSdkAttribute( + downloadRequest.getObjectRequest(), + b -> b.putExecutionAttribute(MULTIPART_DOWNLOAD_RESUME_CONTEXT, new MultipartDownloadResumeContext())); + DownloadFileRequest downloadFileRequestWithAttributes = + downloadRequest.copy(downloadFileRequest -> downloadFileRequest.getObjectRequest(getObjectRequestWithAttributes)); + AsyncResponseTransformer responseTransformer = - AsyncResponseTransformer.toFile(downloadRequest.destination(), + AsyncResponseTransformer.toFile(downloadFileRequestWithAttributes.destination(), FileTransformerConfiguration.defaultCreateOrReplaceExisting()); CompletableFuture returnFuture = new CompletableFuture<>(); - TransferProgressUpdater progressUpdater = doDownloadFile(downloadRequest, responseTransformer, returnFuture); + TransferProgressUpdater progressUpdater = doDownloadFile( + downloadFileRequestWithAttributes, responseTransformer, returnFuture); - return new DefaultFileDownload(returnFuture, progressUpdater.progress(), () -> downloadRequest, null); + return new DefaultFileDownload(returnFuture, progressUpdater.progress(), () -> downloadFileRequestWithAttributes, null); } private TransferProgressUpdater doDownloadFile( @@ -368,12 +397,16 @@ private TransferProgressUpdater doDownloadFile( TransferProgressUpdater progressUpdater = new TransferProgressUpdater(downloadRequest, null); try { progressUpdater.transferInitiated(); - responseTransformer = progressUpdater.wrapResponseTransformer(responseTransformer); + responseTransformer = isS3ClientMultipartEnabled() + ? progressUpdater.wrapResponseTransformerForMultipartDownload( + responseTransformer, downloadRequest.getObjectRequest()) + : progressUpdater.wrapResponseTransformer(responseTransformer); progressUpdater.registerCompletion(returnFuture); assertNotUnsupportedArn(downloadRequest.getObjectRequest().bucket(), "download"); - CompletableFuture future = doGetObject(downloadRequest.getObjectRequest(), responseTransformer); + CompletableFuture future = s3AsyncClient.getObject( + downloadRequest.getObjectRequest(), responseTransformer); // Forward download cancellation to future CompletableFutureUtils.forwardExceptionTo(returnFuture, future); @@ -391,6 +424,16 @@ private TransferProgressUpdater doDownloadFile( @Override public final FileDownload resumeDownloadFile(ResumableFileDownload resumableFileDownload) { Validate.paramNotNull(resumableFileDownload, "resumableFileDownload"); + + // check if the multipart-download was already completed and handle it gracefully. + Optional optCtx = + multipartDownloadResumeContext(resumableFileDownload.downloadFileRequest().getObjectRequest()); + if (optCtx.map(MultipartDownloadResumeContext::isComplete).orElse(false)) { + log.debug(() -> "The multipart download associated to the provided ResumableFileDownload is already completed, " + + "nothing to resume"); + return completedDownload(resumableFileDownload, optCtx.get()); + } + CompletableFuture returnFuture = new CompletableFuture<>(); DownloadFileRequest originalDownloadRequest = resumableFileDownload.downloadFileRequest(); GetObjectRequest getObjectRequest = originalDownloadRequest.getObjectRequest(); @@ -427,6 +470,20 @@ public final FileDownload resumeDownloadFile(ResumableFileDownload resumableFile resumableFileDownload); } + private FileDownload completedDownload(ResumableFileDownload resumableFileDownload, MultipartDownloadResumeContext ctx) { + CompletedFileDownload completedFileDownload = CompletedFileDownload.builder().response(ctx.response()).build(); + DefaultTransferProgressSnapshot completedProgressSnapshot = + DefaultTransferProgressSnapshot.builder() + .sdkResponse(ctx.response()) + .totalBytes(ctx.bytesToLastCompletedParts()) + .transferredBytes(resumableFileDownload.bytesTransferred()) + .build(); + return new DefaultFileDownload(CompletableFuture.completedFuture(completedFileDownload), + new DefaultTransferProgress(completedProgressSnapshot), + resumableFileDownload::downloadFileRequest, + resumableFileDownload); + } + private DownloadFileRequest newOrOriginalRequestForPause(CompletableFuture newDownloadFuture, DownloadFileRequest originalDownloadRequest) { try { @@ -552,14 +609,4 @@ private static boolean isMrapArn(Arn arn) { return !s3EndpointResource.region().isPresent(); } - - // TODO remove once MultipartS3AsyncClient is complete - private CompletableFuture doGetObject( - GetObjectRequest getObjectRequest, AsyncResponseTransformer asyncResponseTransformer) { - S3AsyncClient clientToUse = s3AsyncClient; - if (s3AsyncClient instanceof MultipartS3AsyncClient) { - clientToUse = (S3AsyncClient) ((DelegatingS3AsyncClient) s3AsyncClient).delegate(); - } - return clientToUse.getObject(getObjectRequest, asyncResponseTransformer); - } } diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/TransferManagerFactory.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/TransferManagerFactory.java index 7b4041804a29..429d534e074e 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/TransferManagerFactory.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/TransferManagerFactory.java @@ -21,7 +21,6 @@ import software.amazon.awssdk.core.internal.util.ClassLoaderHelper; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.internal.crt.S3CrtAsyncClient; -import software.amazon.awssdk.services.s3.internal.multipart.MultipartS3AsyncClient; import software.amazon.awssdk.transfer.s3.S3TransferManager; import software.amazon.awssdk.utils.Logger; @@ -53,17 +52,15 @@ public static S3TransferManager createTransferManager(DefaultBuilder tmBuilder) return new CrtS3TransferManager(transferConfiguration, s3AsyncClient, isDefaultS3AsyncClient); } - if (s3AsyncClient.getClass().getName().equals("software.amazon.awssdk.services.s3.DefaultS3AsyncClient")) { - log.warn(() -> "The provided DefaultS3AsyncClient is not an instance of S3CrtAsyncClient, and thus multipart" - + " upload/download feature is not enabled and resumable file upload is not supported. To benefit " - + "from maximum throughput, consider using S3AsyncClient.crtBuilder().build() instead."); - } else if (s3AsyncClient instanceof MultipartS3AsyncClient) { - log.warn(() -> "The provided S3AsyncClient is an instance of MultipartS3AsyncClient, and thus multipart" - + " download feature is not enabled. To benefit from all features, " - + "consider using S3AsyncClient.crtBuilder().build() instead."); - } else { - log.debug(() -> "The provided S3AsyncClient is not an instance of S3CrtAsyncClient, and thus multipart" - + " upload/download feature may not be enabled and resumable file upload may not be supported."); + if (!s3AsyncClient.getClass().getName().equals("software.amazon.awssdk.services.s3.internal.multipart" + + ".MultipartS3AsyncClient")) { + log.debug(() -> "The provided S3AsyncClient is neither " + + "an AWS CRT-based S3 async client (S3AsyncClient.crtBuilder().build()) or " + + "a Java-based S3 async client (S3AsyncClient.builder().multipartEnabled(true).build()), " + + "and thus multipart upload/download feature may not be enabled and resumable file upload may not " + + "be supported. To benefit from maximum throughput, consider using " + + "S3AsyncClient.crtBuilder().build() or " + + "S3AsyncClient.builder().multipartEnabled(true).build() instead"); } return new GenericS3TransferManager(transferConfiguration, s3AsyncClient, isDefaultS3AsyncClient); diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/model/DefaultFileDownload.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/model/DefaultFileDownload.java index 22b3ee09e337..3932762c3089 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/model/DefaultFileDownload.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/model/DefaultFileDownload.java @@ -17,9 +17,11 @@ import java.io.File; import java.time.Instant; +import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.function.Supplier; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.services.s3.internal.multipart.MultipartDownloadUtils; import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.transfer.s3.model.CompletedFileDownload; import software.amazon.awssdk.transfer.s3.model.DownloadFileRequest; @@ -81,12 +83,14 @@ private ResumableFileDownload doPause() { File destination = request.destination().toFile(); long length = destination.length(); Instant fileLastModified = Instant.ofEpochMilli(destination.lastModified()); + List completedParts = MultipartDownloadUtils.completedParts(request.getObjectRequest()); return ResumableFileDownload.builder() .downloadFileRequest(request) .s3ObjectLastModified(s3objectLastModified) .fileLastModified(fileLastModified) .bytesTransferred(length) .totalSizeInBytes(totalSizeInBytes) + .completedParts(completedParts) .build(); } diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/progress/ContentRangeParser.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/progress/ContentRangeParser.java new file mode 100644 index 000000000000..03e67c402a56 --- /dev/null +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/progress/ContentRangeParser.java @@ -0,0 +1,75 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.transfer.s3.internal.progress; + +import java.util.OptionalLong; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.StringUtils; + +/** + * Parse a Content-Range header value into a total byte count. The expected format is the following:

+ * {@code Content-Range: -\/}
+ * {@code Content-Range: -\/*}
{@code Content-Range: *\/}

+ *

+ * The only supported {@code } is the {@code bytes} value. + */ +@SdkInternalApi +public final class ContentRangeParser { + + private static final Logger log = Logger.loggerFor(ContentRangeParser.class); + + private ContentRangeParser() { + } + + /** + * Parse the Content-Range to extract the total number of byte from the content. Only supports the {@code bytes} unit, any + * other unit will result in an empty OptionalLong. If the total length in unknown, which is represented by a {@code *} symbol + * in the header value, an empty OptionalLong will be returned. + * + * @param contentRange the value of the Content-Range header to be parsed. + * @return The total number of bytes in the content range or an empty optional if the contentRange is null, empty or if the + * total length is not a valid long. + */ + public static OptionalLong totalBytes(String contentRange) { + if (StringUtils.isEmpty(contentRange)) { + return OptionalLong.empty(); + } + + String trimmed = contentRange.trim(); + if (!trimmed.startsWith("bytes")) { + return OptionalLong.empty(); + } + + int lastSlash = trimmed.lastIndexOf('/'); + if (lastSlash == -1) { + return OptionalLong.empty(); + } + + String totalBytes = trimmed.substring(lastSlash + 1); + if ("*".equals(totalBytes)) { + return OptionalLong.empty(); + } + + try { + long value = Long.parseLong(totalBytes); + return value > 0 ? OptionalLong.of(value) : OptionalLong.empty(); + } catch (NumberFormatException e) { + log.warn(() -> "failed to parse content range", e); + return OptionalLong.empty(); + } + } +} diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/progress/TransferProgressUpdater.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/progress/TransferProgressUpdater.java index 67970031c9aa..2cab45039d97 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/progress/TransferProgressUpdater.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/progress/TransferProgressUpdater.java @@ -28,6 +28,7 @@ import software.amazon.awssdk.core.async.listener.AsyncResponseTransformerListener; import software.amazon.awssdk.core.async.listener.PublisherListener; import software.amazon.awssdk.crt.s3.S3MetaRequestProgress; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.transfer.s3.model.CompletedObjectTransfer; import software.amazon.awssdk.transfer.s3.model.TransferObjectRequest; @@ -164,41 +165,42 @@ public void subscriberOnComplete() { }; } - public AsyncResponseTransformer wrapResponseTransformer( - AsyncResponseTransformer responseTransformer) { + public AsyncResponseTransformer wrapResponseTransformerForMultipartDownload( + AsyncResponseTransformer responseTransformer, GetObjectRequest request) { return AsyncResponseTransformerListener.wrap( responseTransformer, - new AsyncResponseTransformerListener() { + new BaseAsyncResponseTransformerListener() { @Override public void transformerOnResponse(GetObjectResponse response) { - if (response.contentLength() != null) { + // if the GetObjectRequest is a range-get, the Content-Length headers of the response needs to be used + // to update progress since the Content-Range would incorrectly upgrade progress with the whole object + // size. + if (request.range() != null) { + if (response.contentLength() != null) { progress.updateAndGet(b -> b.totalBytes(response.contentLength()).sdkResponse(response)); + } + } else { + // if the GetObjectRequest is not a range-get, it might be a part-get. In that case, we need to parse + // the Content-Range header to get the correct totalByte amount. + ContentRangeParser + .totalBytes(response.contentRange()) + .ifPresent(totalBytes -> progress.updateAndGet(b -> b.totalBytes(totalBytes).sdkResponse(response))); } } + } + ); + } + public AsyncResponseTransformer wrapResponseTransformer( + AsyncResponseTransformer responseTransformer) { + return AsyncResponseTransformerListener.wrap( + responseTransformer, + new BaseAsyncResponseTransformerListener() { @Override - public void transformerExceptionOccurred(Throwable t) { - transferFailed(t); - } - - @Override - public void publisherSubscribe(Subscriber subscriber) { - resetBytesTransferred(); - } - - @Override - public void subscriberOnNext(ByteBuffer byteBuffer) { - incrementBytesTransferred(byteBuffer.limit()); - } - - @Override - public void subscriberOnError(Throwable t) { - transferFailed(t); - } - - @Override - public void subscriberOnComplete() { - endOfStreamFuture.complete(null); + public void transformerOnResponse(GetObjectResponse response) { + if (response.contentLength() != null) { + progress.updateAndGet(b -> b.totalBytes(response.contentLength()).sdkResponse(response)); + } } }); } @@ -250,4 +252,39 @@ private void transferFailed(Throwable t) { .exception(t) .build()); } + + private class BaseAsyncResponseTransformerListener implements AsyncResponseTransformerListener { + @Override + public void transformerOnResponse(GetObjectResponse response) { + if (response.contentLength() != null) { + progress.updateAndGet(b -> b.totalBytes(response.contentLength()).sdkResponse(response)); + } + } + + @Override + public void transformerExceptionOccurred(Throwable t) { + transferFailed(t); + } + + @Override + public void publisherSubscribe(Subscriber subscriber) { + resetBytesTransferred(); + } + + @Override + public void subscriberOnNext(ByteBuffer byteBuffer) { + incrementBytesTransferred(byteBuffer.limit()); + } + + @Override + public void subscriberOnError(Throwable t) { + transferFailed(t); + } + + @Override + public void subscriberOnComplete() { + endOfStreamFuture.complete(null); + } + + } } diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/serialization/ResumableFileDownloadSerializer.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/serialization/ResumableFileDownloadSerializer.java index 0b05234a34e1..3a60f7edfbbb 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/serialization/ResumableFileDownloadSerializer.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/serialization/ResumableFileDownloadSerializer.java @@ -63,6 +63,7 @@ public static byte[] toJson(ResumableFileDownload download) { "s3ObjectLastModified"); } marshallDownloadFileRequest(download.downloadFileRequest(), jsonGenerator); + TransferManagerJsonMarshaller.LIST.marshall(download.completedParts(), jsonGenerator, "completedParts"); jsonGenerator.writeEndObject(); return jsonGenerator.getBytes(); @@ -138,7 +139,9 @@ private static ResumableFileDownload fromNodes(Map downloadNod builder.s3ObjectLastModified(instantUnmarshaller.unmarshall(downloadNodes.get("s3ObjectLastModified"))); } builder.downloadFileRequest(parseDownloadFileRequest(downloadNodes.get("downloadFileRequest"))); - + if (downloadNodes.get("completedParts") != null) { + builder.completedParts(TransferManagerJsonUnmarshaller.LIST_INT.unmarshall(downloadNodes.get("completedParts"))); + } return builder.build(); } diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/serialization/TransferManagerJsonUnmarshaller.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/serialization/TransferManagerJsonUnmarshaller.java index 1316444e2394..da61e48cf6c3 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/serialization/TransferManagerJsonUnmarshaller.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/serialization/TransferManagerJsonUnmarshaller.java @@ -20,8 +20,10 @@ import java.math.BigDecimal; import java.time.Instant; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.function.Function; +import java.util.stream.Collectors; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.core.SdkField; @@ -99,6 +101,21 @@ public Map unmarshall(String content, SdkField field) { } }; + TransferManagerJsonUnmarshaller> LIST_INT = new TransferManagerJsonUnmarshaller>() { + @Override + public List unmarshall(JsonNode jsonContent, SdkField field) { + if (jsonContent == null) { + return null; + } + return jsonContent.asArray().stream().map(INTEGER::unmarshall).collect(Collectors.toList()); + } + + @Override + public List unmarshall(String content, SdkField field) { + return unmarshall(JsonNode.parser().parse(content), field); + } + }; + default T unmarshall(JsonNode jsonContent, SdkField field) { return jsonContent != null && !jsonContent.isNull() ? unmarshall(jsonContent.text(), field) : null; } diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/utils/ResumableRequestConverter.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/utils/ResumableRequestConverter.java index ac08219add19..4236ece0be23 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/utils/ResumableRequestConverter.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/utils/ResumableRequestConverter.java @@ -15,12 +15,17 @@ package software.amazon.awssdk.transfer.s3.internal.utils; +import static software.amazon.awssdk.core.FileTransformerConfiguration.FailureBehavior.LEAVE; +import static software.amazon.awssdk.core.FileTransformerConfiguration.FileWriteOption.WRITE_TO_POSITION; import static software.amazon.awssdk.transfer.s3.internal.utils.FileUtils.fileNotModified; import java.time.Instant; +import java.util.Optional; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.FileTransformerConfiguration; import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.services.s3.internal.multipart.MultipartDownloadResumeContext; +import software.amazon.awssdk.services.s3.internal.multipart.MultipartDownloadUtils; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.HeadObjectResponse; @@ -39,39 +44,76 @@ private ResumableRequestConverter() { /** * Converts a {@link ResumableFileDownload} to {@link DownloadFileRequest} and {@link AsyncResponseTransformer} pair. + *

+ * If before resuming the download the file on disk was modified, or the s3 object was modified, we need to restart the + * download from the beginning. + *

+ * If the original requests has some individual parts downloaded, we need to make a multipart GET for the remaining parts. + *

+ * Else, we need to make a ranged GET for the remaining bytes. */ public static Pair> - toDownloadFileRequestAndTransformer(ResumableFileDownload resumableFileDownload, - HeadObjectResponse headObjectResponse, - DownloadFileRequest originalDownloadRequest) { + toDownloadFileRequestAndTransformer(ResumableFileDownload resumableFileDownload, + HeadObjectResponse headObjectResponse, + DownloadFileRequest originalDownloadRequest) { GetObjectRequest getObjectRequest = originalDownloadRequest.getObjectRequest(); DownloadFileRequest newDownloadFileRequest; - boolean shouldAppend; Instant lastModified = resumableFileDownload.s3ObjectLastModified().orElse(null); - boolean s3ObjectNotModified = headObjectResponse.lastModified().equals(lastModified); - - boolean fileNotModified = fileNotModified(resumableFileDownload.bytesTransferred(), - resumableFileDownload.fileLastModified(), resumableFileDownload.downloadFileRequest().destination()); - - if (fileNotModified && s3ObjectNotModified) { - newDownloadFileRequest = resumedDownloadFileRequest(resumableFileDownload, - originalDownloadRequest, - getObjectRequest, - headObjectResponse); - shouldAppend = true; - } else { - logIfNeeded(originalDownloadRequest, getObjectRequest, fileNotModified, s3ObjectNotModified); - shouldAppend = false; - newDownloadFileRequest = newDownloadFileRequest(originalDownloadRequest, getObjectRequest, - headObjectResponse); + boolean s3ObjectModified = !headObjectResponse.lastModified().equals(lastModified); + + boolean fileModified = !fileNotModified(resumableFileDownload.bytesTransferred(), + resumableFileDownload.fileLastModified(), + resumableFileDownload.downloadFileRequest().destination()); + + if (fileModified || s3ObjectModified) { + // modification detected: new download request for the whole object from the beginning + logIfNeeded(originalDownloadRequest, getObjectRequest, fileModified, s3ObjectModified); + newDownloadFileRequest = newDownloadFileRequest(originalDownloadRequest, getObjectRequest, headObjectResponse); + + AsyncResponseTransformer responseTransformer = + fileAsyncResponseTransformer(newDownloadFileRequest, false); + return Pair.of(newDownloadFileRequest, responseTransformer); + } + + if (hasRemainingParts(getObjectRequest)) { + // multipart GET for the remaining parts + Long positionToWriteFrom = + MultipartDownloadUtils.multipartDownloadResumeContext(originalDownloadRequest.getObjectRequest()) + .map(MultipartDownloadResumeContext::bytesToLastCompletedParts) + .orElse(0L); + AsyncResponseTransformer responseTransformer = + AsyncResponseTransformer.toFile(originalDownloadRequest.destination(), + FileTransformerConfiguration.builder() + .fileWriteOption(WRITE_TO_POSITION) + .position(positionToWriteFrom) + .failureBehavior(LEAVE) + .build()); + return Pair.of(originalDownloadRequest, responseTransformer); } + // ranged GET for the remaining bytes. + newDownloadFileRequest = resumedDownloadFileRequest(resumableFileDownload, + originalDownloadRequest, + getObjectRequest, + headObjectResponse); AsyncResponseTransformer responseTransformer = - fileAsyncResponseTransformer(newDownloadFileRequest, shouldAppend); + fileAsyncResponseTransformer(newDownloadFileRequest, true); return Pair.of(newDownloadFileRequest, responseTransformer); } + private static boolean hasRemainingParts(GetObjectRequest getObjectRequest) { + Optional optCtx = MultipartDownloadUtils.multipartDownloadResumeContext(getObjectRequest); + if (!optCtx.isPresent()) { + return false; + } + MultipartDownloadResumeContext ctx = optCtx.get(); + if (ctx.response() != null && ctx.response().partsCount() == null) { + return false; + } + return !ctx.completedParts().isEmpty(); + } + private static AsyncResponseTransformer fileAsyncResponseTransformer( DownloadFileRequest newDownloadFileRequest, boolean shouldAppend) { @@ -85,10 +127,10 @@ private static AsyncResponseTransformer fi private static void logIfNeeded(DownloadFileRequest downloadRequest, GetObjectRequest getObjectRequest, - boolean fileNotModified, - boolean s3ObjectNotModified) { + boolean fileModified, + boolean s3ObjectModified) { if (log.logger().isDebugEnabled()) { - if (!s3ObjectNotModified) { + if (s3ObjectModified) { log.debug(() -> String.format("The requested object in bucket (%s) with key (%s) " + "has been modified on Amazon S3 since the last " + "pause. The SDK will download the S3 object from " @@ -96,7 +138,7 @@ private static void logIfNeeded(DownloadFileRequest downloadRequest, getObjectRequest.bucket(), getObjectRequest.key())); } - if (!fileNotModified) { + if (fileModified) { log.debug(() -> String.format("The file (%s) has been modified since " + "the last pause. " + "The SDK will download the requested object in bucket" diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/model/ResumableFileDownload.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/model/ResumableFileDownload.java index 83b6213bd948..bfc61092f4a0 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/model/ResumableFileDownload.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/model/ResumableFileDownload.java @@ -23,6 +23,9 @@ import java.nio.file.Files; import java.nio.file.Path; import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.OptionalLong; @@ -61,6 +64,7 @@ public final class ResumableFileDownload implements ResumableTransfer, private final Instant s3ObjectLastModified; private final Long totalSizeInBytes; private final Instant fileLastModified; + private final List completedParts; private ResumableFileDownload(DefaultBuilder builder) { this.downloadFileRequest = Validate.paramNotNull(builder.downloadFileRequest, "downloadFileRequest"); @@ -69,6 +73,8 @@ private ResumableFileDownload(DefaultBuilder builder) { this.s3ObjectLastModified = builder.s3ObjectLastModified; this.totalSizeInBytes = Validate.isPositiveOrNull(builder.totalSizeInBytes, "totalSizeInBytes"); this.fileLastModified = builder.fileLastModified; + List compledPartsList = Validate.getOrDefault(builder.completedParts, Collections::emptyList); + this.completedParts = Collections.unmodifiableList(new ArrayList<>(compledPartsList)); } @Override @@ -94,6 +100,9 @@ public boolean equals(Object o) { if (!Objects.equals(fileLastModified, that.fileLastModified)) { return false; } + if (!Objects.equals(completedParts, that.completedParts)) { + return false; + } return Objects.equals(totalSizeInBytes, that.totalSizeInBytes); } @@ -104,6 +113,7 @@ public int hashCode() { result = 31 * result + (s3ObjectLastModified != null ? s3ObjectLastModified.hashCode() : 0); result = 31 * result + (fileLastModified != null ? fileLastModified.hashCode() : 0); result = 31 * result + (totalSizeInBytes != null ? totalSizeInBytes.hashCode() : 0); + result = 31 * result + (completedParts != null ? completedParts.hashCode() : 0); return result; } @@ -149,6 +159,15 @@ public OptionalLong totalSizeInBytes() { return totalSizeInBytes == null ? OptionalLong.empty() : OptionalLong.of(totalSizeInBytes); } + /** + * The lists of parts that were successfully completed and saved to the file. Non-empty only for multipart downloads. + * + * @return part numbers of a multipart download that were completed saved to file. + */ + public List completedParts() { + return completedParts; + } + @Override public String toString() { return ToString.builder("ResumableFileDownload") @@ -157,6 +176,7 @@ public String toString() { .add("s3ObjectLastModified", s3ObjectLastModified) .add("totalSizeInBytes", totalSizeInBytes) .add("downloadFileRequest", downloadFileRequest) + .add("completedParts", completedParts) .build(); } @@ -318,6 +338,14 @@ default ResumableFileDownload.Builder downloadFileRequest(Consumer completedParts); } private static final class DefaultBuilder implements Builder { @@ -327,6 +355,7 @@ private static final class DefaultBuilder implements Builder { private Instant s3ObjectLastModified; private Long totalSizeInBytes; private Instant fileLastModified; + private List completedParts; private DefaultBuilder() { } @@ -337,6 +366,7 @@ private DefaultBuilder(ResumableFileDownload persistableFileDownload) { this.totalSizeInBytes = persistableFileDownload.totalSizeInBytes; this.fileLastModified = persistableFileDownload.fileLastModified; this.s3ObjectLastModified = persistableFileDownload.s3ObjectLastModified; + this.completedParts = persistableFileDownload.completedParts; } @Override @@ -369,6 +399,12 @@ public Builder fileLastModified(Instant fileLastModified) { return this; } + @Override + public Builder completedParts(List completedParts) { + this.completedParts = Collections.unmodifiableList(completedParts); + return this; + } + @Override public ResumableFileDownload build() { return new ResumableFileDownload(this); diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/MultipartDownloadJavaBasedTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/MultipartDownloadJavaBasedTest.java deleted file mode 100644 index 1b5c1063239f..000000000000 --- a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/MultipartDownloadJavaBasedTest.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.transfer.s3.internal; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import java.nio.file.Paths; -import java.util.concurrent.CompletableFuture; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.core.async.AsyncResponseTransformer; -import software.amazon.awssdk.services.s3.S3AsyncClient; -import software.amazon.awssdk.services.s3.internal.multipart.MultipartS3AsyncClient; -import software.amazon.awssdk.services.s3.model.GetObjectRequest; -import software.amazon.awssdk.services.s3.model.GetObjectResponse; -import software.amazon.awssdk.services.s3.multipart.MultipartConfiguration; -import software.amazon.awssdk.transfer.s3.S3TransferManager; -import software.amazon.awssdk.transfer.s3.model.CompletedFileDownload; - -class MultipartDownloadJavaBasedTest { - private S3AsyncClient mockDelegate; - private MultipartS3AsyncClient s3Multi; - private S3TransferManager tm; - private UploadDirectoryHelper uploadDirectoryHelper; - private DownloadDirectoryHelper downloadDirectoryHelper; - private TransferManagerConfiguration configuration; - - @BeforeEach - public void methodSetup() { - mockDelegate = mock(S3AsyncClient.class); - s3Multi = MultipartS3AsyncClient.create(mockDelegate, MultipartConfiguration.builder().build()); - uploadDirectoryHelper = mock(UploadDirectoryHelper.class); - configuration = mock(TransferManagerConfiguration.class); - downloadDirectoryHelper = mock(DownloadDirectoryHelper.class); - tm = new GenericS3TransferManager(s3Multi, uploadDirectoryHelper, configuration, downloadDirectoryHelper); - } - - @Test - void usingMultipartDownload_shouldNotThrowException() { - GetObjectResponse response = GetObjectResponse.builder().build(); - when(mockDelegate.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) - .thenReturn(CompletableFuture.completedFuture(response)); - - CompletedFileDownload completedFileDownload = tm.downloadFile(d -> d.getObjectRequest(g -> g.bucket("bucket") - .key("key")) - .destination(Paths.get("."))) - .completionFuture() - .join(); - assertThat(completedFileDownload.response()).isEqualTo(response); - } -} diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/MultipartDownloadResumeContextTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/MultipartDownloadResumeContextTest.java new file mode 100644 index 000000000000..12659b2e6f06 --- /dev/null +++ b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/MultipartDownloadResumeContextTest.java @@ -0,0 +1,95 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.transfer.s3.internal; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.MULTIPART_DOWNLOAD_RESUME_CONTEXT; + +import java.nio.file.Paths; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import java.util.function.Predicate; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.transfer.s3.S3TransferManager; +import software.amazon.awssdk.transfer.s3.model.DownloadFileRequest; +import software.amazon.awssdk.transfer.s3.model.FileDownload; +import software.amazon.awssdk.transfer.s3.model.ResumableFileDownload; + +public class MultipartDownloadResumeContextTest { + + S3AsyncClient s3; + S3TransferManager tm; + + @BeforeEach + void init() { + this.s3 = mock(S3AsyncClient.class); + this.tm = new GenericS3TransferManager(s3, + mock(UploadDirectoryHelper.class), + mock(TransferManagerConfiguration.class), + mock(DownloadDirectoryHelper.class)); + } + + @Test + void pauseAndResume_shouldKeepMultipartContext() { + CompletableFuture future = new CompletableFuture<>(); + when(s3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(future); + when(s3.headObject(any(Consumer.class))) + .thenReturn(new CompletableFuture<>()); + + GetObjectRequest req = GetObjectRequest.builder().key("key").bucket("bucket").build(); + + FileDownload dl = tm.downloadFile( + DownloadFileRequest.builder() + .destination(Paths.get("some", "path")) + .getObjectRequest(req) + .build()); + ResumableFileDownload resume = dl.pause(); + + assertThat(resume.downloadFileRequest().getObjectRequest()) + .matches(hasMultipartContextAttribute(), "[1] hasMultipartContextAttribute"); + + FileDownload dl2 = tm.resumeDownloadFile(resume); + ResumableFileDownload resume2 = dl2.pause(); + + assertThat(resume2.downloadFileRequest().getObjectRequest()) + .matches(hasMultipartContextAttribute(), "[2] hasMultipartContextAttribute"); + } + + private Predicate hasMultipartContextAttribute() { + return getObjectRequest -> { + if (!getObjectRequest.overrideConfiguration().isPresent()) { + return false; + } + + return getObjectRequest.overrideConfiguration() + .get() + .executionAttributes() + .getAttribute(MULTIPART_DOWNLOAD_RESUME_CONTEXT) + != null; + }; + } + +} diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/S3TransferManagerListenerTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/S3TransferManagerListenerTest.java index a9b7529cd9c1..9f8fcc839207 100644 --- a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/S3TransferManagerListenerTest.java +++ b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/S3TransferManagerListenerTest.java @@ -17,6 +17,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -57,6 +58,7 @@ import software.amazon.awssdk.transfer.s3.model.FileDownload; import software.amazon.awssdk.transfer.s3.model.FileUpload; import software.amazon.awssdk.transfer.s3.S3TransferManager; +import software.amazon.awssdk.transfer.s3.model.TransferObjectRequest; import software.amazon.awssdk.transfer.s3.model.Upload; import software.amazon.awssdk.transfer.s3.model.UploadFileRequest; import software.amazon.awssdk.transfer.s3.model.UploadRequest; @@ -186,7 +188,7 @@ public void downloadFile_success_shouldInvokeListener() throws Exception { ArgumentCaptor.forClass(TransferListener.Context.TransferInitiated.class); verify(listener, timeout(1000).times(1)).transferInitiated(captor1.capture()); TransferListener.Context.TransferInitiated ctx1 = captor1.getValue(); - assertThat(ctx1.request()).isSameAs(downloadRequest); + assertDownloadRequest((DownloadFileRequest) ctx1.request(), downloadRequest); // transferSize is not known until we receive GetObjectResponse header assertThat(ctx1.progressSnapshot().totalBytes()).isNotPresent(); assertThat(ctx1.progressSnapshot().transferredBytes()).isZero(); @@ -195,7 +197,7 @@ public void downloadFile_success_shouldInvokeListener() throws Exception { ArgumentCaptor.forClass(TransferListener.Context.BytesTransferred.class); verify(listener, timeout(1000).times(1)).bytesTransferred(captor2.capture()); TransferListener.Context.BytesTransferred ctx2 = captor2.getValue(); - assertThat(ctx2.request()).isSameAs(downloadRequest); + assertDownloadRequest((DownloadFileRequest) ctx2.request(), downloadRequest); // transferSize should now be known assertThat(ctx2.progressSnapshot().totalBytes()).hasValue(contentLength); assertThat(ctx2.progressSnapshot().transferredBytes()).isPositive(); @@ -204,7 +206,7 @@ public void downloadFile_success_shouldInvokeListener() throws Exception { ArgumentCaptor.forClass(TransferListener.Context.TransferComplete.class); verify(listener, timeout(1000).times(1)).transferComplete(captor3.capture()); TransferListener.Context.TransferComplete ctx3 = captor3.getValue(); - assertThat(ctx3.request()).isSameAs(downloadRequest); + assertDownloadRequest((DownloadFileRequest) ctx3.request(), downloadRequest); assertThat(ctx3.progressSnapshot().totalBytes()).hasValue(contentLength); assertThat(ctx3.progressSnapshot().transferredBytes()).isEqualTo(contentLength); assertThat(ctx3.completedTransfer()).isSameAs(download.completionFuture().get()); @@ -213,6 +215,12 @@ public void downloadFile_success_shouldInvokeListener() throws Exception { verifyNoMoreInteractions(listener); } + private void assertDownloadRequest(DownloadFileRequest actual, DownloadFileRequest expected) { + assertThat(actual.destination()).isEqualTo(expected.destination()); + assertThat(actual.transferListeners()).isEqualTo(expected.transferListeners()); + assertTrue(actual.getObjectRequest().equalsBySdkFields(expected.getObjectRequest())); + } + @Test public void download_success_shouldInvokeListener() throws Exception { TransferListener listener = mock(TransferListener.class); diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/S3TransferManagerTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/S3TransferManagerTest.java index 50e05abbafaf..6dd09db8acfa 100644 --- a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/S3TransferManagerTest.java +++ b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/S3TransferManagerTest.java @@ -25,9 +25,11 @@ import java.nio.file.Paths; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import software.amazon.awssdk.core.ResponseBytes; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; @@ -217,6 +219,25 @@ void download_cancel_shouldForwardCancellation() { assertThat(s3CrtFuture).isCancelled(); } + @Test + @Timeout(value = 5, unit = TimeUnit.SECONDS) + void download_futureReturnsNull_doesNotHang() { + AsyncResponseTransformer mockTr = mock(AsyncResponseTransformer.class); + CompletableFuture returnMockFuture = new CompletableFuture<>(); + when(mockS3Crt.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(returnMockFuture); + DownloadRequest downloadRequest = + DownloadRequest.builder() + .getObjectRequest(g -> g.bucket("bucket").key("key")) + .responseTransformer(mockTr).build(); + + CompletableFuture> future = tm.download(downloadRequest).completionFuture(); + returnMockFuture.complete(null); + assertThatThrownBy(future::join) + .hasCauseInstanceOf(NullPointerException.class) + .hasMessageContaining("result must not be null"); + } + @Test void objectLambdaArnBucketProvided_shouldThrowException() { String objectLambdaArn = "arn:xxx:s3-object-lambda"; diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/TransferManagerLoggingTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/TransferManagerLoggingTest.java index 4f2be4d00063..ead70bc22a5f 100644 --- a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/TransferManagerLoggingTest.java +++ b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/TransferManagerLoggingTest.java @@ -17,21 +17,28 @@ import static org.assertj.core.api.Assertions.assertThat; -import java.util.HashSet; import java.util.List; -import java.util.Set; +import java.util.function.Predicate; import org.apache.logging.log4j.Level; import org.apache.logging.log4j.core.LogEvent; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.testutils.LogCaptor; import software.amazon.awssdk.transfer.s3.S3TransferManager; +import software.amazon.awssdk.utils.internal.SystemSettingUtilsTestBackdoor; class TransferManagerLoggingTest { + private static final String EXPECTED_DEBUG_MESSAGE = "The provided S3AsyncClient is neither " + + "an AWS CRT-based S3 async client (S3AsyncClient.crtBuilder().build()) or " + + "a Java-based S3 async client (S3AsyncClient.builder().multipartEnabled(true).build()), " + + "and thus multipart upload/download feature may not be enabled and resumable file upload may not " + + "be supported. To benefit from maximum throughput, consider using " + + "S3AsyncClient.crtBuilder().build() or " + + "S3AsyncClient.builder().multipartEnabled(true).build() instead"; + @Test void transferManager_withCrtClient_shouldNotLogWarnMessages() { @@ -42,26 +49,87 @@ void transferManager_withCrtClient_shouldNotLogWarnMessages() { LogCaptor logCaptor = LogCaptor.create(Level.WARN); S3TransferManager tm = S3TransferManager.builder().s3Client(s3Crt).build()) { List events = logCaptor.loggedEvents(); - assertThat(events).isEmpty(); + assertThat(events) + .withFailMessage("A message from S3TransferManager was logged at DEBUG level when none was expected") + .noneMatch(loggedFromS3TransferManager()); } } @Test - void transferManager_withJavaClient_shouldLogWarnMessage() { + void transferManager_withJavaClientMultiPartNotSet_shouldLogDebugMessage() { + try (S3AsyncClient s3Crt = S3AsyncClient.builder() + .region(Region.US_WEST_2) + .credentialsProvider(() -> AwsBasicCredentials.create("foo", "bar")) + .build(); + LogCaptor logCaptor = LogCaptor.create(Level.DEBUG); + S3TransferManager tm = S3TransferManager.builder().s3Client(s3Crt).build()) { + List events = logCaptor.loggedEvents(); + assertLogged(events, Level.DEBUG, EXPECTED_DEBUG_MESSAGE); + } + } + + @Test + void transferManager_withJavaClientMultiPartEqualsFalse_shouldLogDebugMessage() { try (S3AsyncClient s3Crt = S3AsyncClient.builder() .region(Region.US_WEST_2) .credentialsProvider(() -> AwsBasicCredentials.create("foo", "bar")) + .multipartEnabled(false) .build(); - LogCaptor logCaptor = LogCaptor.create(Level.WARN); + LogCaptor logCaptor = LogCaptor.create(Level.DEBUG); + S3TransferManager tm = S3TransferManager.builder().s3Client(s3Crt).build()) { + List events = logCaptor.loggedEvents(); + assertLogged(events, Level.DEBUG, EXPECTED_DEBUG_MESSAGE); + } + } + + @Test + void transferManager_withDefaultClient_shouldNotLogDebugMessage() { + + SystemSettingUtilsTestBackdoor.addEnvironmentVariableOverride("AWS_REGION", "us-east-1"); + try (LogCaptor logCaptor = LogCaptor.create(Level.DEBUG); + S3TransferManager tm = S3TransferManager.builder().build()) { + List events = logCaptor.loggedEvents(); + assertThat(events) + .withFailMessage("A message from S3TransferManager was logged at DEBUG level when none was expected") + .noneMatch(loggedFromS3TransferManager()); + } + SystemSettingUtilsTestBackdoor.clearEnvironmentVariableOverrides(); + } + + @Test + void transferManager_withMultiPartEnabledJavaClient_shouldNotLogDebugMessage() { + + try (S3AsyncClient s3Crt = S3AsyncClient.builder() + .region(Region.US_WEST_2) + .credentialsProvider(() -> AwsBasicCredentials.create("foo", "bar")) + .multipartEnabled(true) + .build(); + LogCaptor logCaptor = LogCaptor.create(Level.DEBUG); S3TransferManager tm = S3TransferManager.builder().s3Client(s3Crt).build()) { List events = logCaptor.loggedEvents(); - assertLogged(events, Level.WARN, "The provided DefaultS3AsyncClient is not an instance of S3CrtAsyncClient, and " - + "thus multipart upload/download feature is not enabled and resumable file upload" - + " is " - + "not supported. To benefit from maximum throughput, consider using " - + "S3AsyncClient.crtBuilder().build() instead."); + assertThat(events) + .withFailMessage("A message from S3TransferManager was logged at DEBUG level when none was expected") + .noneMatch(loggedFromS3TransferManager()); + } + } + + @Test + void transferManager_withMultiPartEnabledAndCrossRegionEnabledJavaClient_shouldNotLogDebugMessage() { + + try (S3AsyncClient s3Crt = S3AsyncClient.builder() + .region(Region.US_WEST_2) + .credentialsProvider(() -> AwsBasicCredentials.create("foo", "bar")) + .multipartEnabled(true) + .crossRegionAccessEnabled(true) + .build(); + LogCaptor logCaptor = LogCaptor.create(Level.DEBUG); + S3TransferManager tm = S3TransferManager.builder().s3Client(s3Crt).build()) { + List events = logCaptor.loggedEvents(); + assertThat(events) + .withFailMessage("A message from S3TransferManager was logged at DEBUG level when none was expected") + .noneMatch(loggedFromS3TransferManager()); } } @@ -72,4 +140,9 @@ private static void assertLogged(List events, org.apache.logging.log4j assertThat(msg).isEqualTo(message); assertThat(event.getLevel()).isEqualTo(level); } + + private static Predicate loggedFromS3TransferManager() { + String tmLoggerName = "software.amazon.awssdk.transfer.s3.S3TransferManager"; + return logEvent -> tmLoggerName.equals(logEvent.getLoggerName()); + } } diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/TransferProgressUpdaterTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/TransferProgressUpdaterTest.java index da4df5d13564..8c4ed3f4f725 100644 --- a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/TransferProgressUpdaterTest.java +++ b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/TransferProgressUpdaterTest.java @@ -15,6 +15,7 @@ package software.amazon.awssdk.transfer.s3.internal; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -29,6 +30,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Stream; import org.apache.commons.lang3.RandomStringUtils; @@ -37,12 +39,17 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentMatchers; import org.mockito.Mockito; import org.reactivestreams.Subscriber; import software.amazon.awssdk.core.SdkResponse; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.http.async.SimpleSubscriber; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.testutils.RandomTempFile; import software.amazon.awssdk.transfer.s3.CaptureTransferListener; @@ -51,6 +58,8 @@ import software.amazon.awssdk.transfer.s3.model.TransferObjectRequest; import software.amazon.awssdk.transfer.s3.progress.LoggingTransferListener; import software.amazon.awssdk.transfer.s3.progress.TransferListener; +import software.amazon.awssdk.transfer.s3.progress.TransferProgressSnapshot; +import software.amazon.awssdk.utils.async.SimplePublisher; class TransferProgressUpdaterTest { private static final long OBJ_SIZE = 16 * MB; @@ -151,6 +160,99 @@ void transferFailedWhenSubscriptionErrors() throws Exception { Mockito.verify(mockListener, never()).transferComplete(ArgumentMatchers.any(TransferListener.Context.TransferComplete.class)); } + @ParameterizedTest + @ValueSource(longs = {8, 16, 31, 32, 33, 1024, Long.MAX_VALUE}) + void transferProgressUpdater_useContentRangeForTotalBytes(long contentLength) { + TransferObjectRequest unusedMockTransferRequest = Mockito.mock(TransferObjectRequest.class); + TransferProgressUpdater transferProgressUpdater = new TransferProgressUpdater(unusedMockTransferRequest, null); + AsyncResponseTransformer transformer = + transferProgressUpdater.wrapResponseTransformerForMultipartDownload( + new AsyncResponseTransformer() { + @Override + public CompletableFuture prepare() { + return new CompletableFuture<>(); + } + + @Override + public void onResponse(GetObjectResponse response) { + // noop, test only + } + + @Override + public void onStream(SdkPublisher publisher) { + publisher.subscribe(b -> { /* do nothing, test only */ }); + } + + @Override + public void exceptionOccurred(Throwable error) { + // noop, test only + } + }, GetObjectRequest.builder().build()); + transformer.prepare(); + transformer.onResponse(GetObjectResponse.builder() + .contentRange("bytes 0-127/" + contentLength) + .build()); + TransferProgressSnapshot snapshot = transferProgressUpdater.progress().snapshot(); + assertThat(snapshot.totalBytes()).isPresent(); + assertThat(snapshot.totalBytes().getAsLong()).isEqualTo(contentLength); + + // simulate sending bytes + SimplePublisher publisher = new SimplePublisher<>(); + transformer.onStream(SdkPublisher.adapt(publisher)); + assertThat(transferProgressUpdater.progress().snapshot().transferredBytes()).isEqualTo(0L); + + publisher.send(ByteBuffer.wrap(new byte[] {0, 1, 2, 3, 4, 5, 6, 7})).join(); + assertThat(transferProgressUpdater.progress().snapshot().totalBytes().getAsLong()).isEqualTo(contentLength); + assertThat(transferProgressUpdater.progress().snapshot().transferredBytes()).isEqualTo(8L); + } + + @ParameterizedTest + @ValueSource(longs = {8, 16, 31, 32, 33, 1024, Long.MAX_VALUE}) + void transferProgressUpdater_useContentLengthWhenRangeGet(long contentLength) { + TransferObjectRequest unusedMockTransferRequest = Mockito.mock(TransferObjectRequest.class); + TransferProgressUpdater transferProgressUpdater = new TransferProgressUpdater(unusedMockTransferRequest, null); + AsyncResponseTransformer transformer = + transferProgressUpdater.wrapResponseTransformerForMultipartDownload( + new AsyncResponseTransformer() { + @Override + public CompletableFuture prepare() { + return new CompletableFuture<>(); + } + + @Override + public void onResponse(GetObjectResponse response) { + // noop, test only + } + + @Override + public void onStream(SdkPublisher publisher) { + publisher.subscribe(b -> { /* do nothing, test only */ }); + } + + @Override + public void exceptionOccurred(Throwable error) { + // noop, test only + } + }, GetObjectRequest.builder().range("bytes=0-" + contentLength).build()); + transformer.prepare(); + transformer.onResponse(GetObjectResponse.builder() + .contentLength(contentLength) + .build()); + TransferProgressSnapshot snapshot = transferProgressUpdater.progress().snapshot(); + assertThat(snapshot.totalBytes()).isPresent(); + assertThat(snapshot.totalBytes().getAsLong()).isEqualTo(contentLength); + + // simulate sending bytes + SimplePublisher publisher = new SimplePublisher<>(); + transformer.onStream(SdkPublisher.adapt(publisher)); + assertThat(transferProgressUpdater.progress().snapshot().transferredBytes()).isEqualTo(0L); + + publisher.send(ByteBuffer.wrap(new byte[] {0, 1, 2, 3, 4, 5, 6, 7})).join(); + assertThat(transferProgressUpdater.progress().snapshot().totalBytes().getAsLong()).isEqualTo(contentLength); + assertThat(transferProgressUpdater.progress().snapshot().transferredBytes()).isEqualTo(8L); + + } + private static class ExceptionThrowingByteArrayInputStream extends ByteArrayInputStream { private final int exceptionPosition; diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/progress/ContentRangeParserTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/progress/ContentRangeParserTest.java new file mode 100644 index 000000000000..6dc7a7fc3ce8 --- /dev/null +++ b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/progress/ContentRangeParserTest.java @@ -0,0 +1,54 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.transfer.s3.internal.progress; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.OptionalLong; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +class ContentRangeParserTest { + + private ContentRangeParser parser; + + @ParameterizedTest + @MethodSource("argumentProvider") + void testContentRangeParser(String contentRange, OptionalLong expected) { + assertThat(ContentRangeParser.totalBytes(contentRange)).isEqualTo(expected); + } + + static Stream argumentProvider() { + return Stream.of( + Arguments.of(null, OptionalLong.empty()), + Arguments.of("", OptionalLong.empty()), + Arguments.of("bytes 0-0/1", OptionalLong.of(1)), + Arguments.of("bytes 1-2/3", OptionalLong.of(3)), + Arguments.of("bytes 0-23456/890890890", OptionalLong.of(890890890)), + Arguments.of("bytes 1023-81204/890890890", OptionalLong.of(890890890)), + Arguments.of("bytes 1023-81204/999999999999999999999999999999", OptionalLong.empty()), + Arguments.of("bytes 1023-81204/-1234", OptionalLong.empty()), + Arguments.of("bytes 1023-81204/not-a-number", OptionalLong.empty()), + Arguments.of("bytes 1-2/*", OptionalLong.empty()), + Arguments.of("mib 1-2/3", OptionalLong.empty()), + Arguments.of("mib/bla 1-2/3", OptionalLong.empty()), + Arguments.of("bla bla bla", OptionalLong.empty())); + } + +} \ No newline at end of file diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/serialization/ResumableFileDownloadSerializerTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/serialization/ResumableFileDownloadSerializerTest.java index d86e8cece0d7..f9792a816bcf 100644 --- a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/serialization/ResumableFileDownloadSerializerTest.java +++ b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/serialization/ResumableFileDownloadSerializerTest.java @@ -67,7 +67,7 @@ class ResumableFileDownloadSerializerTest { @ParameterizedTest @MethodSource("downloadObjects") - void serializeDeserialize_ShouldWorkForAllDownloads(ResumableFileDownload download) { + void serializeDeserialize_ShouldWorkForAllDownloads(ResumableFileDownload download) { byte[] serializedDownload = ResumableFileDownloadSerializer.toJson(download); ResumableFileDownload deserializedDownload = ResumableFileDownloadSerializer.fromJson(serializedDownload); @@ -75,7 +75,7 @@ void serializeDeserialize_ShouldWorkForAllDownloads(ResumableFileDownload downlo } @Test - void serializeDeserialize_fromStoredString_ShouldWork() { + void serializeDeserialize_fromStoredString_ShouldWork() { ResumableFileDownload download = ResumableFileDownload.builder() .downloadFileRequest(d -> d.destination(Paths.get("test/request")) @@ -95,12 +95,12 @@ void serializeDeserialize_fromStoredString_ShouldWork() { } @Test - void serializeDeserialize_DoesNotPersistConfiguration() { + void serializeDeserialize_DoesNotPersistConfiguration() { ResumableFileDownload download = ResumableFileDownload.builder() .downloadFileRequest(d -> d.destination(PATH) .getObjectRequest(GET_OBJECT_REQUESTS.get("STANDARD")) - .addTransferListener(LoggingTransferListener.create())) + .addTransferListener(LoggingTransferListener.create())) .bytesTransferred(1000L) .build(); @@ -113,7 +113,7 @@ void serializeDeserialize_DoesNotPersistConfiguration() { } @Test - void serializeDeserialize_DoesNotPersistRequestOverrideConfiguration() { + void serializeDeserialize_DoesNotPersistRequestOverrideConfiguration() { GetObjectRequest requestWithOverride = GetObjectRequest.builder() .bucket("BUCKET") @@ -139,6 +139,28 @@ void serializeDeserialize_DoesNotPersistRequestOverrideConfiguration() { assertThat(deserializedDownload).isEqualTo(download.copy(d -> d.downloadFileRequest(fileRequestCopy))); } + @Test + void serializeDeserialize_withCompletedParts_persistCompletedParts() { + ResumableFileDownload download = + ResumableFileDownload.builder() + .downloadFileRequest(d -> d.destination(Paths.get("test/request")) + .getObjectRequest(GET_OBJECT_REQUESTS.get("ALL_TYPES"))) + .bytesTransferred(1000L) + .fileLastModified(parseIso8601Date("2022-03-08T10:15:30Z")) + .totalSizeInBytes(5000L) + .s3ObjectLastModified(parseIso8601Date("2022-03-10T08:21:00Z")) + .completedParts(Arrays.asList(1, 2, 3)) + .build(); + byte[] serializedDownload = ResumableFileDownloadSerializer.toJson(download); + assertThat(new String(serializedDownload, StandardCharsets.UTF_8)) + .isEqualTo(SERIALIZED_DOWNLOAD_OBJECT_WITH_COMPLETED_PARTS); + + ResumableFileDownload deserializedDownload = ResumableFileDownloadSerializer.fromJson( + SERIALIZED_DOWNLOAD_OBJECT_WITH_COMPLETED_PARTS.getBytes(StandardCharsets.UTF_8)); + assertThat(deserializedDownload).isEqualTo(download); + + } + public static Collection downloadObjects() { return Stream.of(differentDownloadSettings(), differentGetObjects()) @@ -159,7 +181,8 @@ private static List differentDownloadSettings() { resumableFileDownload(1000L, null, null, null, request), resumableFileDownload(1000L, null, DATE1, null, request), resumableFileDownload(1000L, 2000L, DATE1, DATE2, request), - resumableFileDownload(Long.MAX_VALUE, Long.MAX_VALUE, DATE1, DATE2, request) + resumableFileDownload(Long.MAX_VALUE, Long.MAX_VALUE, DATE1, DATE2, request), + resumableFileDownload(1000L, 2000L, DATE1, DATE2, request, Arrays.asList(1, 2, 3)) ); } @@ -182,6 +205,16 @@ private static ResumableFileDownload resumableFileDownload(Long bytesTransferred } return builder.build(); } + private static ResumableFileDownload resumableFileDownload(Long bytesTransferred, + Long totalSizeInBytes, + Instant fileLastModified, + Instant s3ObjectLastModified, + DownloadFileRequest request, + List completedParts) { + ResumableFileDownload dl = resumableFileDownload( + bytesTransferred, totalSizeInBytes, fileLastModified, s3ObjectLastModified, request); + return dl.copy(b -> b.completedParts(completedParts)); + } private static DownloadFileRequest downloadRequest(Path path, GetObjectRequest request) { return DownloadFileRequest.builder() @@ -196,5 +229,16 @@ private static DownloadFileRequest downloadRequest(Path path, GetObjectRequest r + "\"getObjectRequest\":{\"Bucket\":\"BUCKET\"," + "\"If-Modified-Since\":1577880630.000,\"Key\":\"KEY\"," + "\"x-amz-request-payer\":\"requester\",\"partNumber\":1," - + "\"x-amz-checksum-mode\":\"ENABLED\"}}}"; + + "\"x-amz-checksum-mode\":\"ENABLED\"}},\"completedParts\":[]}"; + + private static final String SERIALIZED_DOWNLOAD_OBJECT_WITH_COMPLETED_PARTS = + "{\"bytesTransferred\":1000," + + "\"fileLastModified\":1646734530.000," + + "\"totalSizeInBytes\":5000,\"s3ObjectLastModified\":1646900460" + + ".000,\"downloadFileRequest\":{\"destination\":\"test/request\"," + + "\"getObjectRequest\":{\"Bucket\":\"BUCKET\"," + + "\"If-Modified-Since\":1577880630.000,\"Key\":\"KEY\"," + + "\"x-amz-request-payer\":\"requester\",\"partNumber\":1," + + "\"x-amz-checksum-mode\":\"ENABLED\"}},\"completedParts\":[1,2,3]}"; + } diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/serialization/TransferManagerJsonUnmarshallerTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/serialization/TransferManagerJsonUnmarshallerTest.java index f49c4a649df7..8c2fc95d02db 100644 --- a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/serialization/TransferManagerJsonUnmarshallerTest.java +++ b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/serialization/TransferManagerJsonUnmarshallerTest.java @@ -19,12 +19,14 @@ import java.math.BigDecimal; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.stream.Stream; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.internal.ArrayJsonNode; import software.amazon.awssdk.protocols.jsoncore.internal.NullJsonNode; import software.amazon.awssdk.protocols.jsoncore.internal.NumberJsonNode; import software.amazon.awssdk.protocols.jsoncore.internal.StringJsonNode; @@ -57,7 +59,12 @@ private static Stream unmarshallingValues() { Arguments.of(new StringJsonNode(BinaryUtils.toBase64(SdkBytes.fromString("100", StandardCharsets.UTF_8) .asByteArray())), SdkBytes.fromString("100", StandardCharsets.UTF_8), - TransferManagerJsonUnmarshaller.SDK_BYTES) + TransferManagerJsonUnmarshaller.SDK_BYTES), + Arguments.of(new ArrayJsonNode(Arrays.asList(new NumberJsonNode("1"), + new NumberJsonNode("2"), + new NumberJsonNode("3"))), + Arrays.asList(1, 2, 3), + TransferManagerJsonUnmarshaller.LIST_INT) ); } diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/model/ResumableFileDownloadTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/model/ResumableFileDownloadTest.java index c818cd550a71..2d23ea181430 100644 --- a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/model/ResumableFileDownloadTest.java +++ b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/model/ResumableFileDownloadTest.java @@ -26,13 +26,13 @@ import java.nio.file.Files; import java.nio.file.Path; import java.time.Instant; +import java.util.Arrays; import nl.jqno.equalsverifier.EqualsVerifier; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.testutils.RandomTempFile; -import software.amazon.awssdk.transfer.s3.model.ResumableFileDownload; class ResumableFileDownloadTest { @@ -130,6 +130,7 @@ private static ResumableFileDownload resumableFileDownload() { .fileLastModified(DATE1) .s3ObjectLastModified(DATE2) .totalSizeInBytes(2000L) + .completedParts(Arrays.asList(1, 2, 5)) .build(); } } diff --git a/services-custom/s3-transfer-manager/src/test/resources/log4j2.properties b/services-custom/s3-transfer-manager/src/test/resources/log4j2.properties index 827f0c09a093..85978ec46781 100644 --- a/services-custom/s3-transfer-manager/src/test/resources/log4j2.properties +++ b/services-custom/s3-transfer-manager/src/test/resources/log4j2.properties @@ -23,6 +23,12 @@ appender.console.layout.pattern = %d{HH:mm:ss.SSS} [%t] %-5level %logger{36} - % rootLogger.level = info rootLogger.appenderRef.stdout.ref = ConsoleAppender +logger.split.name = software.amazon.awssdk.core.internal.async +logger.split.level = trace + +logger.multi.name = software.amazon.awssdk.services.s3.internal.multipart +logger.multi.level = trace + # Uncomment below to enable more specific logging # #logger.sdk.name = software.amazon.awssdk diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/DownloadObjectHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/DownloadObjectHelper.java new file mode 100644 index 000000000000..9dd72fa5dfab --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/DownloadObjectHelper.java @@ -0,0 +1,78 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.SplittingTransformerConfiguration; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.ChecksumMode; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.utils.Logger; + +@SdkInternalApi +public class DownloadObjectHelper { + private static final Logger log = Logger.loggerFor(DownloadObjectHelper.class); + + private final S3AsyncClient s3AsyncClient; + private final long bufferSizeInBytes; + + public DownloadObjectHelper(S3AsyncClient s3AsyncClient, long bufferSizeInBytes) { + this.s3AsyncClient = s3AsyncClient; + this.bufferSizeInBytes = bufferSizeInBytes; + } + + public CompletableFuture downloadObject( + GetObjectRequest getObjectRequest, AsyncResponseTransformer asyncResponseTransformer) { + if (getObjectRequest.range() != null || getObjectRequest.partNumber() != null) { + logSinglePartMessage(getObjectRequest); + return s3AsyncClient.getObject(getObjectRequest, asyncResponseTransformer); + } + GetObjectRequest requestToPerform = getObjectRequest.toBuilder().checksumMode(ChecksumMode.ENABLED).build(); + AsyncResponseTransformer.SplitResult split = + asyncResponseTransformer.split(SplittingTransformerConfiguration.builder() + .bufferSizeInBytes(bufferSizeInBytes) + .build()); + MultipartDownloaderSubscriber subscriber = subscriber(requestToPerform); + split.publisher().subscribe(subscriber); + return split.resultFuture(); + } + + private MultipartDownloaderSubscriber subscriber(GetObjectRequest getObjectRequest) { + Optional multipartDownloadContext = + MultipartDownloadUtils.multipartDownloadResumeContext(getObjectRequest); + return multipartDownloadContext + .map(ctx -> new MultipartDownloaderSubscriber(s3AsyncClient, getObjectRequest, ctx.highestSequentialCompletedPart())) + .orElseGet(() -> new MultipartDownloaderSubscriber(s3AsyncClient, getObjectRequest)); + } + + private void logSinglePartMessage(GetObjectRequest getObjectRequest) { + log.debug(() -> { + String reason = ""; + if (getObjectRequest.range() != null) { + reason = " because getObjectRequest range is included in the request." + + " range = " + getObjectRequest.range(); + } else if (getObjectRequest.partNumber() != null) { + reason = " because getObjectRequest part number is included in the request." + + " part number = " + getObjectRequest.partNumber(); + } + return "Using single part download" + reason; + }); + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadResumeContext.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadResumeContext.java new file mode 100644 index 000000000000..0d525796e364 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadResumeContext.java @@ -0,0 +1,143 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.SortedSet; +import java.util.TreeSet; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.utils.ToString; +import software.amazon.awssdk.utils.Validate; + +/** + * This class keep tracks of the state of a multipart download across multipart GET requests. + */ +@SdkInternalApi +public class MultipartDownloadResumeContext { + + /** + * Keeps track of complete parts in a list sorted in ascending order + */ + private final SortedSet completedParts; + + /** + * Keep track of the byte index to the last byte of the last completed part + */ + private Long bytesToLastCompletedParts; + + /** + * The total number of parts of the multipart download. + */ + private Integer totalParts; + + /** + * The GetObjectResponse to return to the user. + */ + private GetObjectResponse response; + + public MultipartDownloadResumeContext() { + this(new TreeSet<>(), 0L); + } + + public MultipartDownloadResumeContext(Collection completedParts, Long bytesToLastCompletedParts) { + this.completedParts = new TreeSet<>(Validate.notNull( + completedParts, "completedParts must not be null")); + this.bytesToLastCompletedParts = Validate.notNull( + bytesToLastCompletedParts, "bytesToLastCompletedParts must not be null"); + } + + public List completedParts() { + return Arrays.asList(completedParts.toArray(new Integer[0])); + } + + public Long bytesToLastCompletedParts() { + return bytesToLastCompletedParts; + } + + public void addCompletedPart(int partNumber) { + completedParts.add(partNumber); + } + + public void addToBytesToLastCompletedParts(long bytes) { + bytesToLastCompletedParts += bytes; + } + + public void totalParts(int totalParts) { + this.totalParts = totalParts; + } + + public Integer totalParts() { + return totalParts; + } + + public GetObjectResponse response() { + return this.response; + } + + public void response(GetObjectResponse response) { + this.response = response; + } + + /** + * @return the highest sequentially completed part, 0 means no parts completed. Used for non-sequential operation when parts + * may have been completed in a non-sequential order. For example, if parts [1, 2, 3, 6, 7, 10] were completed, this + * method will return 3. + * + */ + public int highestSequentialCompletedPart() { + if (completedParts.isEmpty() || completedParts.first() != 1) { + return 0; + } + if (completedParts.size() == 1) { + return 1; + } + + // for sequential operation, make sure we don't skip any non-completed part by returning the + // highest sequentially completed part + int previous = completedParts.first(); + for (Integer i : completedParts) { + if (i - previous > 1) { + return previous; + } + previous = i; + } + return completedParts.last(); + } + + /** + * Check if the multipart download is complete or not by checking if the total amount of downloaded parts is equal to the + * total amount of parts. + * + * @return true if all parts were downloaded, false if not. + */ + public boolean isComplete() { + if (totalParts == null) { + return false; + } + return completedParts.size() == totalParts; + } + + @Override + public String toString() { + return ToString.builder("MultipartDownloadContext") + .add("completedParts", completedParts) + .add("bytesToLastCompletedParts", bytesToLastCompletedParts) + .build(); + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadUtils.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadUtils.java new file mode 100644 index 000000000000..807b6a8bbbc0 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadUtils.java @@ -0,0 +1,61 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.MULTIPART_DOWNLOAD_RESUME_CONTEXT; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; + +@SdkInternalApi +public final class MultipartDownloadUtils { + + private MultipartDownloadUtils() { + } + + /** + * This method checks the + * {@link software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute#MULTIPART_DOWNLOAD_RESUME_CONTEXT} + * execution attributes for a context object and returns the complete parts associated with it, or an empty list of no + * context is found. + * + * @param request + * @return The list of completed parts for a GetObjectRequest, or an empty list if none were found. + */ + public static List completedParts(GetObjectRequest request) { + return multipartDownloadResumeContext(request) + .map(MultipartDownloadResumeContext::completedParts) + .orElseGet(Collections::emptyList); + } + + /** + * This method checks the + * {@link software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute#MULTIPART_DOWNLOAD_RESUME_CONTEXT} + * execution attributes for a context object and returns it if it finds one. Otherwise, returns an empty Optional. + * + * @param request the request to look for execution attributes + * @return the MultipartDownloadResumeContext if one is found, otherwise an empty Optional. + */ + public static Optional multipartDownloadResumeContext(GetObjectRequest request) { + return request + .overrideConfiguration() + .flatMap(conf -> Optional.ofNullable(conf.executionAttributes().getAttribute(MULTIPART_DOWNLOAD_RESUME_CONTEXT))); + } + +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriber.java new file mode 100644 index 000000000000..d369d0caff02 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriber.java @@ -0,0 +1,189 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.utils.Logger; + +/** + * A subscriber implementation that will download all individual parts for a multipart get-object request. It receives the + * individual {@link AsyncResponseTransformer} which will be used to perform the individual part requests. This is a 'one-shot' + * class, it should NOT be reused for more than one multipart download + */ +@SdkInternalApi +public class MultipartDownloaderSubscriber implements Subscriber> { + private static final Logger log = Logger.loggerFor(MultipartDownloaderSubscriber.class); + + /** + * The s3 client used to make the individual part requests + */ + private final S3AsyncClient s3; + + /** + * The GetObjectRequest that was provided when calling s3.getObject(...). It is copied for each individual request, and the + * copy has the partNumber field updated as more parts are downloaded. + */ + private final GetObjectRequest getObjectRequest; + + /** + * This value indicates the total number of parts of the object to get. If null, it means we don't know the total amount of + * parts, either because we haven't received a response from s3 yet to set it, or the object to get is not multipart. + */ + private volatile Integer totalParts; + + /** + * The total number of completed parts. A part is considered complete once the completable future associated with its request + * completes successfully. + */ + private final AtomicInteger completedParts; + + /** + * The subscription received from the publisher this subscriber subscribes to. + */ + private Subscription subscription; + + /** + * This future will be completed once this subscriber reaches a terminal state, failed or successfully, and will be completed + * accordingly. + */ + private final CompletableFuture future = new CompletableFuture<>(); + + /** + * The etag of the object being downloaded. + */ + private volatile String eTag; + + /** + * The Subscription lock + */ + private final Object lock = new Object(); + + public MultipartDownloaderSubscriber(S3AsyncClient s3, GetObjectRequest getObjectRequest) { + this(s3, getObjectRequest, 0); + } + + public MultipartDownloaderSubscriber(S3AsyncClient s3, GetObjectRequest getObjectRequest, int completedParts) { + this.s3 = s3; + this.getObjectRequest = getObjectRequest; + this.completedParts = new AtomicInteger(completedParts); + } + + @Override + public void onSubscribe(Subscription s) { + if (this.subscription != null) { + s.cancel(); + return; + } + this.subscription = s; + this.subscription.request(1); + } + + @Override + public void onNext(AsyncResponseTransformer asyncResponseTransformer) { + if (asyncResponseTransformer == null) { + subscription.cancel(); + throw new NullPointerException("onNext must not be called with null asyncResponseTransformer"); + } + + int nextPartToGet = completedParts.get() + 1; + + synchronized (lock) { + if (totalParts != null && nextPartToGet > totalParts) { + log.debug(() -> String.format("Completing multipart download after a total of %d parts downloaded.", totalParts)); + subscription.cancel(); + return; + } + } + + GetObjectRequest actualRequest = nextRequest(nextPartToGet); + log.debug(() -> "Sending GetObjectRequest for next part with partNumber=" + nextPartToGet); + CompletableFuture getObjectFuture = s3.getObject(actualRequest, asyncResponseTransformer); + getObjectFuture.whenComplete((response, error) -> { + if (error != null) { + log.debug(() -> "Error encountered during GetObjectRequest with partNumber=" + nextPartToGet); + onError(error); + return; + } + requestMoreIfNeeded(response); + }); + } + + private void requestMoreIfNeeded(GetObjectResponse response) { + int totalComplete = completedParts.incrementAndGet(); + MultipartDownloadUtils.multipartDownloadResumeContext(getObjectRequest) + .ifPresent(ctx -> { + ctx.addCompletedPart(totalComplete); + ctx.addToBytesToLastCompletedParts(response.contentLength()); + if (ctx.response() == null) { + ctx.response(response); + } + }); + log.debug(() -> String.format("Completed part %d", totalComplete)); + + if (eTag == null) { + this.eTag = response.eTag(); + log.debug(() -> String.format("Multipart object ETag: %s", this.eTag)); + } + + Integer partCount = response.partsCount(); + if (partCount != null && totalParts == null) { + log.debug(() -> String.format("Total amount of parts of the object to download: %d", partCount)); + MultipartDownloadUtils.multipartDownloadResumeContext(getObjectRequest) + .ifPresent(ctx -> ctx.totalParts(partCount)); + totalParts = partCount; + } + + synchronized (lock) { + if (totalParts != null && totalParts > 1 && totalComplete < totalParts) { + subscription.request(1); + } else { + log.debug(() -> String.format("Completing multipart download after a total of %d parts downloaded.", totalParts)); + subscription.cancel(); + } + } + } + + @Override + public void onError(Throwable t) { + future.completeExceptionally(t); + } + + @Override + public void onComplete() { + future.complete(null); + } + + public CompletableFuture future() { + return this.future; + } + + private GetObjectRequest nextRequest(int nextPartToGet) { + return getObjectRequest.copy(req -> { + req.partNumber(nextPartToGet); + if (eTag != null) { + req.ifMatch(eTag); + } + }); + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java index 27738e4d3d9b..266eb081cf24 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java @@ -55,6 +55,7 @@ public final class MultipartS3AsyncClient extends DelegatingS3AsyncClient { private final UploadObjectHelper mpuHelper; private final CopyObjectHelper copyObjectHelper; + private final DownloadObjectHelper downloadObjectHelper; private MultipartS3AsyncClient(S3AsyncClient delegate, MultipartConfiguration multipartConfiguration) { super(delegate); @@ -63,8 +64,10 @@ private MultipartS3AsyncClient(S3AsyncClient delegate, MultipartConfiguration mu MultipartConfigurationResolver resolver = new MultipartConfigurationResolver(validConfiguration); long minPartSizeInBytes = resolver.minimalPartSizeInBytes(); long threshold = resolver.thresholdInBytes(); + long apiCallBufferSize = resolver.apiCallBufferSize(); mpuHelper = new UploadObjectHelper(delegate, resolver); copyObjectHelper = new CopyObjectHelper(delegate, minPartSizeInBytes, threshold); + downloadObjectHelper = new DownloadObjectHelper(delegate, apiCallBufferSize); } @Override @@ -111,10 +114,7 @@ public CompletableFuture copyObject(CopyObjectRequest copyOb @Override public CompletableFuture getObject( GetObjectRequest getObjectRequest, AsyncResponseTransformer asyncResponseTransformer) { - // TODO uncomment once implemented - // getObjectRequest = getObjectRequest.toBuilder().checksumMode(ChecksumMode.ENABLED).build(); - throw new UnsupportedOperationException( - "Multipart download is not yet supported. Instead use the CRT based S3 client for multipart download."); + return downloadObjectHelper.downloadObject(getObjectRequest, asyncResponseTransformer); } @Override diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/MultipartConfiguration.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/MultipartConfiguration.java index be2500703e15..9b074d6244f3 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/MultipartConfiguration.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/MultipartConfiguration.java @@ -15,26 +15,28 @@ package software.amazon.awssdk.services.s3.multipart; -import java.util.function.Consumer; import software.amazon.awssdk.annotations.SdkPublicApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3AsyncClientBuilder; import software.amazon.awssdk.services.s3.model.CopyObjectRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.utils.builder.CopyableBuilder; import software.amazon.awssdk.utils.builder.ToCopyableBuilder; /** * Class that hold configuration properties related to multipart operation for a {@link S3AsyncClient}. Passing this class to the * {@link S3AsyncClientBuilder#multipartConfiguration(MultipartConfiguration)} will enable automatic conversion of - * {@link S3AsyncClient#putObject(Consumer, AsyncRequestBody)}, {@link S3AsyncClient#copyObject(CopyObjectRequest)} to their - * respective multipart operation. + * {@link S3AsyncClient#getObject(GetObjectRequest, AsyncResponseTransformer)}, + * {@link S3AsyncClient#putObject(PutObjectRequest, AsyncRequestBody)} and + * {@link S3AsyncClient#copyObject(CopyObjectRequest)} to their respective multipart operation. *

- * Note: The multipart operation for {@link S3AsyncClient#getObject(GetObjectRequest, AsyncResponseTransformer)} is - * temporarily disabled and will result in throwing a {@link UnsupportedOperationException} if called when configured for - * multipart operation. + * Note that multipart download fetch individual part of the object using {@link GetObjectRequest#partNumber() part number}, this + * means it will only download multiple parts if the + * object itself was uploaded as a {@link S3AsyncClient#createMultipartUpload(CreateMultipartUploadRequest) multipart object} */ @SdkPublicApi public final class MultipartConfiguration implements ToCopyableBuilder { diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/S3MultipartExecutionAttribute.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/S3MultipartExecutionAttribute.java index 7bfacd9d387e..2aad60b253d2 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/S3MultipartExecutionAttribute.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/S3MultipartExecutionAttribute.java @@ -19,6 +19,7 @@ import software.amazon.awssdk.core.async.listener.PublisherListener; import software.amazon.awssdk.core.interceptor.ExecutionAttribute; import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; +import software.amazon.awssdk.services.s3.internal.multipart.MultipartDownloadResumeContext; @SdkProtectedApi public final class S3MultipartExecutionAttribute extends SdkExecutionAttribute { @@ -26,4 +27,6 @@ public final class S3MultipartExecutionAttribute extends SdkExecutionAttribute { public static final ExecutionAttribute PAUSE_OBSERVABLE = new ExecutionAttribute<>("PauseObservable"); public static final ExecutionAttribute> JAVA_PROGRESS_LISTENER = new ExecutionAttribute<>("JavaProgressListener"); + public static final ExecutionAttribute MULTIPART_DOWNLOAD_RESUME_CONTEXT = + new ExecutionAttribute<>("MultipartDownloadResumeContext"); } diff --git a/services/s3/src/main/resources/codegen-resources/customization.config b/services/s3/src/main/resources/codegen-resources/customization.config index 49aac349beda..015a928975ce 100644 --- a/services/s3/src/main/resources/codegen-resources/customization.config +++ b/services/s3/src/main/resources/codegen-resources/customization.config @@ -315,7 +315,7 @@ "multipartCustomization": { "multipartConfigurationClass": "software.amazon.awssdk.services.s3.multipart.MultipartConfiguration", "multipartConfigMethodDoc": "Configuration for multipart operation of this client.", - "multipartEnableMethodDoc": "Enables automatic conversion of PUT and COPY methods to their equivalent multipart operation. CRC32 checksum will be enabled for PUT, unless the checksum is specified or checksum validation is disabled.", + "multipartEnableMethodDoc": "Enables automatic conversion of GET, PUT and COPY methods to their equivalent multipart operation. CRC32 checksum will be enabled for PUT, unless the checksum is specified or checksum validation is disabled.", "contextParamEnabledKey": "S3AsyncClientDecorator.MULTIPART_ENABLED_KEY", "contextParamConfigKey": "S3AsyncClientDecorator.MULTIPART_CONFIGURATION_KEY" }, diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java index a4816c1b568b..ee98bbe9e4c8 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java @@ -143,4 +143,4 @@ private void verifyResumeToken(S3ResumeToken s3ResumeToken, int numExistingParts assertThat(s3ResumeToken.totalNumParts()).isEqualTo(TOTAL_NUM_PARTS); assertThat(s3ResumeToken.numPartsCompleted()).isEqualTo(numExistingParts); } -} \ No newline at end of file +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadTestUtil.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadTestUtil.java new file mode 100644 index 000000000000..708972b6b0d7 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadTestUtil.java @@ -0,0 +1,98 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static com.github.tomakehurst.wiremock.client.WireMock.urlMatching; +import static com.github.tomakehurst.wiremock.client.WireMock.verify; + +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import software.amazon.awssdk.services.s3.utils.AsyncResponseTransformerTestSupplier; + +public class MultipartDownloadTestUtil { + + private static final String RETRY_SCENARIO = "retry"; + private static final String SUCCESS_STATE = "success"; + private static final String FAILED_STATE = "failed"; + + private String testBucket; + private String testKey; + private String eTag; + private Random random = new Random(); + + public MultipartDownloadTestUtil(String testBucket, String testKey, String eTag) { + this.testBucket = testBucket; + this.testKey = testKey; + this.eTag = eTag; + } + + public static List> transformersSuppliers() { + return Arrays.asList( + new AsyncResponseTransformerTestSupplier.ByteTestArtSupplier(), + new AsyncResponseTransformerTestSupplier.InputStreamArtSupplier(), + new AsyncResponseTransformerTestSupplier.PublisherArtSupplier(), + new AsyncResponseTransformerTestSupplier.FileArtSupplier() + ); + } + + public byte[] stubAllParts(String testBucket, String testKey, int amountOfPartToTest, int partSize) { + byte[] expectedBody = new byte[amountOfPartToTest * partSize]; + for (int i = 0; i < amountOfPartToTest; i++) { + byte[] individualBody = stubForPart(testBucket, testKey, i + 1, amountOfPartToTest, partSize); + System.arraycopy(individualBody, 0, expectedBody, i * partSize, individualBody.length); + } + return expectedBody; + } + + public byte[] stubForPart(String testBucket, String testKey,int part, int totalPart, int partSize) { + byte[] body = new byte[partSize]; + random.nextBytes(body); + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=%d", testBucket, testKey, part))).willReturn( + aResponse() + .withHeader("x-amz-mp-parts-count", totalPart + "") + .withHeader("ETag", eTag) + .withBody(body))); + return body; + } + + public void verifyCorrectAmountOfRequestsMade(int amountOfPartToTest) { + String urlTemplate = ".*partNumber=%d.*"; + for (int i = 1; i <= amountOfPartToTest; i++) { + verify(getRequestedFor(urlMatching(String.format(urlTemplate, i)))); + } + verify(0, getRequestedFor(urlMatching(String.format(urlTemplate, amountOfPartToTest + 1)))); + } + + public byte[] stubForPartSuccess(int part, int totalPart, int partSize) { + byte[] body = new byte[partSize]; + random.nextBytes(body); + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=%d", testBucket, testKey, part))) + .inScenario(RETRY_SCENARIO) + .whenScenarioStateIs(SUCCESS_STATE) + .willReturn( + aResponse() + .withHeader("x-amz-mp-parts-count", totalPart + "") + .withHeader("ETag", eTag) + .withBody(body))); + return body; + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadUtilsTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadUtilsTest.java new file mode 100644 index 000000000000..5fce3aff2144 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloadUtilsTest.java @@ -0,0 +1,63 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static org.assertj.core.api.Assertions.assertThat; +import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.MULTIPART_DOWNLOAD_RESUME_CONTEXT; + +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; + +class MultipartDownloadUtilsTest { + + @Test + void noContext_completedPartShouldBeEmpty() { + GetObjectRequest req = GetObjectRequest.builder().build(); + assertThat(MultipartDownloadUtils.completedParts(req)).isEmpty(); + } + + @Test + void noContext_contextShouldBeEmpty() { + GetObjectRequest req = GetObjectRequest.builder().build(); + assertThat(MultipartDownloadUtils.multipartDownloadResumeContext(req)).isEmpty(); + } + + @Test + void contextWithParts_completedPartsShouldReturnListOfParts() { + MultipartDownloadResumeContext ctx = new MultipartDownloadResumeContext(); + ctx.addCompletedPart(1); + ctx.addCompletedPart(2); + ctx.addCompletedPart(3); + GetObjectRequest req = GetObjectRequest + .builder() + .overrideConfiguration(conf -> conf.putExecutionAttribute(MULTIPART_DOWNLOAD_RESUME_CONTEXT, ctx)) + .build(); + + assertThat(MultipartDownloadUtils.completedParts(req)).containsExactly(1, 2, 3); + } + + @Test + void contextWithParts_contextShouldBePresent() { + MultipartDownloadResumeContext ctx = new MultipartDownloadResumeContext(); + GetObjectRequest req = GetObjectRequest + .builder() + .overrideConfiguration(conf -> conf.putExecutionAttribute(MULTIPART_DOWNLOAD_RESUME_CONTEXT, ctx)) + .build(); + + assertThat(MultipartDownloadUtils.multipartDownloadResumeContext(req)).isPresent(); + } + +} \ No newline at end of file diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriberTckTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriberTckTest.java new file mode 100644 index 000000000000..7fab37f7d71a --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriberTckTest.java @@ -0,0 +1,118 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.nio.ByteBuffer; +import java.util.concurrent.CompletableFuture; +import org.mockito.Mockito; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.reactivestreams.tck.SubscriberWhiteboxVerification; +import org.reactivestreams.tck.TestEnvironment; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; + +public class MultipartDownloaderSubscriberTckTest + extends SubscriberWhiteboxVerification> { + private S3AsyncClient s3mock; + + public MultipartDownloaderSubscriberTckTest() { + super(new TestEnvironment()); + this.s3mock = Mockito.mock(S3AsyncClient.class); + } + + @Override + public Subscriber> + createSubscriber(WhiteboxSubscriberProbe> probe) { + CompletableFuture responseFuture = + CompletableFuture.completedFuture(GetObjectResponse.builder().partsCount(4).build()); + when(s3mock.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))).thenReturn(responseFuture); + return new MultipartDownloaderSubscriber(s3mock, GetObjectRequest.builder() + .bucket("test-bucket-unused") + .key("test-key-unused") + .build()) { + @Override + public void onError(Throwable throwable) { + super.onError(throwable); + probe.registerOnError(throwable); + } + + @Override + public void onSubscribe(Subscription subscription) { + super.onSubscribe(subscription); + probe.registerOnSubscribe(new SubscriberPuppet() { + @Override + public void triggerRequest(long elements) { + subscription.request(elements); + } + + @Override + public void signalCancel() { + subscription.cancel(); + } + }); + } + + @Override + public void onNext(AsyncResponseTransformer item) { + super.onNext(item); + probe.registerOnNext(item); + } + + @Override + public void onComplete() { + super.onComplete(); + probe.registerOnComplete(); + } + }; + } + + @Override + public AsyncResponseTransformer createElement(int element) { + return new TestAsyncResponseTransformer(); + } + + private static class TestAsyncResponseTransformer implements AsyncResponseTransformer { + private CompletableFuture future; + + @Override + public CompletableFuture prepare() { + this.future = new CompletableFuture<>(); + return this.future; + } + + @Override + public void onResponse(GetObjectResponse response) { + this.future.complete(response); + } + + @Override + public void onStream(SdkPublisher publisher) { + // do nothing, test + } + + @Override + public void exceptionOccurred(Throwable error) { + future.completeExceptionally(error); + } + } +} \ No newline at end of file diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriberWiremockTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriberWiremockTest.java new file mode 100644 index 000000000000..1c6eb666a9c2 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriberWiremockTest.java @@ -0,0 +1,182 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.params.provider.Arguments.arguments; +import static software.amazon.awssdk.services.s3.internal.multipart.MultipartDownloadTestUtil.transformersSuppliers; + +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import java.net.URI; +import java.util.Arrays; +import java.util.List; +import java.util.UUID; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Subscriber; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.SplittingTransformerConfiguration; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3Configuration; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.utils.AsyncResponseTransformerTestSupplier; +import software.amazon.awssdk.utils.Pair; + +@WireMockTest +class MultipartDownloaderSubscriberWiremockTest { + + private final String testBucket = "test-bucket"; + private final String testKey = "test-key"; + + private S3AsyncClient s3AsyncClient; + private MultipartDownloadTestUtil util; + + @BeforeEach + public void init(WireMockRuntimeInfo wiremock) { + s3AsyncClient = S3AsyncClient.builder() + .credentialsProvider(StaticCredentialsProvider.create( + AwsBasicCredentials.create("key", "secret"))) + .region(Region.US_WEST_2) + .endpointOverride(URI.create("http://localhost:" + wiremock.getHttpPort())) + .serviceConfiguration(S3Configuration.builder() + .pathStyleAccessEnabled(true) + .build()) + .build(); + util = new MultipartDownloadTestUtil(testBucket, testKey, UUID.randomUUID().toString()); + } + + @ParameterizedTest + @MethodSource("argumentsProvider") + void happyPath_shouldReceiveAllBodyPartInCorrectOrder(AsyncResponseTransformerTestSupplier supplier, + int amountOfPartToTest, + int partSize) { + byte[] expectedBody = util.stubAllParts(testBucket, testKey, amountOfPartToTest, partSize); + AsyncResponseTransformer transformer = supplier.transformer(); + AsyncResponseTransformer.SplitResult split = transformer.split( + SplittingTransformerConfiguration.builder() + .bufferSizeInBytes(1024 * 32L) + .build()); + Subscriber> subscriber = new MultipartDownloaderSubscriber( + s3AsyncClient, + GetObjectRequest.builder() + .bucket(testBucket) + .key(testKey) + .build()); + + split.publisher().subscribe(subscriber); + T response = split.resultFuture().join(); + + byte[] body = supplier.body(response); + assertArrayEquals(expectedBody, body); + util.verifyCorrectAmountOfRequestsMade(amountOfPartToTest); + } + + @ParameterizedTest + @MethodSource("argumentsProvider") + void errorOnFirstRequest_shouldCompleteExceptionally(AsyncResponseTransformerTestSupplier supplier, + int amountOfPartToTest, + int partSize) { + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", testBucket, testKey))).willReturn( + aResponse() + .withStatus(400) + .withBody("400test error message"))); + AsyncResponseTransformer transformer = supplier.transformer(); + AsyncResponseTransformer.SplitResult split = transformer.split( + SplittingTransformerConfiguration.builder() + .bufferSizeInBytes(1024 * 32L) + .build()); + Subscriber> subscriber = new MultipartDownloaderSubscriber( + s3AsyncClient, + GetObjectRequest.builder() + .bucket(testBucket) + .key(testKey) + .build()); + + split.publisher().subscribe(subscriber); + assertThatThrownBy(() -> split.resultFuture().join()) + .hasMessageContaining("test error message"); + } + + @ParameterizedTest + @MethodSource("argumentsProvider") + void errorOnThirdRequest_shouldCompleteExceptionallyOnlyPartsGreaterThanTwo( + AsyncResponseTransformerTestSupplier supplier, + int amountOfPartToTest, + int partSize) { + util.stubForPart(testBucket, testKey, 1, 3, partSize); + util.stubForPart(testBucket, testKey, 2, 3, partSize); + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=3", testBucket, testKey))).willReturn( + aResponse() + .withStatus(400) + .withBody("400test error message"))); + AsyncResponseTransformer transformer = supplier.transformer(); + AsyncResponseTransformer.SplitResult split = transformer.split( + SplittingTransformerConfiguration.builder() + .bufferSizeInBytes(1024 * 32L) + .build()); + Subscriber> subscriber = new MultipartDownloaderSubscriber( + s3AsyncClient, + GetObjectRequest.builder() + .bucket(testBucket) + .key(testKey) + .build()); + + if (partSize > 1) { + split.publisher().subscribe(subscriber); + assertThatThrownBy(() -> { + T res = split.resultFuture().join(); + supplier.body(res); + }).hasMessageContaining("test error message"); + } else { + T res = split.resultFuture().join(); + assertNotNull(supplier.body(res)); + } + } + + private static Stream argumentsProvider() { + // amount of part, individual part size + List> partSizes = Arrays.asList( + Pair.of(4, 16), + Pair.of(1, 1024), + Pair.of(31, 1243), + Pair.of(16, 16 * 1024), + Pair.of(1, 1024 * 1024), + Pair.of(4, 1024 * 1024), + Pair.of(1, 4 * 1024 * 1024), + Pair.of(4, 6 * 1024 * 1024), + Pair.of(7, 5 * 3752) + ); + + Stream.Builder sb = Stream.builder(); + transformersSuppliers().forEach(tr -> partSizes.forEach(p -> sb.accept(arguments(tr, p.left(), p.right())))); + return sb.build(); + } + +} \ No newline at end of file diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClientTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClientTest.java new file mode 100644 index 000000000000..9d80be02bb71 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClientTest.java @@ -0,0 +1,68 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.SplittingTransformerConfiguration; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.multipart.MultipartConfiguration; + +class MultipartS3AsyncClientTest { + + @Test + void byteRangeManuallySpecified_shouldBypassMultipart() { + S3AsyncClient mockDelegate = mock(S3AsyncClient.class); + AsyncResponseTransformer> mockTransformer = + mock(AsyncResponseTransformer.class); + GetObjectRequest req = GetObjectRequest.builder() + .bucket("test-bucket") + .key("test-key") + .range("Range: bytes 0-499/1234") + .build(); + S3AsyncClient s3AsyncClient = MultipartS3AsyncClient.create(mockDelegate, MultipartConfiguration.builder().build()); + s3AsyncClient.getObject(req, mockTransformer); + verify(mockTransformer, never()).split(any(SplittingTransformerConfiguration.class)); + verify(mockDelegate, times(1)).getObject(any(GetObjectRequest.class), eq(mockTransformer)); + } + + @Test + void partManuallySpecified_shouldBypassMultipart() { + S3AsyncClient mockDelegate = mock(S3AsyncClient.class); + AsyncResponseTransformer> mockTransformer = + mock(AsyncResponseTransformer.class); + GetObjectRequest req = GetObjectRequest.builder() + .bucket("test-bucket") + .key("test-key") + .partNumber(1) + .build(); + S3AsyncClient s3AsyncClient = MultipartS3AsyncClient.create(mockDelegate, MultipartConfiguration.builder().build()); + s3AsyncClient.getObject(req, mockTransformer); + verify(mockTransformer, never()).split(any(SplittingTransformerConfiguration.class)); + verify(mockDelegate, times(1)).getObject(any(GetObjectRequest.class), eq(mockTransformer)); + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/multipart/MultipartDownloadResumeContextTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/multipart/MultipartDownloadResumeContextTest.java new file mode 100644 index 000000000000..585c39f8f4a1 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/multipart/MultipartDownloadResumeContextTest.java @@ -0,0 +1,59 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.multipart; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import software.amazon.awssdk.services.s3.internal.multipart.MultipartDownloadResumeContext; + +class MultipartDownloadResumeContextTest { + + @ParameterizedTest + @MethodSource("source") + void highest(List completedParts, int expectedNextNonCompleted) { + MultipartDownloadResumeContext context = new MultipartDownloadResumeContext(); + completedParts.forEach(context::addCompletedPart); + assertThat(context.highestSequentialCompletedPart()).isEqualTo(expectedNextNonCompleted); + } + + private static Stream source() { + return Stream.of( + Arguments.of(Arrays.asList(), 0), + Arguments.of(Arrays.asList(0), 0), + Arguments.of(Arrays.asList(1), 1), + Arguments.of(Arrays.asList(1, 2), 2), + Arguments.of(Arrays.asList(1, 2, 3), 3), + Arguments.of(Arrays.asList(1, 2, 3, 4), 4), + Arguments.of(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), 10), + Arguments.of(Arrays.asList(1, 3, 4, 5), 1), + Arguments.of(Arrays.asList(1, 2, 4, 5), 2), + Arguments.of(Arrays.asList(1, 2, 3, 5), 3), + Arguments.of(Arrays.asList(1, 3, 5), 1), + Arguments.of(Arrays.asList(1, 4, 5), 1), + Arguments.of(Arrays.asList(1, 5), 1), + Arguments.of(Arrays.asList(1, 2, 3, 4, 6, 8, 9), 4), + Arguments.of(Arrays.asList(2, 4, 6), 0), + Arguments.of(Arrays.asList(2, 3, 5), 0) + ); + } + +} \ No newline at end of file diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/utils/AsyncResponseTransformerTestSupplier.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/utils/AsyncResponseTransformerTestSupplier.java new file mode 100644 index 000000000000..c87e40ad07f3 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/utils/AsyncResponseTransformerTestSupplier.java @@ -0,0 +1,215 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.utils; + +import static org.junit.jupiter.api.Assertions.fail; + +import com.google.common.jimfs.Jimfs; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.file.FileSystem; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.async.ResponsePublisher; +import software.amazon.awssdk.core.internal.async.FileAsyncResponseTransformer; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.utils.IoUtils; + +/** + * Contains the {@link AsyncResponseTransformer} to be used in a test as well as logic on how to retrieve the body content of the + * request for that specific transformer. + * + * @param the type returned of the future associated with the {@link AsyncResponseTransformer} + */ +public interface AsyncResponseTransformerTestSupplier { + + class ByteTestArtSupplier implements AsyncResponseTransformerTestSupplier> { + + @Override + public byte[] body(ResponseBytes response) { + return response.asByteArray(); + } + + @Override + public AsyncResponseTransformer> transformer() { + return AsyncResponseTransformer.toBytes(); + } + + @Override + public String toString() { + return "AsyncResponseTransformer.toBytes"; + } + } + + class InputStreamArtSupplier implements AsyncResponseTransformerTestSupplier> { + + @Override + public byte[] body(ResponseInputStream response) { + try { + return IoUtils.toByteArray(response); + } catch (IOException ioe) { + throw new UncheckedIOException(ioe); + } + } + + @Override + public AsyncResponseTransformer> transformer() { + return AsyncResponseTransformer.toBlockingInputStream(); + } + + @Override + public String toString() { + return "AsyncResponseTransformer.toBlockingInputStream"; + } + } + + class FileArtSupplier implements AsyncResponseTransformerTestSupplier { + + private Path path; + + @Override + public byte[] body(GetObjectResponse response) { + try { + return Files.readAllBytes(path); + } catch (IOException ioe) { + fail("unexpected IOE during test", ioe); + return new byte[0]; + } + } + + @Override + public AsyncResponseTransformer transformer() { + FileSystem jimfs = Jimfs.newFileSystem(); + String filePath = "/tmp-file-" + UUID.randomUUID(); + this.path = jimfs.getPath(filePath); + return AsyncResponseTransformer.toFile(this.path); + } + + @Override + public String toString() { + return "AsyncResponseTransformer.toFile"; + } + + @Override + public boolean requiresJimfs() { + return true; + } + } + + class PublisherArtSupplier implements AsyncResponseTransformerTestSupplier> { + + @Override + public byte[] body(ResponsePublisher response) { + List buffer = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(1); + AtomicReference error = new AtomicReference<>(); + response.subscribe(new Subscriber() { + Subscription s; + + @Override + public void onSubscribe(Subscription s) { + this.s = s; + s.request(1); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + while (byteBuffer.remaining() > 0) { + buffer.add(byteBuffer.get()); + } + s.request(1); + } + + @Override + public void onError(Throwable t) { + error.set(t); + latch.countDown(); + } + + @Override + public void onComplete() { + latch.countDown(); + } + }); + try { + latch.await(); + } catch (InterruptedException e) { + fail("Unexpected thread interruption during test", e); + } + if (error.get() != null) { + throw new RuntimeException(error.get()); + } + return unbox(buffer.toArray(new Byte[0])); + } + + private byte[] unbox(Byte[] arr) { + byte[] bb = new byte[arr.length]; + int i = 0; + for (Byte b : arr) { + bb[i] = b; + i++; + } + return bb; + } + + @Override + public AsyncResponseTransformer> transformer() { + return AsyncResponseTransformer.toPublisher(); + } + + @Override + public String toString() { + return "AsyncResponseTransformer.toPublisher"; + } + } + + /** + * Call this method to retrieve the AsyncResponseTransformer required to perform the test + * + * @return + */ + AsyncResponseTransformer transformer(); + + /** + * Implementation of this method whould retreive the whole body of the request made using the AsyncResponseTransformer as a + * byte array. + * + * @param response the response the {@link AsyncResponseTransformerTestSupplier#transformer} + * @return + */ + byte[] body(T response); + + /** + * Sonce {@link FileAsyncResponseTransformer} works with file, some test might need to initialize an in-memory + * {@link FileSystem} with jimfs. + * + * @return true if the test using this class requires setup with jimfs + */ + default boolean requiresJimfs() { + return false; + } +} diff --git a/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/BaseJavaS3ClientBenchmark.java b/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/BaseJavaS3ClientBenchmark.java index 8ba2ec156612..6e6becef158a 100644 --- a/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/BaseJavaS3ClientBenchmark.java +++ b/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/BaseJavaS3ClientBenchmark.java @@ -23,9 +23,6 @@ import java.time.Duration; import java.util.ArrayList; import java.util.List; -import software.amazon.awssdk.http.async.SdkAsyncHttpClient; -import software.amazon.awssdk.http.crt.AwsCrtAsyncHttpClient; -import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.utils.Logger; @@ -59,33 +56,10 @@ protected BaseJavaS3ClientBenchmark(TransferManagerBenchmarkConfig config) { .multipartConfiguration(c -> c.minimumPartSizeInBytes(partSizeInMb * MB) .thresholdInBytes(partSizeInMb * 2 * MB) .apiCallBufferSizeInBytes(readBufferInMb * MB)) - .httpClientBuilder(httpClient(config)) + .httpClientBuilder(TransferManagerBenchmark.httpClient(config)) .build(); } - private SdkAsyncHttpClient.Builder httpClient(TransferManagerBenchmarkConfig config) { - if (config.forceCrtHttpClient()) { - logger.info(() -> "Using CRT HTTP client"); - AwsCrtAsyncHttpClient.Builder builder = AwsCrtAsyncHttpClient.builder(); - if (config.readBufferSizeInMb() != null) { - builder.readBufferSizeInBytes(config.readBufferSizeInMb() * MB); - } - if (config.maxConcurrency() != null) { - builder.maxConcurrency(config.maxConcurrency()); - } - return builder; - } - NettyNioAsyncHttpClient.Builder builder = NettyNioAsyncHttpClient.builder(); - if (config.connectionAcquisitionTimeoutInSec() != null) { - Duration connAcqTimeout = Duration.ofSeconds(config.connectionAcquisitionTimeoutInSec()); - builder.connectionAcquisitionTimeout(connAcqTimeout); - } - if (config.maxConcurrency() != null) { - builder.maxConcurrency(config.maxConcurrency()); - } - return builder; - } - protected abstract void sendOneRequest(List latencies) throws Exception; protected abstract long contentLength() throws Exception; diff --git a/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/BaseTransferManagerBenchmark.java b/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/BaseTransferManagerBenchmark.java index d40680bca028..8e11065b53de 100644 --- a/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/BaseTransferManagerBenchmark.java +++ b/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/BaseTransferManagerBenchmark.java @@ -60,19 +60,8 @@ public abstract class BaseTransferManagerBenchmark implements TransferManagerBen BaseTransferManagerBenchmark(TransferManagerBenchmarkConfig config) { logger.info(() -> "Benchmark config: " + config); - Long partSizeInMb = config.partSizeInMb() == null ? null : config.partSizeInMb() * MB; - Long readBufferSizeInMb = config.readBufferSizeInMb() == null ? null : config.readBufferSizeInMb() * MB; - S3CrtAsyncClientBuilder builder = S3CrtAsyncClient.builder() - .targetThroughputInGbps(config.targetThroughput()) - .minimumPartSizeInBytes(partSizeInMb) - .initialReadBufferSizeInBytes(readBufferSizeInMb) - .targetThroughputInGbps(config.targetThroughput() == null ? - Double.valueOf(100.0) : - config.targetThroughput()); - if (config.maxConcurrency() != null) { - builder.maxConcurrency(config.maxConcurrency()); - } - s3 = builder.build(); + + s3 = createS3AsyncClient(config); s3Sync = S3Client.builder().build(); transferManager = S3TransferManager.builder() .s3Client(s3) @@ -171,6 +160,38 @@ private void warmUpDownloadBatch() { CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); } + private S3AsyncClient createS3AsyncClient(TransferManagerBenchmarkConfig config) { + Long partSizeInMb = config.partSizeInMb() == null ? null : config.partSizeInMb() * MB; + Long readBufferSizeInMb = config.readBufferSizeInMb() == null ? null : config.readBufferSizeInMb() * MB; + switch (config.s3Client()) { + case CRT: { + logger.info(() -> "Using CRT S3 Async client"); + S3CrtAsyncClientBuilder builder = S3CrtAsyncClient.builder() + .targetThroughputInGbps(config.targetThroughput()) + .minimumPartSizeInBytes(partSizeInMb) + .initialReadBufferSizeInBytes(readBufferSizeInMb) + .targetThroughputInGbps(config.targetThroughput() == null ? + Double.valueOf(100.0) : + config.targetThroughput()); + if (config.maxConcurrency() != null) { + builder.maxConcurrency(config.maxConcurrency()); + } + return builder.build(); + } + case JAVA: { + logger.info(() -> "Using Java-based S3 Async client"); + return S3AsyncClient.builder() + .multipartEnabled(true) + .multipartConfiguration(c -> c.minimumPartSizeInBytes(partSizeInMb) + .apiCallBufferSizeInBytes(readBufferSizeInMb)) + .httpClientBuilder(TransferManagerBenchmark.httpClient(config)) + .build(); + } + default: + throw new IllegalArgumentException("base s3 client must be crt or java"); + } + } + private void warmUpUploadBatch() { List> futures = new ArrayList<>(); for (int i = 0; i < 20; i++) { diff --git a/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/BenchmarkRunner.java b/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/BenchmarkRunner.java index d83fc87026ae..ceb398c2e465 100644 --- a/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/BenchmarkRunner.java +++ b/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/BenchmarkRunner.java @@ -41,6 +41,8 @@ public final class BenchmarkRunner { private static final String READ_BUFFER_IN_MB = "readBufferInMB"; private static final String VERSION = "version"; + private static final String S3_CLIENT = "s3Client"; + private static final String PREFIX = "prefix"; private static final String TIMEOUT = "timeoutInMin"; @@ -90,6 +92,10 @@ public static void main(String... args) throws org.apache.commons.cli.ParseExcep options.addOption(null, READ_BUFFER_IN_MB, true, "Read buffer size in MB"); options.addOption(null, VERSION, true, "The major version of the transfer manager to run test: " + "v1 | v2 | crt | java, default: v2"); + options.addOption(null, S3_CLIENT, true, "For v2 transfer manager benchmarks, which base s3 client " + + "should be used: " + + "crt | java, default: crt"); + options.addOption(null, PREFIX, true, "S3 Prefix used in downloadDirectory and uploadDirectory"); options.addOption(null, CONTENT_LENGTH, true, "Content length to upload from memory. Used only in the " @@ -140,7 +146,11 @@ public static void main(String... args) throws org.apache.commons.cli.ParseExcep benchmark = new JavaS3ClientCopyBenchmark(config); break; } - throw new UnsupportedOperationException("Java based s3 client benchmark only support upload and copy"); + if (operation == TransferManagerOperation.DOWNLOAD) { + benchmark = new JavaS3ClientDownloadBenchmark(config); + break; + } + throw new UnsupportedOperationException(); default: throw new UnsupportedOperationException(); } @@ -151,6 +161,11 @@ private static TransferManagerBenchmarkConfig parseConfig(CommandLine cmd) { TransferManagerOperation operation = TransferManagerOperation.valueOf(cmd.getOptionValue(OPERATION) .toUpperCase(Locale.ENGLISH)); + TransferManagerBaseS3Client s3Client = cmd.getOptionValue(S3_CLIENT) == null + ? TransferManagerBaseS3Client.CRT + : TransferManagerBaseS3Client.valueOf(cmd.getOptionValue(S3_CLIENT) + .toUpperCase(Locale.ENGLISH)); + String filePath = cmd.getOptionValue(FILE); String bucket = cmd.getOptionValue(BUCKET); String key = cmd.getOptionValue(KEY); @@ -205,6 +220,7 @@ private static TransferManagerBenchmarkConfig parseConfig(CommandLine cmd) { .connectionAcquisitionTimeoutInSec(connAcqTimeoutInSec) .forceCrtHttpClient(forceCrtHttpClient) .maxConcurrency(maxConcurrency) + .s3Client(s3Client) .build(); } @@ -216,6 +232,11 @@ public enum TransferManagerOperation { UPLOAD_DIRECTORY } + public enum TransferManagerBaseS3Client { + CRT, + JAVA + } + private enum SdkVersion { V1, V2, diff --git a/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/JavaS3ClientDownloadBenchmark.java b/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/JavaS3ClientDownloadBenchmark.java new file mode 100644 index 000000000000..484bdd890be6 --- /dev/null +++ b/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/JavaS3ClientDownloadBenchmark.java @@ -0,0 +1,62 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.s3benchmarks; + +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; +import software.amazon.awssdk.core.FileTransformerConfiguration; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.utils.Logger; + +public class JavaS3ClientDownloadBenchmark extends BaseJavaS3ClientBenchmark { + private static final Logger log = Logger.loggerFor(JavaS3ClientDownloadBenchmark.class); + private final String filePath; + + public JavaS3ClientDownloadBenchmark(TransferManagerBenchmarkConfig config) { + super(config); + this.filePath = config.filePath(); + } + + @Override + protected void sendOneRequest(List latencies) throws Exception { + Double latency; + if (filePath == null) { + log.info(() -> "Starting download to memory"); + latency = runWithTime(s3AsyncClient.getObject( + req -> req.key(key).bucket(bucket), new NoOpResponseTransformer<>() + )::join).latency(); + } else { + log.info(() -> "Starting download to file"); + Path path = Paths.get(filePath); + FileTransformerConfiguration conf = FileTransformerConfiguration + .builder() + .failureBehavior(FileTransformerConfiguration.FailureBehavior.LEAVE) + .fileWriteOption(FileTransformerConfiguration.FileWriteOption.CREATE_OR_REPLACE_EXISTING) + .build(); + + latency = runWithTime(s3AsyncClient.getObject( + req -> req.key(key).bucket(bucket), AsyncResponseTransformer.toFile(path, conf) + )::join).latency(); + } + latencies.add(latency); + } + + @Override + protected long contentLength() throws Exception { + return s3Client.headObject(b -> b.bucket(bucket).key(key)).contentLength(); + } +} diff --git a/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/NoOpResponseTransformer.java b/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/NoOpResponseTransformer.java index b6648405f254..6c2a460dbe7a 100644 --- a/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/NoOpResponseTransformer.java +++ b/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/NoOpResponseTransformer.java @@ -25,11 +25,11 @@ /** * A no-op {@link AsyncResponseTransformer} */ -public class NoOpResponseTransformer implements AsyncResponseTransformer { - private CompletableFuture future; +public class NoOpResponseTransformer implements AsyncResponseTransformer { + private CompletableFuture future; @Override - public CompletableFuture prepare() { + public CompletableFuture prepare() { future = new CompletableFuture<>(); return future; } @@ -50,10 +50,10 @@ public void exceptionOccurred(Throwable error) { } static class NoOpSubscriber implements Subscriber { - private final CompletableFuture future; + private final CompletableFuture future; private Subscription subscription; - NoOpSubscriber(CompletableFuture future) { + NoOpSubscriber(CompletableFuture future) { this.future = future; } @@ -75,7 +75,7 @@ public void onError(Throwable throwable) { @Override public void onComplete() { - future.complete(null); + future.complete(new Object()); } } diff --git a/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/TransferManagerBenchmark.java b/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/TransferManagerBenchmark.java index c182934f4e3f..099de3c182a2 100644 --- a/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/TransferManagerBenchmark.java +++ b/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/TransferManagerBenchmark.java @@ -15,19 +15,53 @@ package software.amazon.awssdk.s3benchmarks; +import static software.amazon.awssdk.transfer.s3.SizeConstant.MB; + +import java.time.Duration; import java.util.function.Supplier; +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; +import software.amazon.awssdk.http.crt.AwsCrtAsyncHttpClient; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; +import software.amazon.awssdk.utils.Logger; /** * Factory to create the benchmark */ @FunctionalInterface public interface TransferManagerBenchmark { + Logger logger = Logger.loggerFor(TransferManagerBenchmark.class); /** * The benchmark method to run */ void run(); + static > SdkAsyncHttpClient.Builder httpClient( + TransferManagerBenchmarkConfig config) { + if (config.forceCrtHttpClient()) { + logger.info(() -> "Using CRT HTTP client"); + AwsCrtAsyncHttpClient.Builder builder = AwsCrtAsyncHttpClient.builder(); + if (config.readBufferSizeInMb() != null) { + builder.readBufferSizeInBytes(config.readBufferSizeInMb() * MB); + } + if (config.maxConcurrency() != null) { + builder.maxConcurrency(config.maxConcurrency()); + } + return (T) builder; + } + NettyNioAsyncHttpClient.Builder builder = NettyNioAsyncHttpClient.builder(); + if (config.connectionAcquisitionTimeoutInSec() != null) { + Duration connAcqTimeout = Duration.ofSeconds(config.connectionAcquisitionTimeoutInSec()); + builder.connectionAcquisitionTimeout(connAcqTimeout); + } + if (config.maxConcurrency() != null) { + builder.maxConcurrency(config.maxConcurrency()); + } + return (T) builder; + } + + + static TransferManagerBenchmark v2Download(TransferManagerBenchmarkConfig config) { return new TransferManagerDownloadBenchmark(config); } diff --git a/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/TransferManagerBenchmarkConfig.java b/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/TransferManagerBenchmarkConfig.java index a3750f472498..f7c99a946be9 100644 --- a/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/TransferManagerBenchmarkConfig.java +++ b/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/TransferManagerBenchmarkConfig.java @@ -33,6 +33,7 @@ public final class TransferManagerBenchmarkConfig { private final Long connectionAcquisitionTimeoutInSec; private final Boolean forceCrtHttpClient; private final Integer maxConcurrency; + private final BenchmarkRunner.TransferManagerBaseS3Client s3Client; private final Long readBufferSizeInMb; private final BenchmarkRunner.TransferManagerOperation operation; @@ -55,6 +56,7 @@ private TransferManagerBenchmarkConfig(Builder builder) { this.connectionAcquisitionTimeoutInSec = builder.connectionAcquisitionTimeoutInSec; this.forceCrtHttpClient = builder.forceCrtHttpClient; this.maxConcurrency = builder.maxConcurrency; + this.s3Client = builder.s3Client; } public String filePath() { @@ -121,6 +123,10 @@ public Integer maxConcurrency() { return this.maxConcurrency; } + public BenchmarkRunner.TransferManagerBaseS3Client s3Client() { + return this.s3Client; + } + public static Builder builder() { return new Builder(); } @@ -160,6 +166,7 @@ static final class Builder { private Long connectionAcquisitionTimeoutInSec; private Boolean forceCrtHttpClient; private Integer maxConcurrency; + private BenchmarkRunner.TransferManagerBaseS3Client s3Client; private Integer iteration; private BenchmarkRunner.TransferManagerOperation operation; @@ -247,6 +254,11 @@ public Builder maxConcurrency(Integer maxConcurrency) { return this; } + public Builder s3Client(BenchmarkRunner.TransferManagerBaseS3Client s3Client) { + this.s3Client = s3Client; + return this; + } + public TransferManagerBenchmarkConfig build() { return new TransferManagerBenchmarkConfig(this); } diff --git a/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/TransferManagerDownloadBenchmark.java b/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/TransferManagerDownloadBenchmark.java index 485e5631b24b..925c399db2a8 100644 --- a/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/TransferManagerDownloadBenchmark.java +++ b/test/s3-benchmarks/src/main/java/software/amazon/awssdk/s3benchmarks/TransferManagerDownloadBenchmark.java @@ -98,7 +98,7 @@ private void downloadOnceToFile(List latencies) throws Exception { private void downloadOnceToMemory(List latencies) throws Exception { long start = System.currentTimeMillis(); - AsyncResponseTransformer responseTransformer = new NoOpResponseTransformer<>(); + AsyncResponseTransformer responseTransformer = new NoOpResponseTransformer<>(); transferManager.download(DownloadRequest.builder() .getObjectRequest(req -> req.bucket(bucket).key(key)) .responseTransformer(responseTransformer) diff --git a/test/stability-tests/src/it/java/software/amazon/awssdk/stability/tests/s3/S3AsyncBaseStabilityTest.java b/test/stability-tests/src/it/java/software/amazon/awssdk/stability/tests/s3/S3AsyncBaseStabilityTest.java index 20f27963bdad..3dc409b2cf9f 100644 --- a/test/stability-tests/src/it/java/software/amazon/awssdk/stability/tests/s3/S3AsyncBaseStabilityTest.java +++ b/test/stability-tests/src/it/java/software/amazon/awssdk/stability/tests/s3/S3AsyncBaseStabilityTest.java @@ -16,6 +16,7 @@ package software.amazon.awssdk.stability.tests.s3; import static org.assertj.core.api.Assertions.assertThat; +import static software.amazon.awssdk.stability.tests.utils.StabilityTestRunner.ALLOWED_MAX_PEAK_THREAD_COUNT; import java.io.File; import java.io.IOException; @@ -61,10 +62,18 @@ public abstract class S3AsyncBaseStabilityTest extends AwsTestBase { .build(); } + private final int allowedPeakThreads; + public S3AsyncBaseStabilityTest(S3AsyncClient testClient) { + this(testClient, ALLOWED_MAX_PEAK_THREAD_COUNT); + } + + public S3AsyncBaseStabilityTest(S3AsyncClient testClient, int maxThreadCount) { + this.allowedPeakThreads = maxThreadCount; this.testClient = testClient; } + @RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class) public void largeObject_put_get_usingFile() { String md5Upload = uploadLargeObjectFromFile(); @@ -87,7 +96,7 @@ protected String computeKeyName(int i) { protected void doGetBucketAcl_lowTpsLongInterval() { IntFunction> future = i -> testClient.getBucketAcl(b -> b.bucket(getTestBucketName())); String className = this.getClass().getSimpleName(); - StabilityTestRunner.newRunner() + StabilityTestRunner.newRunner(allowedPeakThreads) .testName(className + ".getBucketAcl_lowTpsLongInterval") .futureFactory(future) .requestCountPerRun(10) @@ -99,7 +108,7 @@ protected void doGetBucketAcl_lowTpsLongInterval() { protected String downloadLargeObjectToFile() { File randomTempFile = RandomTempFile.randomUncreatedFile(); - StabilityTestRunner.newRunner() + StabilityTestRunner.newRunner(allowedPeakThreads) .testName("S3AsyncStabilityTest.downloadLargeObjectToFile") .futures(testClient.getObject(b -> b.bucket(getTestBucketName()).key(LARGE_KEY_NAME), AsyncResponseTransformer.toFile(randomTempFile))) @@ -120,7 +129,7 @@ protected String uploadLargeObjectFromFile() { try { file = new RandomTempFile((long) 2e+9); String md5 = Md5Utils.md5AsBase64(file); - StabilityTestRunner.newRunner() + StabilityTestRunner.newRunner(allowedPeakThreads) .testName("S3AsyncStabilityTest.uploadLargeObjectFromFile") .futures(testClient.putObject(b -> b.bucket(getTestBucketName()).key(LARGE_KEY_NAME), AsyncRequestBody.fromFile(file))) @@ -144,7 +153,7 @@ protected void putObject() { AsyncRequestBody.fromBytes(bytes)); }; - StabilityTestRunner.newRunner() + StabilityTestRunner.newRunner(allowedPeakThreads) .testName("S3AsyncStabilityTest.putObject") .futureFactory(future) .requestCountPerRun(CONCURRENCY) @@ -160,7 +169,7 @@ protected void getObject() { return testClient.getObject(b -> b.bucket(getTestBucketName()).key(keyName), AsyncResponseTransformer.toFile(path)); }; - StabilityTestRunner.newRunner() + StabilityTestRunner.newRunner(allowedPeakThreads) .testName("S3AsyncStabilityTest.getObject") .futureFactory(future) .requestCountPerRun(CONCURRENCY) @@ -183,7 +192,7 @@ protected static void deleteBucketAndAllContents(S3AsyncClient client, String bu client.deleteBucket(DeleteBucketRequest.builder().bucket(bucketName).build()).join(); } catch (Exception e) { - log.error(() -> "Failed to delete bucket: " +bucketName); + log.error(() -> "Failed to delete bucket: " + bucketName, e); } } diff --git a/test/stability-tests/src/it/java/software/amazon/awssdk/stability/tests/s3/S3MultipartJavaBasedStabilityTest.java b/test/stability-tests/src/it/java/software/amazon/awssdk/stability/tests/s3/S3MultipartJavaBasedStabilityTest.java new file mode 100644 index 000000000000..6bda03809700 --- /dev/null +++ b/test/stability-tests/src/it/java/software/amazon/awssdk/stability/tests/s3/S3MultipartJavaBasedStabilityTest.java @@ -0,0 +1,63 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.stability.tests.s3; + +import java.time.Duration; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import software.amazon.awssdk.core.retry.RetryPolicy; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; +import software.amazon.awssdk.services.s3.S3AsyncClient; + +public class S3MultipartJavaBasedStabilityTest extends S3AsyncBaseStabilityTest { + private static final String BUCKET_NAME = String.format("s3multipartjavabasedstabilitytest%d", System.currentTimeMillis()); + private static final S3AsyncClient multipartJavaBasedClient; + + static { + multipartJavaBasedClient = S3AsyncClient.builder() + .credentialsProvider(CREDENTIALS_PROVIDER_CHAIN) + .httpClientBuilder(NettyNioAsyncHttpClient.builder() + .maxConcurrency(CONCURRENCY)) + .multipartEnabled(true) + .overrideConfiguration(b -> b.apiCallTimeout(Duration.ofMinutes(5)) + // Retry at test level + .retryPolicy(RetryPolicy.none())) + .build(); + } + + public S3MultipartJavaBasedStabilityTest() { + // S3 multipart client uses more threads because for large file uploads, it reads from different positions of the files + // at the same time, which will trigger more Java I/O threads to spin up + super(multipartJavaBasedClient, 250); + } + + @BeforeAll + public static void setup() { + multipartJavaBasedClient.createBucket(b -> b.bucket(BUCKET_NAME)).join(); + multipartJavaBasedClient.waiter().waitUntilBucketExists(b -> b.bucket(BUCKET_NAME)).join(); + } + + @AfterAll + public static void cleanup() { + deleteBucketAndAllContents(multipartJavaBasedClient, BUCKET_NAME); + multipartJavaBasedClient.close(); + } + + @Override + protected String getTestBucketName() { + return BUCKET_NAME; + } +} diff --git a/test/stability-tests/src/it/java/software/amazon/awssdk/stability/tests/utils/StabilityTestRunner.java b/test/stability-tests/src/it/java/software/amazon/awssdk/stability/tests/utils/StabilityTestRunner.java index eceda3abc1a5..3009a83d95ed 100644 --- a/test/stability-tests/src/it/java/software/amazon/awssdk/stability/tests/utils/StabilityTestRunner.java +++ b/test/stability-tests/src/it/java/software/amazon/awssdk/stability/tests/utils/StabilityTestRunner.java @@ -69,6 +69,7 @@ */ public class StabilityTestRunner { + public static final int ALLOWED_MAX_PEAK_THREAD_COUNT = 90; private static final Logger log = Logger.loggerFor(StabilityTestRunner.class); private static final double ALLOWED_FAILURE_RATIO = 0.05; private static final int TESTS_TIMEOUT_IN_MINUTES = 60; @@ -76,7 +77,7 @@ public class StabilityTestRunner { // because of the internal thread pool used in AsynchronousFileChannel // Also, synchronous clients have their own thread pools so this measurement needs to be mutable // so that the async and synchronous paths can both use this runner. - private int allowedPeakThreadCount = 90; + private int allowedPeakThreadCount = ALLOWED_MAX_PEAK_THREAD_COUNT;; private ThreadMXBean threadMXBean; private IntFunction> futureFactory; @@ -343,12 +344,12 @@ private void processResult(TestResult testResult) { } if (testResult.peakThreadCount() > allowedPeakThreadCount) { - String errorMessage = String.format("The number of peak thread exceeds the allowed peakThread threshold %s", - allowedPeakThreadCount); + String errorMessage = String.format("The number of peak thread %s exceeds the allowed peakThread threshold %s", + testResult.peakThreadCount(), allowedPeakThreadCount); threadDump(testResult.testName()); - throw new AssertionError(errorMessage); + log.warn(() -> errorMessage); } } diff --git a/utils/src/main/java/software/amazon/awssdk/utils/CompletableFutureUtils.java b/utils/src/main/java/software/amazon/awssdk/utils/CompletableFutureUtils.java index 334e76d96bed..a30527f7071f 100644 --- a/utils/src/main/java/software/amazon/awssdk/utils/CompletableFutureUtils.java +++ b/utils/src/main/java/software/amazon/awssdk/utils/CompletableFutureUtils.java @@ -66,7 +66,7 @@ public static CompletionException errorAsCompletionException(Throwable t) { /** * Forward the {@code Throwable} from {@code src} to {@code dst}. - + * * @param src The source of the {@code Throwable}. * @param dst The destination where the {@code Throwable} will be forwarded to. * @@ -147,8 +147,9 @@ public static CompletableFuture forwardResultTo(CompletableFuture src, } /** - * Completes the {@code dst} future based on the result of the {@code src} future, synchronously, - * after applying the provided transformation {@link Function} if successful. + * Completes the {@code dst} future based on the result of the {@code src} future, synchronously, after applying the provided + * transformation {@link Function} if successful. If the function threw an exception, the destination + * future will be completed exceptionally with that exception. * * @param src The source {@link CompletableFuture} * @param dst The destination where the {@code Throwable} or transformed result will be forwarded to. @@ -160,9 +161,16 @@ public static CompletableFuture forwardTransformedResu src.whenComplete((r, e) -> { if (e != null) { dst.completeExceptionally(e); - } else { - dst.complete(function.apply(r)); + return; + } + DestT result = null; + try { + result = function.apply(r); + } catch (Throwable functionException) { + dst.completeExceptionally(functionException); + return; } + dst.complete(result); }); return src; diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/BaseSubscriberAdapter.java b/utils/src/main/java/software/amazon/awssdk/utils/async/BaseSubscriberAdapter.java new file mode 100644 index 000000000000..ea8acfb9d594 --- /dev/null +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/BaseSubscriberAdapter.java @@ -0,0 +1,304 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.utils.Logger; + +/** + * Base of subscribers that can adapt one type to another. This subscriber will receive onNext signal with the {@code U} type, + * but will need to {@link BaseSubscriberAdapter#fulfillDownstreamDemand() fulfill the downstream demand} of the delegate + * subscriber with instance of the {@code T} type. + * + * @param the type that the delegate subscriber demands. + * @param the type sent by the publisher this subscriber is subscribed to. + */ +@SdkProtectedApi +public abstract class BaseSubscriberAdapter extends DelegatingSubscriber { + private static final Logger log = Logger.loggerFor(BaseSubscriberAdapter.class); + + /** + * The amount of unfulfilled demand open against the upstream subscriber. + */ + protected final AtomicLong upstreamDemand = new AtomicLong(0); + + /** + * The amount of unfulfilled demand the downstream subscriber has opened against us. + */ + protected final AtomicLong downstreamDemand = new AtomicLong(0); + + /** + * A flag that is used to ensure that only one thread is handling updates to the state of this subscriber at a time. This + * allows us to ensure that the downstream onNext, onComplete and onError are only ever invoked serially. + */ + protected final AtomicBoolean handlingStateUpdate = new AtomicBoolean(false); + + /** + * Whether the upstream subscriber has called onError on us. If this is null, we haven't gotten an onError. If it's non-null + * this will be the exception that the upstream passed to our onError. After we get an onError, we'll call onError on the + * downstream subscriber as soon as possible. + */ + protected final AtomicReference onErrorFromUpstream = new AtomicReference<>(null); + + /** + * Whether we have called onComplete or onNext on the downstream subscriber. + */ + protected volatile boolean terminalCallMadeDownstream = false; + + /** + * Whether the upstream subscriber has called onComplete on us. After this happens, we'll drain any outstanding items in the + * allItems queue and then call onComplete on the downstream subscriber. + */ + protected volatile boolean onCompleteCalledByUpstream = false; + + /** + * The subscription to the upstream subscriber. + */ + protected Subscription upstreamSubscription; + + protected BaseSubscriberAdapter(Subscriber subscriber) { + super(subscriber); + } + + /** + * This method is called inside the onNext signal. Implementation should do what is required to store the data + * before fulfilling the demand from the downstream subscriber. + * + * @param item the value with which onNext was called. + */ + abstract void doWithItem(T item); + + /** + * This method is called when demand from the downstream subscriber needs to be fulfilled. Called in a loop + * until {@code downstreamDemand} is no longer needed. Implementations are responsible for decrementing the {@code + * downstreamDemand} accordingly as demand gets fulfilled. + */ + protected abstract void fulfillDownstreamDemand(); + + @Override + public void onSubscribe(Subscription subscription) { + if (upstreamSubscription != null) { + log.warn(() -> "Received duplicate subscription, cancelling the duplicate.", new IllegalStateException()); + subscription.cancel(); + return; + } + + upstreamSubscription = subscription; + subscriber.onSubscribe(new Subscription() { + @Override + public void request(long l) { + addDownstreamDemand(l); + handleStateUpdate(); + } + + @Override + public void cancel() { + subscription.cancel(); + } + }); + } + + @Override + public void onNext(T item) { + try { + doWithItem(item); + } catch (RuntimeException e) { + upstreamSubscription.cancel(); + onError(e); + throw e; + } + + upstreamDemand.decrementAndGet(); + handleStateUpdate(); + } + + @Override + public void onError(Throwable throwable) { + onErrorFromUpstream.compareAndSet(null, throwable); + handleStateUpdate(); + } + + @Override + public void onComplete() { + onCompleteCalledByUpstream = true; + handleStateUpdate(); + } + + /** + * Increment the downstream demand by the provided value, accounting for overflow. + */ + private void addDownstreamDemand(long l) { + if (l > 0) { + downstreamDemand.getAndUpdate(current -> { + long newValue = current + l; + return newValue >= 0 ? newValue : Long.MAX_VALUE; + }); + } else { + log.error(() -> "Demand " + l + " must not be negative."); + upstreamSubscription.cancel(); + onError(new IllegalArgumentException("Demand must not be negative")); + } + } + + /** + * This is invoked after each downstream request or upstream onNext, onError or onComplete. + */ + protected void handleStateUpdate() { + do { + // Anything that happens after this if statement and before we set handlingStateUpdate to false is guaranteed to only + // happen on one thread. For that reason, we should only invoke onNext, onComplete or onError within that block. + if (!handlingStateUpdate.compareAndSet(false, true)) { + return; + } + + try { + // If we've already called onComplete or onError, don't do anything. + if (terminalCallMadeDownstream) { + return; + } + + // Call onNext, onComplete and onError as needed based on the current subscriber state. + handleOnNextState(); + handleUpstreamDemandState(); + handleOnCompleteState(); + handleOnErrorState(); + } catch (Error e) { + throw e; + } catch (Throwable e) { + log.error(() -> "Unexpected exception encountered that violates the reactive streams specification. Attempting " + + "to terminate gracefully.", e); + upstreamSubscription.cancel(); + onError(e); + } finally { + handlingStateUpdate.set(false); + } + + // It's possible we had an important state change between when we decided to release the state update flag, and we + // actually released it. If that seems to have happened, try to handle that state change on this thread, because + // another thread is not guaranteed to come around and do so. + } while (onNextNeeded() || upstreamDemandNeeded() || onCompleteNeeded() || onErrorNeeded()); + } + + /** + * Fulfill downstream demand by flushing + */ + private void handleOnNextState() { + while (onNextNeeded() && !onErrorNeeded()) { + fulfillDownstreamDemand(); + } + } + + /** + * Returns true if we need to call onNext downstream. If this is executed outside the handling-state-update condition, the + * result is subject to change. + */ + private boolean onNextNeeded() { + return downstreamDemand.get() > 0 && additionalOnNextNeededCheck(); + } + + + /** + * Can be overridden by subclasses to provide additional checks before calling onNext on downstream. + */ + boolean additionalOnNextNeededCheck() { + return true; + } + + /** + * Request more upstream demand if it's needed. + */ + private void handleUpstreamDemandState() { + if (upstreamDemandNeeded()) { + ensureUpstreamDemandExists(); + } + } + + /** + * Returns true if we need to increase our upstream demand. + */ + private boolean upstreamDemandNeeded() { + return upstreamDemand.get() <= 0 && downstreamDemand.get() > 0 && additionalUpstreamDemandNeededCheck(); + } + + /** + * Can be overridden by subclasses to provide additional checks to see if we need to increase our upstream demand. + */ + boolean additionalUpstreamDemandNeededCheck() { + return true; + } + + /** + * If there are zero pending items in the queue and the upstream has called onComplete, then tell the downstream we're done. + */ + private void handleOnCompleteState() { + if (onCompleteNeeded()) { + terminalCallMadeDownstream = true; + subscriber.onComplete(); + } + } + + /** + * Returns true if we need to call onComplete downstream. If this is executed outside the handling-state-update condition, the + * result is subject to change. + */ + private boolean onCompleteNeeded() { + return onCompleteCalledByUpstream && !terminalCallMadeDownstream && additionalOnCompleteNeededCheck(); + } + + /** + * Can be overridden by subclasses to provide additional checks before calling onComplete on downstream + */ + boolean additionalOnCompleteNeededCheck() { + return true; + } + + /** + * If the upstream has called onError, then tell the downstream we're done, no matter what state the queue is in. + */ + private void handleOnErrorState() { + if (onErrorNeeded()) { + terminalCallMadeDownstream = true; + subscriber.onError(onErrorFromUpstream.get()); + } + } + + /** + * Returns true if we need to call onError downstream. If this is executed outside the handling-state-update condition, the + * result is subject to change. + */ + private boolean onErrorNeeded() { + return onErrorFromUpstream.get() != null && !terminalCallMadeDownstream; + } + + /** + * Ensure that we have at least 1 demand upstream, so that we can get more items. + */ + private void ensureUpstreamDemandExists() { + if (this.upstreamDemand.get() < 0) { + log.error(() -> "Upstream delivered more data than requested. Resetting state to prevent a frozen stream.", + new IllegalStateException()); + upstreamDemand.set(1); + upstreamSubscription.request(1); + } else if (this.upstreamDemand.compareAndSet(0, 1)) { + upstreamSubscription.request(1); + } + } +} diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/DelegatingBufferingSubscriber.java b/utils/src/main/java/software/amazon/awssdk/utils/async/DelegatingBufferingSubscriber.java new file mode 100644 index 000000000000..9e29737d444b --- /dev/null +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/DelegatingBufferingSubscriber.java @@ -0,0 +1,124 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import static software.amazon.awssdk.utils.async.StoringSubscriber.EventType.ON_COMPLETE; +import static software.amazon.awssdk.utils.async.StoringSubscriber.EventType.ON_NEXT; + +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicLong; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; + +@SdkProtectedApi +public class DelegatingBufferingSubscriber extends BaseSubscriberAdapter { + private static final Logger log = Logger.loggerFor(DelegatingBufferingSubscriber.class); + + /** + * The maximum amount of bytes allowed to be stored in the StoringSubscriber + */ + private final long maximumBufferInBytes; + + /** + * Current amount of bytes buffered in the StoringSubscriber + */ + private final AtomicLong currentlyBuffered = new AtomicLong(0); + + /** + * Stores the bytes received from the upstream publisher, awaiting sending them to the delegate once the buffer size is + * reached. + */ + private final StoringSubscriber storage = new StoringSubscriber<>(Integer.MAX_VALUE); + + protected DelegatingBufferingSubscriber(Long maximumBufferInBytes, Subscriber delegate) { + super(Validate.notNull(delegate, "delegate must not be null")); + this.maximumBufferInBytes = Validate.notNull(maximumBufferInBytes, "maximumBufferInBytes must not be null"); + } + + @Override + public void onSubscribe(Subscription subscription) { + storage.onSubscribe(new DemandIgnoringSubscription(subscription)); + super.onSubscribe(subscription); + } + + @Override + void doWithItem(ByteBuffer buffer) { + storage.onNext(buffer.duplicate()); + currentlyBuffered.addAndGet(buffer.remaining()); + } + + @Override + protected void fulfillDownstreamDemand() { + storage.poll() + .filter(event -> event.type() == ON_NEXT) + .ifPresent(byteBufferEvent -> { + currentlyBuffered.addAndGet(-byteBufferEvent.value().remaining()); + downstreamDemand.decrementAndGet(); + log.trace(() -> "demand: " + downstreamDemand.get()); + subscriber.onNext(byteBufferEvent.value()); + }); + } + + /** + * Returns true if we need to call onNext downstream. + */ + @Override + boolean additionalOnNextNeededCheck() { + return storage.peek().map(event -> event.type() == ON_NEXT).orElse(false); + } + + /** + * Returns true if we need to call onComplete downstream. + */ + @Override + boolean additionalOnCompleteNeededCheck() { + return storage.peek().map(event -> event.type() == ON_COMPLETE).orElse(true); + } + + /** + * Returns true if we need to increase our upstream demand. + */ + @Override + boolean additionalUpstreamDemandNeededCheck() { + return currentlyBuffered.get() < maximumBufferInBytes; + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + private Long maximumBufferInBytes; + private Subscriber delegate; + + public Builder maximumBufferInBytes(Long maximumBufferInBytes) { + this.maximumBufferInBytes = maximumBufferInBytes; + return this; + } + + public Builder delegate(Subscriber delegate) { + this.delegate = delegate; + return this; + } + + public DelegatingBufferingSubscriber build() { + return new DelegatingBufferingSubscriber(maximumBufferInBytes, delegate); + } + } +} diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/FlatteningSubscriber.java b/utils/src/main/java/software/amazon/awssdk/utils/async/FlatteningSubscriber.java index 7ff7b830eed0..5b9fdd7a300d 100644 --- a/utils/src/main/java/software/amazon/awssdk/utils/async/FlatteningSubscriber.java +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/FlatteningSubscriber.java @@ -16,257 +16,64 @@ package software.amazon.awssdk.utils.async; import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkProtectedApi; -import software.amazon.awssdk.utils.Logger; import software.amazon.awssdk.utils.Validate; @SdkProtectedApi -public class FlatteningSubscriber extends DelegatingSubscriber, U> { - private static final Logger log = Logger.loggerFor(FlatteningSubscriber.class); - - /** - * The amount of unfulfilled demand open against the upstream subscriber. - */ - private final AtomicLong upstreamDemand = new AtomicLong(0); - - /** - * The amount of unfulfilled demand the downstream subscriber has opened against us. - */ - private final AtomicLong downstreamDemand = new AtomicLong(0); - - /** - * A flag that is used to ensure that only one thread is handling updates to the state of this subscriber at a time. This - * allows us to ensure that the downstream onNext, onComplete and onError are only ever invoked serially. - */ - private final AtomicBoolean handlingStateUpdate = new AtomicBoolean(false); +public class FlatteningSubscriber extends BaseSubscriberAdapter, U> { /** * Items given to us by the upstream subscriber that we will use to fulfill demand of the downstream subscriber. */ private final LinkedBlockingQueue allItems = new LinkedBlockingQueue<>(); - /** - * Whether the upstream subscriber has called onError on us. If this is null, we haven't gotten an onError. If it's non-null - * this will be the exception that the upstream passed to our onError. After we get an onError, we'll call onError on the - * downstream subscriber as soon as possible. - */ - private final AtomicReference onErrorFromUpstream = new AtomicReference<>(null); - - /** - * Whether we have called onComplete or onNext on the downstream subscriber. - */ - private volatile boolean terminalCallMadeDownstream = false; - - /** - * Whether the upstream subscriber has called onComplete on us. After this happens, we'll drain any outstanding items in the - * allItems queue and then call onComplete on the downstream subscriber. - */ - private volatile boolean onCompleteCalledByUpstream = false; - - /** - * The subscription to the upstream subscriber. - */ - private Subscription upstreamSubscription; - public FlatteningSubscriber(Subscriber subscriber) { super(subscriber); } @Override - public void onSubscribe(Subscription subscription) { - if (upstreamSubscription != null) { - log.warn(() -> "Received duplicate subscription, cancelling the duplicate.", new IllegalStateException()); - subscription.cancel(); - return; - } - - upstreamSubscription = subscription; - subscriber.onSubscribe(new Subscription() { - @Override - public void request(long l) { - addDownstreamDemand(l); - handleStateUpdate(); - } - - @Override - public void cancel() { - subscription.cancel(); - } + void doWithItem(Iterable nextItems) { + nextItems.forEach(item -> { + Validate.notNull(nextItems, "Collections flattened by the flattening subscriber must not contain null."); + allItems.add(item); }); } @Override - public void onNext(Iterable nextItems) { - try { - nextItems.forEach(item -> { - Validate.notNull(nextItems, "Collections flattened by the flattening subscriber must not contain null."); - allItems.add(item); - }); - } catch (RuntimeException e) { - upstreamSubscription.cancel(); - onError(e); - throw e; - } - - upstreamDemand.decrementAndGet(); - handleStateUpdate(); - } - - @Override - public void onError(Throwable throwable) { - onErrorFromUpstream.compareAndSet(null, throwable); - handleStateUpdate(); - } - - @Override - public void onComplete() { - onCompleteCalledByUpstream = true; - handleStateUpdate(); - } - - /** - * Increment the downstream demand by the provided value, accounting for overflow. - */ - private void addDownstreamDemand(long l) { - - if (l > 0) { - downstreamDemand.getAndUpdate(current -> { - long newValue = current + l; - return newValue >= 0 ? newValue : Long.MAX_VALUE; - }); - } else { - log.error(() -> "Demand " + l + " must not be negative."); - upstreamSubscription.cancel(); - onError(new IllegalArgumentException("Demand must not be negative")); - } - } - - /** - * This is invoked after each downstream request or upstream onNext, onError or onComplete. - */ - private void handleStateUpdate() { - do { - // Anything that happens after this if statement and before we set handlingStateUpdate to false is guaranteed to only - // happen on one thread. For that reason, we should only invoke onNext, onComplete or onError within that block. - if (!handlingStateUpdate.compareAndSet(false, true)) { - return; - } - - try { - // If we've already called onComplete or onError, don't do anything. - if (terminalCallMadeDownstream) { - return; - } - - // Call onNext, onComplete and onError as needed based on the current subscriber state. - handleOnNextState(); - handleUpstreamDemandState(); - handleOnCompleteState(); - handleOnErrorState(); - } catch (Error e) { - throw e; - } catch (Throwable e) { - log.error(() -> "Unexpected exception encountered that violates the reactive streams specification. Attempting " - + "to terminate gracefully.", e); - upstreamSubscription.cancel(); - onError(e); - } finally { - handlingStateUpdate.set(false); - } - - // It's possible we had an important state change between when we decided to release the state update flag, and we - // actually released it. If that seems to have happened, try to handle that state change on this thread, because - // another thread is not guaranteed to come around and do so. - } while (onNextNeeded() || upstreamDemandNeeded() || onCompleteNeeded() || onErrorNeeded()); - } - - /** - * Fulfill downstream demand by pulling items out of the item queue and sending them downstream. - */ - private void handleOnNextState() { - while (onNextNeeded() && !onErrorNeeded()) { - downstreamDemand.decrementAndGet(); - subscriber.onNext(allItems.poll()); - } + protected void fulfillDownstreamDemand() { + downstreamDemand.decrementAndGet(); + subscriber.onNext(allItems.poll()); } /** * Returns true if we need to call onNext downstream. If this is executed outside the handling-state-update condition, the * result is subject to change. */ - private boolean onNextNeeded() { - return !allItems.isEmpty() && downstreamDemand.get() > 0; - } - - /** - * Request more upstream demand if it's needed. - */ - private void handleUpstreamDemandState() { - if (upstreamDemandNeeded()) { - ensureUpstreamDemandExists(); - } + @Override + boolean additionalOnNextNeededCheck() { + return !allItems.isEmpty(); } /** * Returns true if we need to increase our upstream demand. */ - private boolean upstreamDemandNeeded() { - return upstreamDemand.get() <= 0 && downstreamDemand.get() > 0 && allItems.isEmpty(); - } - - /** - * If there are zero pending items in the queue and the upstream has called onComplete, then tell the downstream - * we're done. - */ - private void handleOnCompleteState() { - if (onCompleteNeeded()) { - terminalCallMadeDownstream = true; - subscriber.onComplete(); - } + @Override + boolean additionalUpstreamDemandNeededCheck() { + return allItems.isEmpty(); } /** * Returns true if we need to call onNext downstream. If this is executed outside the handling-state-update condition, the * result is subject to change. */ - private boolean onCompleteNeeded() { - return onCompleteCalledByUpstream && allItems.isEmpty() && !terminalCallMadeDownstream; - } - - /** - * If the upstream has called onError, then tell the downstream we're done, no matter what state the queue is in. - */ - private void handleOnErrorState() { - if (onErrorNeeded()) { - terminalCallMadeDownstream = true; - subscriber.onError(onErrorFromUpstream.get()); - } - } - - /** - * Returns true if we need to call onError downstream. If this is executed outside the handling-state-update condition, the - * result is subject to change. - */ - private boolean onErrorNeeded() { - return onErrorFromUpstream.get() != null && !terminalCallMadeDownstream; + @Override + boolean additionalOnCompleteNeededCheck() { + return allItems.isEmpty(); } - /** - * Ensure that we have at least 1 demand upstream, so that we can get more items. - */ - private void ensureUpstreamDemandExists() { - if (this.upstreamDemand.get() < 0) { - log.error(() -> "Upstream delivered more data than requested. Resetting state to prevent a frozen stream.", - new IllegalStateException()); - upstreamDemand.set(1); - upstreamSubscription.request(1); - } else if (this.upstreamDemand.compareAndSet(0, 1)) { - upstreamSubscription.request(1); - } + @Override + public void onNext(Iterable item) { + super.onNext(item); } } diff --git a/utils/src/test/java/software/amazon/awssdk/utils/CompletableFutureUtilsTest.java b/utils/src/test/java/software/amazon/awssdk/utils/CompletableFutureUtilsTest.java index 2b6fe49e2cc1..f8312c0cc506 100644 --- a/utils/src/test/java/software/amazon/awssdk/utils/CompletableFutureUtilsTest.java +++ b/utils/src/test/java/software/amazon/awssdk/utils/CompletableFutureUtilsTest.java @@ -108,6 +108,18 @@ public void forwardTransformedResultTo_srcCompletesExceptionally_shouldCompleteD assertThatThrownBy(dst::join).hasCause(exception); } + @Test(timeout = 1000) + public void forwardTransformedResultTo_functionThrowsException_shouldCompleteExceptionally() { + CompletableFuture src = new CompletableFuture<>(); + CompletableFuture dst = new CompletableFuture<>(); + + CompletableFutureUtils.forwardTransformedResultTo(src, dst, x -> { throw new RuntimeException("foobar"); }); + src.complete(0); + assertThatThrownBy(dst::join) + .hasMessageContaining("foobar") + .hasCauseInstanceOf(RuntimeException.class); + } + @Test(timeout = 1000) public void anyFail_shouldCompleteWhenAnyFutureFails() { RuntimeException exception = new RuntimeException("blah"); @@ -206,4 +218,6 @@ public void joinLikeSync_canceled_throwsCancellationException() { .hasNoSuppressedExceptions() .hasNoCause() .isInstanceOf(CancellationException.class); - }} + } + +} diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/DelegatingBufferingSubscriberTckTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/DelegatingBufferingSubscriberTckTest.java new file mode 100644 index 000000000000..cdb2ec4a9a41 --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/DelegatingBufferingSubscriberTckTest.java @@ -0,0 +1,102 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import java.nio.ByteBuffer; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.reactivestreams.tck.SubscriberWhiteboxVerification; +import org.reactivestreams.tck.TestEnvironment; + +public class DelegatingBufferingSubscriberTckTest extends SubscriberWhiteboxVerification { + + private static final byte[] DATA = {0, 1, 2, 3, 4, 5, 6, 7}; + + protected DelegatingBufferingSubscriberTckTest() { + super(new TestEnvironment()); + } + + @Override + public Subscriber createSubscriber(WhiteboxSubscriberProbe probe) { + Subscriber delegate = new NoOpSubscriber(); + return new DelegatingBufferingSubscriber(1024L, delegate) { + @Override + public void onSubscribe(Subscription s) { + super.onSubscribe(s); + probe.registerOnSubscribe(new SubscriberPuppet() { + @Override + public void triggerRequest(long l) { + s.request(l); + } + + @Override + public void signalCancel() { + s.cancel(); + } + }); + } + + @Override + public void onNext(ByteBuffer bb) { + super.onNext(bb); + probe.registerOnNext(bb); + } + + @Override + public void onError(Throwable t) { + super.onError(t); + probe.registerOnError(t); + } + + @Override + public void onComplete() { + super.onComplete(); + probe.registerOnComplete(); + } + }; + } + + @Override + public ByteBuffer createElement(int element) { + return ByteBuffer.wrap(DATA); + } + + static class NoOpSubscriber implements Subscriber { + private Subscription subscription; + + @Override + public void onSubscribe(Subscription s) { + this.subscription = s; + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + // do nothing, test only + } + + @Override + public void onComplete() { + // do nothing, test only + } + } + +} \ No newline at end of file diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/DelegatingBufferingSubscriberTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/DelegatingBufferingSubscriberTest.java new file mode 100644 index 000000000000..219c970a6cbf --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/DelegatingBufferingSubscriberTest.java @@ -0,0 +1,280 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +class DelegatingBufferingSubscriberTest { + + @Test + void givenMultipleBufferTotalToBufferSize_ExpectSubscriberGetThemAll() { + TestSubscriber testSubscriber = new TestSubscriber(32); + Subscriber subscriber = DelegatingBufferingSubscriber.builder() + .maximumBufferInBytes(32L) + .delegate(testSubscriber) + .build(); + SimplePublisher publisher = new SimplePublisher<>(); + publisher.subscribe(subscriber); + + testSubscriber.assertNothingReceived(); + for (int i = 0; i < 3; i++) { + ByteBuffer buff = byteArrayWithValue((byte) i, 8); + publisher.send(buff); + } + + ByteBuffer buff = byteArrayWithValue((byte) 3, 8); + publisher.send(buff); + assertThat(testSubscriber.onNextCallAmount).isEqualTo(4); + assertThat(testSubscriber.totalReceived).isEqualTo(32); + + publisher.complete(); + assertThat(testSubscriber.onNextCallAmount).isEqualTo(4); + assertThat(testSubscriber.totalReceived).isEqualTo(32); + + testSubscriber.assertAllReceivedInChunk(8); + assertThat(testSubscriber.onCompleteCalled).isTrue(); + } + + @Test + void givenMultipleBufferLessThenBufferSize_ExpectSubscriberGetThemAll() { + TestSubscriber testSubscriber = new TestSubscriber(32); + Subscriber subscriber = DelegatingBufferingSubscriber.builder() + .maximumBufferInBytes(64L) + .delegate(testSubscriber) + .build(); + SimplePublisher publisher = new SimplePublisher<>(); + publisher.subscribe(subscriber); + + testSubscriber.assertNothingReceived(); + for (int i = 0; i < 4; i++) { + ByteBuffer buff = byteArrayWithValue((byte) i, 8); + publisher.send(buff); + } + + publisher.complete(); + testSubscriber.assertBytesReceived(4, 32); + testSubscriber.assertAllReceivedInChunk(8); + } + + @Test + void exceedsBufferInMultipleChunk_BytesReceivedInMultipleBatches() { + TestSubscriber testSubscriber = new TestSubscriber(64); + Subscriber subscriber = DelegatingBufferingSubscriber.builder() + .maximumBufferInBytes(32L) + .delegate(testSubscriber) + .build(); + SimplePublisher publisher = new SimplePublisher<>(); + publisher.subscribe(subscriber); + + testSubscriber.assertNothingReceived(); + for (int i = 0; i < 3; i++) { + ByteBuffer buff = byteArrayWithValue((byte) i, 8); + publisher.send(buff); + } + + for (int i = 3; i < 8; i++) { + ByteBuffer buff = byteArrayWithValue((byte) i, 8); + publisher.send(buff); + } + testSubscriber.assertBytesReceived(8, 64); + + publisher.complete(); + + // make sure nothing more is received + testSubscriber.assertBytesReceived(8, 64); + testSubscriber.assertAllReceivedInChunk(8); + } + + @Test + void whenDataExceedsBufferSingle_ExpectAllBytesReceived() { + TestSubscriber testSubscriber = new TestSubscriber(256); + Subscriber subscriber = DelegatingBufferingSubscriber.builder() + .maximumBufferInBytes(32L) + .delegate(testSubscriber) + .build(); + SimplePublisher publisher = new SimplePublisher<>(); + publisher.subscribe(subscriber); + + publisher.send(byteArrayWithValue((byte) 0, 256)); + testSubscriber.assertBytesReceived(1, 256); + + publisher.complete(); + testSubscriber.assertBytesReceived(1, 256); + + testSubscriber.assertAllReceivedInChunk(256); + } + + @Test + void multipleBuffer_unevenSizes() { + TestSubscriber testSubscriber = new TestSubscriber(59); + Subscriber subscriber = DelegatingBufferingSubscriber.builder() + .maximumBufferInBytes(32L) + .delegate(testSubscriber) + .build(); + SimplePublisher publisher = new SimplePublisher<>(); + publisher.subscribe(subscriber); + + testSubscriber.assertNothingReceived(); + publisher.send(byteArrayWithValue((byte) 0, 9)); + + publisher.send(byteArrayWithValue((byte) 1, 20)); + + publisher.send(byteArrayWithValue((byte) 2, 30)); + testSubscriber.assertBytesReceived(3, 59); + + publisher.complete(); + testSubscriber.assertBytesReceived(3, 59); + + ByteBuffer received = testSubscriber.received; + received.position(0); + for (int i = 0; i < 9; i++) { + assertThat(received.get()).isEqualTo((byte) 0); + } + for (int i = 0; i < 20; i++) { + assertThat(received.get()).isEqualTo((byte) 1); + } + for (int i = 0; i < 30; i++) { + assertThat(received.get()).isEqualTo((byte) 2); + } + } + + @Test + void stochastic_ExpectAllBytesReceived() { + AtomicInteger i = new AtomicInteger(0); + int totalSendToMake = 16; + TestSubscriber testSubscriber = new TestSubscriber(512 * 1024); + Subscriber subscriber = DelegatingBufferingSubscriber.builder() + .maximumBufferInBytes(32 * 1024L) + .delegate(testSubscriber) + .build(); + SimplePublisher publisher = new SimplePublisher<>(); + publisher.subscribe(subscriber); + ExecutorService executor = Executors.newFixedThreadPool(8); + CountDownLatch latch = new CountDownLatch(totalSendToMake); + for (int j = 0; j < totalSendToMake; j++) { + executor.submit(() -> { + ByteBuffer buffer = byteArrayWithValue((byte) i.incrementAndGet(), 32 * 1024); + publisher.send(buffer).whenComplete((res, err) -> { + if (err != null) { + fail("unexpected error sending data"); + } + latch.countDown(); + }); + }); + } + try { + latch.await(); + } catch (InterruptedException e) { + fail("Test interrupted while waiting for all submitted task to finish"); + } + publisher.complete(); + testSubscriber.assertBytesReceived(totalSendToMake, 512 * 1024); + } + + @Test + void publisherError_ExpectSubscriberOnErrorToBeCalled() { + TestSubscriber testSubscriber = new TestSubscriber(32); + Subscriber subscriber = DelegatingBufferingSubscriber.builder() + .maximumBufferInBytes(32L) + .delegate(testSubscriber) + .build(); + SimplePublisher publisher = new SimplePublisher<>(); + publisher.subscribe(subscriber); + + for (int i = 0; i < 4; i++) { + ByteBuffer buff = byteArrayWithValue((byte) i, 8); + publisher.send(buff); + } + publisher.error(new RuntimeException("test exception")); + + publisher.complete(); + testSubscriber.assertBytesReceived(4, 32); + assertThat(testSubscriber.onErrorCalled).isTrue(); + } + + private class TestSubscriber implements Subscriber { + int onNextCallAmount = 0; + int totalReceived = 0; + int totalSizeExpected; + boolean onErrorCalled = false; + boolean onCompleteCalled = false; + ByteBuffer received; + + public TestSubscriber(int totalSizeExpected) { + this.totalSizeExpected = totalSizeExpected; + this.received = ByteBuffer.allocate(totalSizeExpected); + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + System.out.println("received in delegate " + byteBuffer.remaining()); + onNextCallAmount++; + totalReceived += byteBuffer.remaining(); + received.put(byteBuffer); + } + + @Override + public void onError(Throwable t) { + onErrorCalled = true; + } + + @Override + public void onComplete() { + onCompleteCalled = true; + } + + void assertNothingReceived() { + assertThat(onNextCallAmount).isZero(); + assertThat(totalReceived).isZero(); + } + + void assertBytesReceived(int timesOnNextWasCalled, int totalBytesReceived) { + assertThat(onNextCallAmount).isEqualTo(timesOnNextWasCalled); + assertThat(totalReceived).isEqualTo(totalBytesReceived); + } + + void assertAllReceivedInChunk(int chunkSize) { + received.position(0); + for (int i = 0; i < totalReceived / chunkSize; i++) { + for (int j = 0; j < chunkSize; j++) { + assertThat(received.get(i * chunkSize + j)).isEqualTo((byte) i); + } + } + } + } + + private static ByteBuffer byteArrayWithValue(byte value, int size) { + byte[] arr = new byte[size]; + Arrays.fill(arr, value); + return ByteBuffer.wrap(arr); + } +}