From e1a63897a8eeaa76d18fcd5daf5fc9111b8c8bb6 Mon Sep 17 00:00:00 2001 From: Francesco Guardiani Date: Mon, 24 Jul 2023 17:27:09 +0200 Subject: [PATCH] New failure propagation (#100) * Implement ErrorMessage header * Refactor ProtocolException to hold code numbers outside the gRPC Status.Code range * Write out the ErrorMessage when there is a failure cause. * * Now when closing the state machine, we invoke syscalls.fail(cause) on every failure with Code UNKNOWN. In other words, a failure is non-terminal when code is UNKNOWN, otherwise is terminal. * Now GrpcServerCallListenerAdaptor won't erase the cause anymore when getting an exception from the user code, because we need it for syscalls.fail(cause) * Make sure we don't write non-terminal failures for side effects --- .../impl/GrpcServerCallListenerAdaptor.java | 3 +- .../sdk/core/impl/InvocationStateMachine.java | 16 ++- .../restate/sdk/core/impl/MessageHeader.java | 4 +- .../restate/sdk/core/impl/MessageType.java | 8 ++ .../sdk/core/impl/ProtocolException.java | 56 +++++++--- .../sdk/core/impl/RestateServerCall.java | 12 +- .../restate/sdk/core/impl/SyscallsImpl.java | 16 ++- .../java/dev/restate/sdk/core/impl/Util.java | 6 + .../restate/sdk/core/impl/AssertUtils.java | 51 +++++++++ .../restate/sdk/core/impl/CoreTestRunner.java | 4 +- .../sdk/core/impl/GetAndSetStateTest.java | 3 +- .../dev/restate/sdk/core/impl/ProtoUtils.java | 17 ++- .../restate/sdk/core/impl/SideEffectTest.java | 4 +- .../core/impl/StateMachineFailuresTest.java | 19 +++- .../sdk/core/impl/UserFailuresTest.java | 103 +++++++++++++++++- .../sdk/lambda/LambdaRestateServer.java | 2 +- .../sdk/vertx/HttpRequestFlowAdapter.java | 1 + .../sdk/vertx/HttpResponseFlowAdapter.java | 4 +- .../sdk/vertx/RequestHttpServerHandler.java | 2 +- 19 files changed, 281 insertions(+), 50 deletions(-) create mode 100644 sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/AssertUtils.java diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/GrpcServerCallListenerAdaptor.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/GrpcServerCallListenerAdaptor.java index 6540b776..a203f426 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/GrpcServerCallListenerAdaptor.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/GrpcServerCallListenerAdaptor.java @@ -2,6 +2,7 @@ import io.grpc.Metadata; import io.grpc.ServerCall; +import io.grpc.Status; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -67,7 +68,7 @@ private void closeWithException(Throwable e) { serverCall.close(Util.SUSPENDED_STATUS, new Metadata()); } else { LOG.warn("Error when processing the invocation", e); - serverCall.close(Util.toGrpcStatusErasingCause(e), new Metadata()); + serverCall.close(Status.fromThrowable(e), new Metadata()); } } } diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/InvocationStateMachine.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/InvocationStateMachine.java index a18e9a8a..817b4ed7 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/InvocationStateMachine.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/InvocationStateMachine.java @@ -2,6 +2,7 @@ import com.google.protobuf.ByteString; import com.google.protobuf.MessageLite; +import com.google.rpc.Code; import dev.restate.generated.sdk.java.Java; import dev.restate.generated.service.protocol.Protocol; import dev.restate.sdk.core.InvocationId; @@ -13,7 +14,6 @@ import dev.restate.sdk.core.impl.ReadyResults.ReadyResultInternal; import dev.restate.sdk.core.syscalls.DeferredResult; import dev.restate.sdk.core.syscalls.SyscallCallback; -import io.grpc.Status; import io.opentelemetry.api.common.Attributes; import io.opentelemetry.api.trace.Span; import java.util.*; @@ -221,7 +221,16 @@ void fail(Throwable cause) { this.inputSubscription.cancel(); } if (this.outputSubscriber != null) { - this.outputSubscriber.onError(cause); + if (cause instanceof ProtocolException) { + this.outputSubscriber.onNext(((ProtocolException) cause).toErrorMessage()); + } else if (cause != null) { + this.outputSubscriber.onNext( + Protocol.ErrorMessage.newBuilder() + .setCode(Code.UNKNOWN_VALUE) + .setMessage(cause.toString()) + .build()); + } + this.outputSubscriber.onComplete(); this.outputSubscriber = null; } this.insideSideEffect = false; @@ -640,8 +649,7 @@ private void incrementCurrentIndex() { private void checkInsideSideEffectGuard() { if (this.insideSideEffect) { - throw new ProtocolException( - "A syscall was invoked from within a side effect closure.", null, Status.Code.UNKNOWN); + throw ProtocolException.invalidSideEffectCall(); } } diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/MessageHeader.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/MessageHeader.java index 9814fc35..77ddd968 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/MessageHeader.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/MessageHeader.java @@ -55,6 +55,8 @@ public static MessageHeader parse(long encoded) throws ProtocolException { public static MessageHeader fromMessage(MessageLite msg) { if (msg instanceof Protocol.SuspensionMessage) { return new MessageHeader(MessageType.SuspensionMessage, (short) 0, msg.getSerializedSize()); + } else if (msg instanceof Protocol.ErrorMessage) { + return new MessageHeader(MessageType.ErrorMessage, (short) 0, msg.getSerializedSize()); } else if (msg instanceof Protocol.PollInputStreamEntryMessage) { return new MessageHeader( MessageType.PollInputStreamEntryMessage, (short) 0, msg.getSerializedSize()); @@ -108,8 +110,6 @@ public static MessageHeader fromMessage(MessageLite msg) { } else if (msg instanceof Java.SideEffectEntryMessage) { return new MessageHeader( MessageType.SideEffectEntryMessage, REQUIRES_ACK_FLAG, msg.getSerializedSize()); - } else if (msg instanceof Protocol.StartMessage) { - throw new IllegalArgumentException("SDK should never send a StartMessage"); } else if (msg instanceof Protocol.CompletionMessage) { throw new IllegalArgumentException("SDK should never send a CompletionMessage"); } diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/MessageType.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/MessageType.java index b0212d6d..9e742a08 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/MessageType.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/MessageType.java @@ -9,6 +9,7 @@ public enum MessageType { StartMessage, CompletionMessage, SuspensionMessage, + ErrorMessage, // IO PollInputStreamEntryMessage, @@ -33,6 +34,7 @@ public enum MessageType { public static final short START_MESSAGE_TYPE = 0x0000; public static final short COMPLETION_MESSAGE_TYPE = 0x0001; public static final short SUSPENSION_MESSAGE_TYPE = 0x0002; + public static final short ERROR_MESSAGE_TYPE = 0x0003; public static final short POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE = 0x0400; public static final short OUTPUT_STREAM_ENTRY_MESSAGE_TYPE = 0x0401; public static final short GET_STATE_ENTRY_MESSAGE_TYPE = 0x0800; @@ -54,6 +56,8 @@ public Parser messageParser() { return Protocol.CompletionMessage.parser(); case SuspensionMessage: return Protocol.SuspensionMessage.parser(); + case ErrorMessage: + return Protocol.ErrorMessage.parser(); case PollInputStreamEntryMessage: return Protocol.PollInputStreamEntryMessage.parser(); case OutputStreamEntryMessage: @@ -90,6 +94,8 @@ public short encode() { return COMPLETION_MESSAGE_TYPE; case SuspensionMessage: return SUSPENSION_MESSAGE_TYPE; + case ErrorMessage: + return ERROR_MESSAGE_TYPE; case PollInputStreamEntryMessage: return POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE; case OutputStreamEntryMessage: @@ -126,6 +132,8 @@ public static MessageType decode(short value) throws ProtocolException { return CompletionMessage; case SUSPENSION_MESSAGE_TYPE: return SuspensionMessage; + case ERROR_MESSAGE_TYPE: + return ErrorMessage; case POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE: return PollInputStreamEntryMessage; case OUTPUT_STREAM_ENTRY_MESSAGE_TYPE: diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ProtocolException.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ProtocolException.java index 6687e1f2..e7dc5d5a 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ProtocolException.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ProtocolException.java @@ -3,28 +3,47 @@ import com.google.protobuf.MessageLite; import dev.restate.generated.service.protocol.Protocol; import io.grpc.Status; +import java.io.PrintWriter; +import java.io.StringWriter; public class ProtocolException extends RuntimeException { + static final int JOURNAL_MISMATCH_CODE = 32; + static final int PROTOCOL_VIOLATION = 33; + static final ProtocolException CLOSED = new ProtocolException("Invocation closed"); - private final Status.Code grpcCode; + private final int failureCode; private ProtocolException(String message) { - this(message, Status.Code.INTERNAL); + this(message, Status.Code.INTERNAL.value()); } - private ProtocolException(String message, Status.Code grpcCode) { - this(message, null, grpcCode); + private ProtocolException(String message, int failureCode) { + this(message, null, failureCode); } - public ProtocolException(String message, Throwable cause, Status.Code grpcCode) { + public ProtocolException(String message, Throwable cause, int failureCode) { super(message, cause); - this.grpcCode = grpcCode; + this.failureCode = failureCode; + } + + public int getFailureCode() { + return failureCode; } - public Status.Code getGrpcCode() { - return grpcCode; + public Protocol.ErrorMessage toErrorMessage() { + // Convert stacktrace to string + StringWriter sw = new StringWriter(); + PrintWriter pw = new PrintWriter(sw); + pw.println("Stacktrace:"); + this.printStackTrace(pw); + + return Protocol.ErrorMessage.newBuilder() + .setCode(failureCode) + .setMessage(this.toString()) + .setDescription(sw.toString()) + .build(); } static ProtocolException unexpectedMessage( @@ -34,26 +53,37 @@ static ProtocolException unexpectedMessage( + expected.getCanonicalName() + "', Actual: '" + actual.getClass().getCanonicalName() - + "'"); + + "'", + PROTOCOL_VIOLATION); } static ProtocolException entryDoesNotMatch(MessageLite expected, MessageLite actual) { return new ProtocolException( - "Journal entry " + expected.getClass() + " does not match: " + expected + " != " + actual); + "Journal entry " + expected.getClass() + " does not match: " + expected + " != " + actual, + JOURNAL_MISMATCH_CODE); } static ProtocolException completionDoesNotMatch( String entry, Protocol.CompletionMessage.ResultCase actual) { return new ProtocolException( - "Completion for entry " + entry + " doesn't expect completion variant " + actual); + "Completion for entry " + entry + " doesn't expect completion variant " + actual, + JOURNAL_MISMATCH_CODE); } static ProtocolException unknownMessageType(short type) { - return new ProtocolException("MessageType " + Integer.toHexString(type) + " unknown"); + return new ProtocolException( + "MessageType " + Integer.toHexString(type) + " unknown", PROTOCOL_VIOLATION); } static ProtocolException methodNotFound(String svcName, String methodName) { return new ProtocolException( - "Cannot find method '" + svcName + "/" + methodName + "'", Status.Code.NOT_FOUND); + "Cannot find method '" + svcName + "/" + methodName + "'", Status.Code.NOT_FOUND.value()); + } + + static ProtocolException invalidSideEffectCall() { + return new ProtocolException( + "A syscall was invoked from within a side effect closure.", + null, + Status.Code.UNKNOWN.value()); } } diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/RestateServerCall.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/RestateServerCall.java index 1f63d37f..ba846065 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/RestateServerCall.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/RestateServerCall.java @@ -7,7 +7,6 @@ import io.grpc.ServerCall; import io.grpc.Status; import java.util.Objects; -import java.util.Optional; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -88,15 +87,10 @@ public void close(Status status, Metadata trailers) { // Let's cancel the listener first listener.onCancel(); - Optional protocolException = - Util.findCause(status.getCause(), t -> t instanceof ProtocolException); - if (protocolException.isPresent()) { - // If it's a protocol exception, we propagate the failure to syscalls, which will propagate - // it to the network layer - syscalls.fail(protocolException.get()); + if (status.getCode() == Status.Code.UNKNOWN) { + // If no cause, just propagate a generic runtime exception + syscalls.fail(status.getCause() != null ? status.getCause() : status.asRuntimeException()); } else { - // If not a protocol exception, then it's an exception coming from user which we write on - // the journal syscalls.writeOutput( status.asRuntimeException(), SyscallCallback.ofVoid( diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SyscallsImpl.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SyscallsImpl.java index 9eb1e443..a9695875 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SyscallsImpl.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SyscallsImpl.java @@ -1,5 +1,6 @@ package dev.restate.sdk.core.impl; +import static dev.restate.sdk.core.impl.Util.isTerminalException; import static dev.restate.sdk.core.impl.Util.toProtocolFailure; import com.google.protobuf.ByteString; @@ -197,10 +198,17 @@ public void exitSideEffectBlockWithException( Throwable toWrite, ExitSideEffectSyscallCallback callback) { LOG.trace("exitSideEffectBlock with failure"); - // If it's a protocol exception, don't write it - Optional protocolException = Util.findProtocolException(toWrite); - if (protocolException.isPresent()) { - throw protocolException.get(); + // If it's a non-terminal exception (such as a protocol exception), + // we don't write it but simply throw it + if (!(isTerminalException(toWrite))) { + // For safety wrt Syscalls API we do this check and wrapping, + // but with the current APIs the exception should always be RuntimeException + // because that's what can be thrown inside a lambda + if (toWrite instanceof RuntimeException) { + throw (RuntimeException) toWrite; + } else { + throw new RuntimeException(toWrite); + } } this.stateMachine.exitSideEffectBlock( diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/Util.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/Util.java index 6f6b6c3e..2e4e8074 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/Util.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/Util.java @@ -1,6 +1,7 @@ package dev.restate.sdk.core.impl; import com.google.protobuf.MessageLite; +import com.google.rpc.Code; import dev.restate.generated.sdk.java.Java; import dev.restate.generated.service.protocol.Protocol; import dev.restate.sdk.core.SuspendedException; @@ -79,6 +80,11 @@ static Status toGrpcStatusErasingCause(Throwable throwable) { return Status.UNKNOWN.withDescription(throwable.getMessage()); } + static boolean isTerminalException(Throwable throwable) { + return throwable instanceof StatusRuntimeException + && ((StatusRuntimeException) throwable).getStatus().getCode().value() != Code.UNKNOWN_VALUE; + } + static void assertIsEntry(MessageLite msg) { if (!isEntry(msg)) { throw new IllegalStateException("Expected input to be entry"); diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/AssertUtils.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/AssertUtils.java new file mode 100644 index 00000000..ecbc92bf --- /dev/null +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/AssertUtils.java @@ -0,0 +1,51 @@ +package dev.restate.sdk.core.impl; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.STRING; +import static org.assertj.core.api.InstanceOfAssertFactories.type; + +import com.google.protobuf.MessageLite; +import com.google.rpc.Code; +import dev.restate.generated.service.protocol.Protocol; +import java.util.List; +import java.util.function.Consumer; + +public class AssertUtils { + + public static Consumer> containsOnly(Consumer consumer) { + return msgs -> assertThat(msgs).satisfiesExactly(consumer); + } + + public static Consumer> containsOnlyExactErrorMessage(Throwable e) { + return containsOnly(exactErrorMessage(e)); + } + + public static Consumer errorMessage( + Consumer consumer) { + return msg -> + assertThat(msg).asInstanceOf(type(Protocol.ErrorMessage.class)).satisfies(consumer); + } + + public static Consumer exactErrorMessage(Throwable e) { + return errorMessage( + msg -> + assertThat(msg) + .returns(e.toString(), Protocol.ErrorMessage::getMessage) + .returns(Code.UNKNOWN_VALUE, Protocol.ErrorMessage::getCode)); + } + + public static Consumer errorMessageStartingWith(String str) { + return errorMessage( + msg -> + assertThat(msg).extracting(Protocol.ErrorMessage::getMessage, STRING).startsWith(str)); + } + + public static Consumer protocolExceptionErrorMessage(int code) { + return errorMessage( + msg -> + assertThat(msg) + .returns(code, Protocol.ErrorMessage::getCode) + .extracting(Protocol.ErrorMessage::getMessage, STRING) + .startsWith(ProtocolException.class.getCanonicalName())); + } +} diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/CoreTestRunner.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/CoreTestRunner.java index 5f311b34..735cad6b 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/CoreTestRunner.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/CoreTestRunner.java @@ -361,7 +361,9 @@ public BiConsumer, Duration> getOutputAssert() { .last() .isNotNull() .isInstanceOfAny( - Protocol.OutputStreamEntryMessage.class, Protocol.SuspensionMessage.class); + Protocol.OutputStreamEntryMessage.class, + Protocol.SuspensionMessage.class, + Protocol.ErrorMessage.class); }; } } diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/GetAndSetStateTest.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/GetAndSetStateTest.java index fcb52cf3..e4af6cda 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/GetAndSetStateTest.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/GetAndSetStateTest.java @@ -1,5 +1,6 @@ package dev.restate.sdk.core.impl; +import static dev.restate.sdk.core.impl.AssertUtils.containsOnlyExactErrorMessage; import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation; import static dev.restate.sdk.core.impl.ProtoUtils.*; @@ -89,6 +90,6 @@ Stream definitions() { testInvocation(new SetNullState(), GreeterGrpc.getGreetMethod()) .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till"))) .usingAllThreadingModels() - .assertingFailure(NullPointerException.class)); + .assertingOutput(containsOnlyExactErrorMessage(new NullPointerException()))); } } diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/ProtoUtils.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/ProtoUtils.java index 52eb6c2d..965f683b 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/ProtoUtils.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/ProtoUtils.java @@ -1,5 +1,7 @@ package dev.restate.sdk.core.impl; +import static dev.restate.sdk.core.impl.Util.toProtocolFailure; + import com.google.protobuf.ByteString; import com.google.protobuf.Empty; import com.google.protobuf.MessageLite; @@ -10,6 +12,7 @@ import dev.restate.sdk.core.impl.testservices.GreetingRequest; import dev.restate.sdk.core.impl.testservices.GreetingResponse; import io.grpc.MethodDescriptor; +import io.grpc.Status; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -74,7 +77,7 @@ public static Protocol.CompletionMessage completionMessage( static Protocol.CompletionMessage completionMessage(int index, Throwable e) { return Protocol.CompletionMessage.newBuilder() .setEntryIndex(index) - .setFailure(Util.toProtocolFailure(e)) + .setFailure(toProtocolFailure(Status.INTERNAL.withDescription(e.getMessage()))) .build(); } @@ -94,9 +97,15 @@ static Protocol.OutputStreamEntryMessage outputMessage(MessageLiteOrBuilder valu .build(); } + static Protocol.OutputStreamEntryMessage outputMessage(Status s) { + return Protocol.OutputStreamEntryMessage.newBuilder() + .setFailure(Util.toProtocolFailure(s.asRuntimeException())) + .build(); + } + static Protocol.OutputStreamEntryMessage outputMessage(Throwable e) { return Protocol.OutputStreamEntryMessage.newBuilder() - .setFailure(Util.toProtocolFailure(e)) + .setFailure(toProtocolFailure(Status.INTERNAL.withDescription(e.getMessage()))) .build(); } @@ -144,7 +153,9 @@ static Protocol.InvokeEntryMessag static Protocol.InvokeEntryMessage invokeMessage( MethodDescriptor methodDescriptor, T parameter, Throwable e) { - return invokeMessage(methodDescriptor, parameter).setFailure(Util.toProtocolFailure(e)).build(); + return invokeMessage(methodDescriptor, parameter) + .setFailure(toProtocolFailure(Status.INTERNAL.withDescription(e.getMessage()))) + .build(); } static Protocol.AwakeableEntryMessage.Builder awakeable() { diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SideEffectTest.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SideEffectTest.java index beaa3394..463ba14a 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SideEffectTest.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SideEffectTest.java @@ -1,5 +1,6 @@ package dev.restate.sdk.core.impl; +import static dev.restate.sdk.core.impl.AssertUtils.containsOnlyExactErrorMessage; import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation; import static dev.restate.sdk.core.impl.ProtoUtils.*; import static org.assertj.core.api.Assertions.assertThat; @@ -171,7 +172,8 @@ Stream definitions() { testInvocation(new SideEffectGuard(), GreeterGrpc.getGreetMethod()) .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till"))) .usingAllThreadingModels() - .assertingFailure(ProtocolException.class), + .assertingOutput( + containsOnlyExactErrorMessage(ProtocolException.invalidSideEffectCall())), testInvocation(new SideEffectThenAwakeable(), GreeterGrpc.getGreetMethod()) .withInput( startMessage(2), diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/StateMachineFailuresTest.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/StateMachineFailuresTest.java index 10ded53f..1770da45 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/StateMachineFailuresTest.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/StateMachineFailuresTest.java @@ -1,5 +1,7 @@ package dev.restate.sdk.core.impl; +import static dev.restate.sdk.core.impl.AssertUtils.containsOnly; +import static dev.restate.sdk.core.impl.AssertUtils.errorMessageStartingWith; import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation; import static dev.restate.sdk.core.impl.ProtoUtils.*; @@ -88,24 +90,33 @@ Stream definitions() { inputMessage(GreetingRequest.newBuilder().setName("Till")), getStateMessage("Something")) .usingAllThreadingModels() - .assertingFailure(ProtocolException.class), + .assertingOutput( + containsOnly( + AssertUtils.protocolExceptionErrorMessage( + ProtocolException.JOURNAL_MISMATCH_CODE))), testInvocation(new GetState(), GreeterGrpc.getGreetMethod()) .withInput( startMessage(2), inputMessage(GreetingRequest.newBuilder().setName("Till")), getStateMessage("STATE", "This is not an integer")) .usingAllThreadingModels() - .assertingFailure(NumberFormatException.class), + .assertingOutput( + containsOnly( + errorMessageStartingWith(NumberFormatException.class.getCanonicalName()))), testInvocation(new EndSideEffectSerializationFailure(), GreeterGrpc.getGreetMethod()) .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till"))) .usingAllThreadingModels() - .assertingFailure(IllegalStateException.class), + .assertingOutput( + containsOnly( + errorMessageStartingWith(IllegalStateException.class.getCanonicalName()))), testInvocation(new EndSideEffectDeserializationFailure(), GreeterGrpc.getGreetMethod()) .withInput( startMessage(2), inputMessage(GreetingRequest.newBuilder().setName("Till")), Java.SideEffectEntryMessage.newBuilder()) .usingAllThreadingModels() - .assertingFailure(IllegalStateException.class)); + .assertingOutput( + containsOnly( + errorMessageStartingWith(IllegalStateException.class.getCanonicalName())))); } } diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/UserFailuresTest.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/UserFailuresTest.java index 61c21b25..c8b6b797 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/UserFailuresTest.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/UserFailuresTest.java @@ -1,29 +1,124 @@ package dev.restate.sdk.core.impl; +import static dev.restate.sdk.core.impl.AssertUtils.containsOnlyExactErrorMessage; import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation; import static dev.restate.sdk.core.impl.ProtoUtils.*; +import dev.restate.generated.sdk.java.Java; +import dev.restate.sdk.blocking.RestateBlockingService; import dev.restate.sdk.core.impl.testservices.GreeterGrpc; import dev.restate.sdk.core.impl.testservices.GreetingRequest; import dev.restate.sdk.core.impl.testservices.GreetingResponse; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; import io.grpc.stub.StreamObserver; import java.util.stream.Stream; class UserFailuresTest extends CoreTestRunner { - private static class FailingGreeter extends GreeterGrpc.GreeterImplBase { + private static final Status MY_ERROR = Status.INTERNAL.withDescription("my error"); + + private static class ThrowIllegalStateException extends GreeterGrpc.GreeterImplBase { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { throw new IllegalStateException("Whatever"); } } + private static class ResponseObserverOnErrorIllegalStateException + extends GreeterGrpc.GreeterImplBase { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + responseObserver.onError(new IllegalStateException("Whatever")); + } + } + + private static class SideEffectThrowIllegalStateException extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + restateContext() + .sideEffect( + () -> { + throw new IllegalStateException("Whatever"); + }); + } + } + + private static class ThrowUnknownStatusRuntimeException extends GreeterGrpc.GreeterImplBase { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + throw new StatusRuntimeException(Status.UNKNOWN.withDescription("Whatever")); + } + } + + private static class ThrowStatusRuntimeException extends GreeterGrpc.GreeterImplBase { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + throw new StatusRuntimeException(MY_ERROR); + } + } + + private static class ResponseObserverOnErrorStatusRuntimeException + extends GreeterGrpc.GreeterImplBase { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + responseObserver.onError(new StatusRuntimeException(MY_ERROR)); + } + } + + private static class SideEffectThrowStatusRuntimeException extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + restateContext() + .sideEffect( + () -> { + throw new StatusRuntimeException(MY_ERROR); + }); + } + } + @Override Stream definitions() { return Stream.of( - testInvocation(new FailingGreeter(), GreeterGrpc.getGreetMethod()) - .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till"))) + // Cases returning ErrorMessage + testInvocation(new ThrowIllegalStateException(), GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) + .usingAllThreadingModels() + .assertingOutput(containsOnlyExactErrorMessage(new IllegalStateException("Whatever"))), + testInvocation(new ThrowUnknownStatusRuntimeException(), GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) + .usingAllThreadingModels() + .assertingOutput( + containsOnlyExactErrorMessage( + Status.UNKNOWN.withDescription("Whatever").asRuntimeException())), + testInvocation( + new ResponseObserverOnErrorIllegalStateException(), GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) + .usingAllThreadingModels() + .assertingOutput(containsOnlyExactErrorMessage(new IllegalStateException("Whatever"))), + testInvocation(new SideEffectThrowIllegalStateException(), GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) + .usingAllThreadingModels() + .assertingOutput(containsOnlyExactErrorMessage(new IllegalStateException("Whatever"))), + + // Cases completing the invocation with OutputStreamEntry.failure + testInvocation(new ThrowStatusRuntimeException(), GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) + .usingAllThreadingModels() + .expectingOutput(outputMessage(MY_ERROR)), + testInvocation( + new ResponseObserverOnErrorStatusRuntimeException(), GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) + .usingAllThreadingModels() + .expectingOutput(outputMessage(MY_ERROR)), + testInvocation(new SideEffectThrowStatusRuntimeException(), GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) .usingAllThreadingModels() - .expectingOutput(outputMessage(new IllegalStateException("Whatever")))); + .expectingOutput( + Java.SideEffectEntryMessage.newBuilder() + .setFailure(Util.toProtocolFailure(MY_ERROR)), + outputMessage(MY_ERROR))); } } diff --git a/sdk-lambda/src/main/java/dev/restate/sdk/lambda/LambdaRestateServer.java b/sdk-lambda/src/main/java/dev/restate/sdk/lambda/LambdaRestateServer.java index 4d0cce09..896188c0 100644 --- a/sdk-lambda/src/main/java/dev/restate/sdk/lambda/LambdaRestateServer.java +++ b/sdk-lambda/src/main/java/dev/restate/sdk/lambda/LambdaRestateServer.java @@ -110,7 +110,7 @@ private APIGatewayProxyResponseEvent handleInvoke(APIGatewayProxyRequestEvent in } catch (ProtocolException e) { LOG.warn("Error when resolving the grpc handler", e); return new APIGatewayProxyResponseEvent() - .withStatusCode(e.getGrpcCode() == Status.Code.NOT_FOUND ? 404 : 500); + .withStatusCode(e.getFailureCode() == Status.Code.NOT_FOUND.value() ? 404 : 500); } BufferedPublisher publisher = new BufferedPublisher(requestBody); diff --git a/sdk-vertx/src/main/java/dev/restate/sdk/vertx/HttpRequestFlowAdapter.java b/sdk-vertx/src/main/java/dev/restate/sdk/vertx/HttpRequestFlowAdapter.java index e9781bf5..0b30f583 100644 --- a/sdk-vertx/src/main/java/dev/restate/sdk/vertx/HttpRequestFlowAdapter.java +++ b/sdk-vertx/src/main/java/dev/restate/sdk/vertx/HttpRequestFlowAdapter.java @@ -72,6 +72,7 @@ private void handleIncomingBuffer(Buffer buffer) { } private void handleRequestFailure(Throwable e) { + LOG.trace("Request error", e); this.inputMessagesSubscriber.onError(e); } diff --git a/sdk-vertx/src/main/java/dev/restate/sdk/vertx/HttpResponseFlowAdapter.java b/sdk-vertx/src/main/java/dev/restate/sdk/vertx/HttpResponseFlowAdapter.java index 7f73a37b..4d8d55ef 100644 --- a/sdk-vertx/src/main/java/dev/restate/sdk/vertx/HttpResponseFlowAdapter.java +++ b/sdk-vertx/src/main/java/dev/restate/sdk/vertx/HttpResponseFlowAdapter.java @@ -76,7 +76,7 @@ private void propagatePublisherFailure(Throwable e) { pe -> // TODO which status codes we need to map here? httpServerResponse.setStatusCode( - pe.getGrpcCode() == Status.Code.NOT_FOUND ? 404 : 500), + pe.getFailureCode() == Status.Code.NOT_FOUND.value() ? 404 : 500), () -> httpServerResponse.setStatusCode(500)); } LOG.warn("Error from publisher", e); @@ -84,6 +84,7 @@ private void propagatePublisherFailure(Throwable e) { } private void endResponse() { + LOG.trace("Closing response"); if (!this.httpServerResponse.ended()) { this.httpServerResponse.end(); } @@ -91,6 +92,7 @@ private void endResponse() { } private void cancelSubscription() { + LOG.trace("Cancelling subscription"); if (this.outputSubscription != null) { Flow.Subscription outputSubscription = this.outputSubscription; this.outputSubscription = null; diff --git a/sdk-vertx/src/main/java/dev/restate/sdk/vertx/RequestHttpServerHandler.java b/sdk-vertx/src/main/java/dev/restate/sdk/vertx/RequestHttpServerHandler.java index abdad8f6..48e47d19 100644 --- a/sdk-vertx/src/main/java/dev/restate/sdk/vertx/RequestHttpServerHandler.java +++ b/sdk-vertx/src/main/java/dev/restate/sdk/vertx/RequestHttpServerHandler.java @@ -115,7 +115,7 @@ public void handle(HttpServerRequest request) { request .response() .setStatusCode( - e.getGrpcCode() == Status.Code.NOT_FOUND + e.getFailureCode() == Status.Code.NOT_FOUND.value() ? NOT_FOUND.code() : INTERNAL_SERVER_ERROR.code()) .end();