diff --git a/sdk-api/src/test/java/dev/restate/sdk/DeferredTest.java b/sdk-api/src/test/java/dev/restate/sdk/DeferredTest.java index 68b6b505..b9ea5035 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/DeferredTest.java +++ b/sdk-api/src/test/java/dev/restate/sdk/DeferredTest.java @@ -134,7 +134,7 @@ protected TestInvocationBuilder awaitOnAlreadyResolvedAwaitables() { protected TestInvocationBuilder awaitWithTimeout() { return testDefinitionForService( - "AwaitOnAlreadyResolvedAwaitables", + "AwaitWithTimeout", Serde.VOID, JsonSerdes.STRING, (ctx, unused) -> { diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/AckStateMachine.java b/sdk-core/src/main/java/dev/restate/sdk/core/AckStateMachine.java new file mode 100644 index 00000000..c31f3f47 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/AckStateMachine.java @@ -0,0 +1,56 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core; + +/** State machine tracking acks */ +class AckStateMachine extends BaseSuspendableCallbackStateMachine { + + interface AckCallback extends SuspendableCallback { + void onAck(); + } + + private int lastAcknowledgedEntry = -1; + + /** -1 means no side effect waiting to be acked. */ + private int lastEntryToAck = -1; + + void waitLastAck(AckCallback callback) { + if (lastEntryIsAcked()) { + callback.onAck(); + } else { + this.setCallback(callback); + } + } + + void tryHandleAck(int entryIndex) { + this.lastAcknowledgedEntry = Math.max(entryIndex, this.lastAcknowledgedEntry); + if (lastEntryIsAcked()) { + this.consumeCallback(AckCallback::onAck); + } + } + + void registerEntryToAck(int entryIndex) { + this.lastEntryToAck = Math.max(entryIndex, this.lastEntryToAck); + } + + private boolean lastEntryIsAcked() { + return this.lastEntryToAck <= this.lastAcknowledgedEntry; + } + + public int getLastEntryToAck() { + return lastEntryToAck; + } + + @Override + void abort(Throwable cause) { + super.abort(cause); + // We can't do anything else if the input stream is closed, so we just fail the callback, if any + this.tryFailCallback(); + } +} diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java b/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java index 7cbf56ff..35a516de 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java @@ -55,7 +55,7 @@ class InvocationStateMachine implements InvocationFlow.InvocationProcessor { // Buffering of messages and completions private final IncomingEntriesStateMachine incomingEntriesStateMachine; - private final SideEffectAckStateMachine sideEffectAckStateMachine; + private final AckStateMachine ackStateMachine; private final ReadyResultStateMachine readyResultStateMachine; // Flow sub/pub @@ -75,7 +75,7 @@ class InvocationStateMachine implements InvocationFlow.InvocationProcessor { this.incomingEntriesStateMachine = new IncomingEntriesStateMachine(); this.readyResultStateMachine = new ReadyResultStateMachine(); - this.sideEffectAckStateMachine = new SideEffectAckStateMachine(); + this.ackStateMachine = new AckStateMachine(); this.afterStartCallback = new CallbackHandle<>(); } @@ -142,8 +142,7 @@ public void onNext(InvocationFlow.InvocationInput invocationInput) { // runtime. this.readyResultStateMachine.offerCompletion((Protocol.CompletionMessage) msg); } else if (msg instanceof Protocol.EntryAckMessage) { - this.sideEffectAckStateMachine.tryHandleSideEffectAck( - ((Protocol.EntryAckMessage) msg).getEntryIndex()); + this.ackStateMachine.tryHandleAck(((Protocol.EntryAckMessage) msg).getEntryIndex()); } else { this.incomingEntriesStateMachine.offer(msg); } @@ -159,7 +158,7 @@ public void onError(Throwable throwable) { public void onComplete() { LOG.trace("Input publisher closed"); this.readyResultStateMachine.abort(AbortedExecutionException.INSTANCE); - this.sideEffectAckStateMachine.abort(AbortedExecutionException.INSTANCE); + this.ackStateMachine.abort(AbortedExecutionException.INSTANCE); } // --- Init routine to wait for the start message @@ -287,7 +286,7 @@ private void closeWithMessage(MessageLite closeMessage, Throwable cause) { // Unblock any eventual waiting callbacks this.afterStartCallback.consume(cb -> cb.onCancel(cause)); this.readyResultStateMachine.abort(cause); - this.sideEffectAckStateMachine.abort(cause); + this.ackStateMachine.abort(cause); this.incomingEntriesStateMachine.abort(cause); this.span.end(); } @@ -456,21 +455,21 @@ void exitSideEffectBlock( } // Write new entry - this.sideEffectAckStateMachine.registerExecutedSideEffect(this.currentJournalEntryIndex); + this.ackStateMachine.registerEntryToAck(this.currentJournalEntryIndex); this.writeEntry(sideEffectEntry); // Wait for entry to be acked Protocol.RunEntryMessage finalSideEffectEntry = sideEffectEntry; - this.sideEffectAckStateMachine.waitLastSideEffectAck( - new SideEffectAckStateMachine.SideEffectAckCallback() { + this.ackStateMachine.waitLastAck( + new AckStateMachine.AckCallback() { @Override - public void onLastSideEffectAck() { + public void onAck() { completeSideEffectCallbackWithEntry(finalSideEffectEntry, callback); } @Override public void onSuspend() { - suspend(List.of(sideEffectAckStateMachine.getLastExecutedSideEffect())); + suspend(List.of(ackStateMachine.getLastEntryToAck())); callback.onCancel(AbortedExecutionException.INSTANCE); } @@ -621,8 +620,7 @@ private void resolveCombinatorDeferred( + "This is a symptom of an SDK bug, please contact the developers."); } - writeCombinatorEntry(Collections.emptyList()); - callback.onSuccess(null); + writeCombinatorEntry(Collections.emptyList(), callback); return; } @@ -636,8 +634,7 @@ private void resolveCombinatorDeferred( // Try to resolve the combinator now if (rootDeferred.tryResolve(entryIndex)) { - writeCombinatorEntry(resolvedOrder); - callback.onSuccess(null); + writeCombinatorEntry(resolvedOrder, callback); return; } } else { @@ -667,8 +664,7 @@ public boolean onNewResult(Map> resultMap) { // Try to resolve the combinator now if (rootDeferred.tryResolve(entryIndex)) { - writeCombinatorEntry(resolvedOrder); - callback.onSuccess(null); + writeCombinatorEntry(resolvedOrder, callback); return true; } } @@ -694,12 +690,35 @@ public void onError(Throwable e) { } } - private void writeCombinatorEntry(List resolvedList) { + private void writeCombinatorEntry(List resolvedList, SyscallCallback callback) { // Create and write the entry Java.CombinatorAwaitableEntryMessage entry = Java.CombinatorAwaitableEntryMessage.newBuilder().addAllEntryIndex(resolvedList).build(); span.addEvent("Combinator"); + + // We register the combinator entry to wait for an ack + this.ackStateMachine.registerEntryToAck(this.currentJournalEntryIndex); writeEntry(entry); + + // Let's wait for the ack + this.ackStateMachine.waitLastAck( + new AckStateMachine.AckCallback() { + @Override + public void onAck() { + callback.onSuccess(null); + } + + @Override + public void onSuspend() { + suspend(List.of(ackStateMachine.getLastEntryToAck())); + callback.onCancel(AbortedExecutionException.INSTANCE); + } + + @Override + public void onError(Throwable e) { + callback.onCancel(e); + } + }); } // --- Internal callback diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java b/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java index cb9c72c0..ec92dff0 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java @@ -9,6 +9,7 @@ package dev.restate.sdk.core; import com.google.protobuf.MessageLite; +import dev.restate.generated.sdk.java.Java; import dev.restate.generated.service.protocol.Protocol; public class MessageHeader { @@ -82,6 +83,9 @@ public static MessageHeader fromMessage(MessageLite msg) { } else if (msg instanceof Protocol.RunEntryMessage) { return new MessageHeader( MessageType.RunEntryMessage, REQUIRES_ACK_FLAG, msg.getSerializedSize()); + } else if (msg instanceof Java.CombinatorAwaitableEntryMessage) { + return new MessageHeader( + MessageType.CombinatorAwaitableEntryMessage, REQUIRES_ACK_FLAG, msg.getSerializedSize()); } // Messages with no flags return new MessageHeader(MessageType.fromMessage(msg), 0, msg.getSerializedSize()); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/SideEffectAckStateMachine.java b/sdk-core/src/main/java/dev/restate/sdk/core/SideEffectAckStateMachine.java deleted file mode 100644 index 3fd0ed0d..00000000 --- a/sdk-core/src/main/java/dev/restate/sdk/core/SideEffectAckStateMachine.java +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH -// -// This file is part of the Restate Java SDK, -// which is released under the MIT license. -// -// You can find a copy of the license in file LICENSE in the root -// directory of this repository or package, or at -// https://github.com/restatedev/sdk-java/blob/main/LICENSE -package dev.restate.sdk.core; - -/** State machine tracking side effects acks */ -class SideEffectAckStateMachine - extends BaseSuspendableCallbackStateMachine { - - interface SideEffectAckCallback extends SuspendableCallback { - void onLastSideEffectAck(); - } - - private int lastAcknowledgedEntry = -1; - - /** -1 means no side effect waiting to be acked. */ - private int lastExecutedSideEffect = -1; - - void waitLastSideEffectAck(SideEffectAckCallback callback) { - if (canExecuteSideEffect()) { - callback.onLastSideEffectAck(); - } else { - this.setCallback(callback); - } - } - - void tryHandleSideEffectAck(int entryIndex) { - this.lastAcknowledgedEntry = Math.max(entryIndex, this.lastAcknowledgedEntry); - if (canExecuteSideEffect()) { - this.consumeCallback(SideEffectAckCallback::onLastSideEffectAck); - } - } - - void registerExecutedSideEffect(int entryIndex) { - this.lastExecutedSideEffect = entryIndex; - } - - private boolean canExecuteSideEffect() { - return this.lastExecutedSideEffect <= this.lastAcknowledgedEntry; - } - - public int getLastExecutedSideEffect() { - return lastExecutedSideEffect; - } - - @Override - void abort(Throwable cause) { - super.abort(cause); - // We can't do anything else if the input stream is closed, so we just fail the callback, if any - this.tryFailCallback(); - } -} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/DeferredTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/DeferredTestSuite.java index 3e93b3b4..f012a4f3 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/DeferredTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/DeferredTestSuite.java @@ -16,6 +16,7 @@ import dev.restate.generated.sdk.java.Java; import dev.restate.generated.service.protocol.Protocol; +import dev.restate.generated.service.protocol.Protocol.Empty; import java.util.function.Supplier; import java.util.stream.Stream; @@ -42,7 +43,7 @@ protected Stream anyTestDefinitions( return Stream.of( testInvocation .get() - .withInput(startMessage(1), ProtoUtils.inputMessage()) + .withInput(startMessage(1), inputMessage()) .expectingOutput( invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), invokeMessage(GREETER_SERVICE_TARGET, "Till"), @@ -52,19 +53,30 @@ protected Stream anyTestDefinitions( .get() .withInput( startMessage(3), - ProtoUtils.inputMessage(), + inputMessage(), invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), - invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL")) + invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL"), + ackMessage(3)) .expectingOutput(combinatorsMessage(2), outputMessage("TILL"), END_MESSAGE) .named("Only one completion will generate the combinators message"), testInvocation .get() .withInput( startMessage(3), - ProtoUtils.inputMessage(), + inputMessage(), + invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), + invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL")) + .expectingOutput(combinatorsMessage(2), suspensionMessage(3)) + .named("Completed without ack will suspend"), + testInvocation + .get() + .withInput( + startMessage(3), + inputMessage(), invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), invokeMessage(GREETER_SERVICE_TARGET, "Till") - .setFailure(Util.toProtocolFailure(new IllegalStateException("My error")))) + .setFailure(Util.toProtocolFailure(new IllegalStateException("My error"))), + ackMessage(3)) .expectingOutput( combinatorsMessage(2), outputMessage(new IllegalStateException("My error")), @@ -74,9 +86,10 @@ protected Stream anyTestDefinitions( .get() .withInput( startMessage(3), - ProtoUtils.inputMessage(), + inputMessage(), invokeMessage(GREETER_SERVICE_TARGET, "Francesco", "FRANCESCO"), - invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL")) + invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL"), + ackMessage(3)) .assertingOutput( msgs -> { assertThat(msgs).hasSize(3); @@ -100,7 +113,7 @@ protected Stream anyTestDefinitions( .get() .withInput( startMessage(4), - ProtoUtils.inputMessage(), + inputMessage(), invokeMessage(GREETER_SERVICE_TARGET, "Francesco", "FRANCESCO"), invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL"), combinatorsMessage(2)) @@ -109,7 +122,7 @@ protected Stream anyTestDefinitions( testInvocation .get() .withInput( - startMessage(1), ProtoUtils.inputMessage(), completionMessage(1, "FRANCESCO")) + startMessage(1), inputMessage(), completionMessage(1, "FRANCESCO"), ackMessage(3)) .onlyUnbuffered() .expectingOutput( invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), @@ -128,7 +141,7 @@ public Stream definitions() { Stream.of( // --- Reverse await order this.reverseAwaitOrder() - .withInput(startMessage(1), ProtoUtils.inputMessage()) + .withInput(startMessage(1), inputMessage()) .expectingOutput( invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), invokeMessage(GREETER_SERVICE_TARGET, "Till"), @@ -137,7 +150,7 @@ public Stream definitions() { this.reverseAwaitOrder() .withInput( startMessage(1), - ProtoUtils.inputMessage(), + inputMessage(), completionMessage(1, "FRANCESCO"), completionMessage(2, "TILL")) .onlyUnbuffered() @@ -151,7 +164,7 @@ public Stream definitions() { this.reverseAwaitOrder() .withInput( startMessage(1), - ProtoUtils.inputMessage(), + inputMessage(), completionMessage(2, "TILL"), completionMessage(1, "FRANCESCO")) .onlyUnbuffered() @@ -163,7 +176,7 @@ public Stream definitions() { END_MESSAGE) .named("A2 and A1 completed later"), this.reverseAwaitOrder() - .withInput(startMessage(1), ProtoUtils.inputMessage(), completionMessage(2, "TILL")) + .withInput(startMessage(1), inputMessage(), completionMessage(2, "TILL")) .onlyUnbuffered() .expectingOutput( invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), @@ -172,8 +185,7 @@ public Stream definitions() { suspensionMessage(1)) .named("Only A2 completed"), this.reverseAwaitOrder() - .withInput( - startMessage(1), ProtoUtils.inputMessage(), completionMessage(1, "FRANCESCO")) + .withInput(startMessage(1), inputMessage(), completionMessage(1, "FRANCESCO")) .onlyUnbuffered() .expectingOutput( invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), @@ -183,8 +195,7 @@ public Stream definitions() { // --- Await twice the same executable this.awaitTwiceTheSameAwaitable() - .withInput( - startMessage(1), ProtoUtils.inputMessage(), completionMessage(1, "FRANCESCO")) + .withInput(startMessage(1), inputMessage(), completionMessage(1, "FRANCESCO")) .onlyUnbuffered() .expectingOutput( invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), @@ -193,7 +204,7 @@ public Stream definitions() { // --- All combinator this.awaitAll() - .withInput(startMessage(1), ProtoUtils.inputMessage()) + .withInput(startMessage(1), inputMessage()) .expectingOutput( invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), invokeMessage(GREETER_SERVICE_TARGET, "Till"), @@ -202,7 +213,7 @@ public Stream definitions() { this.awaitAll() .withInput( startMessage(3), - ProtoUtils.inputMessage(), + inputMessage(), invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL")) .expectingOutput(suspensionMessage(1)) @@ -210,9 +221,10 @@ public Stream definitions() { this.awaitAll() .withInput( startMessage(3), - ProtoUtils.inputMessage(), + inputMessage(), invokeMessage(GREETER_SERVICE_TARGET, "Francesco", "FRANCESCO"), - invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL")) + invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL"), + ackMessage(3)) .assertingOutput( msgs -> { assertThat(msgs).hasSize(3); @@ -231,7 +243,7 @@ public Stream definitions() { this.awaitAll() .withInput( startMessage(4), - ProtoUtils.inputMessage(), + inputMessage(), invokeMessage(GREETER_SERVICE_TARGET, "Francesco", "FRANCESCO"), invokeMessage(GREETER_SERVICE_TARGET, "Till", "TILL"), combinatorsMessage(1, 2)) @@ -240,9 +252,10 @@ public Stream definitions() { this.awaitAll() .withInput( startMessage(1), - ProtoUtils.inputMessage(), + inputMessage(), completionMessage(1, "FRANCESCO"), - completionMessage(2, "TILL")) + completionMessage(2, "TILL"), + ackMessage(3)) .onlyUnbuffered() .expectingOutput( invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), @@ -254,8 +267,9 @@ public Stream definitions() { this.awaitAll() .withInput( startMessage(1), - ProtoUtils.inputMessage(), - completionMessage(1, new IllegalStateException("My error"))) + inputMessage(), + completionMessage(1, new IllegalStateException("My error")), + ackMessage(3)) .onlyUnbuffered() .expectingOutput( invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), @@ -267,9 +281,10 @@ public Stream definitions() { this.awaitAll() .withInput( startMessage(1), - ProtoUtils.inputMessage(), + inputMessage(), completionMessage(1, "FRANCESCO"), - completionMessage(2, new IllegalStateException("My error"))) + completionMessage(2, new IllegalStateException("My error")), + ackMessage(3)) .onlyUnbuffered() .expectingOutput( invokeMessage(GREETER_SERVICE_TARGET, "Francesco"), @@ -283,7 +298,7 @@ public Stream definitions() { this.combineAnyWithAll() .withInput( startMessage(6), - ProtoUtils.inputMessage(), + inputMessage(), awakeable("1"), awakeable("2"), awakeable("3"), @@ -293,7 +308,7 @@ public Stream definitions() { this.combineAnyWithAll() .withInput( startMessage(6), - ProtoUtils.inputMessage(), + inputMessage(), awakeable("1"), awakeable("2"), awakeable("3"), @@ -306,7 +321,7 @@ public Stream definitions() { this.awaitAnyIndex() .withInput( startMessage(6), - ProtoUtils.inputMessage(), + inputMessage(), awakeable("1"), awakeable("2"), awakeable("3"), @@ -316,7 +331,7 @@ public Stream definitions() { this.awaitAnyIndex() .withInput( startMessage(6), - ProtoUtils.inputMessage(), + inputMessage(), awakeable("1"), awakeable("2"), awakeable("3"), @@ -328,7 +343,12 @@ public Stream definitions() { // --- Compose nested and resolved all should work this.awaitOnAlreadyResolvedAwaitables() .withInput( - startMessage(3), ProtoUtils.inputMessage(), awakeable("1"), awakeable("2")) + startMessage(3), + inputMessage(), + awakeable("1"), + awakeable("2"), + ackMessage(3), + ackMessage(4)) .assertingOutput( msgs -> { assertThat(msgs).hasSize(4); @@ -348,7 +368,10 @@ public Stream definitions() { // --- Await with timeout this.awaitWithTimeout() .withInput( - startMessage(1), ProtoUtils.inputMessage(), completionMessage(1, "FRANCESCO")) + startMessage(1), + inputMessage(), + completionMessage(1, "FRANCESCO"), + ackMessage(3)) .onlyUnbuffered() .assertingOutput( messages -> { @@ -366,10 +389,9 @@ public Stream definitions() { this.awaitWithTimeout() .withInput( startMessage(1), - ProtoUtils.inputMessage(), - Protocol.CompletionMessage.newBuilder() - .setEntryIndex(2) - .setEmpty(Protocol.Empty.getDefaultInstance())) + inputMessage(), + completionMessage(2).setEmpty(Empty.getDefaultInstance()), + ackMessage(3)) .onlyUnbuffered() .assertingOutput( messages -> {