Skip to content

Commit

Permalink
New failure propagation (#100)
Browse files Browse the repository at this point in the history
* Implement ErrorMessage header
* Refactor ProtocolException to hold code numbers outside the gRPC Status.Code range
* Write out the ErrorMessage when there is a failure cause.
* * Now when closing the state machine, we invoke syscalls.fail(cause) on every failure with Code UNKNOWN. In other words, a failure is non-terminal when code is UNKNOWN, otherwise is terminal.
* Now GrpcServerCallListenerAdaptor won't erase the cause anymore when getting an exception from the user code, because we need it for syscalls.fail(cause)
* Make sure we don't write non-terminal failures for side effects
  • Loading branch information
slinkydeveloper committed Jul 24, 2023
1 parent 0093ff2 commit e1a6389
Show file tree
Hide file tree
Showing 19 changed files with 281 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.Status;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

Expand Down Expand Up @@ -67,7 +68,7 @@ private void closeWithException(Throwable e) {
serverCall.close(Util.SUSPENDED_STATUS, new Metadata());
} else {
LOG.warn("Error when processing the invocation", e);
serverCall.close(Util.toGrpcStatusErasingCause(e), new Metadata());
serverCall.close(Status.fromThrowable(e), new Metadata());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.google.protobuf.ByteString;
import com.google.protobuf.MessageLite;
import com.google.rpc.Code;
import dev.restate.generated.sdk.java.Java;
import dev.restate.generated.service.protocol.Protocol;
import dev.restate.sdk.core.InvocationId;
Expand All @@ -13,7 +14,6 @@
import dev.restate.sdk.core.impl.ReadyResults.ReadyResultInternal;
import dev.restate.sdk.core.syscalls.DeferredResult;
import dev.restate.sdk.core.syscalls.SyscallCallback;
import io.grpc.Status;
import io.opentelemetry.api.common.Attributes;
import io.opentelemetry.api.trace.Span;
import java.util.*;
Expand Down Expand Up @@ -221,7 +221,16 @@ void fail(Throwable cause) {
this.inputSubscription.cancel();
}
if (this.outputSubscriber != null) {
this.outputSubscriber.onError(cause);
if (cause instanceof ProtocolException) {
this.outputSubscriber.onNext(((ProtocolException) cause).toErrorMessage());
} else if (cause != null) {
this.outputSubscriber.onNext(
Protocol.ErrorMessage.newBuilder()
.setCode(Code.UNKNOWN_VALUE)
.setMessage(cause.toString())
.build());
}
this.outputSubscriber.onComplete();
this.outputSubscriber = null;
}
this.insideSideEffect = false;
Expand Down Expand Up @@ -640,8 +649,7 @@ private void incrementCurrentIndex() {

private void checkInsideSideEffectGuard() {
if (this.insideSideEffect) {
throw new ProtocolException(
"A syscall was invoked from within a side effect closure.", null, Status.Code.UNKNOWN);
throw ProtocolException.invalidSideEffectCall();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ public static MessageHeader parse(long encoded) throws ProtocolException {
public static MessageHeader fromMessage(MessageLite msg) {
if (msg instanceof Protocol.SuspensionMessage) {
return new MessageHeader(MessageType.SuspensionMessage, (short) 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.ErrorMessage) {
return new MessageHeader(MessageType.ErrorMessage, (short) 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.PollInputStreamEntryMessage) {
return new MessageHeader(
MessageType.PollInputStreamEntryMessage, (short) 0, msg.getSerializedSize());
Expand Down Expand Up @@ -108,8 +110,6 @@ public static MessageHeader fromMessage(MessageLite msg) {
} else if (msg instanceof Java.SideEffectEntryMessage) {
return new MessageHeader(
MessageType.SideEffectEntryMessage, REQUIRES_ACK_FLAG, msg.getSerializedSize());
} else if (msg instanceof Protocol.StartMessage) {
throw new IllegalArgumentException("SDK should never send a StartMessage");
} else if (msg instanceof Protocol.CompletionMessage) {
throw new IllegalArgumentException("SDK should never send a CompletionMessage");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public enum MessageType {
StartMessage,
CompletionMessage,
SuspensionMessage,
ErrorMessage,

// IO
PollInputStreamEntryMessage,
Expand All @@ -33,6 +34,7 @@ public enum MessageType {
public static final short START_MESSAGE_TYPE = 0x0000;
public static final short COMPLETION_MESSAGE_TYPE = 0x0001;
public static final short SUSPENSION_MESSAGE_TYPE = 0x0002;
public static final short ERROR_MESSAGE_TYPE = 0x0003;
public static final short POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE = 0x0400;
public static final short OUTPUT_STREAM_ENTRY_MESSAGE_TYPE = 0x0401;
public static final short GET_STATE_ENTRY_MESSAGE_TYPE = 0x0800;
Expand All @@ -54,6 +56,8 @@ public Parser<? extends MessageLite> messageParser() {
return Protocol.CompletionMessage.parser();
case SuspensionMessage:
return Protocol.SuspensionMessage.parser();
case ErrorMessage:
return Protocol.ErrorMessage.parser();
case PollInputStreamEntryMessage:
return Protocol.PollInputStreamEntryMessage.parser();
case OutputStreamEntryMessage:
Expand Down Expand Up @@ -90,6 +94,8 @@ public short encode() {
return COMPLETION_MESSAGE_TYPE;
case SuspensionMessage:
return SUSPENSION_MESSAGE_TYPE;
case ErrorMessage:
return ERROR_MESSAGE_TYPE;
case PollInputStreamEntryMessage:
return POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE;
case OutputStreamEntryMessage:
Expand Down Expand Up @@ -126,6 +132,8 @@ public static MessageType decode(short value) throws ProtocolException {
return CompletionMessage;
case SUSPENSION_MESSAGE_TYPE:
return SuspensionMessage;
case ERROR_MESSAGE_TYPE:
return ErrorMessage;
case POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE:
return PollInputStreamEntryMessage;
case OUTPUT_STREAM_ENTRY_MESSAGE_TYPE:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,47 @@
import com.google.protobuf.MessageLite;
import dev.restate.generated.service.protocol.Protocol;
import io.grpc.Status;
import java.io.PrintWriter;
import java.io.StringWriter;

public class ProtocolException extends RuntimeException {

static final int JOURNAL_MISMATCH_CODE = 32;
static final int PROTOCOL_VIOLATION = 33;

static final ProtocolException CLOSED = new ProtocolException("Invocation closed");

private final Status.Code grpcCode;
private final int failureCode;

private ProtocolException(String message) {
this(message, Status.Code.INTERNAL);
this(message, Status.Code.INTERNAL.value());
}

private ProtocolException(String message, Status.Code grpcCode) {
this(message, null, grpcCode);
private ProtocolException(String message, int failureCode) {
this(message, null, failureCode);
}

public ProtocolException(String message, Throwable cause, Status.Code grpcCode) {
public ProtocolException(String message, Throwable cause, int failureCode) {
super(message, cause);
this.grpcCode = grpcCode;
this.failureCode = failureCode;
}

public int getFailureCode() {
return failureCode;
}

public Status.Code getGrpcCode() {
return grpcCode;
public Protocol.ErrorMessage toErrorMessage() {
// Convert stacktrace to string
StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw);
pw.println("Stacktrace:");
this.printStackTrace(pw);

return Protocol.ErrorMessage.newBuilder()
.setCode(failureCode)
.setMessage(this.toString())
.setDescription(sw.toString())
.build();
}

static ProtocolException unexpectedMessage(
Expand All @@ -34,26 +53,37 @@ static ProtocolException unexpectedMessage(
+ expected.getCanonicalName()
+ "', Actual: '"
+ actual.getClass().getCanonicalName()
+ "'");
+ "'",
PROTOCOL_VIOLATION);
}

static ProtocolException entryDoesNotMatch(MessageLite expected, MessageLite actual) {
return new ProtocolException(
"Journal entry " + expected.getClass() + " does not match: " + expected + " != " + actual);
"Journal entry " + expected.getClass() + " does not match: " + expected + " != " + actual,
JOURNAL_MISMATCH_CODE);
}

static ProtocolException completionDoesNotMatch(
String entry, Protocol.CompletionMessage.ResultCase actual) {
return new ProtocolException(
"Completion for entry " + entry + " doesn't expect completion variant " + actual);
"Completion for entry " + entry + " doesn't expect completion variant " + actual,
JOURNAL_MISMATCH_CODE);
}

static ProtocolException unknownMessageType(short type) {
return new ProtocolException("MessageType " + Integer.toHexString(type) + " unknown");
return new ProtocolException(
"MessageType " + Integer.toHexString(type) + " unknown", PROTOCOL_VIOLATION);
}

static ProtocolException methodNotFound(String svcName, String methodName) {
return new ProtocolException(
"Cannot find method '" + svcName + "/" + methodName + "'", Status.Code.NOT_FOUND);
"Cannot find method '" + svcName + "/" + methodName + "'", Status.Code.NOT_FOUND.value());
}

static ProtocolException invalidSideEffectCall() {
return new ProtocolException(
"A syscall was invoked from within a side effect closure.",
null,
Status.Code.UNKNOWN.value());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import io.grpc.ServerCall;
import io.grpc.Status;
import java.util.Objects;
import java.util.Optional;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

Expand Down Expand Up @@ -88,15 +87,10 @@ public void close(Status status, Metadata trailers) {
// Let's cancel the listener first
listener.onCancel();

Optional<Throwable> protocolException =
Util.findCause(status.getCause(), t -> t instanceof ProtocolException);
if (protocolException.isPresent()) {
// If it's a protocol exception, we propagate the failure to syscalls, which will propagate
// it to the network layer
syscalls.fail(protocolException.get());
if (status.getCode() == Status.Code.UNKNOWN) {
// If no cause, just propagate a generic runtime exception
syscalls.fail(status.getCause() != null ? status.getCause() : status.asRuntimeException());
} else {
// If not a protocol exception, then it's an exception coming from user which we write on
// the journal
syscalls.writeOutput(
status.asRuntimeException(),
SyscallCallback.ofVoid(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dev.restate.sdk.core.impl;

import static dev.restate.sdk.core.impl.Util.isTerminalException;
import static dev.restate.sdk.core.impl.Util.toProtocolFailure;

import com.google.protobuf.ByteString;
Expand Down Expand Up @@ -197,10 +198,17 @@ public void exitSideEffectBlockWithException(
Throwable toWrite, ExitSideEffectSyscallCallback<?> callback) {
LOG.trace("exitSideEffectBlock with failure");

// If it's a protocol exception, don't write it
Optional<ProtocolException> protocolException = Util.findProtocolException(toWrite);
if (protocolException.isPresent()) {
throw protocolException.get();
// If it's a non-terminal exception (such as a protocol exception),
// we don't write it but simply throw it
if (!(isTerminalException(toWrite))) {
// For safety wrt Syscalls API we do this check and wrapping,
// but with the current APIs the exception should always be RuntimeException
// because that's what can be thrown inside a lambda
if (toWrite instanceof RuntimeException) {
throw (RuntimeException) toWrite;
} else {
throw new RuntimeException(toWrite);
}
}

this.stateMachine.exitSideEffectBlock(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dev.restate.sdk.core.impl;

import com.google.protobuf.MessageLite;
import com.google.rpc.Code;
import dev.restate.generated.sdk.java.Java;
import dev.restate.generated.service.protocol.Protocol;
import dev.restate.sdk.core.SuspendedException;
Expand Down Expand Up @@ -79,6 +80,11 @@ static Status toGrpcStatusErasingCause(Throwable throwable) {
return Status.UNKNOWN.withDescription(throwable.getMessage());
}

static boolean isTerminalException(Throwable throwable) {
return throwable instanceof StatusRuntimeException
&& ((StatusRuntimeException) throwable).getStatus().getCode().value() != Code.UNKNOWN_VALUE;
}

static void assertIsEntry(MessageLite msg) {
if (!isEntry(msg)) {
throw new IllegalStateException("Expected input to be entry");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package dev.restate.sdk.core.impl;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.InstanceOfAssertFactories.STRING;
import static org.assertj.core.api.InstanceOfAssertFactories.type;

import com.google.protobuf.MessageLite;
import com.google.rpc.Code;
import dev.restate.generated.service.protocol.Protocol;
import java.util.List;
import java.util.function.Consumer;

public class AssertUtils {

public static Consumer<List<MessageLite>> containsOnly(Consumer<? super MessageLite> consumer) {
return msgs -> assertThat(msgs).satisfiesExactly(consumer);
}

public static Consumer<List<MessageLite>> containsOnlyExactErrorMessage(Throwable e) {
return containsOnly(exactErrorMessage(e));
}

public static Consumer<? super MessageLite> errorMessage(
Consumer<? super Protocol.ErrorMessage> consumer) {
return msg ->
assertThat(msg).asInstanceOf(type(Protocol.ErrorMessage.class)).satisfies(consumer);
}

public static Consumer<? super MessageLite> exactErrorMessage(Throwable e) {
return errorMessage(
msg ->
assertThat(msg)
.returns(e.toString(), Protocol.ErrorMessage::getMessage)
.returns(Code.UNKNOWN_VALUE, Protocol.ErrorMessage::getCode));
}

public static Consumer<? super MessageLite> errorMessageStartingWith(String str) {
return errorMessage(
msg ->
assertThat(msg).extracting(Protocol.ErrorMessage::getMessage, STRING).startsWith(str));
}

public static Consumer<? super MessageLite> protocolExceptionErrorMessage(int code) {
return errorMessage(
msg ->
assertThat(msg)
.returns(code, Protocol.ErrorMessage::getCode)
.extracting(Protocol.ErrorMessage::getMessage, STRING)
.startsWith(ProtocolException.class.getCanonicalName()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,9 @@ public BiConsumer<FutureSubscriber<MessageLite>, Duration> getOutputAssert() {
.last()
.isNotNull()
.isInstanceOfAny(
Protocol.OutputStreamEntryMessage.class, Protocol.SuspensionMessage.class);
Protocol.OutputStreamEntryMessage.class,
Protocol.SuspensionMessage.class,
Protocol.ErrorMessage.class);
};
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dev.restate.sdk.core.impl;

import static dev.restate.sdk.core.impl.AssertUtils.containsOnlyExactErrorMessage;
import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation;
import static dev.restate.sdk.core.impl.ProtoUtils.*;

Expand Down Expand Up @@ -89,6 +90,6 @@ Stream<TestDefinition> definitions() {
testInvocation(new SetNullState(), GreeterGrpc.getGreetMethod())
.withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till")))
.usingAllThreadingModels()
.assertingFailure(NullPointerException.class));
.assertingOutput(containsOnlyExactErrorMessage(new NullPointerException())));
}
}
Loading

0 comments on commit e1a6389

Please sign in to comment.