diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/AdditionalDataSubscriber.java b/utils/src/main/java/software/amazon/awssdk/utils/async/AdditionalDataSubscriber.java new file mode 100644 index 000000000000..41cca81a9a18 --- /dev/null +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/AdditionalDataSubscriber.java @@ -0,0 +1,142 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; + +/** + * Allows additional data to be sent before invoking onComplete on the downstream subscriber. + */ +@SdkProtectedApi +public class AdditionalDataSubscriber extends DelegatingSubscriber { + private static final Logger log = Logger.loggerFor(AdditionalDataSubscriber.class); + + /** + * The subscription to the upstream subscriber. + */ + private Subscription upstreamSubscription; + + /** + * The amount of unfulfilled demand the downstream subscriber has opened against us. + */ + private final AtomicLong downstreamDemand = new AtomicLong(0); + + /** + * Whether the upstream subscriber has called onComplete on us. + */ + private volatile boolean onCompleteCalledByUpstream = false; + + /** + * Whether the upstream subscriber has called onError on us. + */ + private volatile boolean onErrorCalledByUpstream = false; + + /** + * Whether we have called onComplete on the downstream subscriber. + */ + private AtomicBoolean onCompleteCalledOnDownstream = new AtomicBoolean(false); + + private final Supplier additionalDataSupplier; + + public AdditionalDataSubscriber(Subscriber subscriber, Supplier additionalDataSupplier) { + super(Validate.paramNotNull(subscriber, "subscriber")); + this.additionalDataSupplier = Validate.paramNotNull(additionalDataSupplier, "additionalDataSupplier"); + } + + @Override + public void onSubscribe(Subscription subscription) { + + if (upstreamSubscription != null) { + log.warn(() -> "Received duplicate subscription, cancelling the duplicate.", new IllegalStateException()); + subscription.cancel(); + return; + } + + upstreamSubscription = subscription; + + subscriber.onSubscribe(new Subscription() { + + @Override + public void request(long l) { + if (onErrorCalledByUpstream) { + return; + } + + if (onCompleteCalledByUpstream) { + sendAdditionalDataAndComplete(); + return; + } + + addDownstreamDemand(l); + upstreamSubscription.request(l); + } + + @Override + public void cancel() { + upstreamSubscription.cancel(); + } + }); + } + + @Override + public void onError(Throwable throwable) { + onErrorCalledByUpstream = true; + subscriber.onError(throwable); + } + + @Override + public void onNext(T t) { + Validate.paramNotNull(t, "item"); + downstreamDemand.decrementAndGet(); + subscriber.onNext(t); + } + + @Override + public void onComplete() { + onCompleteCalledByUpstream = true; + if (downstreamDemand.get() > 0) { + sendAdditionalDataAndComplete(); + } + } + + private void addDownstreamDemand(long l) { + + if (l > 0) { + downstreamDemand.getAndUpdate(current -> { + long newValue = current + l; + return newValue >= 0 ? newValue : Long.MAX_VALUE; + }); + } else { + upstreamSubscription.cancel(); + onError(new IllegalArgumentException("Demand must not be negative")); + } + } + + private void sendAdditionalDataAndComplete() { + if (onCompleteCalledOnDownstream.compareAndSet(false, true)) { + T additionalData = additionalDataSupplier.get(); + subscriber.onNext(additionalData); + subscriber.onComplete(); + } + } +} diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/AdditionalDataSubscriberTckTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/AdditionalDataSubscriberTckTest.java new file mode 100644 index 000000000000..33d343494aca --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/AdditionalDataSubscriberTckTest.java @@ -0,0 +1,74 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import java.util.concurrent.CompletableFuture; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.reactivestreams.tck.SubscriberWhiteboxVerification; +import org.reactivestreams.tck.TestEnvironment; + +public class AdditionalDataSubscriberTckTest extends SubscriberWhiteboxVerification { + protected AdditionalDataSubscriberTckTest() { + super(new TestEnvironment()); + } + + @Override + public Subscriber createSubscriber(WhiteboxSubscriberProbe probe) { + Subscriber foo = new SequentialSubscriber<>(s -> {}, new CompletableFuture<>()); + + return new AdditionalDataSubscriber(foo, () -> Integer.MIN_VALUE) { + @Override + public void onError(Throwable throwable) { + super.onError(throwable); + probe.registerOnError(throwable); + } + + @Override + public void onSubscribe(Subscription subscription) { + super.onSubscribe(subscription); + probe.registerOnSubscribe(new SubscriberPuppet() { + @Override + public void triggerRequest(long elements) { + subscription.request(elements); + } + + @Override + public void signalCancel() { + subscription.cancel(); + } + }); + } + + @Override + public void onNext(Integer nextItem) { + super.onNext(nextItem); + probe.registerOnNext(nextItem); + } + + @Override + public void onComplete() { + super.onComplete(); + probe.registerOnComplete(); + } + }; + } + + @Override + public Integer createElement(int i) { + return i; + } +} diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/AdditionalDataSubscriberTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/AdditionalDataSubscriberTest.java new file mode 100644 index 000000000000..7d577823c372 --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/AdditionalDataSubscriberTest.java @@ -0,0 +1,65 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscriber; + +public class AdditionalDataSubscriberTest { + + @Test + void additionalDataSupplierNull_shouldThrowException() { + SequentialSubscriber downstreamSubscriber = new SequentialSubscriber(i -> {}, new CompletableFuture()); + assertThatThrownBy(() -> new AdditionalDataSubscriber<>(downstreamSubscriber, null)) + .hasMessageContaining("must not be null"); + } + + @Test + void subscriberNull_shouldThrowException() { + assertThatThrownBy(() -> new AdditionalDataSubscriber<>(null, () -> 1)) + .hasMessageContaining("must not be null"); + } + + @Test + void additionalDataNotNull_shouldNotSendAdditionalData() { + List result = new ArrayList<>(); + CompletableFuture future = new CompletableFuture(); + SequentialSubscriber downstreamSubscriber = new SequentialSubscriber(i -> result.add(i), future); + + Subscriber subscriber = new AdditionalDataSubscriber<>(downstreamSubscriber, () -> Integer.MAX_VALUE); + + publishData(subscriber); + + future.join(); + + assertThat(result).containsExactly(0, 1, 2, Integer.MAX_VALUE); + } + + private void publishData(Subscriber subscriber) { + SimplePublisher simplePublisher = new SimplePublisher<>(); + simplePublisher.subscribe(subscriber); + for (int i = 0; i < 3; i++) { + simplePublisher.send(i); + } + simplePublisher.complete(); + } +}