From 4657365e3373b690e78d888cff107cb87a9f87fa Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 16 Apr 2024 11:22:41 +0200 Subject: [PATCH] Support named entries, and allow users to set side effects name. This commit adds support in the state machine to propagate additional entry info about "failing entries" in the EndMessage, and allows the user to set the name for side effect entries. --- .../dev/restate/sdk/kotlin/ContextImpl.kt | 7 +- .../main/kotlin/dev/restate/sdk/kotlin/api.kt | 13 ++- .../dev/restate/sdk/kotlin/SideEffectTest.kt | 11 ++ .../main/java/dev/restate/sdk/Context.java | 21 +++- .../java/dev/restate/sdk/ContextImpl.java | 3 +- .../java/dev/restate/sdk/SideEffectTest.java | 32 ++++++ .../restate/sdk/common/syscalls/Syscalls.java | 2 +- .../java/dev/restate/sdk/core/Entries.java | 57 ++++++++++ .../sdk/core/ExecutorSwitchingSyscalls.java | 4 +- .../sdk/core/InvocationStateMachine.java | 106 +++++++++++------- .../dev/restate/sdk/core/MessageHeader.java | 34 +----- .../dev/restate/sdk/core/MessageType.java | 43 +++++++ .../restate/sdk/core/ProtocolException.java | 16 --- .../dev/restate/sdk/core/SyscallsImpl.java | 4 +- .../main/java/dev/restate/sdk/core/Util.java | 38 +++++++ .../main/sdk-proto/dev/restate/sdk/java.proto | 6 + .../restate/sdk/core/SideEffectTestSuite.java | 35 +++++- .../workflow/impl/WorkflowContextImpl.java | 5 +- 18 files changed, 333 insertions(+), 104 deletions(-) diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt index f56ad3ea..9c98bcaa 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt @@ -130,11 +130,16 @@ internal class ContextImpl internal constructor(private val syscalls: Syscalls) } } - override suspend fun runBlock(serde: Serde, block: suspend () -> T): T { + override suspend fun runBlock( + serde: Serde, + name: String, + block: suspend () -> T + ): T { val exitResult = suspendCancellableCoroutine { cont: CancellableContinuation> -> syscalls.enterSideEffectBlock( + name, object : EnterSideEffectSyscallCallback { override fun onSuccess(t: ByteString?) { val deferred: CompletableDeferred = CompletableDeferred() diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt index f5d38328..a5d50eba 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt @@ -106,6 +106,9 @@ sealed interface Context { * suspension point) without re-executing the closure. Use this feature if you want to perform * non-deterministic operations. * + * You can name this closure using the `name` parameter. This name will be available in the + * observability tools. + * *

The closure should tolerate retries, that is Restate might re-execute the closure multiple * times until it records a result. * @@ -138,11 +141,12 @@ sealed interface Context { * To propagate failures to the run call-site, make sure to wrap them in [TerminalException]. * * @param serde the type tag of the return value, used to serialize/deserialize it. + * @param name the name of the side effect. * @param block closure to execute. * @param T type of the return value. * @return value of the runBlock operation. */ - suspend fun runBlock(serde: Serde, block: suspend () -> T): T + suspend fun runBlock(serde: Serde, name: String = "", block: suspend () -> T): T /** * Create an [Awakeable], addressable through [Awakeable.id]. @@ -221,8 +225,11 @@ sealed interface Context { * @param T type of the return value. * @return value of the runBlock operation. */ -suspend inline fun Context.runBlock(noinline block: suspend () -> T): T { - return this.runBlock(KtSerdes.json(), block) +suspend inline fun Context.runBlock( + name: String = "", + noinline block: suspend () -> T +): T { + return this.runBlock(KtSerdes.json(), name, block) } /** diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt index a3d6f70b..a3afc544 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt +++ b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt @@ -26,6 +26,12 @@ class SideEffectTest : SideEffectTestSuite() { "Hello $result" } + override fun namedSideEffect(name: String, sideEffectOutput: String): TestInvocationBuilder = + testDefinitionForService("SideEffect") { ctx, _: Unit -> + val result = ctx.runBlock(name) { sideEffectOutput } + "Hello $result" + } + override fun consecutiveSideEffect(sideEffectOutput: String): TestInvocationBuilder = testDefinitionForService("ConsecutiveSideEffect") { ctx, _: Unit -> val firstResult = ctx.runBlock { sideEffectOutput } @@ -54,4 +60,9 @@ class SideEffectTest : SideEffectTestSuite() { ctx.runBlock { ctx.send(GREETER_SERVICE_TARGET, KtSerdes.json(), "something") } throw IllegalStateException("This point should not be reached") } + + override fun failingSideEffect(name: String, reason: String): TestInvocationBuilder = + testDefinitionForService("FailingSideEffect") { ctx, _: Unit -> + ctx.runBlock(name) { throw IllegalStateException(reason) } + } } diff --git a/sdk-api/src/main/java/dev/restate/sdk/Context.java b/sdk-api/src/main/java/dev/restate/sdk/Context.java index 171b5d3e..dd01b59b 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Context.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Context.java @@ -100,6 +100,9 @@ default void sleep(Duration duration) { * suspension point) without re-executing the closure. Use this feature if you want to perform * non-deterministic operations. * + *

You can name this closure using the {@code name} parameter. This name will be available in + * the observability tools. + * *

The closure should tolerate retries, that is Restate might re-execute the closure multiple * times until it records a result. * @@ -133,16 +136,18 @@ default void sleep(Duration duration) { * To propagate run failures to the call-site, make sure to wrap them in {@link * TerminalException}. * + * @param name name of the side effect. * @param serde the type tag of the return value, used to serialize/deserialize it. * @param action closure to execute. * @param type of the return value. * @return value of the run operation. */ - T run(Serde serde, ThrowingSupplier action) throws TerminalException; + T run(String name, Serde serde, ThrowingSupplier action) throws TerminalException; - /** Like {@link #run(Serde, ThrowingSupplier)}, but without returning a value. */ - default void run(ThrowingRunnable runnable) throws TerminalException { + /** Like {@link #run(String, Serde, ThrowingSupplier)}, but without returning a value. */ + default void run(String name, ThrowingRunnable runnable) throws TerminalException { run( + name, CoreSerdes.VOID, () -> { runnable.run(); @@ -150,6 +155,16 @@ default void run(ThrowingRunnable runnable) throws TerminalException { }); } + /** Like {@link #run(String, Serde, ThrowingSupplier)}, but without a name. */ + default T run(Serde serde, ThrowingSupplier action) throws TerminalException { + return run(null, serde, action); + } + + /** Like {@link #run(String, ThrowingRunnable)}, but without a name. */ + default void run(ThrowingRunnable runnable) throws TerminalException { + run(null, runnable); + } + /** * Create an {@link Awakeable}, addressable through {@link Awakeable#id()}. * diff --git a/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java index 1704f8f6..89d1d084 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java +++ b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java @@ -110,9 +110,10 @@ public void send(Target target, Serde inputSerde, T parameter, Duration d } @Override - public T run(Serde serde, ThrowingSupplier action) { + public T run(String name, Serde serde, ThrowingSupplier action) { CompletableFuture> enterFut = new CompletableFuture<>(); syscalls.enterSideEffectBlock( + name, new EnterSideEffectSyscallCallback() { @Override public void onNotExecuted() { diff --git a/sdk-api/src/test/java/dev/restate/sdk/SideEffectTest.java b/sdk-api/src/test/java/dev/restate/sdk/SideEffectTest.java index 38c30210..a34c7e6e 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/SideEffectTest.java +++ b/sdk-api/src/test/java/dev/restate/sdk/SideEffectTest.java @@ -18,6 +18,7 @@ public class SideEffectTest extends SideEffectTestSuite { + @Override protected TestInvocationBuilder sideEffect(String sideEffectOutput) { return testDefinitionForService( "SideEffect", @@ -29,6 +30,19 @@ protected TestInvocationBuilder sideEffect(String sideEffectOutput) { }); } + @Override + protected TestInvocationBuilder namedSideEffect(String name, String sideEffectOutput) { + return testDefinitionForService( + "SideEffect", + CoreSerdes.VOID, + CoreSerdes.JSON_STRING, + (ctx, unused) -> { + String result = ctx.run(name, CoreSerdes.JSON_STRING, () -> sideEffectOutput); + return "Hello " + result; + }); + } + + @Override protected TestInvocationBuilder consecutiveSideEffect(String sideEffectOutput) { return testDefinitionForService( "ConsecutiveSideEffect", @@ -42,6 +56,7 @@ protected TestInvocationBuilder consecutiveSideEffect(String sideEffectOutput) { }); } + @Override protected TestInvocationBuilder checkContextSwitching() { return testDefinitionForService( "CheckContextSwitching", @@ -65,6 +80,7 @@ protected TestInvocationBuilder checkContextSwitching() { }); } + @Override protected TestInvocationBuilder sideEffectGuard() { return testDefinitionForService( "SideEffectGuard", @@ -75,4 +91,20 @@ protected TestInvocationBuilder sideEffectGuard() { throw new IllegalStateException("This point should not be reached"); }); } + + @Override + protected TestInvocationBuilder failingSideEffect(String name, String reason) { + return testDefinitionForService( + "FailingSideEffect", + CoreSerdes.VOID, + CoreSerdes.JSON_STRING, + (ctx, unused) -> { + ctx.run( + name, + () -> { + throw new IllegalStateException(reason); + }); + return null; + }); + } } diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java index a4d5a627..495481b6 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java +++ b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java @@ -68,7 +68,7 @@ void send( @Nullable Duration delay, SyscallCallback requestCallback); - void enterSideEffectBlock(EnterSideEffectSyscallCallback callback); + void enterSideEffectBlock(@Nullable String name, EnterSideEffectSyscallCallback callback); void exitSideEffectBlock(ByteString toWrite, ExitSideEffectSyscallCallback callback); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java b/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java index 3e0a193c..2e7f7706 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java @@ -26,6 +26,8 @@ final class Entries { private Entries() {} abstract static class JournalEntry { + abstract String getName(E expected); + void checkEntryHeader(E expected, MessageLite actual) throws ProtocolException {} abstract void trace(E expected, Span span); @@ -57,6 +59,11 @@ static final class OutputEntry extends JournalEntry { private OutputEntry() {} + @Override + String getName(OutputEntryMessage expected) { + return expected.getName(); + } + @Override public void trace(OutputEntryMessage expected, Span span) { span.addEvent("Output"); @@ -81,6 +88,11 @@ public boolean hasResult(GetStateEntryMessage actual) { return actual.getResultCase() != GetStateEntryMessage.ResultCase.RESULT_NOT_SET; } + @Override + String getName(GetStateEntryMessage expected) { + return expected.getName(); + } + @Override void checkEntryHeader(GetStateEntryMessage expected, MessageLite actual) throws ProtocolException { @@ -163,6 +175,11 @@ public boolean hasResult(GetStateKeysEntryMessage actual) { return actual.getResultCase() != GetStateKeysEntryMessage.ResultCase.RESULT_NOT_SET; } + @Override + String getName(GetStateKeysEntryMessage expected) { + return expected.getName(); + } + @Override void checkEntryHeader(GetStateKeysEntryMessage expected, MessageLite actual) throws ProtocolException { @@ -232,6 +249,11 @@ public void trace(ClearStateEntryMessage expected, Span span) { "ClearState", Attributes.of(Tracing.RESTATE_STATE_KEY, expected.getKey().toString())); } + @Override + String getName(ClearStateEntryMessage expected) { + return expected.getName(); + } + @Override void checkEntryHeader(ClearStateEntryMessage expected, MessageLite actual) throws ProtocolException { @@ -256,6 +278,11 @@ public void trace(ClearAllStateEntryMessage expected, Span span) { span.addEvent("ClearAllState"); } + @Override + String getName(ClearAllStateEntryMessage expected) { + return expected.getName(); + } + @Override void checkEntryHeader(ClearAllStateEntryMessage expected, MessageLite actual) throws ProtocolException { @@ -281,6 +308,11 @@ public void trace(SetStateEntryMessage expected, Span span) { "SetState", Attributes.of(Tracing.RESTATE_STATE_KEY, expected.getKey().toString())); } + @Override + String getName(SetStateEntryMessage expected) { + return expected.getName(); + } + @Override void checkEntryHeader(SetStateEntryMessage expected, MessageLite actual) throws ProtocolException { @@ -305,6 +337,11 @@ static final class SleepEntry extends CompletableJournalEntry syscalls.enterSideEffectBlock(callback)); + public void enterSideEffectBlock(String name, EnterSideEffectSyscallCallback callback) { + syscallsExecutor.execute(() -> syscalls.enterSideEffectBlock(name, callback)); } @Override diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java b/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java index 59dc3dd9..f954e345 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java @@ -21,7 +21,6 @@ import io.opentelemetry.api.trace.Span; import java.util.*; import java.util.concurrent.Flow; -import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.stream.Collectors; import org.apache.logging.log4j.LogManager; @@ -43,12 +42,15 @@ class InvocationStateMachine implements InvocationFlow.InvocationProcessor { // Obtained after WAITING_START private ByteString id; + private String debugId; private String key; private int entriesToReplay; private UserStateStore userStateStore; - // Index tracking progress in the journal - private int currentJournalIndex; + // Those values track the progress in the journal + private int currentJournalEntryIndex = -1; + private String currentJournalEntryName = null; + private MessageType currentJournalEntryType = null; // Buffering of messages and completions private final IncomingEntriesStateMachine incomingEntriesStateMachine; @@ -176,6 +178,7 @@ void onStartMessage(MessageLite msg) { // Unpack the StartMessage Protocol.StartMessage startMessage = (Protocol.StartMessage) msg; this.id = startMessage.getId(); + this.debugId = startMessage.getDebugId(); InvocationId invocationId = new InvocationIdImpl(startMessage.getDebugId()); this.key = startMessage.getKey(); this.entriesToReplay = startMessage.getKnownEntries(); @@ -212,8 +215,9 @@ void onStartMessage(MessageLite msg) { this.inputSubscription.request(Long.MAX_VALUE); // Now wait input entry + this.nextJournalEntry(null, MessageType.InputEntryMessage); this.readEntry( - (i, inputMsg) -> { + inputMsg -> { if (!(inputMsg instanceof Protocol.InputEntryMessage)) { throw ProtocolException.unexpectedMessage(Protocol.InputEntryMessage.class, inputMsg); } @@ -251,17 +255,13 @@ void suspend(Collection suspensionIndexes) { void fail(Throwable cause) { LOG.warn("Invocation failed", cause); - Protocol.ErrorMessage msg; - if (cause instanceof ProtocolException) { - msg = ((ProtocolException) cause).toErrorMessage(); - } else { - msg = - Protocol.ErrorMessage.newBuilder() - .setCode(TerminalException.INTERNAL_SERVER_ERROR_CODE) - .setMessage(cause.toString()) - .build(); - } - this.closeWithMessage(msg, cause); + this.closeWithMessage( + Util.toErrorMessage( + cause, + this.currentJournalEntryIndex, + this.currentJournalEntryName, + this.currentJournalEntryType), + cause); } private void closeWithMessage(MessageLite closeMessage, Throwable cause) { @@ -295,12 +295,15 @@ void processCompletableJournalEntry( Entries.CompletableJournalEntry journalEntry, SyscallCallback> callback) { checkInsideSideEffectGuard(); + this.nextJournalEntry( + journalEntry.getName(expectedEntryMessage), MessageType.fromMessage(expectedEntryMessage)); + if (this.invocationState == InvocationState.CLOSED) { callback.onCancel(AbortedExecutionException.INSTANCE); } else if (this.invocationState == InvocationState.REPLAYING) { // Retrieve the entry this.readEntry( - (entryIndex, actualEntryMessage) -> { + actualEntryMessage -> { journalEntry.checkEntryHeader(expectedEntryMessage, actualEntryMessage); if (journalEntry.hasResult((E) actualEntryMessage)) { @@ -308,17 +311,19 @@ void processCompletableJournalEntry( journalEntry.updateUserStateStoreWithEntry( (E) actualEntryMessage, this.userStateStore); Result readyResultInternal = journalEntry.parseEntryResult((E) actualEntryMessage); - callback.onSuccess(DeferredResults.completedSingle(entryIndex, readyResultInternal)); + callback.onSuccess( + DeferredResults.completedSingle( + this.currentJournalEntryIndex, readyResultInternal)); } else { // Entry is not completed yet this.readyResultStateMachine.offerCompletionParser( - entryIndex, + this.currentJournalEntryIndex, completionMessage -> { journalEntry.updateUserStateStorageWithCompletion( (E) actualEntryMessage, completionMessage, this.userStateStore); return journalEntry.parseCompletionResult(completionMessage); }); - callback.onSuccess(DeferredResults.single(entryIndex)); + callback.onSuccess(DeferredResults.single(this.currentJournalEntryIndex)); } }, callback::onCancel); @@ -331,9 +336,6 @@ void processCompletableJournalEntry( journalEntry.trace(entryToWrite, span); } - // Retrieve the index - int entryIndex = this.currentJournalIndex; - // Write out the input entry this.writeEntry(entryToWrite); @@ -341,11 +343,11 @@ void processCompletableJournalEntry( // Complete it with the result, as we already have it callback.onSuccess( DeferredResults.completedSingle( - entryIndex, journalEntry.parseEntryResult(entryToWrite))); + this.currentJournalEntryIndex, journalEntry.parseEntryResult(entryToWrite))); } else { // Register the completion parser this.readyResultStateMachine.offerCompletionParser( - entryIndex, + this.currentJournalEntryIndex, completionMessage -> { journalEntry.updateUserStateStorageWithCompletion( entryToWrite, completionMessage, this.userStateStore); @@ -353,7 +355,7 @@ void processCompletableJournalEntry( }); // Call the onSuccess - callback.onSuccess(DeferredResults.single(entryIndex)); + callback.onSuccess(DeferredResults.single(this.currentJournalEntryIndex)); } } else { throw new IllegalStateException( @@ -367,12 +369,15 @@ void processJournalEntry( Entries.JournalEntry journalEntry, SyscallCallback callback) { checkInsideSideEffectGuard(); + this.nextJournalEntry( + journalEntry.getName(expectedEntryMessage), MessageType.fromMessage(expectedEntryMessage)); + if (this.invocationState == InvocationState.CLOSED) { callback.onCancel(AbortedExecutionException.INSTANCE); } else if (this.invocationState == InvocationState.REPLAYING) { // Retrieve the entry this.readEntry( - (entryIndex, actualEntryMessage) -> { + actualEntryMessage -> { journalEntry.checkEntryHeader(expectedEntryMessage, actualEntryMessage); journalEntry.updateUserStateStoreWithEntry((E) actualEntryMessage, this.userStateStore); callback.onSuccess(null); @@ -397,14 +402,16 @@ void processJournalEntry( } } - void enterSideEffectBlock(EnterSideEffectSyscallCallback callback) { + void enterSideEffectBlock(String name, EnterSideEffectSyscallCallback callback) { checkInsideSideEffectGuard(); + this.nextJournalEntry(name, MessageType.SideEffectEntryMessage); + if (this.invocationState == InvocationState.CLOSED) { callback.onCancel(AbortedExecutionException.INSTANCE); } else if (this.invocationState == InvocationState.REPLAYING) { // Retrieve the entry this.readEntry( - (entryIndex, msg) -> { + msg -> { Util.assertEntryClass(Java.SideEffectEntryMessage.class, msg); // We have a result already, complete the callback @@ -437,16 +444,22 @@ void exitSideEffectBlock( span.addEvent("Exit SideEffect"); } + // For side effects, let's write out the name too, if available + if (this.currentJournalEntryName != null) { + sideEffectEntry = sideEffectEntry.toBuilder().setName(this.currentJournalEntryName).build(); + } + // Write new entry - this.sideEffectAckStateMachine.registerExecutedSideEffect(this.currentJournalIndex); + this.sideEffectAckStateMachine.registerExecutedSideEffect(this.currentJournalEntryIndex); this.writeEntry(sideEffectEntry); // Wait for entry to be acked + Java.SideEffectEntryMessage finalSideEffectEntry = sideEffectEntry; this.sideEffectAckStateMachine.waitLastSideEffectAck( new SideEffectAckStateMachine.SideEffectAckCallback() { @Override public void onLastSideEffectAck() { - completeSideEffectCallbackWithEntry(sideEffectEntry, callback); + completeSideEffectCallbackWithEntry(finalSideEffectEntry, callback); } @Override @@ -570,10 +583,12 @@ private void resolveCombinatorDeferred( // Calling .await() on a combinator deferred within a side effect is not allowed // as resolving it creates or read a journal entry. checkInsideSideEffectGuard(); + this.nextJournalEntry(null, MessageType.CombinatorAwaitableEntryMessage); + if (Objects.equals(this.invocationState, InvocationState.REPLAYING)) { // Retrieve the CombinatorAwaitableEntryMessage this.readEntry( - (entryIndex, actualMsg) -> { + actualMsg -> { Util.assertEntryClass(Java.CombinatorAwaitableEntryMessage.class, actualMsg); if (!rootDeferred.tryResolve( @@ -688,16 +703,14 @@ private void transitionState(InvocationState newInvocationState) { // Cannot move out of the closed state return; } - LOG.debug("Transitioning {} to {}", this, newInvocationState); + LOG.debug("Transitioning state machine to {}", newInvocationState); this.invocationState = newInvocationState; this.loggingContextSetter.set( RestateEndpoint.LoggingContextSetter.INVOCATION_STATUS_KEY, newInvocationState.toString()); } - private void incrementCurrentIndex() { - this.currentJournalIndex++; - - if (currentJournalIndex >= entriesToReplay + private void tryTransitionProcessing() { + if (currentJournalEntryIndex == entriesToReplay - 1 && this.invocationState == InvocationState.REPLAYING) { if (!this.incomingEntriesStateMachine.isEmpty()) { throw new IllegalStateException("Entries queue should be empty at this point"); @@ -706,19 +719,31 @@ private void incrementCurrentIndex() { } } + private void nextJournalEntry(String entryName, MessageType entryType) { + this.currentJournalEntryIndex++; + this.currentJournalEntryName = entryName; + this.currentJournalEntryType = entryType; + + LOG.debug( + "Current journal entry [{}]({}): {}", + this.currentJournalEntryIndex, + this.currentJournalEntryName, + this.currentJournalEntryType); + } + private void checkInsideSideEffectGuard() { if (this.insideSideEffect) { throw ProtocolException.invalidSideEffectCall(); } } - void readEntry(BiConsumer msgCallback, Consumer errorCallback) { + void readEntry(Consumer msgCallback, Consumer errorCallback) { this.incomingEntriesStateMachine.read( new IncomingEntriesStateMachine.OnEntryCallback() { @Override public void onEntry(MessageLite msg) { - incrementCurrentIndex(); - msgCallback.accept(currentJournalIndex - 1, msg); + tryTransitionProcessing(); + msgCallback.accept(msg); } @Override @@ -737,11 +762,10 @@ public void onError(Throwable e) { private void writeEntry(MessageLite message) { LOG.trace("Writing to output message {} {}", message.getClass(), message); Objects.requireNonNull(this.outputSubscriber).onNext(message); - this.incrementCurrentIndex(); } @Override public String toString() { - return "InvocationStateMachine{id=" + id + '}'; + return "InvocationStateMachine[" + debugId + ']'; } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java b/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java index 99cc9e63..cf2b9ae7 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java @@ -55,19 +55,7 @@ public static MessageHeader parse(long encoded) throws ProtocolException { } public static MessageHeader fromMessage(MessageLite msg) { - if (msg instanceof Protocol.SuspensionMessage) { - return new MessageHeader(MessageType.SuspensionMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.ErrorMessage) { - return new MessageHeader(MessageType.ErrorMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.EndMessage) { - return new MessageHeader(MessageType.EndMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.EntryAckMessage) { - return new MessageHeader(MessageType.EntryAckMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.InputEntryMessage) { - return new MessageHeader(MessageType.InputEntryMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.OutputEntryMessage) { - return new MessageHeader(MessageType.OutputEntryMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.GetStateEntryMessage) { + if (msg instanceof Protocol.GetStateEntryMessage) { return new MessageHeader( MessageType.GetStateEntryMessage, ((Protocol.GetStateEntryMessage) msg).getResultCase() @@ -75,12 +63,6 @@ public static MessageHeader fromMessage(MessageLite msg) { ? DONE_FLAG : 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.SetStateEntryMessage) { - return new MessageHeader(MessageType.SetStateEntryMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.ClearStateEntryMessage) { - return new MessageHeader(MessageType.ClearStateEntryMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.ClearAllStateEntryMessage) { - return new MessageHeader(MessageType.ClearAllStateEntryMessage, 0, msg.getSerializedSize()); } else if (msg instanceof Protocol.GetStateKeysEntryMessage) { return new MessageHeader( MessageType.GetStateKeysEntryMessage, @@ -105,9 +87,6 @@ public static MessageHeader fromMessage(MessageLite msg) { ? DONE_FLAG : 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.BackgroundInvokeEntryMessage) { - return new MessageHeader( - MessageType.BackgroundInvokeEntryMessage, 0, msg.getSerializedSize()); } else if (msg instanceof Protocol.AwakeableEntryMessage) { return new MessageHeader( MessageType.AwakeableEntryMessage, @@ -116,19 +95,12 @@ public static MessageHeader fromMessage(MessageLite msg) { ? DONE_FLAG : 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.CompleteAwakeableEntryMessage) { - return new MessageHeader( - MessageType.CompleteAwakeableEntryMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Java.CombinatorAwaitableEntryMessage) { - return new MessageHeader( - MessageType.CombinatorAwaitableEntryMessage, 0, msg.getSerializedSize()); } else if (msg instanceof Java.SideEffectEntryMessage) { return new MessageHeader( MessageType.SideEffectEntryMessage, REQUIRES_ACK_FLAG, msg.getSerializedSize()); - } else if (msg instanceof Protocol.CompletionMessage) { - throw new IllegalArgumentException("SDK should never send a CompletionMessage"); } - throw new IllegalStateException(); + // Messages with no flags + return new MessageHeader(MessageType.fromMessage(msg), 0, msg.getSerializedSize()); } public static void checkProtocolVersion(MessageHeader header) { diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/MessageType.java b/sdk-core/src/main/java/dev/restate/sdk/core/MessageType.java index 2b79cd6b..4f013a05 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/MessageType.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/MessageType.java @@ -201,4 +201,47 @@ public static MessageType decode(short value) throws ProtocolException { } throw ProtocolException.unknownMessageType(value); } + + public static MessageType fromMessage(MessageLite msg) { + if (msg instanceof Protocol.SuspensionMessage) { + return MessageType.SuspensionMessage; + } else if (msg instanceof Protocol.ErrorMessage) { + return MessageType.ErrorMessage; + } else if (msg instanceof Protocol.EndMessage) { + return MessageType.EndMessage; + } else if (msg instanceof Protocol.EntryAckMessage) { + return MessageType.EntryAckMessage; + } else if (msg instanceof Protocol.InputEntryMessage) { + return MessageType.InputEntryMessage; + } else if (msg instanceof Protocol.OutputEntryMessage) { + return MessageType.OutputEntryMessage; + } else if (msg instanceof Protocol.GetStateEntryMessage) { + return MessageType.GetStateEntryMessage; + } else if (msg instanceof Protocol.SetStateEntryMessage) { + return MessageType.SetStateEntryMessage; + } else if (msg instanceof Protocol.ClearStateEntryMessage) { + return MessageType.ClearStateEntryMessage; + } else if (msg instanceof Protocol.ClearAllStateEntryMessage) { + return MessageType.ClearAllStateEntryMessage; + } else if (msg instanceof Protocol.GetStateKeysEntryMessage) { + return MessageType.GetStateKeysEntryMessage; + } else if (msg instanceof Protocol.SleepEntryMessage) { + return MessageType.SleepEntryMessage; + } else if (msg instanceof Protocol.InvokeEntryMessage) { + return MessageType.InvokeEntryMessage; + } else if (msg instanceof Protocol.BackgroundInvokeEntryMessage) { + return MessageType.BackgroundInvokeEntryMessage; + } else if (msg instanceof Protocol.AwakeableEntryMessage) { + return MessageType.AwakeableEntryMessage; + } else if (msg instanceof Protocol.CompleteAwakeableEntryMessage) { + return MessageType.CompleteAwakeableEntryMessage; + } else if (msg instanceof Java.CombinatorAwaitableEntryMessage) { + return MessageType.CombinatorAwaitableEntryMessage; + } else if (msg instanceof Java.SideEffectEntryMessage) { + return MessageType.SideEffectEntryMessage; + } else if (msg instanceof Protocol.CompletionMessage) { + throw new IllegalArgumentException("SDK should never send a CompletionMessage"); + } + throw new IllegalStateException(); + } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java b/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java index 25d5089f..588eeba8 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java @@ -11,8 +11,6 @@ import com.google.protobuf.MessageLite; import dev.restate.generated.service.protocol.Protocol; import dev.restate.sdk.common.TerminalException; -import java.io.PrintWriter; -import java.io.StringWriter; public class ProtocolException extends RuntimeException { @@ -42,20 +40,6 @@ public int getCode() { return code; } - public Protocol.ErrorMessage toErrorMessage() { - // Convert stacktrace to string - StringWriter sw = new StringWriter(); - PrintWriter pw = new PrintWriter(sw); - pw.println("Stacktrace:"); - this.printStackTrace(pw); - - return Protocol.ErrorMessage.newBuilder() - .setCode(code) - .setMessage(this.toString()) - .setDescription(sw.toString()) - .build(); - } - static ProtocolException unexpectedMessage( Class expected, MessageLite actual) { return new ProtocolException( diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java index fe261225..80036c2e 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java @@ -219,11 +219,11 @@ public void send( } @Override - public void enterSideEffectBlock(EnterSideEffectSyscallCallback callback) { + public void enterSideEffectBlock(String name, EnterSideEffectSyscallCallback callback) { wrapAndPropagateExceptions( () -> { LOG.trace("enterSideEffectBlock"); - this.stateMachine.enterSideEffectBlock(callback); + this.stateMachine.enterSideEffectBlock(name, callback); }, callback); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/Util.java b/sdk-core/src/main/java/dev/restate/sdk/core/Util.java index 0acebcf1..2d668398 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/Util.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/Util.java @@ -13,9 +13,12 @@ import dev.restate.generated.service.protocol.Protocol; import dev.restate.sdk.common.AbortedExecutionException; import dev.restate.sdk.common.TerminalException; +import java.io.PrintWriter; +import java.io.StringWriter; import java.util.Objects; import java.util.Optional; import java.util.function.Predicate; +import org.jspecify.annotations.Nullable; public final class Util { private Util() {} @@ -76,6 +79,41 @@ static Protocol.Failure toProtocolFailure(Throwable throwable) { return toProtocolFailure(TerminalException.INTERNAL_SERVER_ERROR_CODE, throwable.toString()); } + static Protocol.ErrorMessage toErrorMessage( + Throwable throwable, + int currentJournalIndex, + @Nullable String currentJournalEntryName, + @Nullable MessageType currentJournalEntryType) { + Protocol.ErrorMessage.Builder msg = + Protocol.ErrorMessage.newBuilder().setMessage(throwable.toString()); + + if (throwable instanceof ProtocolException) { + msg.setCode(((ProtocolException) throwable).getCode()); + } else { + msg.setCode(TerminalException.INTERNAL_SERVER_ERROR_CODE); + } + + // Convert stacktrace to string + StringWriter sw = new StringWriter(); + PrintWriter pw = new PrintWriter(sw); + pw.println("Stacktrace:"); + throwable.printStackTrace(pw); + msg.setDescription(sw.toString()); + + // Add journal entry info + if (currentJournalIndex >= 0) { + msg.setRelatedEntryIndex(currentJournalIndex); + } + if (currentJournalEntryName != null) { + msg.setRelatedEntryName(currentJournalEntryName); + } + if (currentJournalEntryType != null) { + msg.setRelatedEntryType(currentJournalEntryType.encode()); + } + + return msg.build(); + } + static TerminalException toRestateException(Protocol.Failure failure) { return new TerminalException(failure.getCode(), failure.getMessage()); } diff --git a/sdk-core/src/main/sdk-proto/dev/restate/sdk/java.proto b/sdk-core/src/main/sdk-proto/dev/restate/sdk/java.proto index 0e31615e..71e2a0fe 100644 --- a/sdk-core/src/main/sdk-proto/dev/restate/sdk/java.proto +++ b/sdk-core/src/main/sdk-proto/dev/restate/sdk/java.proto @@ -18,6 +18,9 @@ option java_package = "dev.restate.generated.sdk.java"; // Type: 0xFC00 + 0 message CombinatorAwaitableEntryMessage { repeated uint32 entry_index = 1; + + // Entry name + string name = 12; } // Type: 0xFC00 + 1 @@ -27,4 +30,7 @@ message SideEffectEntryMessage { bytes value = 14; dev.restate.service.protocol.Failure failure = 15; }; + + // Entry name + string name = 12; } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java index 48e58f3f..e475db22 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java @@ -8,26 +8,33 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import static dev.restate.sdk.core.AssertUtils.containsOnlyExactErrorMessage; +import static dev.restate.sdk.core.AssertUtils.*; import static dev.restate.sdk.core.ProtoUtils.*; import static dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.STRING; import static org.assertj.core.api.InstanceOfAssertFactories.type; import dev.restate.generated.sdk.java.Java; +import dev.restate.generated.service.protocol.Protocol; import dev.restate.sdk.common.CoreSerdes; +import dev.restate.sdk.common.TerminalException; import java.util.stream.Stream; public abstract class SideEffectTestSuite implements TestDefinitions.TestSuite { protected abstract TestInvocationBuilder sideEffect(String sideEffectOutput); + protected abstract TestInvocationBuilder namedSideEffect(String name, String sideEffectOutput); + protected abstract TestInvocationBuilder consecutiveSideEffect(String sideEffectOutput); protected abstract TestInvocationBuilder checkContextSwitching(); protected abstract TestInvocationBuilder sideEffectGuard(); + protected abstract TestInvocationBuilder failingSideEffect(String name, String reason); + @Override public Stream definitions() { return Stream.of( @@ -46,6 +53,13 @@ public Stream definitions() { outputMessage("Hello Francesco"), END_MESSAGE) .named("Without optimization and with acks returns"), + this.namedSideEffect("get-my-name", "Francesco") + .withInput(startMessage(1), inputMessage("Till")) + .expectingOutput( + Java.SideEffectEntryMessage.newBuilder() + .setName("get-my-name") + .setValue(CoreSerdes.JSON_STRING.serializeToByteString("Francesco")), + suspensionMessage(1)), this.consecutiveSideEffect("Francesco") .withInput(startMessage(1), inputMessage("Till")) .expectingOutput( @@ -74,6 +88,25 @@ public Stream definitions() { outputMessage("Hello FRANCESCO"), END_MESSAGE) .named("With optimization and ack on first and second side effect will resume"), + this.failingSideEffect("my-side-effect", "some failure") + .withInput(startMessage(1), inputMessage()) + .onlyUnbuffered() + .assertingOutput( + containsOnly( + errorMessage( + errorMessage -> + assertThat(errorMessage) + .returns( + TerminalException.INTERNAL_SERVER_ERROR_CODE, + Protocol.ErrorMessage::getCode) + .returns(1, Protocol.ErrorMessage::getRelatedEntryIndex) + .returns( + (int) MessageType.SideEffectEntryMessage.encode(), + Protocol.ErrorMessage::getRelatedEntryType) + .returns( + "my-side-effect", Protocol.ErrorMessage::getRelatedEntryName) + .extracting(Protocol.ErrorMessage::getMessage, STRING) + .contains("some failure")))), // --- Other tests this.checkContextSwitching() diff --git a/sdk-workflow-api/src/main/java/dev/restate/sdk/workflow/impl/WorkflowContextImpl.java b/sdk-workflow-api/src/main/java/dev/restate/sdk/workflow/impl/WorkflowContextImpl.java index 6d67b694..7ea9523b 100644 --- a/sdk-workflow-api/src/main/java/dev/restate/sdk/workflow/impl/WorkflowContextImpl.java +++ b/sdk-workflow-api/src/main/java/dev/restate/sdk/workflow/impl/WorkflowContextImpl.java @@ -209,8 +209,9 @@ public void send(Target target, Serde inputSerde, T parameter, Duration d } @Override - public T run(Serde serde, ThrowingSupplier action) throws TerminalException { - return ctx.run(serde, action); + public T run(String name, Serde serde, ThrowingSupplier action) + throws TerminalException { + return ctx.run(name, serde, action); } @Override