diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SdkPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SdkPublisher.java index 58c5dec433c6..5563b716dae6 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SdkPublisher.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SdkPublisher.java @@ -20,10 +20,12 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; +import java.util.function.Supplier; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkPublicApi; +import software.amazon.awssdk.utils.async.AddingTrailingDataSubscriber; import software.amazon.awssdk.utils.async.BufferingSubscriber; import software.amazon.awssdk.utils.async.EventListeningSubscriber; import software.amazon.awssdk.utils.async.FilteringSubscriber; @@ -118,6 +120,18 @@ default SdkPublisher limit(int limit) { return subscriber -> subscribe(new LimitingSubscriber<>(subscriber, limit)); } + + /** + * Creates a new publisher that emits trailing events provided by {@code trailingDataSupplier} in addition to the + * published events. + * + * @param trailingDataSupplier supplier to provide the trailing data + * @return New publisher that will publish additional events + */ + default SdkPublisher addTrailingData(Supplier> trailingDataSupplier) { + return subscriber -> subscribe(new AddingTrailingDataSubscriber(subscriber, trailingDataSupplier)); + } + /** * Add a callback that will be invoked after this publisher invokes {@link Subscriber#onComplete()}. * diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SdkPublishersTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SdkPublishersTest.java index 592873971934..c71816e1ff27 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SdkPublishersTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SdkPublishersTest.java @@ -21,6 +21,7 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -141,6 +142,23 @@ public void flatMapIterableHandlesError() { .hasCause(exception); } + @Test + public void addTrailingData_handlesCorrectly() { + FakeSdkPublisher fakePublisher = new FakeSdkPublisher<>(); + + FakeStringSubscriber fakeSubscriber = new FakeStringSubscriber(); + fakePublisher.addTrailingData(() -> Arrays.asList("two", "three")) + .subscribe(fakeSubscriber); + + fakePublisher.publish("one"); + fakePublisher.complete(); + + assertThat(fakeSubscriber.recordedEvents()).containsExactly("one", "two", "three"); + assertThat(fakeSubscriber.isComplete()).isTrue(); + assertThat(fakeSubscriber.isError()).isFalse(); + } + + private final static class FakeByteBufferSubscriber implements Subscriber { private final List recordedEvents = new ArrayList<>(); diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriber.java b/utils/src/main/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriber.java new file mode 100644 index 000000000000..cd8b8c25eb27 --- /dev/null +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriber.java @@ -0,0 +1,171 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import java.util.Iterator; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; + +/** + * Allows to send trailing data before invoking onComplete on the downstream subscriber. + * trailingDataIterable will be created when the upstream subscriber has called onComplete. + */ +@SdkProtectedApi +public class AddingTrailingDataSubscriber extends DelegatingSubscriber { + private static final Logger log = Logger.loggerFor(AddingTrailingDataSubscriber.class); + + /** + * The subscription to the upstream subscriber. + */ + private Subscription upstreamSubscription; + + /** + * The amount of unfulfilled demand the downstream subscriber has opened against us. + */ + private final AtomicLong downstreamDemand = new AtomicLong(0); + + /** + * Whether the upstream subscriber has called onComplete on us. + */ + private volatile boolean onCompleteCalledByUpstream = false; + + /** + * Whether the upstream subscriber has called onError on us. + */ + private volatile boolean onErrorCalledByUpstream = false; + + /** + * Whether we have called onComplete on the downstream subscriber. + */ + private volatile boolean onCompleteCalledOnDownstream = false; + + private final Supplier> trailingDataIterableSupplier; + private Iterator trailingDataIterator; + + public AddingTrailingDataSubscriber(Subscriber subscriber, + Supplier> trailingDataIterableSupplier) { + super(Validate.paramNotNull(subscriber, "subscriber")); + this.trailingDataIterableSupplier = Validate.paramNotNull(trailingDataIterableSupplier, "trailingDataIterableSupplier"); + } + + @Override + public void onSubscribe(Subscription subscription) { + + if (upstreamSubscription != null) { + log.warn(() -> "Received duplicate subscription, cancelling the duplicate.", new IllegalStateException()); + subscription.cancel(); + return; + } + + upstreamSubscription = subscription; + + subscriber.onSubscribe(new Subscription() { + + @Override + public void request(long l) { + if (onErrorCalledByUpstream || onCompleteCalledOnDownstream) { + return; + } + + addDownstreamDemand(l); + + if (onCompleteCalledByUpstream) { + sendTrailingDataAndCompleteIfNeeded(); + return; + } + upstreamSubscription.request(l); + } + + @Override + public void cancel() { + upstreamSubscription.cancel(); + } + }); + } + + @Override + public void onError(Throwable throwable) { + onErrorCalledByUpstream = true; + subscriber.onError(throwable); + } + + @Override + public void onNext(T t) { + Validate.paramNotNull(t, "item"); + downstreamDemand.decrementAndGet(); + subscriber.onNext(t); + } + + @Override + public void onComplete() { + onCompleteCalledByUpstream = true; + sendTrailingDataAndCompleteIfNeeded(); + } + + private void addDownstreamDemand(long l) { + + if (l > 0) { + downstreamDemand.getAndUpdate(current -> { + long newValue = current + l; + return newValue >= 0 ? newValue : Long.MAX_VALUE; + }); + } else { + upstreamSubscription.cancel(); + onError(new IllegalArgumentException("Demand must not be negative")); + } + } + + private synchronized void sendTrailingDataAndCompleteIfNeeded() { + if (onCompleteCalledOnDownstream) { + return; + } + + if (trailingDataIterator == null) { + Iterable supplier = trailingDataIterableSupplier.get(); + if (supplier == null) { + completeDownstreamSubscriber(); + return; + } + + trailingDataIterator = supplier.iterator(); + } + + sendTrailingDataIfNeeded(); + + if (!trailingDataIterator.hasNext()) { + completeDownstreamSubscriber(); + } + } + + private void sendTrailingDataIfNeeded() { + long demand = downstreamDemand.get(); + + while (trailingDataIterator.hasNext() && demand > 0) { + subscriber.onNext(trailingDataIterator.next()); + demand = downstreamDemand.decrementAndGet(); + } + } + + private void completeDownstreamSubscriber() { + subscriber.onComplete(); + onCompleteCalledOnDownstream = true; + } +} diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTckTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTckTest.java new file mode 100644 index 000000000000..7eced2270c3a --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTckTest.java @@ -0,0 +1,75 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import java.util.Arrays; +import java.util.concurrent.CompletableFuture; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.reactivestreams.tck.SubscriberWhiteboxVerification; +import org.reactivestreams.tck.TestEnvironment; + +public class AddingTrailingDataSubscriberTckTest extends SubscriberWhiteboxVerification { + protected AddingTrailingDataSubscriberTckTest() { + super(new TestEnvironment()); + } + + @Override + public Subscriber createSubscriber(WhiteboxSubscriberProbe probe) { + Subscriber foo = new SequentialSubscriber<>(s -> {}, new CompletableFuture<>()); + + return new AddingTrailingDataSubscriber(foo, () -> Arrays.asList(0, 1, 2)) { + @Override + public void onError(Throwable throwable) { + super.onError(throwable); + probe.registerOnError(throwable); + } + + @Override + public void onSubscribe(Subscription subscription) { + super.onSubscribe(subscription); + probe.registerOnSubscribe(new SubscriberPuppet() { + @Override + public void triggerRequest(long elements) { + subscription.request(elements); + } + + @Override + public void signalCancel() { + subscription.cancel(); + } + }); + } + + @Override + public void onNext(Integer nextItem) { + super.onNext(nextItem); + probe.registerOnNext(nextItem); + } + + @Override + public void onComplete() { + super.onComplete(); + probe.registerOnComplete(); + } + }; + } + + @Override + public Integer createElement(int i) { + return i; + } +} diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTest.java new file mode 100644 index 000000000000..b4a72c459bb2 --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTest.java @@ -0,0 +1,99 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.google.common.collect.Lists; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscriber; + +public class AddingTrailingDataSubscriberTest { + + @Test + void trailingDataSupplierNull_shouldThrowException() { + SequentialSubscriber downstreamSubscriber = new SequentialSubscriber(i -> {}, new CompletableFuture()); + assertThatThrownBy(() -> new AddingTrailingDataSubscriber<>(downstreamSubscriber, null)) + .hasMessageContaining("must not be null"); + } + + @Test + void subscriberNull_shouldThrowException() { + assertThatThrownBy(() -> new AddingTrailingDataSubscriber<>(null, () -> Arrays.asList(1, 2))) + .hasMessageContaining("must not be null"); + } + + @Test + void trailingDataHasItems_shouldSendAdditionalData() { + List result = new ArrayList<>(); + CompletableFuture future = new CompletableFuture(); + SequentialSubscriber downstreamSubscriber = new SequentialSubscriber(i -> result.add(i), future); + + Subscriber subscriber = new AddingTrailingDataSubscriber<>(downstreamSubscriber, + () -> Arrays.asList(Integer.MAX_VALUE, + Integer.MIN_VALUE)); + + publishData(subscriber); + + future.join(); + + assertThat(result).containsExactly(0, 1, 2, Integer.MAX_VALUE, Integer.MIN_VALUE); + } + + @Test + void trailingDataEmpty_shouldNotSendAdditionalData() { + List result = new ArrayList<>(); + CompletableFuture future = new CompletableFuture(); + SequentialSubscriber downstreamSubscriber = new SequentialSubscriber(i -> result.add(i), future); + + Subscriber subscriber = new AddingTrailingDataSubscriber<>(downstreamSubscriber, () -> new ArrayList<>()); + + publishData(subscriber); + + future.join(); + + assertThat(result).containsExactly(0, 1, 2); + } + + @Test + void trailingDataNull_shouldCompleteNormally() { + List result = new ArrayList<>(); + CompletableFuture future = new CompletableFuture(); + SequentialSubscriber downstreamSubscriber = new SequentialSubscriber(i -> result.add(i), future); + + Subscriber subscriber = new AddingTrailingDataSubscriber<>(downstreamSubscriber, () -> null); + + publishData(subscriber); + + future.join(); + + assertThat(result).containsExactly(0, 1, 2); + } + + private void publishData(Subscriber subscriber) { + SimplePublisher simplePublisher = new SimplePublisher<>(); + simplePublisher.subscribe(subscriber); + for (int i = 0; i < 3; i++) { + simplePublisher.send(i); + } + simplePublisher.complete(); + } +}