From 2756b1ca609431fb37dff2d65cfe47fcd760f864 Mon Sep 17 00:00:00 2001 From: Zoe Wang <33073555+zoewangg@users.noreply.github.com> Date: Tue, 29 Aug 2023 11:21:15 -0700 Subject: [PATCH 1/2] Add AdditionalDataSubscriber to allow users to send additional data to the downstream subscriber --- .../async/AddingTrailingDataSubscriber.java | 148 ++++++++++++++++++ .../AddingTrailingDataSubscriberTckTest.java | 74 +++++++++ .../AddingTrailingDataSubscriberTest.java | 80 ++++++++++ 3 files changed, 302 insertions(+) create mode 100644 utils/src/main/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriber.java create mode 100644 utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTckTest.java create mode 100644 utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTest.java diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriber.java b/utils/src/main/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriber.java new file mode 100644 index 000000000000..5646140255cc --- /dev/null +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriber.java @@ -0,0 +1,148 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; + +/** + * Allows to send trailing data before invoking onComplete on the downstream subscriber. + * If the trailingDataSupplier returns null, this class will invoke onComplete directly + */ +@SdkProtectedApi +public class AddingTrailingDataSubscriber extends DelegatingSubscriber { + private static final Logger log = Logger.loggerFor(AddingTrailingDataSubscriber.class); + + /** + * The subscription to the upstream subscriber. + */ + private Subscription upstreamSubscription; + + /** + * The amount of unfulfilled demand the downstream subscriber has opened against us. + */ + private final AtomicLong downstreamDemand = new AtomicLong(0); + + /** + * Whether the upstream subscriber has called onComplete on us. + */ + private volatile boolean onCompleteCalledByUpstream = false; + + /** + * Whether the upstream subscriber has called onError on us. + */ + private volatile boolean onErrorCalledByUpstream = false; + + /** + * Whether we have called onComplete on the downstream subscriber. + */ + private AtomicBoolean onCompleteCalledOnDownstream = new AtomicBoolean(false); + + private final Supplier trailingDataSupplier; + private volatile T trailingData; + + public AddingTrailingDataSubscriber(Subscriber subscriber, + Supplier trailingDataSupplier) { + super(Validate.paramNotNull(subscriber, "subscriber")); + this.trailingDataSupplier = Validate.paramNotNull(trailingDataSupplier, "trailingDataSupplier"); + } + + @Override + public void onSubscribe(Subscription subscription) { + + if (upstreamSubscription != null) { + log.warn(() -> "Received duplicate subscription, cancelling the duplicate.", new IllegalStateException()); + subscription.cancel(); + return; + } + + upstreamSubscription = subscription; + + subscriber.onSubscribe(new Subscription() { + + @Override + public void request(long l) { + if (onErrorCalledByUpstream) { + return; + } + + if (onCompleteCalledByUpstream) { + sendTrailingDataIfNeededAndComplete(); + return; + } + + addDownstreamDemand(l); + upstreamSubscription.request(l); + } + + @Override + public void cancel() { + upstreamSubscription.cancel(); + } + }); + } + + @Override + public void onError(Throwable throwable) { + onErrorCalledByUpstream = true; + subscriber.onError(throwable); + } + + @Override + public void onNext(T t) { + Validate.paramNotNull(t, "item"); + downstreamDemand.decrementAndGet(); + subscriber.onNext(t); + } + + @Override + public void onComplete() { + onCompleteCalledByUpstream = true; + + trailingData = trailingDataSupplier.get(); + if (trailingData == null || downstreamDemand.get() > 0) { + sendTrailingDataIfNeededAndComplete(); + } + } + + private void addDownstreamDemand(long l) { + + if (l > 0) { + downstreamDemand.getAndUpdate(current -> { + long newValue = current + l; + return newValue >= 0 ? newValue : Long.MAX_VALUE; + }); + } else { + upstreamSubscription.cancel(); + onError(new IllegalArgumentException("Demand must not be negative")); + } + } + + private void sendTrailingDataIfNeededAndComplete() { + if (onCompleteCalledOnDownstream.compareAndSet(false, true)) { + if (trailingData != null) { + subscriber.onNext(trailingData); + } + subscriber.onComplete(); + } + } +} diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTckTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTckTest.java new file mode 100644 index 000000000000..8cf9890b2290 --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTckTest.java @@ -0,0 +1,74 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import java.util.concurrent.CompletableFuture; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.reactivestreams.tck.SubscriberWhiteboxVerification; +import org.reactivestreams.tck.TestEnvironment; + +public class AddingTrailingDataSubscriberTckTest extends SubscriberWhiteboxVerification { + protected AddingTrailingDataSubscriberTckTest() { + super(new TestEnvironment()); + } + + @Override + public Subscriber createSubscriber(WhiteboxSubscriberProbe probe) { + Subscriber foo = new SequentialSubscriber<>(s -> {}, new CompletableFuture<>()); + + return new AddingTrailingDataSubscriber(foo, () -> Integer.MIN_VALUE) { + @Override + public void onError(Throwable throwable) { + super.onError(throwable); + probe.registerOnError(throwable); + } + + @Override + public void onSubscribe(Subscription subscription) { + super.onSubscribe(subscription); + probe.registerOnSubscribe(new SubscriberPuppet() { + @Override + public void triggerRequest(long elements) { + subscription.request(elements); + } + + @Override + public void signalCancel() { + subscription.cancel(); + } + }); + } + + @Override + public void onNext(Integer nextItem) { + super.onNext(nextItem); + probe.registerOnNext(nextItem); + } + + @Override + public void onComplete() { + super.onComplete(); + probe.registerOnComplete(); + } + }; + } + + @Override + public Integer createElement(int i) { + return i; + } +} diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTest.java new file mode 100644 index 000000000000..d4cdd6d5e9b8 --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTest.java @@ -0,0 +1,80 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscriber; + +public class AddingTrailingDataSubscriberTest { + + @Test + void trailingDataSupplierNull_shouldThrowException() { + SequentialSubscriber downstreamSubscriber = new SequentialSubscriber(i -> {}, new CompletableFuture()); + assertThatThrownBy(() -> new AddingTrailingDataSubscriber<>(downstreamSubscriber, null)) + .hasMessageContaining("must not be null"); + } + + @Test + void subscriberNull_shouldThrowException() { + assertThatThrownBy(() -> new AddingTrailingDataSubscriber<>(null, () -> 1)) + .hasMessageContaining("must not be null"); + } + + @Test + void trailingDataNotNull_shouldNotSendAdditionalData() { + List result = new ArrayList<>(); + CompletableFuture future = new CompletableFuture(); + SequentialSubscriber downstreamSubscriber = new SequentialSubscriber(i -> result.add(i), future); + + Subscriber subscriber = new AddingTrailingDataSubscriber<>(downstreamSubscriber, () -> Integer.MAX_VALUE); + + publishData(subscriber); + + future.join(); + + assertThat(result).containsExactly(0, 1, 2, Integer.MAX_VALUE); + } + + @Test + void trailingDataNull_shouldNotSendAdditionalData() { + List result = new ArrayList<>(); + CompletableFuture future = new CompletableFuture(); + SequentialSubscriber downstreamSubscriber = new SequentialSubscriber(i -> result.add(i), future); + + Subscriber subscriber = new AddingTrailingDataSubscriber<>(downstreamSubscriber, () -> null); + + publishData(subscriber); + + future.join(); + + assertThat(result).containsExactly(0, 1, 2); + } + + private void publishData(Subscriber subscriber) { + SimplePublisher simplePublisher = new SimplePublisher<>(); + simplePublisher.subscribe(subscriber); + for (int i = 0; i < 3; i++) { + simplePublisher.send(i); + } + simplePublisher.complete(); + } +} From 09a02e526cf4419d0f36c8051002261301aee3a8 Mon Sep 17 00:00:00 2001 From: Zoe Wang <33073555+zoewangg@users.noreply.github.com> Date: Tue, 29 Aug 2023 18:24:09 -0700 Subject: [PATCH 2/2] Support iterable --- .../awssdk/core/async/SdkPublisher.java | 14 ++++ .../awssdk/core/async/SdkPublishersTest.java | 18 +++++ .../async/AddingTrailingDataSubscriber.java | 65 +++++++++++++------ .../AddingTrailingDataSubscriberTckTest.java | 3 +- .../AddingTrailingDataSubscriberTest.java | 29 +++++++-- 5 files changed, 102 insertions(+), 27 deletions(-) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SdkPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SdkPublisher.java index 58c5dec433c6..5563b716dae6 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SdkPublisher.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/SdkPublisher.java @@ -20,10 +20,12 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; +import java.util.function.Supplier; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkPublicApi; +import software.amazon.awssdk.utils.async.AddingTrailingDataSubscriber; import software.amazon.awssdk.utils.async.BufferingSubscriber; import software.amazon.awssdk.utils.async.EventListeningSubscriber; import software.amazon.awssdk.utils.async.FilteringSubscriber; @@ -118,6 +120,18 @@ default SdkPublisher limit(int limit) { return subscriber -> subscribe(new LimitingSubscriber<>(subscriber, limit)); } + + /** + * Creates a new publisher that emits trailing events provided by {@code trailingDataSupplier} in addition to the + * published events. + * + * @param trailingDataSupplier supplier to provide the trailing data + * @return New publisher that will publish additional events + */ + default SdkPublisher addTrailingData(Supplier> trailingDataSupplier) { + return subscriber -> subscribe(new AddingTrailingDataSubscriber(subscriber, trailingDataSupplier)); + } + /** * Add a callback that will be invoked after this publisher invokes {@link Subscriber#onComplete()}. * diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SdkPublishersTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SdkPublishersTest.java index 592873971934..c71816e1ff27 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SdkPublishersTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/SdkPublishersTest.java @@ -21,6 +21,7 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -141,6 +142,23 @@ public void flatMapIterableHandlesError() { .hasCause(exception); } + @Test + public void addTrailingData_handlesCorrectly() { + FakeSdkPublisher fakePublisher = new FakeSdkPublisher<>(); + + FakeStringSubscriber fakeSubscriber = new FakeStringSubscriber(); + fakePublisher.addTrailingData(() -> Arrays.asList("two", "three")) + .subscribe(fakeSubscriber); + + fakePublisher.publish("one"); + fakePublisher.complete(); + + assertThat(fakeSubscriber.recordedEvents()).containsExactly("one", "two", "three"); + assertThat(fakeSubscriber.isComplete()).isTrue(); + assertThat(fakeSubscriber.isError()).isFalse(); + } + + private final static class FakeByteBufferSubscriber implements Subscriber { private final List recordedEvents = new ArrayList<>(); diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriber.java b/utils/src/main/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriber.java index 5646140255cc..cd8b8c25eb27 100644 --- a/utils/src/main/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriber.java +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriber.java @@ -15,7 +15,7 @@ package software.amazon.awssdk.utils.async; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.Iterator; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; import org.reactivestreams.Subscriber; @@ -26,7 +26,7 @@ /** * Allows to send trailing data before invoking onComplete on the downstream subscriber. - * If the trailingDataSupplier returns null, this class will invoke onComplete directly + * trailingDataIterable will be created when the upstream subscriber has called onComplete. */ @SdkProtectedApi public class AddingTrailingDataSubscriber extends DelegatingSubscriber { @@ -55,15 +55,15 @@ public class AddingTrailingDataSubscriber extends DelegatingSubscriber /** * Whether we have called onComplete on the downstream subscriber. */ - private AtomicBoolean onCompleteCalledOnDownstream = new AtomicBoolean(false); + private volatile boolean onCompleteCalledOnDownstream = false; - private final Supplier trailingDataSupplier; - private volatile T trailingData; + private final Supplier> trailingDataIterableSupplier; + private Iterator trailingDataIterator; public AddingTrailingDataSubscriber(Subscriber subscriber, - Supplier trailingDataSupplier) { + Supplier> trailingDataIterableSupplier) { super(Validate.paramNotNull(subscriber, "subscriber")); - this.trailingDataSupplier = Validate.paramNotNull(trailingDataSupplier, "trailingDataSupplier"); + this.trailingDataIterableSupplier = Validate.paramNotNull(trailingDataIterableSupplier, "trailingDataIterableSupplier"); } @Override @@ -81,16 +81,16 @@ public void onSubscribe(Subscription subscription) { @Override public void request(long l) { - if (onErrorCalledByUpstream) { + if (onErrorCalledByUpstream || onCompleteCalledOnDownstream) { return; } + addDownstreamDemand(l); + if (onCompleteCalledByUpstream) { - sendTrailingDataIfNeededAndComplete(); + sendTrailingDataAndCompleteIfNeeded(); return; } - - addDownstreamDemand(l); upstreamSubscription.request(l); } @@ -117,11 +117,7 @@ public void onNext(T t) { @Override public void onComplete() { onCompleteCalledByUpstream = true; - - trailingData = trailingDataSupplier.get(); - if (trailingData == null || downstreamDemand.get() > 0) { - sendTrailingDataIfNeededAndComplete(); - } + sendTrailingDataAndCompleteIfNeeded(); } private void addDownstreamDemand(long l) { @@ -137,12 +133,39 @@ private void addDownstreamDemand(long l) { } } - private void sendTrailingDataIfNeededAndComplete() { - if (onCompleteCalledOnDownstream.compareAndSet(false, true)) { - if (trailingData != null) { - subscriber.onNext(trailingData); + private synchronized void sendTrailingDataAndCompleteIfNeeded() { + if (onCompleteCalledOnDownstream) { + return; + } + + if (trailingDataIterator == null) { + Iterable supplier = trailingDataIterableSupplier.get(); + if (supplier == null) { + completeDownstreamSubscriber(); + return; } - subscriber.onComplete(); + + trailingDataIterator = supplier.iterator(); } + + sendTrailingDataIfNeeded(); + + if (!trailingDataIterator.hasNext()) { + completeDownstreamSubscriber(); + } + } + + private void sendTrailingDataIfNeeded() { + long demand = downstreamDemand.get(); + + while (trailingDataIterator.hasNext() && demand > 0) { + subscriber.onNext(trailingDataIterator.next()); + demand = downstreamDemand.decrementAndGet(); + } + } + + private void completeDownstreamSubscriber() { + subscriber.onComplete(); + onCompleteCalledOnDownstream = true; } } diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTckTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTckTest.java index 8cf9890b2290..7eced2270c3a 100644 --- a/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTckTest.java +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTckTest.java @@ -15,6 +15,7 @@ package software.amazon.awssdk.utils.async; +import java.util.Arrays; import java.util.concurrent.CompletableFuture; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; @@ -30,7 +31,7 @@ protected AddingTrailingDataSubscriberTckTest() { public Subscriber createSubscriber(WhiteboxSubscriberProbe probe) { Subscriber foo = new SequentialSubscriber<>(s -> {}, new CompletableFuture<>()); - return new AddingTrailingDataSubscriber(foo, () -> Integer.MIN_VALUE) { + return new AddingTrailingDataSubscriber(foo, () -> Arrays.asList(0, 1, 2)) { @Override public void onError(Throwable throwable) { super.onError(throwable); diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTest.java index d4cdd6d5e9b8..b4a72c459bb2 100644 --- a/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTest.java +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/AddingTrailingDataSubscriberTest.java @@ -18,7 +18,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import com.google.common.collect.Lists; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.Test; @@ -35,27 +37,44 @@ void trailingDataSupplierNull_shouldThrowException() { @Test void subscriberNull_shouldThrowException() { - assertThatThrownBy(() -> new AddingTrailingDataSubscriber<>(null, () -> 1)) + assertThatThrownBy(() -> new AddingTrailingDataSubscriber<>(null, () -> Arrays.asList(1, 2))) .hasMessageContaining("must not be null"); } @Test - void trailingDataNotNull_shouldNotSendAdditionalData() { + void trailingDataHasItems_shouldSendAdditionalData() { List result = new ArrayList<>(); CompletableFuture future = new CompletableFuture(); SequentialSubscriber downstreamSubscriber = new SequentialSubscriber(i -> result.add(i), future); - Subscriber subscriber = new AddingTrailingDataSubscriber<>(downstreamSubscriber, () -> Integer.MAX_VALUE); + Subscriber subscriber = new AddingTrailingDataSubscriber<>(downstreamSubscriber, + () -> Arrays.asList(Integer.MAX_VALUE, + Integer.MIN_VALUE)); publishData(subscriber); future.join(); - assertThat(result).containsExactly(0, 1, 2, Integer.MAX_VALUE); + assertThat(result).containsExactly(0, 1, 2, Integer.MAX_VALUE, Integer.MIN_VALUE); } @Test - void trailingDataNull_shouldNotSendAdditionalData() { + void trailingDataEmpty_shouldNotSendAdditionalData() { + List result = new ArrayList<>(); + CompletableFuture future = new CompletableFuture(); + SequentialSubscriber downstreamSubscriber = new SequentialSubscriber(i -> result.add(i), future); + + Subscriber subscriber = new AddingTrailingDataSubscriber<>(downstreamSubscriber, () -> new ArrayList<>()); + + publishData(subscriber); + + future.join(); + + assertThat(result).containsExactly(0, 1, 2); + } + + @Test + void trailingDataNull_shouldCompleteNormally() { List result = new ArrayList<>(); CompletableFuture future = new CompletableFuture(); SequentialSubscriber downstreamSubscriber = new SequentialSubscriber(i -> result.add(i), future);