Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,10 @@ private void handleError(Throwable e) {
}

private AsyncResponseTransformer<T, T> getDelegateTransformer(Long startAt) {
if (transformerCount.get() == 0) {
if (transformerCount.get() == 0 &&
initialConfig.fileWriteOption() != FileTransformerConfiguration.FileWriteOption.WRITE_TO_POSITION) {
// On the first request we need to maintain the same config so
// that the file is actually created on disk if it doesn't exist (for example, if CREATE_NEW or
// CREATE_OR_REPLACE_EXISTING is used)
// that the file is actually created on disk if it doesn't exist (for CREATE_NEW or CREATE_OR_REPLACE_EXISTING)
return AsyncResponseTransformer.toFile(path, initialConfig);
}
switch (initialConfig.fileWriteOption()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import static org.assertj.core.api.Assertions.assertThat;
import static software.amazon.awssdk.testutils.service.S3BucketUtils.temporaryBucketName;
import static software.amazon.awssdk.transfer.s3.SizeConstant.KB;
import static software.amazon.awssdk.transfer.s3.SizeConstant.MB;

import java.io.File;
Expand Down Expand Up @@ -82,6 +83,24 @@ void pauseAndResume_shouldResumeDownload() {
assertThat(path.toFile()).hasSameBinaryContentAs(sourceFile);
}

@Test
void pauseAndResume_beforeFirstPartCompletes_shouldResumeDownload() {
Path path = RandomTempFile.randomUncreatedFile().toPath();
DownloadFileRequest request = DownloadFileRequest.builder()
.getObjectRequest(b -> b.bucket(BUCKET).key(KEY))
.destination(path)
.build();
FileDownload download = tmJava.downloadFile(request);

// stop before we complete first part, so only wait for an amount of bytes much lower than 1 part, 1 KiB should do it
waitUntilAmountTransferred(download, KB);
ResumableFileDownload resumableFileDownload = download.pause();
FileDownload resumed = tmJava.resumeDownloadFile(resumableFileDownload);
resumed.completionFuture().join();
assertThat(path.toFile()).hasSameBinaryContentAs(sourceFile);
}


@Test
void pauseAndResume_whenAlreadyComplete_shouldHandleGracefully() {
Path path = RandomTempFile.randomUncreatedFile().toPath();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ private TransferProgressUpdater doDownloadFile(
TransferProgressUpdater progressUpdater = new TransferProgressUpdater(downloadRequest, null);
try {
progressUpdater.transferInitiated();
responseTransformer = isS3ClientMultipartEnabled()
responseTransformer = isS3ClientMultipartEnabled() && downloadRequest.getObjectRequest().range() == null
? progressUpdater.wrapForNonSerialFileDownload(
responseTransformer, downloadRequest.getObjectRequest())
: progressUpdater.wrapResponseTransformer(responseTransformer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,11 @@ private ResumableRequestConverter() {

if (hasRemainingParts(getObjectRequest)) {
log.debug(() -> "The paused download was performed with part GET, now resuming download of remaining parts");
Long positionToWriteFrom =
MultipartDownloadUtils.multipartDownloadResumeContext(originalDownloadRequest.getObjectRequest())
.map(MultipartDownloadResumeContext::bytesToLastCompletedParts)
.orElse(0L);
AsyncResponseTransformer<GetObjectResponse, GetObjectResponse> responseTransformer =
AsyncResponseTransformer.toFile(originalDownloadRequest.destination(),
FileTransformerConfiguration.builder()
.fileWriteOption(WRITE_TO_POSITION)
.position(positionToWriteFrom)
.position(0L)
.failureBehavior(LEAVE)
.build());
return Pair.of(originalDownloadRequest, responseTransformer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@

package software.amazon.awssdk.services.s3.internal.multipart;

import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
Expand All @@ -30,6 +34,7 @@
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.utils.CompletableFutureUtils;
import software.amazon.awssdk.utils.ContentRangeParser;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.Pair;

Expand Down Expand Up @@ -66,7 +71,7 @@ public class ParallelMultipartDownloaderSubscriber
* The total number of completed parts. A part is considered complete once the completable future associated with its request
* completes successfully.
*/
private final AtomicInteger completedParts = new AtomicInteger();
private final AtomicInteger completedParts;

/**
* The future returned to the user when calling
Expand All @@ -80,7 +85,7 @@ public class ParallelMultipartDownloaderSubscriber
* The {@link GetObjectResponse} to be returned in the completed future to the user. It corresponds to the response of first
* part GetObject
*/
private GetObjectResponse getObjectResponse;
private volatile GetObjectResponse getObjectResponse;

/**
* The subscription received from the publisher this subscriber subscribes to.
Expand Down Expand Up @@ -135,12 +140,17 @@ public class ParallelMultipartDownloaderSubscriber
private final AtomicInteger partNumber = new AtomicInteger(0);

/**
* Tracks if one of the parts requests future completed exceptionally. If this occurs, it means all retries were
* attempted for that part, but it still failed. This is a failure state, the error should be reported back to the user
* and any more request should be ignored.
* Tracks if one of the parts requests future completed exceptionally. If this occurs, it means all retries were attempted for
* that part, but it still failed. This is a failure state, the error should be reported back to the user and any more request
* should be ignored.
*/
private final AtomicBoolean isCompletedExceptionally = new AtomicBoolean(false);

/**
* When resuming a paused download, indicates which parts were already completed before pausing.
*/
private final Set<Integer> initialCompletedParts;

public ParallelMultipartDownloaderSubscriber(S3AsyncClient s3,
GetObjectRequest getObjectRequest,
CompletableFuture<GetObjectResponse> resultFuture,
Expand All @@ -149,6 +159,36 @@ public ParallelMultipartDownloaderSubscriber(S3AsyncClient s3,
this.getObjectRequest = getObjectRequest;
this.resultFuture = resultFuture;
this.maxInFlightParts = maxInFlightParts;
this.initialCompletedParts = initialCompletedParts(getObjectRequest);
this.completedParts = new AtomicInteger(initialCompletedParts.size());

if (resumingDownload()) {
int totalPartsFromInitialRequest = MultipartDownloadUtils.multipartDownloadResumeContext(getObjectRequest)
.map(MultipartDownloadResumeContext::totalParts)
.orElse(0);
if (totalPartsFromInitialRequest > 0) {
totalPartsFuture.complete(totalPartsFromInitialRequest);
}
getObjectResponse = MultipartDownloadUtils.multipartDownloadResumeContext(getObjectRequest)
.map(MultipartDownloadResumeContext::response)
.orElse(null);
}
}

private static Set<Integer> initialCompletedParts(GetObjectRequest getObjectRequest) {
return Collections.unmodifiableSet(
MultipartDownloadUtils.multipartDownloadResumeContext(getObjectRequest)
.map(MultipartDownloadResumeContext::completedParts)
.<Set<Integer>>map(HashSet::new)
.orElse(Collections.emptySet())
);
}

private boolean resumingDownload() {
Optional<Boolean> hasAlreadyCompletedParts =
MultipartDownloadUtils.multipartDownloadResumeContext(getObjectRequest)
.map(ctx -> !ctx.completedParts().isEmpty());
return hasAlreadyCompletedParts.orElse(false);
}

@Override
Expand Down Expand Up @@ -176,19 +216,18 @@ public void onNext(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse
+ " - Total pending transformers: " + pendingTransformers.size()
+ " - Current in flight requests: " + inFlightRequests.keySet());

int currentPartNum = partNumber.incrementAndGet();
int currentPartNum = nextPart();

if (currentPartNum == 1) {
sendFirstRequest(asyncResponseTransformer);
} else {
pendingTransformers.offer(Pair.of(currentPartNum, asyncResponseTransformer));
totalPartsFuture.thenAccept(
totalParts -> processingRequests(asyncResponseTransformer, currentPartNum, totalParts));
}
}

private void processingRequests(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse> asyncResponseTransformer,
int currentPartNum, Integer totalParts) {
int currentPartNum, int totalParts) {

if (currentPartNum > totalParts) {
// Do not process requests above total parts.
Expand All @@ -203,6 +242,7 @@ private void processingRequests(AsyncResponseTransformer<GetObjectResponse, GetO
return;
}

sendNextRequest(asyncResponseTransformer, currentPartNum, totalParts);
processPendingTransformers(totalParts);
}

Expand Down Expand Up @@ -233,11 +273,14 @@ private void sendNextRequest(AsyncResponseTransformer<GetObjectResponse, GetObje
inFlightRequests.remove(currentPartNumber);
inFlightRequestsNum.decrementAndGet();
completedParts.incrementAndGet();
MultipartDownloadUtils.multipartDownloadResumeContext(getObjectRequest)
.ifPresent(ctx -> ctx.addCompletedPart(currentPartNumber));

if (completedParts.get() >= totalParts) {
if (completedParts.get() > totalParts) {
resultFuture.completeExceptionally(new IllegalStateException("Total parts exceeded"));
} else {
updateResumeContextForCompletion(res);
resultFuture.complete(getObjectResponse);
}

Expand All @@ -254,6 +297,14 @@ private void sendNextRequest(AsyncResponseTransformer<GetObjectResponse, GetObje
});
}

private void updateResumeContextForCompletion(GetObjectResponse response) {
ContentRangeParser.totalBytes(response.contentRange())
.ifPresent(total -> MultipartDownloadUtils
.multipartDownloadResumeContext(getObjectRequest)
.ifPresent(ctx ->
ctx.addToBytesToLastCompletedParts(total)));
}

private void sendFirstRequest(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse> asyncResponseTransformer) {
log.debug(() -> "Sending first request");
GetObjectRequest request = nextRequest(1);
Expand Down Expand Up @@ -282,6 +333,13 @@ private void sendFirstRequest(AsyncResponseTransformer<GetObjectResponse, GetObj
getObjectResponse = res;

processPendingTransformers(res.partsCount());
MultipartDownloadUtils.multipartDownloadResumeContext(getObjectRequest)
.ifPresent(ctx -> {
ctx.addCompletedPart(1);
ctx.response(res);
ctx.totalParts(res.partsCount());
});

synchronized (subscriptionLock) {
subscription.request(1);
}
Expand Down Expand Up @@ -312,7 +370,7 @@ private void setInitialPartCountAndEtag(GetObjectResponse response) {

private void handlePartError(Throwable e, int part) {
isCompletedExceptionally.set(true);
log.debug(() -> "Error on part " + part, e);
log.debug(() -> "Error on part " + part, e);
resultFuture.completeExceptionally(e);
inFlightRequests.values().forEach(future -> future.cancel(true));
}
Expand All @@ -334,9 +392,12 @@ private void processPendingTransformers(int totalParts) {

private void doProcessPendingTransformers(int totalParts) {
while (shouldProcessPendingTransformers()) {
Pair<Integer, AsyncResponseTransformer<GetObjectResponse, GetObjectResponse>> transformer =
pendingTransformers.poll();
sendNextRequest(transformer.right(), transformer.left(), totalParts);
Pair<Integer, AsyncResponseTransformer<GetObjectResponse, GetObjectResponse>> pair = pendingTransformers.poll();
Integer part = pair.left();
AsyncResponseTransformer<GetObjectResponse, GetObjectResponse> transformer = pair.right();
if (part <= totalParts) {
sendNextRequest(transformer, part, totalParts);
}
}
}

Expand Down Expand Up @@ -372,4 +433,18 @@ private GetObjectRequest nextRequest(int nextPartToGet) {
});
}

private int nextPart() {
if (initialCompletedParts.isEmpty()) {
return partNumber.incrementAndGet();
}

synchronized (initialCompletedParts) {
int part = partNumber.incrementAndGet();
while (initialCompletedParts.contains(part)) {
part = partNumber.incrementAndGet();
}
return part;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,11 @@ void tearDown() throws Exception {
}

@ParameterizedTest
@ValueSource(ints = {2, 3, 4, 5, 6, 7, 8, 9, 10, 49})
void happyPath_multipartDownload_partsLessThanMaxInFlight(int numParts) throws Exception {
@ValueSource(ints = {2, 3, 4, 5, 6, 7, 8, 9, 10, 49, // less than maxInFlightParts
50, // == maxInFlightParts
51, 100, 101 // more than maxInFlightParts
})
void happyPath_multipartDownload(int numParts) throws Exception {
int partSize = 1024;
byte[] expectedBody = utils.stubAllParts(testBucket, testKey, numParts, partSize);

Expand Down
Loading