Skip to content

Commit

Permalink
Add AdditionalDataSubscriber to allow users to send additional data t…
Browse files Browse the repository at this point in the history
…o the downstream subscriber
  • Loading branch information
zoewangg committed Aug 29, 2023
1 parent 0326bf1 commit 2756b1c
Show file tree
Hide file tree
Showing 3 changed files with 302 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.utils.async;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.SdkProtectedApi;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.Validate;

/**
* Allows to send trailing data before invoking onComplete on the downstream subscriber.
* If the trailingDataSupplier returns null, this class will invoke onComplete directly
*/
@SdkProtectedApi
public class AddingTrailingDataSubscriber<T> extends DelegatingSubscriber<T, T> {
private static final Logger log = Logger.loggerFor(AddingTrailingDataSubscriber.class);

/**
* The subscription to the upstream subscriber.
*/
private Subscription upstreamSubscription;

/**
* The amount of unfulfilled demand the downstream subscriber has opened against us.
*/
private final AtomicLong downstreamDemand = new AtomicLong(0);

/**
* Whether the upstream subscriber has called onComplete on us.
*/
private volatile boolean onCompleteCalledByUpstream = false;

/**
* Whether the upstream subscriber has called onError on us.
*/
private volatile boolean onErrorCalledByUpstream = false;

/**
* Whether we have called onComplete on the downstream subscriber.
*/
private AtomicBoolean onCompleteCalledOnDownstream = new AtomicBoolean(false);

private final Supplier<T> trailingDataSupplier;
private volatile T trailingData;

public AddingTrailingDataSubscriber(Subscriber<? super T> subscriber,
Supplier<T> trailingDataSupplier) {
super(Validate.paramNotNull(subscriber, "subscriber"));
this.trailingDataSupplier = Validate.paramNotNull(trailingDataSupplier, "trailingDataSupplier");
}

@Override
public void onSubscribe(Subscription subscription) {

if (upstreamSubscription != null) {
log.warn(() -> "Received duplicate subscription, cancelling the duplicate.", new IllegalStateException());
subscription.cancel();
return;
}

upstreamSubscription = subscription;

subscriber.onSubscribe(new Subscription() {

@Override
public void request(long l) {
if (onErrorCalledByUpstream) {
return;
}

if (onCompleteCalledByUpstream) {
sendTrailingDataIfNeededAndComplete();
return;
}

addDownstreamDemand(l);
upstreamSubscription.request(l);
}

@Override
public void cancel() {
upstreamSubscription.cancel();
}
});
}

@Override
public void onError(Throwable throwable) {
onErrorCalledByUpstream = true;
subscriber.onError(throwable);
}

@Override
public void onNext(T t) {
Validate.paramNotNull(t, "item");
downstreamDemand.decrementAndGet();
subscriber.onNext(t);
}

@Override
public void onComplete() {
onCompleteCalledByUpstream = true;

trailingData = trailingDataSupplier.get();
if (trailingData == null || downstreamDemand.get() > 0) {
sendTrailingDataIfNeededAndComplete();
}
}

private void addDownstreamDemand(long l) {

if (l > 0) {
downstreamDemand.getAndUpdate(current -> {
long newValue = current + l;
return newValue >= 0 ? newValue : Long.MAX_VALUE;
});
} else {
upstreamSubscription.cancel();
onError(new IllegalArgumentException("Demand must not be negative"));
}
}

private void sendTrailingDataIfNeededAndComplete() {
if (onCompleteCalledOnDownstream.compareAndSet(false, true)) {
if (trailingData != null) {
subscriber.onNext(trailingData);
}
subscriber.onComplete();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.utils.async;

import java.util.concurrent.CompletableFuture;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.reactivestreams.tck.SubscriberWhiteboxVerification;
import org.reactivestreams.tck.TestEnvironment;

public class AddingTrailingDataSubscriberTckTest extends SubscriberWhiteboxVerification<Integer> {
protected AddingTrailingDataSubscriberTckTest() {
super(new TestEnvironment());
}

@Override
public Subscriber<Integer> createSubscriber(WhiteboxSubscriberProbe<Integer> probe) {
Subscriber<Integer> foo = new SequentialSubscriber<>(s -> {}, new CompletableFuture<>());

return new AddingTrailingDataSubscriber<Integer>(foo, () -> Integer.MIN_VALUE) {
@Override
public void onError(Throwable throwable) {
super.onError(throwable);
probe.registerOnError(throwable);
}

@Override
public void onSubscribe(Subscription subscription) {
super.onSubscribe(subscription);
probe.registerOnSubscribe(new SubscriberPuppet() {
@Override
public void triggerRequest(long elements) {
subscription.request(elements);
}

@Override
public void signalCancel() {
subscription.cancel();
}
});
}

@Override
public void onNext(Integer nextItem) {
super.onNext(nextItem);
probe.registerOnNext(nextItem);
}

@Override
public void onComplete() {
super.onComplete();
probe.registerOnComplete();
}
};
}

@Override
public Integer createElement(int i) {
return i;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.utils.async;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.Test;
import org.reactivestreams.Subscriber;

public class AddingTrailingDataSubscriberTest {

@Test
void trailingDataSupplierNull_shouldThrowException() {
SequentialSubscriber<Integer> downstreamSubscriber = new SequentialSubscriber<Integer>(i -> {}, new CompletableFuture());
assertThatThrownBy(() -> new AddingTrailingDataSubscriber<>(downstreamSubscriber, null))
.hasMessageContaining("must not be null");
}

@Test
void subscriberNull_shouldThrowException() {
assertThatThrownBy(() -> new AddingTrailingDataSubscriber<>(null, () -> 1))
.hasMessageContaining("must not be null");
}

@Test
void trailingDataNotNull_shouldNotSendAdditionalData() {
List<Integer> result = new ArrayList<>();
CompletableFuture future = new CompletableFuture();
SequentialSubscriber<Integer> downstreamSubscriber = new SequentialSubscriber<Integer>(i -> result.add(i), future);

Subscriber<Integer> subscriber = new AddingTrailingDataSubscriber<>(downstreamSubscriber, () -> Integer.MAX_VALUE);

publishData(subscriber);

future.join();

assertThat(result).containsExactly(0, 1, 2, Integer.MAX_VALUE);
}

@Test
void trailingDataNull_shouldNotSendAdditionalData() {
List<Integer> result = new ArrayList<>();
CompletableFuture future = new CompletableFuture();
SequentialSubscriber<Integer> downstreamSubscriber = new SequentialSubscriber<Integer>(i -> result.add(i), future);

Subscriber<Integer> subscriber = new AddingTrailingDataSubscriber<>(downstreamSubscriber, () -> null);

publishData(subscriber);

future.join();

assertThat(result).containsExactly(0, 1, 2);
}

private void publishData(Subscriber<Integer> subscriber) {
SimplePublisher<Integer> simplePublisher = new SimplePublisher<>();
simplePublisher.subscribe(subscriber);
for (int i = 0; i < 3; i++) {
simplePublisher.send(i);
}
simplePublisher.complete();
}
}

0 comments on commit 2756b1c

Please sign in to comment.