Skip to content

Commit

Permalink
Abort method for S3OutputStream for ability to cancel the upload midw…
Browse files Browse the repository at this point in the history
…ay through (#1221)

---------
Co-authored-by: Andrei <andrei@riskfront.ai>
  • Loading branch information
zhemaituk authored Sep 19, 2024
1 parent b5df124 commit 2f78149
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,17 @@ public void flush() throws IOException {
localOutputStream.flush();
}

@Override
public void abort() throws IOException {
if (closed) {
throw new IllegalStateException("Stream is already closed. Too late to abort.");
}

localOutputStream.close();
closed = true;
deleteTempFile();
}

@Override
public void close() throws IOException {
if (closed) {
Expand All @@ -145,19 +156,22 @@ public void close() throws IOException {
}
}
this.upload(builder.build());
boolean result = file.delete();

if (!result) {
getLogger().warn(String.format("Temporary file %s could not be deleted", file.getPath()));
}
deleteTempFile();
}
catch (Exception se) {
getLogger().error(
String.format("Failed to upload %s. Temporary file @%s", location.getObject(), file.getPath()));
getLogger().error("Failed to upload {}. Temporary file @{}", location.getObject(), file.getPath());
throw new UploadFailedException(file.getPath(), se);
}
}

private void deleteTempFile() {
boolean result = file.delete();

if (!result) {
getLogger().warn("Temporary file {} could not be deleted", file.getPath());
}
}

protected abstract void upload(PutObjectRequest putObjectRequest);

protected Logger getLogger() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,19 @@ public void write(int b) {
}
}

@Override
public void abort() {
synchronized (this.monitor) {
if (isClosed()) {
throw new IllegalStateException("Stream is already closed. Too late to abort.");
}
if (isMultiPartUpload()) {
abortMultiPartUpload(multipartUploadResponse);
}
outputStream = null;
}
}

@Override
public void close() {
synchronized (this.monitor) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package io.awspring.cloud.s3;

import java.io.IOException;
import java.io.OutputStream;

/**
Expand All @@ -25,4 +26,9 @@
*/
public abstract class S3OutputStream extends OutputStream {

/**
* Cancels the upload and cleans up temporal resources (temp files, partial multipart upload).
*/
public void abort() throws IOException {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.assertj.core.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -70,4 +71,16 @@ void throwsExceptionWhenUploadFails() throws IOException {
}
}

@Test
void abortsWhenExplicitlyInvoked() throws IOException {
S3Client s3Client = mock(S3Client.class);

try (DiskBufferingS3OutputStream diskBufferingS3OutputStream = new DiskBufferingS3OutputStream(
new Location("bucket", "key"), s3Client, null)) {
diskBufferingS3OutputStream.write("hello".getBytes(StandardCharsets.UTF_8));
diskBufferingS3OutputStream.abort();
}

verify(s3Client, never()).putObject(any(PutObjectRequest.class), any(RequestBody.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,74 @@ void abortsWhenCompletingMultipartUploadFails() throws IOException {
assertThat(requestCaptor.getValue().uploadId()).isEqualTo("uploadId");
}
}

@Test
void abortsWhenExplicitlyInvoked() throws IOException {
when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class)))
.thenReturn(CreateMultipartUploadResponse.builder().uploadId("uploadId").build());

when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class)))
.thenReturn(UploadPartResponse.builder().build());

when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class)))
.thenThrow(SdkException.builder().build());

final byte[] content = new byte[DEFAULT_BUFFER_CAPACITY_IN_BYTES + 1];

try (InMemoryBufferingS3OutputStream outputStream = new InMemoryBufferingS3OutputStream(
new Location("bucket", "key", null), s3Client, null, null, DEFAULT_BUFFER_CAPACITY)) {
new Random().nextBytes(content);
outputStream.write(content);
outputStream.abort();
}
final ArgumentCaptor<AbortMultipartUploadRequest> requestCaptor = ArgumentCaptor
.forClass(AbortMultipartUploadRequest.class);

verify(s3Client, times(1)).abortMultipartUpload(requestCaptor.capture());
assertThat(requestCaptor.getValue().bucket()).isEqualTo("bucket");
assertThat(requestCaptor.getValue().key()).isEqualTo("key");
assertThat(requestCaptor.getValue().uploadId()).isEqualTo("uploadId");
}

@Test
void abortsWhenInvokedBeforeWriting() {
try (InMemoryBufferingS3OutputStream outputStream = new InMemoryBufferingS3OutputStream(
new Location("bucket", "key", null), s3Client, null, null, DEFAULT_BUFFER_CAPACITY)) {
outputStream.abort();
}

verify(s3Client, never()).createMultipartUpload(any(CreateMultipartUploadRequest.class));
verify(s3Client, never()).abortMultipartUpload(any(AbortMultipartUploadRequest.class));
}

@Test
void failsWhenAbortingAfterClosing() {
InMemoryBufferingS3OutputStream outputStream = null;
try {
outputStream = new InMemoryBufferingS3OutputStream(new Location("bucket", "key", null), s3Client, null,
null, DEFAULT_BUFFER_CAPACITY);
}
finally {
assertThat(outputStream).isNotNull();
outputStream.close();
try {
outputStream.abort();
fail("IllegalStateException should be thrown.");
}
catch (IllegalStateException e) {
final ArgumentCaptor<PutObjectRequest> requestCaptor = ArgumentCaptor.forClass(PutObjectRequest.class);
final ArgumentCaptor<RequestBody> bodyCaptor = ArgumentCaptor.forClass(RequestBody.class);

verify(s3Client, times(1)).putObject(requestCaptor.capture(), bodyCaptor.capture());

assertThat(requestCaptor.getValue().bucket()).isEqualTo("bucket");
assertThat(requestCaptor.getValue().key()).isEqualTo("key");
assertThat(requestCaptor.getValue().contentLength()).isEqualTo(0);
assertThat(requestCaptor.getValue().contentMD5()).isNotNull();

verify(s3Client, never()).createMultipartUpload(any(CreateMultipartUploadRequest.class));
verify(s3Client, never()).abortMultipartUpload(any(AbortMultipartUploadRequest.class));
}
}
}
}

0 comments on commit 2f78149

Please sign in to comment.