Skip to content

Commit

Permalink
Now send acks on combinators (#334)
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper authored May 29, 2024
1 parent a3f7505 commit 957c7b6
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 114 deletions.
2 changes: 1 addition & 1 deletion sdk-api/src/test/java/dev/restate/sdk/DeferredTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ protected TestInvocationBuilder awaitOnAlreadyResolvedAwaitables() {

protected TestInvocationBuilder awaitWithTimeout() {
return testDefinitionForService(
"AwaitOnAlreadyResolvedAwaitables",
"AwaitWithTimeout",
Serde.VOID,
JsonSerdes.STRING,
(ctx, unused) -> {
Expand Down
56 changes: 56 additions & 0 deletions sdk-core/src/main/java/dev/restate/sdk/core/AckStateMachine.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
//
// This file is part of the Restate Java SDK,
// which is released under the MIT license.
//
// You can find a copy of the license in file LICENSE in the root
// directory of this repository or package, or at
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.core;

/** State machine tracking acks */
class AckStateMachine extends BaseSuspendableCallbackStateMachine<AckStateMachine.AckCallback> {

interface AckCallback extends SuspendableCallback {
void onAck();
}

private int lastAcknowledgedEntry = -1;

/** -1 means no side effect waiting to be acked. */
private int lastEntryToAck = -1;

void waitLastAck(AckCallback callback) {
if (lastEntryIsAcked()) {
callback.onAck();
} else {
this.setCallback(callback);
}
}

void tryHandleAck(int entryIndex) {
this.lastAcknowledgedEntry = Math.max(entryIndex, this.lastAcknowledgedEntry);
if (lastEntryIsAcked()) {
this.consumeCallback(AckCallback::onAck);
}
}

void registerEntryToAck(int entryIndex) {
this.lastEntryToAck = Math.max(entryIndex, this.lastEntryToAck);
}

private boolean lastEntryIsAcked() {
return this.lastEntryToAck <= this.lastAcknowledgedEntry;
}

public int getLastEntryToAck() {
return lastEntryToAck;
}

@Override
void abort(Throwable cause) {
super.abort(cause);
// We can't do anything else if the input stream is closed, so we just fail the callback, if any
this.tryFailCallback();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class InvocationStateMachine implements InvocationFlow.InvocationProcessor {

// Buffering of messages and completions
private final IncomingEntriesStateMachine incomingEntriesStateMachine;
private final SideEffectAckStateMachine sideEffectAckStateMachine;
private final AckStateMachine ackStateMachine;
private final ReadyResultStateMachine readyResultStateMachine;

// Flow sub/pub
Expand All @@ -75,7 +75,7 @@ class InvocationStateMachine implements InvocationFlow.InvocationProcessor {

this.incomingEntriesStateMachine = new IncomingEntriesStateMachine();
this.readyResultStateMachine = new ReadyResultStateMachine();
this.sideEffectAckStateMachine = new SideEffectAckStateMachine();
this.ackStateMachine = new AckStateMachine();

this.afterStartCallback = new CallbackHandle<>();
}
Expand Down Expand Up @@ -142,8 +142,7 @@ public void onNext(InvocationFlow.InvocationInput invocationInput) {
// runtime.
this.readyResultStateMachine.offerCompletion((Protocol.CompletionMessage) msg);
} else if (msg instanceof Protocol.EntryAckMessage) {
this.sideEffectAckStateMachine.tryHandleSideEffectAck(
((Protocol.EntryAckMessage) msg).getEntryIndex());
this.ackStateMachine.tryHandleAck(((Protocol.EntryAckMessage) msg).getEntryIndex());
} else {
this.incomingEntriesStateMachine.offer(msg);
}
Expand All @@ -159,7 +158,7 @@ public void onError(Throwable throwable) {
public void onComplete() {
LOG.trace("Input publisher closed");
this.readyResultStateMachine.abort(AbortedExecutionException.INSTANCE);
this.sideEffectAckStateMachine.abort(AbortedExecutionException.INSTANCE);
this.ackStateMachine.abort(AbortedExecutionException.INSTANCE);
}

// --- Init routine to wait for the start message
Expand Down Expand Up @@ -287,7 +286,7 @@ private void closeWithMessage(MessageLite closeMessage, Throwable cause) {
// Unblock any eventual waiting callbacks
this.afterStartCallback.consume(cb -> cb.onCancel(cause));
this.readyResultStateMachine.abort(cause);
this.sideEffectAckStateMachine.abort(cause);
this.ackStateMachine.abort(cause);
this.incomingEntriesStateMachine.abort(cause);
this.span.end();
}
Expand Down Expand Up @@ -456,21 +455,21 @@ void exitSideEffectBlock(
}

// Write new entry
this.sideEffectAckStateMachine.registerExecutedSideEffect(this.currentJournalEntryIndex);
this.ackStateMachine.registerEntryToAck(this.currentJournalEntryIndex);
this.writeEntry(sideEffectEntry);

// Wait for entry to be acked
Protocol.RunEntryMessage finalSideEffectEntry = sideEffectEntry;
this.sideEffectAckStateMachine.waitLastSideEffectAck(
new SideEffectAckStateMachine.SideEffectAckCallback() {
this.ackStateMachine.waitLastAck(
new AckStateMachine.AckCallback() {
@Override
public void onLastSideEffectAck() {
public void onAck() {
completeSideEffectCallbackWithEntry(finalSideEffectEntry, callback);
}

@Override
public void onSuspend() {
suspend(List.of(sideEffectAckStateMachine.getLastExecutedSideEffect()));
suspend(List.of(ackStateMachine.getLastEntryToAck()));
callback.onCancel(AbortedExecutionException.INSTANCE);
}

Expand Down Expand Up @@ -621,8 +620,7 @@ private void resolveCombinatorDeferred(
+ "This is a symptom of an SDK bug, please contact the developers.");
}

writeCombinatorEntry(Collections.emptyList());
callback.onSuccess(null);
writeCombinatorEntry(Collections.emptyList(), callback);
return;
}

Expand All @@ -636,8 +634,7 @@ private void resolveCombinatorDeferred(

// Try to resolve the combinator now
if (rootDeferred.tryResolve(entryIndex)) {
writeCombinatorEntry(resolvedOrder);
callback.onSuccess(null);
writeCombinatorEntry(resolvedOrder, callback);
return;
}
} else {
Expand Down Expand Up @@ -667,8 +664,7 @@ public boolean onNewResult(Map<Integer, Result<?>> resultMap) {

// Try to resolve the combinator now
if (rootDeferred.tryResolve(entryIndex)) {
writeCombinatorEntry(resolvedOrder);
callback.onSuccess(null);
writeCombinatorEntry(resolvedOrder, callback);
return true;
}
}
Expand All @@ -694,12 +690,35 @@ public void onError(Throwable e) {
}
}

private void writeCombinatorEntry(List<Integer> resolvedList) {
private void writeCombinatorEntry(List<Integer> resolvedList, SyscallCallback<Void> callback) {
// Create and write the entry
Java.CombinatorAwaitableEntryMessage entry =
Java.CombinatorAwaitableEntryMessage.newBuilder().addAllEntryIndex(resolvedList).build();
span.addEvent("Combinator");

// We register the combinator entry to wait for an ack
this.ackStateMachine.registerEntryToAck(this.currentJournalEntryIndex);
writeEntry(entry);

// Let's wait for the ack
this.ackStateMachine.waitLastAck(
new AckStateMachine.AckCallback() {
@Override
public void onAck() {
callback.onSuccess(null);
}

@Override
public void onSuspend() {
suspend(List.of(ackStateMachine.getLastEntryToAck()));
callback.onCancel(AbortedExecutionException.INSTANCE);
}

@Override
public void onError(Throwable e) {
callback.onCancel(e);
}
});
}

// --- Internal callback
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package dev.restate.sdk.core;

import com.google.protobuf.MessageLite;
import dev.restate.generated.sdk.java.Java;
import dev.restate.generated.service.protocol.Protocol;

public class MessageHeader {
Expand Down Expand Up @@ -82,6 +83,9 @@ public static MessageHeader fromMessage(MessageLite msg) {
} else if (msg instanceof Protocol.RunEntryMessage) {
return new MessageHeader(
MessageType.RunEntryMessage, REQUIRES_ACK_FLAG, msg.getSerializedSize());
} else if (msg instanceof Java.CombinatorAwaitableEntryMessage) {
return new MessageHeader(
MessageType.CombinatorAwaitableEntryMessage, REQUIRES_ACK_FLAG, msg.getSerializedSize());
}
// Messages with no flags
return new MessageHeader(MessageType.fromMessage(msg), 0, msg.getSerializedSize());
Expand Down

This file was deleted.

Loading

0 comments on commit 957c7b6

Please sign in to comment.