diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/http/ExecutionContext.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/http/ExecutionContext.java index 01cc3e0c5166..141ba473986f 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/http/ExecutionContext.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/http/ExecutionContext.java @@ -15,13 +15,11 @@ package software.amazon.awssdk.core.http; -import java.util.Optional; import software.amazon.awssdk.annotations.NotThreadSafe; import software.amazon.awssdk.annotations.SdkProtectedApi; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptorChain; import software.amazon.awssdk.core.interceptor.InterceptorContext; -import software.amazon.awssdk.core.internal.progress.listener.ProgressUpdater; import software.amazon.awssdk.core.signer.Signer; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.utils.builder.CopyableBuilder; @@ -38,7 +36,6 @@ public final class ExecutionContext implements ToCopyableBuilder progressUpdater() { - return progressUpdater != null ? Optional.of(progressUpdater) : Optional.empty(); - } - @Override public Builder toBuilder() { return new Builder(this); @@ -95,7 +87,6 @@ public static class Builder implements CopyableBuilder CompletableFuture execute( .then(() -> new HttpChecksumStage(ClientType.ASYNC)) .then(MakeRequestImmutableStage::new) .then(RequestPipelineBuilder - .first(AsyncSigningStage::new) + .first(BeforeExecutionProgressReportingStage::new) + .then(AsyncSigningStage::new) .then(AsyncBeforeTransmissionExecutionInterceptorsStage::new) .then(d -> new MakeAsyncHttpRequestStage<>(responseHandler, d)) .wrappedWith(AsyncApiCallAttemptMetricCollectionStage::new) .wrappedWith((deps, wrapped) -> new AsyncRetryableStage2<>(responseHandler, deps, wrapped)) .then(async(() -> new UnwrapResponseContainer<>())) + .then(async(() -> new AfterExecutionProgressReportingStage<>())) .then(async(() -> new AfterExecutionInterceptorsStage<>())) .wrappedWith(AsyncExecutionFailureExceptionReportingStage::new) .wrappedWith(AsyncApiCallTimeoutTrackingStage::new) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java index f9d954962864..6653d72b9f55 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java @@ -27,6 +27,7 @@ import software.amazon.awssdk.core.http.HttpResponseHandler; import software.amazon.awssdk.core.internal.http.pipeline.RequestPipelineBuilder; import software.amazon.awssdk.core.internal.http.pipeline.stages.AfterExecutionInterceptorsStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.AfterExecutionProgressReportingStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AfterTransmissionExecutionInterceptorsStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.ApiCallAttemptMetricCollectionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.ApiCallAttemptTimeoutTrackingStage; @@ -34,6 +35,7 @@ import software.amazon.awssdk.core.internal.http.pipeline.stages.ApiCallTimeoutTrackingStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.ApplyTransactionIdStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.ApplyUserAgentStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.BeforeExecutionProgressReportingStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.BeforeTransmissionExecutionInterceptorsStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.BeforeUnmarshallingExecutionInterceptorsStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.CompressRequestStage; @@ -190,7 +192,8 @@ public OutputT execute(HttpResponseHandler> response .then(MakeRequestImmutableStage::new) // End of mutating request .then(RequestPipelineBuilder - .first(SigningStage::new) + .first(BeforeExecutionProgressReportingStage::new) + .then(SigningStage::new) .then(BeforeTransmissionExecutionInterceptorsStage::new) .then(MakeHttpRequestStage::new) .then(AfterTransmissionExecutionInterceptorsStage::new) @@ -204,6 +207,7 @@ public OutputT execute(HttpResponseHandler> response .wrappedWith(ApiCallTimeoutTrackingStage::new)::build) .wrappedWith((deps, wrapped) -> new ApiCallMetricCollectionStage<>(wrapped)) .then(() -> new UnwrapResponseContainer<>()) + .then(() -> new AfterExecutionProgressReportingStage<>()) .then(() -> new AfterExecutionInterceptorsStage<>()) .wrappedWith(ExecutionFailureExceptionReportingStage::new) .build(httpClientDependencies) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/RequestExecutionContext.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/RequestExecutionContext.java index 73d9f2ea94f4..af2aaa60b376 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/RequestExecutionContext.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/RequestExecutionContext.java @@ -26,6 +26,7 @@ import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.core.internal.http.pipeline.RequestPipeline; import software.amazon.awssdk.core.internal.http.timers.TimeoutTracker; +import software.amazon.awssdk.core.internal.progress.listener.ProgressUpdater; import software.amazon.awssdk.core.signer.Signer; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.utils.Validate; @@ -44,6 +45,7 @@ public final class RequestExecutionContext { private TimeoutTracker apiCallTimeoutTracker; private TimeoutTracker apiCallAttemptTimeoutTracker; private MetricCollector attemptMetricCollector; + private ProgressUpdater progressUpdater; private RequestExecutionContext(Builder builder) { this.requestProvider = builder.requestProvider; @@ -127,6 +129,14 @@ public void attemptMetricCollector(MetricCollector metricCollector) { this.attemptMetricCollector = metricCollector; } + public ProgressUpdater progressUpdater() { + return progressUpdater; + } + + public void progressUpdater(ProgressUpdater progressUpdater) { + this.progressUpdater = progressUpdater; + } + /** * Sets the request body provider. * Used for transforming the original body provider to sign events for diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AfterExecutionProgressReportingStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AfterExecutionProgressReportingStage.java new file mode 100644 index 000000000000..5950b5891173 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AfterExecutionProgressReportingStage.java @@ -0,0 +1,36 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.internal.http.pipeline.RequestPipeline; +import software.amazon.awssdk.core.internal.util.ProgressListenerUtils; + +@SdkInternalApi +public class AfterExecutionProgressReportingStage implements RequestPipeline { + @Override + public OutputT execute(OutputT input, RequestExecutionContext context) throws Exception { + if (input instanceof SdkResponse) { + ProgressListenerUtils.updateProgressListenersWithSuccessResponse((SdkResponse) input, context.progressUpdater()); + } + + return input; + } + + +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncExecutionFailureExceptionReportingStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncExecutionFailureExceptionReportingStage.java index 4a0cbf3b6dba..8ff68499b5f5 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncExecutionFailureExceptionReportingStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncExecutionFailureExceptionReportingStage.java @@ -16,6 +16,7 @@ package software.amazon.awssdk.core.internal.http.pipeline.stages; import static software.amazon.awssdk.core.internal.http.pipeline.stages.utils.ExceptionReportingUtils.reportFailureToInterceptors; +import static software.amazon.awssdk.core.internal.http.pipeline.stages.utils.ExceptionReportingUtils.reportFailureToProgressListeners; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; @@ -42,16 +43,17 @@ public CompletableFuture execute(SdkHttpFullRequest input, RequestExecu CompletableFuture executeFuture = wrappedExecute.handle((o, t) -> { if (t != null) { Throwable toReport = t; - if (toReport instanceof CompletionException) { toReport = toReport.getCause(); } toReport = reportFailureToInterceptors(context, toReport); + // If Progress Listeners are attached to the request, update them with the throwable + reportFailureToProgressListeners(context.progressUpdater(), toReport); + throw CompletableFutureUtils.errorAsCompletionException(ThrowableUtils.asSdkException(toReport)); - } else { - return o; } + return o; }); return CompletableFutureUtils.forwardExceptionTo(executeFuture, wrappedExecute); } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/BeforeExecutionProgressReportingStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/BeforeExecutionProgressReportingStage.java new file mode 100644 index 000000000000..bda4633b0f42 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/BeforeExecutionProgressReportingStage.java @@ -0,0 +1,47 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.internal.http.pipeline.RequestToRequestPipeline; +import software.amazon.awssdk.core.internal.progress.listener.DefaultProgressUpdater; +import software.amazon.awssdk.core.internal.progress.listener.NoOpProgressUpdater; +import software.amazon.awssdk.core.internal.progress.listener.ProgressUpdater; +import software.amazon.awssdk.core.internal.util.ProgressListenerUtils; +import software.amazon.awssdk.http.SdkHttpFullRequest; + +@SdkInternalApi +public class BeforeExecutionProgressReportingStage implements RequestToRequestPipeline { + + @Override + public SdkHttpFullRequest execute(SdkHttpFullRequest input, RequestExecutionContext context) throws Exception { + + if (ProgressListenerUtils.progressListenerAttached(context.originalRequest())) { + Long requestContentLength = + (context.requestProvider() != null && context.requestProvider().contentLength().isPresent()) ? + context.requestProvider().contentLength().get() : null; + + ProgressUpdater progressUpdater = new DefaultProgressUpdater(context.originalRequest(), requestContentLength); + progressUpdater.requestPrepared(input); + context.progressUpdater(progressUpdater); + } else { + context.progressUpdater(new NoOpProgressUpdater()); + } + + return input; + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ExecutionFailureExceptionReportingStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ExecutionFailureExceptionReportingStage.java index 322796d331ae..9ce35fcbf9b5 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ExecutionFailureExceptionReportingStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ExecutionFailureExceptionReportingStage.java @@ -16,6 +16,7 @@ package software.amazon.awssdk.core.internal.http.pipeline.stages; import static software.amazon.awssdk.core.internal.http.pipeline.stages.utils.ExceptionReportingUtils.reportFailureToInterceptors; +import static software.amazon.awssdk.core.internal.http.pipeline.stages.utils.ExceptionReportingUtils.reportFailureToProgressListeners; import static software.amazon.awssdk.core.internal.util.ThrowableUtils.failure; import software.amazon.awssdk.annotations.SdkInternalApi; @@ -37,6 +38,8 @@ public OutputT execute(SdkHttpFullRequest input, RequestExecutionContext context return wrapped.execute(input, context); } catch (Exception e) { Throwable throwable = reportFailureToInterceptors(context, e); + + reportFailureToProgressListeners(context.progressUpdater(), throwable); throw failure(throwable); } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HandleResponseStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HandleResponseStage.java index b22bbbf14bc4..a6c18d2a3f68 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HandleResponseStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HandleResponseStage.java @@ -15,6 +15,9 @@ package software.amazon.awssdk.core.internal.http.pipeline.stages; +import static software.amazon.awssdk.core.internal.util.ProgressListenerUtils.updateProgressListenersWithResponseStatus; +import static software.amazon.awssdk.core.internal.util.ProgressListenerUtils.wrapWithBytesReadTrackingStream; + import java.time.Duration; import java.util.concurrent.atomic.AtomicLong; import software.amazon.awssdk.annotations.SdkInternalApi; @@ -45,10 +48,11 @@ public HandleResponseStage(HttpResponseHandler> responseHandle @Override public Response execute(SdkHttpFullResponse httpResponse, RequestExecutionContext context) throws Exception { + + updateProgressListenersWithResponseStatus(context.progressUpdater(), httpResponse); SdkHttpFullResponse bytesReadTracking = trackBytesRead(httpResponse, context); Response response = responseHandler.handle(bytesReadTracking, context.executionAttributes()); - collectMetrics(context); return response; @@ -85,7 +89,11 @@ private SdkHttpFullResponse trackBytesRead(SdkHttpFullResponse httpFullResponse, private AbortableInputStream trackBytesRead(AbortableInputStream content, RequestExecutionContext context) { AtomicLong bytesRead = context.executionAttributes().getAttribute(SdkInternalExecutionAttribute.RESPONSE_BYTES_READ); - BytesReadTrackingInputStream bytesReadTrackedStream = new BytesReadTrackingInputStream(content, bytesRead); + + BytesReadTrackingInputStream bytesReadTrackedStream = + wrapWithBytesReadTrackingStream( + content, bytesRead, context.progressUpdater()); + return AbortableInputStream.create(bytesReadTrackedStream); } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/MakeAsyncHttpRequestStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/MakeAsyncHttpRequestStage.java index 5c443f07a9a5..d831736470cd 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/MakeAsyncHttpRequestStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/MakeAsyncHttpRequestStage.java @@ -17,6 +17,7 @@ import static software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute.SDK_HTTP_EXECUTION_ATTRIBUTES; import static software.amazon.awssdk.core.internal.http.timers.TimerUtils.resolveTimeoutInMillis; +import static software.amazon.awssdk.core.internal.util.ProgressListenerUtils.updateProgressListenersWithResponseStatus; import static software.amazon.awssdk.http.Header.CONTENT_LENGTH; import java.nio.ByteBuffer; @@ -48,6 +49,8 @@ import software.amazon.awssdk.core.internal.http.timers.TimerUtils; import software.amazon.awssdk.core.internal.metrics.BytesReadTrackingPublisher; import software.amazon.awssdk.core.internal.util.MetricUtils; +import software.amazon.awssdk.core.internal.util.ProgressListenerUtils; +import software.amazon.awssdk.core.internal.util.ResponseProgressUpdaterInvoker; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.SdkHttpMethod; @@ -132,6 +135,12 @@ private CompletableFuture> executeHttpRequest(SdkHttpFullReque SdkHttpContentPublisher requestProvider = context.requestProvider() == null ? new SimpleHttpContentPublisher(request) : new SdkHttpContentPublisherAdapter(context.requestProvider()); + + AtomicLong bytesRead = context.executionAttributes().getAttribute(SdkInternalExecutionAttribute.RESPONSE_BYTES_READ); + requestProvider = ProgressListenerUtils.wrapWithByteTracking( + requestProvider, bytesRead, + context.progressUpdater()); + // Set content length if it hasn't been set already. SdkHttpFullRequest requestWithContentLength = getRequestWithContentLength(request, requestProvider); @@ -303,13 +312,20 @@ public void onHeaders(SdkHttpResponse headers) { long d = now - startTime; context.attemptMetricCollector().reportMetric(CoreMetric.TIME_TO_FIRST_BYTE, Duration.ofNanos(d)); super.onHeaders(headers); + + updateProgressListenersWithResponseStatus(context.progressUpdater(), headers); } @Override public void onStream(Publisher stream) { AtomicLong bytesReadCounter = context.executionAttributes() .getAttribute(SdkInternalExecutionAttribute.RESPONSE_BYTES_READ); - BytesReadTrackingPublisher bytesReadTrackingPublisher = new BytesReadTrackingPublisher(stream, bytesReadCounter); + + Publisher bytesReadTrackingPublisher = + new BytesReadTrackingPublisher(stream, + bytesReadCounter, + new ResponseProgressUpdaterInvoker(context.progressUpdater())); + super.onStream(bytesReadTrackingPublisher); } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/MakeHttpRequestStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/MakeHttpRequestStage.java index 32cd094a57de..74caae766094 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/MakeHttpRequestStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/MakeHttpRequestStage.java @@ -16,6 +16,7 @@ package software.amazon.awssdk.core.internal.http.pipeline.stages; import java.time.Duration; +import java.util.concurrent.atomic.AtomicLong; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; @@ -23,8 +24,12 @@ import software.amazon.awssdk.core.internal.http.InterruptMonitor; import software.amazon.awssdk.core.internal.http.RequestExecutionContext; import software.amazon.awssdk.core.internal.http.pipeline.RequestPipeline; +import software.amazon.awssdk.core.internal.metrics.BytesReadTrackingInputStream; import software.amazon.awssdk.core.internal.util.MetricUtils; +import software.amazon.awssdk.core.internal.util.ProgressListenerUtils; import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.http.AbortableInputStream; +import software.amazon.awssdk.http.ContentStreamProvider; import software.amazon.awssdk.http.ExecutableHttpRequest; import software.amazon.awssdk.http.HttpExecuteRequest; import software.amazon.awssdk.http.HttpExecuteResponse; @@ -61,15 +66,29 @@ public Pair execute(SdkHttpFullRequest } private HttpExecuteResponse executeHttpRequest(SdkHttpFullRequest request, RequestExecutionContext context) throws Exception { + MetricCollector attemptMetricCollector = context.attemptMetricCollector(); MetricCollector httpMetricCollector = MetricUtils.createHttpMetricsCollector(context); + ContentStreamProvider contentStreamProvider = null; + if (request.contentStreamProvider().isPresent()) { + AtomicLong bytesRead = context.executionAttributes() + .getAttribute(SdkInternalExecutionAttribute.RESPONSE_BYTES_READ); + + BytesReadTrackingInputStream wrappedByteTracking = ProgressListenerUtils.wrapWithBytesReadTrackingStream( + AbortableInputStream.create(request.contentStreamProvider().get().newStream()), + bytesRead, + context.progressUpdater()); + + contentStreamProvider = ContentStreamProvider.fromInputStream(wrappedByteTracking); + } + ExecutableHttpRequest requestCallable = sdkHttpClient .prepareRequest(HttpExecuteRequest.builder() .request(request) .metricCollector(httpMetricCollector) - .contentStreamProvider(request.contentStreamProvider().orElse(null)) + .contentStreamProvider(contentStreamProvider) .build()); context.apiCallTimeoutTracker().abortable(requestCallable); diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/utils/ExceptionReportingUtils.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/utils/ExceptionReportingUtils.java index 06a4b2544fc7..841d621ecf65 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/utils/ExceptionReportingUtils.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/utils/ExceptionReportingUtils.java @@ -18,6 +18,7 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.internal.http.RequestExecutionContext; import software.amazon.awssdk.core.internal.interceptor.DefaultFailedExecutionContext; +import software.amazon.awssdk.core.internal.progress.listener.ProgressUpdater; import software.amazon.awssdk.utils.Logger; @SdkInternalApi @@ -28,7 +29,8 @@ private ExceptionReportingUtils() { } /** - * Report the failure to the execution interceptors. Swallow any exceptions thrown from the interceptor since + * Report the failure to the execution interceptors and progress listeners if present. + * Swallow any exceptions thrown from the interceptor since * we don't want to replace the execution failure. * * @param context The execution context. @@ -46,6 +48,21 @@ public static Throwable reportFailureToInterceptors(RequestExecutionContext cont return modifiedContext.exception(); } + /** + * Report the failure to the Progress Listeners if they are present + * + * @param progressUpdater The execution context. + * @param failure The execution failure. + */ + public static void reportFailureToProgressListeners(ProgressUpdater progressUpdater, Throwable failure) { + + try { + progressUpdater.attemptFailure(failure); + } catch (Exception exception) { + log.warn(() -> "Progess Listener update threw an error while invoking attemptFailure().", exception); + } + } + private static DefaultFailedExecutionContext runModifyException(RequestExecutionContext context, Throwable e) { DefaultFailedExecutionContext failedContext = DefaultFailedExecutionContext.builder() diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingInputStream.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingInputStream.java index c5c74cf15ec1..46248fd826e4 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingInputStream.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingInputStream.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.util.concurrent.atomic.AtomicLong; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.internal.util.ProgressUpdaterInvoker; import software.amazon.awssdk.core.io.SdkFilterInputStream; import software.amazon.awssdk.http.Abortable; import software.amazon.awssdk.http.AbortableInputStream; @@ -26,11 +27,14 @@ public final class BytesReadTrackingInputStream extends SdkFilterInputStream implements Abortable { private final Abortable abortableIs; private final AtomicLong bytesRead; + private final ProgressUpdaterInvoker progressUpdaterInvoker; - public BytesReadTrackingInputStream(AbortableInputStream in, AtomicLong bytesRead) { + public BytesReadTrackingInputStream(AbortableInputStream in, AtomicLong bytesRead, + ProgressUpdaterInvoker progressUpdaterInvoker) { super(in); this.abortableIs = in; this.bytesRead = bytesRead; + this.progressUpdaterInvoker = progressUpdaterInvoker; } public long bytesRead() { @@ -68,6 +72,10 @@ public int read(byte[] b) throws IOException { private void updateBytesRead(long read) { if (read > 0) { bytesRead.addAndGet(read); + + if (progressUpdaterInvoker != null) { + progressUpdaterInvoker.incrementBytesTransferred(read); + } } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingPublisher.java index dd8ef03b7312..07a68280261c 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingPublisher.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingPublisher.java @@ -16,52 +16,75 @@ package software.amazon.awssdk.core.internal.metrics; import java.nio.ByteBuffer; +import java.util.Optional; import java.util.concurrent.atomic.AtomicLong; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.internal.util.ProgressUpdaterInvoker; +import software.amazon.awssdk.http.async.SdkHttpContentPublisher; /** * Publisher that tracks how many bytes are published from the wrapped publisher to the downstream subscriber. + * If request contains Progress Listeners attached, the callbacks invoke methods to update and track request status + * by invoking progress updater methods with the bytes being transacted */ @SdkInternalApi -public final class BytesReadTrackingPublisher implements Publisher { +public final class BytesReadTrackingPublisher implements SdkHttpContentPublisher { private final Publisher upstream; private final AtomicLong bytesRead; + private final ProgressUpdaterInvoker progressUpdaterInvoker; - public BytesReadTrackingPublisher(Publisher upstream, AtomicLong bytesRead) { + public BytesReadTrackingPublisher(Publisher upstream, AtomicLong bytesRead, + ProgressUpdaterInvoker progressUpdaterInvoker) { this.upstream = upstream; this.bytesRead = bytesRead; + this.progressUpdaterInvoker = progressUpdaterInvoker; } @Override public void subscribe(Subscriber subscriber) { - upstream.subscribe(new BytesReadTracker(subscriber, bytesRead)); + upstream.subscribe(new BytesReadTracker(subscriber, bytesRead, progressUpdaterInvoker)); } public long bytesRead() { return bytesRead.get(); } + @Override + public Optional contentLength() { + return Optional.empty(); + } + private static final class BytesReadTracker implements Subscriber { private final Subscriber downstream; private final AtomicLong bytesRead; + private final ProgressUpdaterInvoker progressUpdaterInvoker; - private BytesReadTracker(Subscriber downstream, AtomicLong bytesRead) { + private BytesReadTracker(Subscriber downstream, + AtomicLong bytesRead, ProgressUpdaterInvoker progressUpdaterInvoker) { this.downstream = downstream; this.bytesRead = bytesRead; + this.progressUpdaterInvoker = progressUpdaterInvoker; } @Override public void onSubscribe(Subscription subscription) { downstream.onSubscribe(subscription); + if (progressUpdaterInvoker.progressUpdater() != null) { + progressUpdaterInvoker.resetBytes(); + } } @Override public void onNext(ByteBuffer byteBuffer) { - bytesRead.addAndGet(byteBuffer.remaining()); + long byteBufferSize = byteBuffer.remaining(); + bytesRead.addAndGet(byteBufferSize); downstream.onNext(byteBuffer); + if (progressUpdaterInvoker != null) { + progressUpdaterInvoker.incrementBytesTransferred(byteBufferSize); + } } @Override diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/ProgressListenerContext.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/ProgressListenerContext.java index 86af01d93395..6a821d1bcd0a 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/ProgressListenerContext.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/ProgressListenerContext.java @@ -15,6 +15,7 @@ package software.amazon.awssdk.core.internal.progress; +import java.util.Optional; import software.amazon.awssdk.annotations.Immutable; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.SdkRequest; @@ -89,8 +90,8 @@ public SdkHttpResponse httpResponse() { } @Override - public SdkResponse response() { - return response; + public Optional response() { + return Optional.of(response); } @Override diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/listener/DefaultProgressUpdater.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/listener/DefaultProgressUpdater.java new file mode 100644 index 000000000000..400f254ac994 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/listener/DefaultProgressUpdater.java @@ -0,0 +1,175 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.progress.listener; + +import java.util.Collections; +import java.util.Optional; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.RequestOverrideConfiguration; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.core.internal.progress.ProgressListenerContext; +import software.amazon.awssdk.core.internal.progress.ProgressListenerFailedContext; +import software.amazon.awssdk.core.internal.progress.snapshot.DefaultProgressSnapshot; +import software.amazon.awssdk.core.progress.listener.SdkExchangeProgress; +import software.amazon.awssdk.core.progress.snapshot.ProgressSnapshot; +import software.amazon.awssdk.http.SdkHttpRequest; + +/** + * ProgressUpdater exposes methods that invokes listener methods to update and store request progress state + */ +@SdkInternalApi +public class DefaultProgressUpdater implements ProgressUpdater { + private final DefaultSdkExchangeProgress requestBodyProgress; + private final DefaultSdkExchangeProgress responseBodyProgress; + private ProgressListenerContext context; + private final ProgressListenerInvoker listenerInvoker; + + public DefaultProgressUpdater(SdkRequest sdkRequest, + Long requestContentLength) { + DefaultProgressSnapshot.Builder uploadProgressSnapshotBuilder = DefaultProgressSnapshot.builder(); + uploadProgressSnapshotBuilder.transferredBytes(0L); + Optional.ofNullable(requestContentLength).ifPresent(uploadProgressSnapshotBuilder::totalBytes); + + ProgressSnapshot uploadProgressSnapshot = uploadProgressSnapshotBuilder.build(); + requestBodyProgress = new DefaultSdkExchangeProgress(uploadProgressSnapshot); + + DefaultProgressSnapshot.Builder downloadProgressSnapshotBuilder = DefaultProgressSnapshot.builder(); + downloadProgressSnapshotBuilder.transferredBytes(0L); + ProgressSnapshot downloadProgressSnapshot = downloadProgressSnapshotBuilder.build(); + responseBodyProgress = new DefaultSdkExchangeProgress(downloadProgressSnapshot); + + context = ProgressListenerContext.builder() + .request(sdkRequest) + .uploadProgressSnapshot(uploadProgressSnapshot) + .downloadProgressSnapshot(downloadProgressSnapshot) + .build(); + + listenerInvoker = new ProgressListenerInvoker(sdkRequest.overrideConfiguration() + .map(RequestOverrideConfiguration::progressListeners) + .orElse(Collections.emptyList())); + } + + @Override + public void updateRequestContentLength(Long requestContentLength) { + requestBodyProgress.updateAndGet(b -> b.totalBytes(requestContentLength)); + } + + @Override + public void updateResponseContentLength(Long responseContentLength) { + responseBodyProgress.updateAndGet(b -> b.totalBytes(responseContentLength)); + } + + public SdkExchangeProgress requestBodyProgress() { + return requestBodyProgress; + } + + public SdkExchangeProgress responseBodyProgress() { + return responseBodyProgress; + } + + @Override + public void requestPrepared(SdkHttpRequest httpRequest) { + listenerInvoker.requestPrepared(context.copy(b -> b.httpRequest(httpRequest))); + } + + @Override + public void requestHeaderSent() { + listenerInvoker.requestHeaderSent(context); + } + + @Override + public void resetBytesSent() { + requestBodyProgress.updateAndGet(b -> b.transferredBytes(0L)); + } + + @Override + public void resetBytesReceived() { + responseBodyProgress.updateAndGet(b -> b.transferredBytes(0L)); + } + + @Override + public void incrementBytesSent(long numBytes) { + long uploadBytes = requestBodyProgress.progressSnapshot().transferredBytes(); + + ProgressSnapshot snapshot = requestBodyProgress.updateAndGet(b -> b.transferredBytes(uploadBytes + numBytes)); + listenerInvoker.requestBytesSent(context.copy(b -> b.uploadProgressSnapshot(snapshot))); + } + + @Override + public void incrementBytesReceived(long numBytes) { + long downloadedBytes = responseBodyProgress.progressSnapshot().transferredBytes(); + + ProgressSnapshot snapshot = responseBodyProgress.updateAndGet(b -> b.transferredBytes(downloadedBytes + numBytes)); + listenerInvoker.responseBytesReceived(context.copy(b -> b.downloadProgressSnapshot(snapshot))); + } + + @Override + public void responseHeaderReceived() { + listenerInvoker.responseHeaderReceived(context); + } + + @Override + public void executionSuccess(SdkResponse response) { + + listenerInvoker.executionSuccess(context.copy(b -> b.response(response))); + } + + @Override + public void executionFailure(Throwable t) { + listenerInvoker.executionFailure(ProgressListenerFailedContext.builder() + .progressListenerContext( + context.copy( + b -> { + b.uploadProgressSnapshot( + requestBodyProgress.progressSnapshot()); + b.downloadProgressSnapshot( + responseBodyProgress.progressSnapshot()); + })) + .exception(t) + .build()); + } + + @Override + public void attemptFailure(Throwable t) { + listenerInvoker.attemptFailure(ProgressListenerFailedContext.builder() + .progressListenerContext( + context.copy( + b -> { + b.uploadProgressSnapshot( + requestBodyProgress.progressSnapshot()); + b.downloadProgressSnapshot( + responseBodyProgress.progressSnapshot()); + })) + .exception(t) + .build()); + } + + @Override + public void attemptFailureResponseBytesReceived(Throwable t) { + listenerInvoker.attemptFailureResponseBytesReceived(ProgressListenerFailedContext.builder() + .progressListenerContext( + context.copy( + b -> { + b.uploadProgressSnapshot( + requestBodyProgress.progressSnapshot()); + b.downloadProgressSnapshot( + responseBodyProgress.progressSnapshot()); + })) + .exception(t) + .build()); + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/listener/DefaultSdkExchangeProgress.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/listener/DefaultSdkExchangeProgress.java index a7af7a3ab69c..cdb071ddee42 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/listener/DefaultSdkExchangeProgress.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/listener/DefaultSdkExchangeProgress.java @@ -27,7 +27,7 @@ /** * An SDK-internal implementation of {@link SdkExchangeProgress}. This implementation acts as a thin wrapper around {@link * AtomicReference}, where calls to get the latest {@link #progressSnapshot()} simply return the latest reference, while {@link - * ProgressUpdater} is responsible for continuously updating the latest reference. + * DefaultProgressUpdater} is responsible for continuously updating the latest reference. * * @see SdkExchangeProgress */ diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/listener/NoOpProgressUpdater.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/listener/NoOpProgressUpdater.java new file mode 100644 index 000000000000..afe90597e5bb --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/listener/NoOpProgressUpdater.java @@ -0,0 +1,45 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.progress.listener; + +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.http.SdkHttpRequest; + +@SdkInternalApi +public class NoOpProgressUpdater implements ProgressUpdater { + + @Override + public void requestPrepared(SdkHttpRequest httpRequest) { + } + + @Override + public void responseHeaderReceived() { + } + + @Override + public void executionSuccess(SdkResponse response) { + } + + @Override + public void executionFailure(Throwable t) { + } + + @Override + public void attemptFailure(Throwable t) { + } +} + diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/listener/ProgressUpdater.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/listener/ProgressUpdater.java index 3db626977b6a..ccbb568effdf 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/listener/ProgressUpdater.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/progress/listener/ProgressUpdater.java @@ -15,143 +15,43 @@ package software.amazon.awssdk.core.internal.progress.listener; -import java.util.Collections; -import java.util.Optional; import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.core.RequestOverrideConfiguration; -import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SdkResponse; -import software.amazon.awssdk.core.internal.progress.ProgressListenerContext; -import software.amazon.awssdk.core.internal.progress.ProgressListenerFailedContext; -import software.amazon.awssdk.core.internal.progress.snapshot.DefaultProgressSnapshot; -import software.amazon.awssdk.core.progress.listener.SdkExchangeProgress; -import software.amazon.awssdk.core.progress.snapshot.ProgressSnapshot; +import software.amazon.awssdk.http.SdkHttpRequest; -/** - * ProgressUpdater exposes methods that invokes listener methods to update and store request progress state - */ @SdkInternalApi -public class ProgressUpdater { - private final DefaultSdkExchangeProgress requestBodyProgress; - private final DefaultSdkExchangeProgress responseBodyProgress; - private ProgressListenerContext context; - private final ProgressListenerInvoker listenerInvoker; - - public ProgressUpdater(SdkRequest sdkRequest, - Long requestContentLength) { - DefaultProgressSnapshot.Builder uploadProgressSnapshotBuilder = DefaultProgressSnapshot.builder(); - uploadProgressSnapshotBuilder.transferredBytes(0L); - Optional.ofNullable(requestContentLength).ifPresent(uploadProgressSnapshotBuilder::totalBytes); - - ProgressSnapshot uploadProgressSnapshot = uploadProgressSnapshotBuilder.build(); - requestBodyProgress = new DefaultSdkExchangeProgress(uploadProgressSnapshot); - - DefaultProgressSnapshot.Builder downloadProgressSnapshotBuilder = DefaultProgressSnapshot.builder(); - downloadProgressSnapshotBuilder.transferredBytes(0L); - ProgressSnapshot downloadProgressSnapshot = downloadProgressSnapshotBuilder.build(); - responseBodyProgress = new DefaultSdkExchangeProgress(downloadProgressSnapshot); - - context = ProgressListenerContext.builder() - .request(sdkRequest) - .uploadProgressSnapshot(uploadProgressSnapshot) - .downloadProgressSnapshot(downloadProgressSnapshot) - .build(); - - listenerInvoker = new ProgressListenerInvoker(sdkRequest.overrideConfiguration() - .map(RequestOverrideConfiguration::progressListeners) - .orElse(Collections.emptyList())); - } - - public void updateResponseContentLength(Long responseContentLength) { - responseBodyProgress.updateAndGet(b -> b.totalBytes(responseContentLength)); - } - - public SdkExchangeProgress requestBodyProgress() { - return requestBodyProgress; +public interface ProgressUpdater { + default void updateRequestContentLength(Long requestContentLength) { } - public SdkExchangeProgress responseBodyProgress() { - return responseBodyProgress; + default void updateResponseContentLength(Long responseContentLength) { } - public void requestPrepared() { - listenerInvoker.requestPrepared(context); - } + void requestPrepared(SdkHttpRequest httpRequest); - public void requestHeaderSent() { - listenerInvoker.requestHeaderSent(context); + default void requestHeaderSent() { } - public void resetBytesSent() { - requestBodyProgress.updateAndGet(b -> b.transferredBytes(0L)); + default void resetBytesSent() { } - public void resetBytesReceived() { - responseBodyProgress.updateAndGet(b -> b.transferredBytes(0L)); + default void resetBytesReceived() { } - public void incrementBytesSent(long numBytes) { - long uploadBytes = requestBodyProgress.progressSnapshot().transferredBytes(); - - ProgressSnapshot snapshot = requestBodyProgress.updateAndGet(b -> b.transferredBytes(uploadBytes + numBytes)); - listenerInvoker.requestBytesSent(context.copy(b -> b.uploadProgressSnapshot(snapshot))); + default void incrementBytesSent(long numBytes) { } - public void incrementBytesReceived(long numBytes) { - long downloadedBytes = responseBodyProgress.progressSnapshot().transferredBytes(); - - ProgressSnapshot snapshot = responseBodyProgress.updateAndGet(b -> b.transferredBytes(downloadedBytes + numBytes)); - listenerInvoker.responseBytesReceived(context.copy(b -> b.downloadProgressSnapshot(snapshot))); + default void incrementBytesReceived(long numBytes) { } - public void responseHeaderReceived() { - listenerInvoker.responseHeaderReceived(context); - } + void responseHeaderReceived(); - public void executionSuccess(SdkResponse response) { + void executionSuccess(SdkResponse response); - listenerInvoker.executionSuccess(context.copy(b -> b.response(response))); - } - - public void executionFailure(Throwable t) { - listenerInvoker.executionFailure(ProgressListenerFailedContext.builder() - .progressListenerContext( - context.copy( - b -> { - b.uploadProgressSnapshot( - requestBodyProgress.progressSnapshot()); - b.downloadProgressSnapshot( - responseBodyProgress.progressSnapshot()); - })) - .exception(t) - .build()); - } + void executionFailure(Throwable t); - public void attemptFailure(Throwable t) { - listenerInvoker.attemptFailure(ProgressListenerFailedContext.builder() - .progressListenerContext( - context.copy( - b -> { - b.uploadProgressSnapshot( - requestBodyProgress.progressSnapshot()); - b.downloadProgressSnapshot( - responseBodyProgress.progressSnapshot()); - })) - .exception(t) - .build()); - } + void attemptFailure(Throwable t); - public void attemptFailureResponseBytesReceived(Throwable t) { - listenerInvoker.attemptFailureResponseBytesReceived(ProgressListenerFailedContext.builder() - .progressListenerContext( - context.copy( - b -> { - b.uploadProgressSnapshot( - requestBodyProgress.progressSnapshot()); - b.downloadProgressSnapshot( - responseBodyProgress.progressSnapshot()); - })) - .exception(t) - .build()); + default void attemptFailureResponseBytesReceived(Throwable t){ } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/util/ProgressListenerUtils.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/util/ProgressListenerUtils.java new file mode 100644 index 000000000000..7900f87c2421 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/util/ProgressListenerUtils.java @@ -0,0 +1,72 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.util; + +import static software.amazon.awssdk.http.Header.CONTENT_LENGTH; + +import java.util.concurrent.atomic.AtomicLong; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.RequestOverrideConfiguration; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.core.internal.metrics.BytesReadTrackingInputStream; +import software.amazon.awssdk.core.internal.metrics.BytesReadTrackingPublisher; +import software.amazon.awssdk.core.internal.progress.listener.ProgressUpdater; +import software.amazon.awssdk.http.AbortableInputStream; +import software.amazon.awssdk.http.SdkHttpHeaders; +import software.amazon.awssdk.http.async.SdkHttpContentPublisher; +import software.amazon.awssdk.utils.StringUtils; + +@SdkInternalApi +public final class ProgressListenerUtils { + + private ProgressListenerUtils() { + } + + public static SdkHttpContentPublisher wrapWithByteTracking( + SdkHttpContentPublisher requestProvider, AtomicLong bytesRead, ProgressUpdater progressUpdater) { + return new BytesReadTrackingPublisher(requestProvider, bytesRead, + new RequestProgressUpdaterInvoker(progressUpdater)); + } + + public static BytesReadTrackingInputStream wrapWithBytesReadTrackingStream( + AbortableInputStream content, AtomicLong bytesRead, ProgressUpdater progressUpdater) { + + return new BytesReadTrackingInputStream(content, + bytesRead, + new RequestProgressUpdaterInvoker(progressUpdater)); + } + + public static void updateProgressListenersWithResponseStatus(ProgressUpdater progressUpdater, + SdkHttpHeaders headers) { + progressUpdater.responseHeaderReceived(); + headers.firstMatchingHeader(CONTENT_LENGTH).ifPresent(value -> { + if (!StringUtils.isNotBlank(value)) { + progressUpdater.updateResponseContentLength(Long.parseLong(value)); + } + }); + } + + public static void updateProgressListenersWithSuccessResponse(SdkResponse response, + ProgressUpdater progressUpdater) { + progressUpdater.executionSuccess(response); + } + + public static boolean progressListenerAttached(SdkRequest request) { + return request.overrideConfiguration() + .map(RequestOverrideConfiguration::progressListeners).isPresent(); + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/util/ProgressUpdaterInvoker.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/util/ProgressUpdaterInvoker.java new file mode 100644 index 000000000000..0d47f66ed77b --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/util/ProgressUpdaterInvoker.java @@ -0,0 +1,28 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.util; + +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.internal.progress.listener.ProgressUpdater; + +@SdkInternalApi +public interface ProgressUpdaterInvoker { + void incrementBytesTransferred(long bytes); + + void resetBytes(); + + ProgressUpdater progressUpdater(); +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/util/RequestProgressUpdaterInvoker.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/util/RequestProgressUpdaterInvoker.java new file mode 100644 index 000000000000..5baae925a707 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/util/RequestProgressUpdaterInvoker.java @@ -0,0 +1,43 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.util; + +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.internal.progress.listener.ProgressUpdater; + +@SdkInternalApi +public class RequestProgressUpdaterInvoker implements ProgressUpdaterInvoker { + private final ProgressUpdater progressUpdater; + + public RequestProgressUpdaterInvoker(ProgressUpdater progressUpdater) { + this.progressUpdater = progressUpdater; + } + + @Override + public void incrementBytesTransferred(long bytes) { + progressUpdater.incrementBytesSent(bytes); + } + + @Override + public void resetBytes() { + progressUpdater.resetBytesSent(); + } + + @Override + public ProgressUpdater progressUpdater() { + return progressUpdater; + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/util/ResponseProgressUpdaterInvoker.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/util/ResponseProgressUpdaterInvoker.java new file mode 100644 index 000000000000..5b40786dfddd --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/util/ResponseProgressUpdaterInvoker.java @@ -0,0 +1,43 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.util; + +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.internal.progress.listener.ProgressUpdater; + +@SdkInternalApi +public class ResponseProgressUpdaterInvoker implements ProgressUpdaterInvoker { + private final ProgressUpdater deafultProgressUpdater; + + public ResponseProgressUpdaterInvoker(ProgressUpdater deafultProgressUpdater) { + this.deafultProgressUpdater = deafultProgressUpdater; + } + + @Override + public void incrementBytesTransferred(long bytes) { + deafultProgressUpdater.incrementBytesReceived(bytes); + } + + @Override + public void resetBytes() { + deafultProgressUpdater.resetBytesReceived(); + } + + @Override + public ProgressUpdater progressUpdater() { + return deafultProgressUpdater; + } +} \ No newline at end of file diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/progress/listener/ProgressListener.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/progress/listener/ProgressListener.java index 4a2e2a4df24d..a9ef5ceda092 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/progress/listener/ProgressListener.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/progress/listener/ProgressListener.java @@ -15,6 +15,7 @@ package software.amazon.awssdk.core.progress.listener; +import java.util.Optional; import software.amazon.awssdk.annotations.Immutable; import software.amazon.awssdk.annotations.SdkPreviewApi; import software.amazon.awssdk.annotations.SdkProtectedApi; @@ -189,8 +190,7 @@ default void responseBytesReceived(Context.ResponseBytesReceived context) { /** * For Expect: 100-continue embedded requests, the service returning anything other than 100 continue * indicates a request failure. This method captures the error in the payload - * After this, either executionFailure or requestHeaderSent will always be invoked depending on - * whether the error type is retryable or not + * After this it will either be an executionFailure or a retry the request. *

* Available context attributes: *

    @@ -423,7 +423,7 @@ public interface ExecutionSuccess extends ResponseBytesReceived { /** * The successful completion of a request submitted to the Sdk */ - SdkResponse response(); + Optional response(); } /** diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/client/handler/AsyncClientHandlerTransformerVerificationTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/client/handler/AsyncClientHandlerTransformerVerificationTest.java index ce7d7397f10b..249a454a3838 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/client/handler/AsyncClientHandlerTransformerVerificationTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/client/handler/AsyncClientHandlerTransformerVerificationTest.java @@ -46,7 +46,6 @@ import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.exception.SdkServiceException; import software.amazon.awssdk.core.http.HttpResponseHandler; -import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.protocol.VoidSdkResponse; import software.amazon.awssdk.core.retry.RetryPolicy; import software.amazon.awssdk.core.runtime.transform.Marshaller; diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/async/SimpleRequestProviderTckTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/async/SimpleRequestProviderTckTest.java index aad8819b0860..de19d7bdc983 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/async/SimpleRequestProviderTckTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/async/SimpleRequestProviderTckTest.java @@ -18,6 +18,7 @@ import java.io.ByteArrayInputStream; import java.net.URI; import java.nio.ByteBuffer; +import java.util.Optional; import org.reactivestreams.Publisher; import org.reactivestreams.tck.PublisherVerification; import org.reactivestreams.tck.TestEnvironment; diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AfterExecutionProgressReportingStageTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AfterExecutionProgressReportingStageTest.java new file mode 100644 index 000000000000..6a879352b789 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AfterExecutionProgressReportingStageTest.java @@ -0,0 +1,60 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import static software.amazon.awssdk.core.internal.util.ProgressListenerTestUtils.createSdkHttpRequest; +import static software.amazon.awssdk.core.internal.util.ProgressListenerTestUtils.createSdkResponseBuilder; +import static software.amazon.awssdk.core.internal.util.ProgressListenerTestUtils.progressListenerContext; + +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SdkRequestOverrideConfiguration; +import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.internal.progress.listener.DefaultProgressUpdater; +import software.amazon.awssdk.core.progress.listener.ProgressListener; + +class AfterExecutionProgressReportingStageTest { + + @Test + void afterExecutionProgressListener_calledFrom_ExecutionPipeline() throws Exception { + ProgressListener progressListener = Mockito.mock(ProgressListener.class); + + SdkRequestOverrideConfiguration config = SdkRequestOverrideConfiguration.builder() + .addProgressListener(progressListener) + .build(); + + SdkRequest request = createSdkHttpRequest(config).build(); + + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(request, null); + + RequestExecutionContext requestExecutionContext = progressListenerContext(false, request, + defaultProgressUpdater); + + SdkResponse response = createSdkResponseBuilder().build(); + + AfterExecutionProgressReportingStage afterExecutionUpdateProgressStage = new AfterExecutionProgressReportingStage(); + afterExecutionUpdateProgressStage.execute(response, requestExecutionContext); + + Mockito.verify(progressListener, Mockito.times(0)).requestPrepared(Mockito.any()); + Mockito.verify(progressListener, Mockito.times(0)).requestBytesSent(Mockito.any()); + Mockito.verify(progressListener, Mockito.times(0)).responseHeaderReceived(Mockito.any()); + Mockito.verify(progressListener, Mockito.times(0)).responseBytesReceived(Mockito.any()); + Mockito.verify(progressListener, Mockito.times(1)).executionSuccess(Mockito.any()); + Mockito.verify(progressListener, Mockito.times(0)).executionFailure(Mockito.any()); + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/BeforeExecutionProgressReportingStageTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/BeforeExecutionProgressReportingStageTest.java new file mode 100644 index 000000000000..e34a10c65e11 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/BeforeExecutionProgressReportingStageTest.java @@ -0,0 +1,61 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import static software.amazon.awssdk.core.internal.util.ProgressListenerTestUtils.createHttpRequestBuilder; +import static software.amazon.awssdk.core.internal.util.ProgressListenerTestUtils.createSdkHttpRequest; +import static software.amazon.awssdk.core.internal.util.ProgressListenerTestUtils.progressListenerContext; + +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SdkRequestOverrideConfiguration; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.internal.progress.listener.DefaultProgressUpdater; +import software.amazon.awssdk.core.progress.listener.ProgressListener; +import software.amazon.awssdk.http.SdkHttpFullRequest; + +class BeforeExecutionProgressReportingStageTest { + @Test + void beforeExecutionProgressListener_calledFrom_ExecutionPipeline() throws Exception { + ProgressListener progressListener = Mockito.mock(ProgressListener.class); + + SdkRequestOverrideConfiguration config = SdkRequestOverrideConfiguration.builder() + .addProgressListener(progressListener) + .build(); + + SdkHttpFullRequest requestBuilder = createHttpRequestBuilder().build(); + + SdkRequest request = createSdkHttpRequest(config).build(); + + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(request, null); + + RequestExecutionContext requestExecutionContext = progressListenerContext(false, request, + defaultProgressUpdater); + + BeforeExecutionProgressReportingStage beforeExecutionUpdateProgressStage = new BeforeExecutionProgressReportingStage(); + beforeExecutionUpdateProgressStage.execute(requestBuilder, requestExecutionContext); + + Mockito.verify(progressListener, Mockito.times(1)).requestPrepared(Mockito.any()); + Mockito.verify(progressListener, Mockito.times(0)).requestBytesSent(Mockito.any()); + Mockito.verify(progressListener, Mockito.times(0)).responseHeaderReceived(Mockito.any()); + Mockito.verify(progressListener, Mockito.times(0)).responseBytesReceived(Mockito.any()); + Mockito.verify(progressListener, Mockito.times(0)).executionSuccess(Mockito.any()); + Mockito.verify(progressListener, Mockito.times(0)).executionFailure(Mockito.any()); + Mockito.verify(progressListener, Mockito.times(0)).attemptFailure(Mockito.any()); + + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ExecutionFailureExceptionReportingStageTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ExecutionFailureExceptionReportingStageTest.java new file mode 100644 index 000000000000..32469ad22fe4 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ExecutionFailureExceptionReportingStageTest.java @@ -0,0 +1,97 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertThrows; +import static software.amazon.awssdk.core.internal.util.ProgressListenerTestUtils.createSdkHttpRequest; +import static software.amazon.awssdk.core.internal.util.ProgressListenerTestUtils.progressListenerContext; + +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import software.amazon.awssdk.core.Response; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SdkRequestOverrideConfiguration; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.internal.http.pipeline.RequestPipeline; +import software.amazon.awssdk.core.internal.progress.listener.DefaultProgressUpdater; +import software.amazon.awssdk.core.progress.listener.ProgressListener; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpFullResponse; +import utils.ValidSdkObjects; + +class ExecutionFailureExceptionReportingStageTest { + + + @Test + void when_sync_executeThrowsException_attemptFailureInvoked() throws Exception { + + RequestPipeline> requestPipeline = Mockito.mock(RequestPipeline.class); + ProgressListener progressListener = Mockito.mock(ProgressListener.class); + + SdkRequestOverrideConfiguration config = SdkRequestOverrideConfiguration.builder() + .addProgressListener(progressListener) + .build(); + + SdkRequest request = createSdkHttpRequest(config).build(); + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(request, null); + RequestExecutionContext requestExecutionContext = progressListenerContext(false, request, + defaultProgressUpdater); + + ExecutionFailureExceptionReportingStage executionFailureExceptionReportingStage = new ExecutionFailureExceptionReportingStage(requestPipeline); + when(requestPipeline.execute(any(), any())).thenThrow(new RuntimeException()); + assertThrows(RuntimeException.class, () -> executionFailureExceptionReportingStage.execute(ValidSdkObjects.sdkHttpFullRequest().build(), requestExecutionContext)); + + Mockito.verify(progressListener, Mockito.times(0)).requestPrepared(any()); + Mockito.verify(progressListener, Mockito.times(0)).requestBytesSent(any()); + Mockito.verify(progressListener, Mockito.times(0)).responseHeaderReceived(any()); + Mockito.verify(progressListener, Mockito.times(0)).responseBytesReceived(any()); + Mockito.verify(progressListener, Mockito.times(0)).executionSuccess(any()); + Mockito.verify(progressListener, Mockito.times(1)).attemptFailure(any()); + } + + @Test + void when_async_executeThrowsException_attemptFailureInvoked() throws Exception { + + RequestPipeline requestPipeline = Mockito.mock(RequestPipeline.class); + ProgressListener progressListener = Mockito.mock(ProgressListener.class); + CompletableFuture future = new CompletableFuture<>(); + + SdkRequestOverrideConfiguration config = SdkRequestOverrideConfiguration.builder() + .addProgressListener(progressListener) + .build(); + + SdkRequest request = createSdkHttpRequest(config).build(); + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(request, null); + RequestExecutionContext requestExecutionContext = progressListenerContext(false, request, + defaultProgressUpdater); + + AsyncExecutionFailureExceptionReportingStage executionFailureExceptionReportingStage = new AsyncExecutionFailureExceptionReportingStage(requestPipeline); + when(requestPipeline.execute(any(), any())).thenReturn(future); + future.completeExceptionally(new RuntimeException()); + + executionFailureExceptionReportingStage.execute(ValidSdkObjects.sdkHttpFullRequest().build(), requestExecutionContext); + + Mockito.verify(progressListener, Mockito.times(0)).requestPrepared(any()); + Mockito.verify(progressListener, Mockito.times(0)).requestBytesSent(any()); + Mockito.verify(progressListener, Mockito.times(0)).responseHeaderReceived(any()); + Mockito.verify(progressListener, Mockito.times(0)).responseBytesReceived(any()); + Mockito.verify(progressListener, Mockito.times(0)).executionSuccess(any()); + Mockito.verify(progressListener, Mockito.times(1)).attemptFailure(any()); + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingInputStreamTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingInputStreamTest.java index 213c78e96152..c73c9761cefa 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingInputStreamTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingInputStreamTest.java @@ -29,6 +29,15 @@ import java.util.concurrent.atomic.AtomicLong; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SdkRequestOverrideConfiguration; +import software.amazon.awssdk.core.http.NoopTestRequest; +import software.amazon.awssdk.core.internal.progress.listener.DefaultProgressUpdater; +import software.amazon.awssdk.core.internal.util.ResponseProgressUpdaterInvoker; +import software.amazon.awssdk.core.internal.util.ProgressUpdaterInvoker; +import software.amazon.awssdk.core.internal.util.RequestProgressUpdaterInvoker; +import software.amazon.awssdk.core.progress.listener.ProgressListener; import software.amazon.awssdk.http.Abortable; import software.amazon.awssdk.http.AbortableInputStream; @@ -150,21 +159,73 @@ public void readByteArrayRange_returnsPositive_updatesTotal() throws IOException assertThat(trackingInputStream.bytesRead()).isEqualTo(4); } + @Test + void readByteArrayRange_withProgressListener_invokesResponseBytesReceived() throws IOException { + + ProgressListener progressListener = mock(ProgressListener.class); + + SdkRequestOverrideConfiguration config = SdkRequestOverrideConfiguration.builder() + .addProgressListener(progressListener) + .build(); + + SdkRequest request = NoopTestRequest.builder() + .overrideConfiguration(config) + .build(); + + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(request, null); + + when(mockStream.read(any(byte[].class), eq(2), eq(2))).thenReturn(2); + + BytesReadTrackingInputStream trackingInputStream = newTrackingStreamWithProgressUpdater(new AtomicLong(0L), + new ResponseProgressUpdaterInvoker(defaultProgressUpdater)); + trackingInputStream.read(new byte[8], 2, 2); + + verify(progressListener, Mockito.times(1)).responseBytesReceived(any()); + } + + @Test + void writeByteArrayRange_withProgressListener_invokesRequestBytesSent() throws IOException { + + ProgressListener progressListener = mock(ProgressListener.class); + + SdkRequestOverrideConfiguration config = SdkRequestOverrideConfiguration.builder() + .addProgressListener(progressListener) + .build(); + + SdkRequest request = NoopTestRequest.builder() + .overrideConfiguration(config) + .build(); + + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(request, null); + + when(mockStream.read(any(byte[].class), eq(2), eq(2))).thenReturn(2); + + BytesReadTrackingInputStream trackingInputStream = newTrackingStreamWithProgressUpdater(new AtomicLong(0L), + new RequestProgressUpdaterInvoker(defaultProgressUpdater)); + trackingInputStream.read(new byte[8], 2, 2); + + verify(progressListener, Mockito.times(1)).requestBytesSent(any()); + } + @Test public void abort_abortsDelegate() { Abortable mockAbortable = mock(Abortable.class); AbortableInputStream abortableIs = AbortableInputStream.create(mockStream, mockAbortable); - BytesReadTrackingInputStream trackingInputStream = new BytesReadTrackingInputStream(abortableIs, new AtomicLong(0)); + BytesReadTrackingInputStream trackingInputStream = new BytesReadTrackingInputStream(abortableIs, new AtomicLong(0), null); trackingInputStream.abort(); verify(mockAbortable).abort(); } private BytesReadTrackingInputStream newTrackingStream(AtomicLong read) { - return new BytesReadTrackingInputStream(abortableStream, read); + return new BytesReadTrackingInputStream(abortableStream, read, null); } private BytesReadTrackingInputStream newTrackingStream() { return newTrackingStream(new AtomicLong(0)); } + + private BytesReadTrackingInputStream newTrackingStreamWithProgressUpdater(AtomicLong read, ProgressUpdaterInvoker progressUpdaterInvoker) { + return new BytesReadTrackingInputStream(abortableStream, read, progressUpdaterInvoker); + } } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingPublisherTckTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingPublisherTckTest.java index 2afffec30592..963502079714 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingPublisherTckTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingPublisherTckTest.java @@ -23,6 +23,8 @@ import org.reactivestreams.Publisher; import org.reactivestreams.tck.PublisherVerification; import org.reactivestreams.tck.TestEnvironment; +import software.amazon.awssdk.core.internal.progress.listener.NoOpProgressUpdater; +import software.amazon.awssdk.core.internal.util.RequestProgressUpdaterInvoker; /** * TCK verification class for {@link BytesReadTrackingPublisher}. @@ -34,7 +36,7 @@ public BytesReadTrackingPublisherTckTest() { @Override public Publisher createPublisher(long l) { - return new BytesReadTrackingPublisher(createUpstreamPublisher(l), new AtomicLong(0)); + return new BytesReadTrackingPublisher(createUpstreamPublisher(l), new AtomicLong(0), new RequestProgressUpdaterInvoker(new NoOpProgressUpdater())); } @Override diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingPublisherTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingPublisherTest.java index 259b13d53329..68f37734f3db 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingPublisherTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/metrics/BytesReadTrackingPublisherTest.java @@ -23,7 +23,16 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatchers; +import org.mockito.Mockito; import org.reactivestreams.Publisher; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SdkRequestOverrideConfiguration; +import software.amazon.awssdk.core.http.NoopTestRequest; +import software.amazon.awssdk.core.internal.progress.listener.DefaultProgressUpdater; +import software.amazon.awssdk.core.internal.progress.listener.NoOpProgressUpdater; +import software.amazon.awssdk.core.internal.util.ResponseProgressUpdaterInvoker; +import software.amazon.awssdk.core.progress.listener.ProgressListener; /** * Functional tests for {@link BytesReadTrackingPublisher}. @@ -31,18 +40,19 @@ public class BytesReadTrackingPublisherTest { @Test - public void test_requestAll_calculatesCorrectTotal() { + public void requestAll_calculatesCorrectTotal() { long nElements = 1024; int elementSize = 4; Publisher upstreamPublisher = createUpstreamPublisher(nElements, elementSize); - BytesReadTrackingPublisher trackingPublisher = new BytesReadTrackingPublisher(upstreamPublisher, new AtomicLong(0)); + BytesReadTrackingPublisher trackingPublisher = new BytesReadTrackingPublisher(upstreamPublisher, new AtomicLong(0), + new ResponseProgressUpdaterInvoker(new NoOpProgressUpdater())); readFully(trackingPublisher); assertThat(trackingPublisher.bytesRead()).isEqualTo(nElements * elementSize); } @Test - public void test_requestAll_updatesInputCount() { + public void requestAll_updatesInputCount() { long nElements = 8; int elementSize = 2; @@ -50,7 +60,9 @@ public void test_requestAll_updatesInputCount() { AtomicLong bytesRead = new AtomicLong(baseBytesRead); Publisher upstreamPublisher = createUpstreamPublisher(nElements, elementSize); - BytesReadTrackingPublisher trackingPublisher = new BytesReadTrackingPublisher(upstreamPublisher, bytesRead); + BytesReadTrackingPublisher trackingPublisher = new BytesReadTrackingPublisher( + upstreamPublisher, bytesRead, new ResponseProgressUpdaterInvoker(new NoOpProgressUpdater())); + readFully(trackingPublisher); long expectedRead = baseBytesRead + nElements * elementSize; @@ -58,6 +70,34 @@ public void test_requestAll_updatesInputCount() { assertThat(trackingPublisher.bytesRead()).isEqualTo(expectedRead); } + @Test + void progressUpdater_invokes_incrementBytesReceived() { + int nElements = 8; + int elementSize = 2; + + long baseBytesRead = 1024; + AtomicLong bytesRead = new AtomicLong(baseBytesRead); + + ProgressListener progressListener = Mockito.mock(ProgressListener.class); + + SdkRequestOverrideConfiguration config = SdkRequestOverrideConfiguration.builder() + .addProgressListener(progressListener) + .build(); + + SdkRequest request = NoopTestRequest.builder() + .overrideConfiguration(config) + .build(); + + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(request, null); + + Publisher upstreamPublisher = createUpstreamPublisher(nElements, elementSize); + Publisher trackingPublisher = new BytesReadTrackingPublisher(upstreamPublisher, bytesRead, + new ResponseProgressUpdaterInvoker(defaultProgressUpdater)); + readFully(trackingPublisher); + + Mockito.verify(progressListener, Mockito.times(nElements)).responseBytesReceived(ArgumentMatchers.any()); + } + private Publisher createUpstreamPublisher(long elements, int elementSize) { return Flowable.fromIterable(Stream.generate(() -> ByteBuffer.wrap(new byte[elementSize])) .limit(elements) diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/metrics/BytesSentTrackingPublisherTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/metrics/BytesSentTrackingPublisherTest.java new file mode 100644 index 000000000000..9d5500109250 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/metrics/BytesSentTrackingPublisherTest.java @@ -0,0 +1,93 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.metrics; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.reactivex.Flowable; +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatchers; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SdkRequestOverrideConfiguration; +import software.amazon.awssdk.core.http.NoopTestRequest; +import software.amazon.awssdk.core.internal.progress.listener.DefaultProgressUpdater; +import software.amazon.awssdk.core.internal.util.RequestProgressUpdaterInvoker; +import software.amazon.awssdk.core.progress.listener.ProgressListener; + +public class BytesSentTrackingPublisherTest { + + @Test + public void validate_updatesBytesSent_invocation_tracksBytesSentAccurately() { + int nElements = 8; + int elementSize = 2; + + DefaultProgressUpdater defaultProgressUpdater = Mockito.mock(DefaultProgressUpdater.class); + + Publisher upstreamPublisher = createUpstreamPublisher(nElements, elementSize); + BytesReadTrackingPublisher trackingPublisher = new BytesReadTrackingPublisher(upstreamPublisher, new AtomicLong(0), + new RequestProgressUpdaterInvoker(defaultProgressUpdater)); + readFully(trackingPublisher); + + long expectedSent = nElements * elementSize; + + assertThat(trackingPublisher.bytesRead()).isEqualTo(expectedSent); + } + + @Test + public void progressUpdater_invokes_incrementBytesSent() { + int nElements = 8; + int elementSize = 2; + + ProgressListener progressListener = Mockito.mock(ProgressListener.class); + + SdkRequestOverrideConfiguration config = SdkRequestOverrideConfiguration.builder() + .addProgressListener(progressListener) + .build(); + + SdkRequest request = NoopTestRequest.builder() + .overrideConfiguration(config) + .build(); + + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(request, null); + + Publisher upstreamPublisher = createUpstreamPublisher(nElements, elementSize); + BytesReadTrackingPublisher trackingPublisher = new BytesReadTrackingPublisher(upstreamPublisher, new AtomicLong(0L), + new RequestProgressUpdaterInvoker(defaultProgressUpdater)); + readFully(trackingPublisher); + + long expectedSent = nElements * elementSize; + + assertThat(trackingPublisher.bytesRead()).isEqualTo(expectedSent); + Mockito.verify(progressListener, Mockito.times(nElements)).requestBytesSent(ArgumentMatchers.any()); + } + + private Publisher createUpstreamPublisher(long elements, int elementSize) { + return Flowable.fromIterable(Stream.generate(() -> ByteBuffer.wrap(new byte[elementSize])) + .limit(elements) + .collect(Collectors.toList())); + } + + private void readFully(Publisher publisher) { + Flowable.fromPublisher(publisher).toList().blockingGet(); + } +} + diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/progress/listener/ProgressUpdaterTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/progress/listener/ProgressUpdaterTest.java index 0cb887561449..86ddfcbc7ad6 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/progress/listener/ProgressUpdaterTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/progress/listener/ProgressUpdaterTest.java @@ -21,6 +21,7 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; +import java.net.URI; import java.util.Arrays; import java.util.stream.Stream; import org.junit.jupiter.api.Assertions; @@ -36,8 +37,10 @@ import software.amazon.awssdk.core.http.NoopTestRequest; import software.amazon.awssdk.core.progress.listener.ProgressListener; import software.amazon.awssdk.core.protocol.VoidSdkResponse; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; -public class ProgressUpdaterTest { +class ProgressUpdaterTest { private CaptureProgressListener captureProgressListener; private static final long BYTES_TRANSFERRED = 5L; private static final Throwable attemptFailure = new Throwable("AttemptFailureException"); @@ -60,7 +63,7 @@ private static Stream contentLength() { } @Test - public void requestPrepared_transferredBytes_equals_zero() { + void requestPrepared_transferredBytes_equals_zero() { CaptureProgressListener mockListener = Mockito.mock(CaptureProgressListener.class); @@ -73,10 +76,10 @@ public void requestPrepared_transferredBytes_equals_zero() { .overrideConfiguration(overrideConfig) .build(); - ProgressUpdater progressUpdater = new ProgressUpdater(sdkRequest, null); - progressUpdater.requestPrepared(); + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(sdkRequest, null); + defaultProgressUpdater.requestPrepared(createHttpRequest()); - assertEquals(0.0, progressUpdater.requestBodyProgress().progressSnapshot().transferredBytes(), 0.0); + assertEquals(0.0, defaultProgressUpdater.requestBodyProgress().progressSnapshot().transferredBytes(), 0.0); assertTrue(captureProgressListener.requestPrepared()); assertFalse(captureProgressListener.requestHeaderSent()); assertFalse(captureProgressListener.responseHeaderReceived()); @@ -89,7 +92,7 @@ public void requestPrepared_transferredBytes_equals_zero() { } @Test - public void requestHeaderSent_transferredBytes_equals_zero() { + void requestHeaderSent_transferredBytes_equals_zero() { CaptureProgressListener mockListener = Mockito.mock(CaptureProgressListener.class); @@ -102,10 +105,10 @@ public void requestHeaderSent_transferredBytes_equals_zero() { .overrideConfiguration(overrideConfig) .build(); - ProgressUpdater progressUpdater = new ProgressUpdater(sdkRequest, null); - progressUpdater.requestHeaderSent(); + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(sdkRequest, null); + defaultProgressUpdater.requestHeaderSent(); - assertEquals(0.0, progressUpdater.requestBodyProgress().progressSnapshot().transferredBytes(), 0.0); + assertEquals(0.0, defaultProgressUpdater.requestBodyProgress().progressSnapshot().transferredBytes(), 0.0); assertFalse(captureProgressListener.requestPrepared()); assertTrue(captureProgressListener.requestHeaderSent()); assertFalse(captureProgressListener.responseHeaderReceived()); @@ -117,7 +120,7 @@ public void requestHeaderSent_transferredBytes_equals_zero() { } @Test - public void requestBytesSent_transferredBytes() { + void requestBytesSent_transferredBytes() { CaptureProgressListener mockListener = Mockito.mock(CaptureProgressListener.class); @@ -130,13 +133,13 @@ public void requestBytesSent_transferredBytes() { .overrideConfiguration(overrideConfig) .build(); - ProgressUpdater progressUpdater = new ProgressUpdater(sdkRequest, null); - progressUpdater.incrementBytesSent(BYTES_TRANSFERRED); - assertEquals(BYTES_TRANSFERRED, progressUpdater.requestBodyProgress().progressSnapshot().transferredBytes(), 0.0); + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(sdkRequest, null); + defaultProgressUpdater.incrementBytesSent(BYTES_TRANSFERRED); + assertEquals(BYTES_TRANSFERRED, defaultProgressUpdater.requestBodyProgress().progressSnapshot().transferredBytes(), 0.0); - progressUpdater.incrementBytesSent(BYTES_TRANSFERRED); + defaultProgressUpdater.incrementBytesSent(BYTES_TRANSFERRED); assertEquals(BYTES_TRANSFERRED + BYTES_TRANSFERRED, - progressUpdater.requestBodyProgress().progressSnapshot().transferredBytes(), 0.0); + defaultProgressUpdater.requestBodyProgress().progressSnapshot().transferredBytes(), 0.0); Mockito.verify(mockListener, never()).executionFailure(ArgumentMatchers.any(ProgressListener.Context.ExecutionFailure.class)); Mockito.verify(mockListener, never()).attemptFailure(ArgumentMatchers.any(ProgressListener.Context.ExecutionFailure.class)); @@ -146,7 +149,7 @@ public void requestBytesSent_transferredBytes() { } @Test - public void validate_resetBytesSent() { + void validate_resetBytesSent() { CaptureProgressListener mockListener = Mockito.mock(CaptureProgressListener.class); @@ -159,17 +162,17 @@ public void validate_resetBytesSent() { .overrideConfiguration(overrideConfig) .build(); - ProgressUpdater progressUpdater = new ProgressUpdater(sdkRequest, null); - progressUpdater.incrementBytesSent(BYTES_TRANSFERRED); - assertEquals(BYTES_TRANSFERRED, progressUpdater.requestBodyProgress().progressSnapshot().transferredBytes(), 0.0); + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(sdkRequest, null); + defaultProgressUpdater.incrementBytesSent(BYTES_TRANSFERRED); + assertEquals(BYTES_TRANSFERRED, defaultProgressUpdater.requestBodyProgress().progressSnapshot().transferredBytes(), 0.0); - progressUpdater.resetBytesSent(); - assertEquals(0, progressUpdater.requestBodyProgress().progressSnapshot().transferredBytes(), 0.0); + defaultProgressUpdater.resetBytesSent(); + assertEquals(0, defaultProgressUpdater.requestBodyProgress().progressSnapshot().transferredBytes(), 0.0); } @Test - public void validate_resetBytesReceived() { + void validate_resetBytesReceived() { CaptureProgressListener mockListener = Mockito.mock(CaptureProgressListener.class); @@ -182,18 +185,18 @@ public void validate_resetBytesReceived() { .overrideConfiguration(overrideConfig) .build(); - ProgressUpdater progressUpdater = new ProgressUpdater(sdkRequest, null); - progressUpdater.incrementBytesReceived(BYTES_TRANSFERRED); - assertEquals(BYTES_TRANSFERRED, progressUpdater.responseBodyProgress().progressSnapshot().transferredBytes(), 0.0); + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(sdkRequest, null); + defaultProgressUpdater.incrementBytesReceived(BYTES_TRANSFERRED); + assertEquals(BYTES_TRANSFERRED, defaultProgressUpdater.responseBodyProgress().progressSnapshot().transferredBytes(), 0.0); - progressUpdater.resetBytesReceived(); - assertEquals(0, progressUpdater.responseBodyProgress().progressSnapshot().transferredBytes(), 0.0); + defaultProgressUpdater.resetBytesReceived(); + assertEquals(0, defaultProgressUpdater.responseBodyProgress().progressSnapshot().transferredBytes(), 0.0); } @ParameterizedTest @MethodSource("contentLength") - public void ratioTransferred_upload_transferredBytes(long contentLength) { + void ratioTransferred_upload_transferredBytes(long contentLength) { CaptureProgressListener mockListener = Mockito.mock(CaptureProgressListener.class); @@ -206,15 +209,39 @@ public void ratioTransferred_upload_transferredBytes(long contentLength) { .overrideConfiguration(overrideConfig) .build(); - ProgressUpdater progressUpdater = new ProgressUpdater(sdkRequest, contentLength); - progressUpdater.incrementBytesSent(BYTES_TRANSFERRED); + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(sdkRequest, null); + defaultProgressUpdater.updateRequestContentLength(contentLength); + defaultProgressUpdater.incrementBytesSent(BYTES_TRANSFERRED); assertEquals((double) BYTES_TRANSFERRED / contentLength, - progressUpdater.requestBodyProgress().progressSnapshot().ratioTransferred().getAsDouble(), 0.0); + defaultProgressUpdater.requestBodyProgress().progressSnapshot().ratioTransferred().getAsDouble(), 0.0); + + } + + @ParameterizedTest + @MethodSource("contentLength") + void ratioTransferred_download_transferredBytes(long contentLength) { + + CaptureProgressListener mockListener = Mockito.mock(CaptureProgressListener.class); + + SdkRequestOverrideConfiguration.Builder builder = SdkRequestOverrideConfiguration.builder(); + builder.progressListeners(Arrays.asList(mockListener, captureProgressListener)); + + SdkRequestOverrideConfiguration overrideConfig = builder.build(); + + SdkRequest sdkRequest = NoopTestRequest.builder() + .overrideConfiguration(overrideConfig) + .build(); + + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(sdkRequest, null); + defaultProgressUpdater.updateResponseContentLength(contentLength); + defaultProgressUpdater.incrementBytesReceived(BYTES_TRANSFERRED); + assertEquals((double) BYTES_TRANSFERRED / contentLength, + defaultProgressUpdater.responseBodyProgress().progressSnapshot().ratioTransferred().getAsDouble(), 0.0); } @Test - public void responseHeaderReceived_transferredBytes_equals_zero() { + void responseHeaderReceived_transferredBytes_equals_zero() { CaptureProgressListener mockListener = Mockito.mock(CaptureProgressListener.class); @@ -227,10 +254,10 @@ public void responseHeaderReceived_transferredBytes_equals_zero() { .overrideConfiguration(overrideConfig) .build(); - ProgressUpdater progressUpdater = new ProgressUpdater(sdkRequest, null); - progressUpdater.responseHeaderReceived(); + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(sdkRequest, null); + defaultProgressUpdater.responseHeaderReceived(); - assertEquals(0.0, progressUpdater.requestBodyProgress().progressSnapshot().transferredBytes(), 0.0); + assertEquals(0.0, defaultProgressUpdater.requestBodyProgress().progressSnapshot().transferredBytes(), 0.0); assertFalse(captureProgressListener.requestPrepared()); assertFalse(captureProgressListener.requestHeaderSent()); assertTrue(captureProgressListener.responseHeaderReceived()); @@ -242,7 +269,7 @@ public void responseHeaderReceived_transferredBytes_equals_zero() { } @Test - public void executionSuccess_transferredBytes_valid() { + void executionSuccess_transferredBytes_valid() { CaptureProgressListener mockListener = Mockito.mock(CaptureProgressListener.class); @@ -255,15 +282,15 @@ public void executionSuccess_transferredBytes_valid() { .overrideConfiguration(overrideConfig) .build(); - ProgressUpdater progressUpdater = new ProgressUpdater(sdkRequest, null); - progressUpdater.incrementBytesReceived(BYTES_TRANSFERRED); - assertEquals(BYTES_TRANSFERRED, progressUpdater.responseBodyProgress().progressSnapshot().transferredBytes(), 0.0); + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(sdkRequest, null); + defaultProgressUpdater.incrementBytesReceived(BYTES_TRANSFERRED); + assertEquals(BYTES_TRANSFERRED, defaultProgressUpdater.responseBodyProgress().progressSnapshot().transferredBytes(), 0.0); - progressUpdater.incrementBytesReceived(BYTES_TRANSFERRED); + defaultProgressUpdater.incrementBytesReceived(BYTES_TRANSFERRED); assertEquals(BYTES_TRANSFERRED + BYTES_TRANSFERRED, - progressUpdater.responseBodyProgress().progressSnapshot().transferredBytes(), 0.0); + defaultProgressUpdater.responseBodyProgress().progressSnapshot().transferredBytes(), 0.0); - progressUpdater.executionSuccess(VoidSdkResponse.builder().sdkHttpResponse(null).build()); + defaultProgressUpdater.executionSuccess(VoidSdkResponse.builder().sdkHttpResponse(null).build()); Mockito.verify(mockListener, never()).executionFailure(ArgumentMatchers.any(ProgressListener.Context.ExecutionFailure.class)); Mockito.verify(mockListener, never()).attemptFailure(ArgumentMatchers.any(ProgressListener.Context.ExecutionFailure.class)); Mockito.verify(mockListener, never()).attemptFailureResponseBytesReceived(ArgumentMatchers.any(ProgressListener.Context.ExecutionFailure.class)); @@ -272,7 +299,7 @@ public void executionSuccess_transferredBytes_valid() { } @Test - public void attemptFailureResponseBytesReceived() { + void attemptFailureResponseBytesReceived() { CaptureProgressListener mockListener = Mockito.mock(CaptureProgressListener.class); @@ -285,10 +312,10 @@ public void attemptFailureResponseBytesReceived() { .overrideConfiguration(overrideConfig) .build(); - ProgressUpdater progressUpdater = new ProgressUpdater(sdkRequest, null); - progressUpdater.requestPrepared(); - progressUpdater.responseHeaderReceived(); - progressUpdater.attemptFailureResponseBytesReceived(attemptFailureResponseBytesReceived); + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(sdkRequest, null); + defaultProgressUpdater.requestPrepared(createHttpRequest()); + defaultProgressUpdater.responseHeaderReceived(); + defaultProgressUpdater.attemptFailureResponseBytesReceived(attemptFailureResponseBytesReceived); Mockito.verify(mockListener, times(1)).requestPrepared(ArgumentMatchers.any(ProgressListener.Context.RequestPrepared.class)); Mockito.verify(mockListener, times(1)).responseHeaderReceived(ArgumentMatchers.any(ProgressListener.Context.ResponseHeaderReceived.class)); @@ -300,7 +327,7 @@ public void attemptFailureResponseBytesReceived() { } @Test - public void attemptFailure() { + void attemptFailure() { CaptureProgressListener mockListener = Mockito.mock(CaptureProgressListener.class); @@ -313,9 +340,9 @@ public void attemptFailure() { .overrideConfiguration(overrideConfig) .build(); - ProgressUpdater progressUpdater = new ProgressUpdater(sdkRequest, null); - progressUpdater.requestPrepared(); - progressUpdater.attemptFailure(attemptFailure); + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(sdkRequest, null); + defaultProgressUpdater.requestPrepared(createHttpRequest()); + defaultProgressUpdater.attemptFailure(attemptFailure); Mockito.verify(mockListener, times(1)).requestPrepared(ArgumentMatchers.any(ProgressListener.Context.RequestPrepared.class)); Mockito.verify(mockListener, times(0)).responseHeaderReceived(ArgumentMatchers.any(ProgressListener.Context.ResponseHeaderReceived.class)); @@ -329,7 +356,7 @@ public void attemptFailure() { } @Test - public void executionFailure() { + void executionFailure() { CaptureProgressListener mockListener = Mockito.mock(CaptureProgressListener.class); @@ -342,9 +369,9 @@ public void executionFailure() { .overrideConfiguration(overrideConfig) .build(); - ProgressUpdater progressUpdater = new ProgressUpdater(sdkRequest, null); - progressUpdater.requestPrepared(); - progressUpdater.executionFailure(executionFailure); + DefaultProgressUpdater defaultProgressUpdater = new DefaultProgressUpdater(sdkRequest, null); + defaultProgressUpdater.requestPrepared(createHttpRequest()); + defaultProgressUpdater.executionFailure(executionFailure); Mockito.verify(mockListener, times(1)).requestPrepared(ArgumentMatchers.any(ProgressListener.Context.RequestPrepared.class)); @@ -357,4 +384,10 @@ public void executionFailure() { Assertions.assertEquals(captureProgressListener.exceptionCaught().getMessage(), executionFailure.getMessage()); } + + private SdkHttpFullRequest createHttpRequest() { + return SdkHttpFullRequest.builder().uri(URI.create("https://endpoint.host")) + .method(SdkHttpMethod.GET) + .build(); + } } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/progress/snapshot/DefaultProgressSnapshotTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/progress/snapshot/DefaultProgressSnapshotTest.java index b2d125f656d1..2d4404556f6f 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/progress/snapshot/DefaultProgressSnapshotTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/progress/snapshot/DefaultProgressSnapshotTest.java @@ -15,10 +15,8 @@ package software.amazon.awssdk.core.internal.progress.snapshot; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import java.time.Duration; import java.time.Instant; @@ -29,9 +27,8 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import software.amazon.awssdk.core.internal.progress.snapshot.DefaultProgressSnapshot; -public class DefaultProgressSnapshotTest { +class DefaultProgressSnapshotTest { private static Stream getArgumentsForInvalidParameterValidationTests() { return Stream.of(Arguments.of("transferredBytes (2) must not be greater than totalBytes (1)", @@ -112,7 +109,7 @@ private static Stream getArgumentsForBytesTest() { @ParameterizedTest @MethodSource("getArgumentsForInvalidParameterValidationTests") - public void test_invalid_arguments_shouldThrow(String expectedErrorMsg, DefaultProgressSnapshot.Builder builder, + void test_invalid_arguments_shouldThrow(String expectedErrorMsg, DefaultProgressSnapshot.Builder builder, Exception e) { assertThatThrownBy(builder::build) .isInstanceOf(e.getClass()) @@ -121,12 +118,12 @@ public void test_invalid_arguments_shouldThrow(String expectedErrorMsg, DefaultP @ParameterizedTest @MethodSource("getArgumentsForMissingParameterValidationTests") - public void test_missing_params_shouldReturnEmpty(boolean condition) { + void test_missing_params_shouldReturnEmpty(boolean condition) { Assertions.assertFalse(condition); } @Test - public void ratioTransferred() { + void ratioTransferred() { DefaultProgressSnapshot snapshot = DefaultProgressSnapshot.builder() .transferredBytes(1L) .totalBytes(5L) @@ -137,18 +134,18 @@ public void ratioTransferred() { @ParameterizedTest @MethodSource("getArgumentsForBytesTest") - public void test_estimatedBytesRemaining_and_totalBytes(long expectedBytes, long actualBytes) { + void test_estimatedBytesRemaining_and_totalBytes(long expectedBytes, long actualBytes) { Assertions.assertEquals(expectedBytes, actualBytes); } @ParameterizedTest @MethodSource("getArgumentsForTimeTest") - public void test_elapsedTime_and_estimatedTimeRemaining(long expected, long timeInMillis, long delta) { + void test_elapsedTime_and_estimatedTimeRemaining(long expected, long timeInMillis, long delta) { Assertions.assertEquals(expected, timeInMillis, delta); } @Test - public void averageBytesPer() { + void averageBytesPer() { DefaultProgressSnapshot snapshot = DefaultProgressSnapshot.builder() .transferredBytes(100L) .startTime(Instant.now().minusMillis(100)) diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/util/ProgressListenerTestUtils.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/util/ProgressListenerTestUtils.java new file mode 100644 index 000000000000..59d693c742c5 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/util/ProgressListenerTestUtils.java @@ -0,0 +1,78 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.util; + +import java.net.URI; +import java.util.ArrayList; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SdkRequestOverrideConfiguration; +import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.http.ExecutionContext; +import software.amazon.awssdk.core.http.NoopTestRequest; +import software.amazon.awssdk.core.interceptor.ExecutionInterceptorChain; +import software.amazon.awssdk.core.interceptor.InterceptorContext; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.internal.progress.listener.DefaultProgressUpdater; +import software.amazon.awssdk.core.protocol.VoidSdkResponse; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; + +public final class ProgressListenerTestUtils { + + public static final AsyncRequestBody ASYNC_REQUEST_BODY = AsyncRequestBody.fromString("TestBody"); + public static final RequestBody REQUEST_BODY = RequestBody.fromString("TestBody"); + + private ProgressListenerTestUtils() { + } + + public static SdkResponse.Builder createSdkResponseBuilder() { + return VoidSdkResponse.builder(); + } + + public static SdkRequest.Builder createSdkHttpRequest(SdkRequestOverrideConfiguration config) { + return NoopTestRequest.builder() + .overrideConfiguration(config); + } + + public static RequestExecutionContext progressListenerContext(boolean isAsyncStreaming, SdkRequest sdkRequest, + DefaultProgressUpdater defaultProgressUpdater) { + + RequestExecutionContext.Builder builder = + RequestExecutionContext.builder(). + executionContext(ExecutionContext.builder() + .interceptorContext(InterceptorContext.builder() + .request(sdkRequest) + .build()) + .interceptorChain(new ExecutionInterceptorChain(new ArrayList<>())) + .build()) + .originalRequest(sdkRequest); + if (isAsyncStreaming) { + builder.requestProvider(ASYNC_REQUEST_BODY); + } + + RequestExecutionContext context = builder.build(); + context.progressUpdater(defaultProgressUpdater); + return context; + } + + public static SdkHttpFullRequest.Builder createHttpRequestBuilder() { + return SdkHttpFullRequest.builder().uri(URI.create("https://endpoint.host")) + .method(SdkHttpMethod.GET) + .contentStreamProvider(REQUEST_BODY.contentStreamProvider()); + } +}