diff --git a/sdk-api-gen/src/main/resources/templates/Client.hbs b/sdk-api-gen/src/main/resources/templates/Client.hbs index fd835a55..8d73c2cf 100644 --- a/sdk-api-gen/src/main/resources/templates/Client.hbs +++ b/sdk-api-gen/src/main/resources/templates/Client.hbs @@ -109,6 +109,21 @@ public class {{generatedClassSimpleName}} { {{outputSerdeFieldName}}, {{#if inputEmpty}}null{{else}}req{{/if}}, requestOptions); + } + + public {{#if outputEmpty}}java.util.concurrent.CompletableFuture{{else}}java.util.concurrent.CompletableFuture<{{{boxedOutputFqcn}}}>{{/if}} {{name}}Async({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + return this.{{name}}Async( + {{^inputEmpty}}req, {{/inputEmpty}} + dev.restate.sdk.client.RequestOptions.DEFAULT); + } + + public {{#if outputEmpty}}java.util.concurrent.CompletableFuture{{else}}java.util.concurrent.CompletableFuture<{{{boxedOutputFqcn}}}>{{/if}} {{name}}Async({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) { + return this.ingressClient.callAsync( + {{#if isObject}}Target.virtualObject(COMPONENT_NAME, this.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}}, + {{inputSerdeFieldName}}, + {{outputSerdeFieldName}}, + {{#if inputEmpty}}null{{else}}req{{/if}}, + requestOptions); }{{/handlers}} public Send send() { @@ -129,6 +144,20 @@ public class {{generatedClassSimpleName}} { {{inputSerdeFieldName}}, {{#if inputEmpty}}null{{else}}req{{/if}}, requestOptions); + } + + public java.util.concurrent.CompletableFuture {{name}}Async({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { + return this.{{name}}Async( + {{^inputEmpty}}req, {{/inputEmpty}} + dev.restate.sdk.client.RequestOptions.DEFAULT); + } + + public java.util.concurrent.CompletableFuture {{name}}Async({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) { + return IngressClient.this.ingressClient.sendAsync( + {{#if isObject}}Target.virtualObject(COMPONENT_NAME, IngressClient.this.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}}, + {{inputSerdeFieldName}}, + {{#if inputEmpty}}null{{else}}req{{/if}}, + requestOptions); }{{/handlers}} } } diff --git a/sdk-api-kotlin-gen/build.gradle.kts b/sdk-api-kotlin-gen/build.gradle.kts index fe64cad9..94b3f3a3 100644 --- a/sdk-api-kotlin-gen/build.gradle.kts +++ b/sdk-api-kotlin-gen/build.gradle.kts @@ -21,6 +21,7 @@ dependencies { testImplementation(testingLibs.assertj) testImplementation(coreLibs.protobuf.java) testImplementation(coreLibs.log4j.core) + testImplementation(kotlinLibs.kotlinx.coroutines) // Import test suites from sdk-core testImplementation(project(":sdk-core", "testArchive")) diff --git a/sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs b/sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs index 3c36cc97..c71192d1 100644 --- a/sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs +++ b/sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs @@ -6,6 +6,7 @@ import dev.restate.sdk.common.StateKey import dev.restate.sdk.common.Serde import dev.restate.sdk.common.Target import kotlin.time.Duration +import kotlinx.coroutines.future.await object {{generatedClassSimpleName}} { @@ -73,12 +74,12 @@ object {{generatedClassSimpleName}} { {{#handlers}} suspend fun {{name}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}requestOptions: dev.restate.sdk.client.RequestOptions = dev.restate.sdk.client.RequestOptions.DEFAULT): {{{boxedOutputFqcn}}} { - return this.ingressClient.call( + return this.ingressClient.callAsync( {{#if isObject}}Target.virtualObject(COMPONENT_NAME, this.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}}, {{inputSerdeFieldName}}, {{outputSerdeFieldName}}, {{#if inputEmpty}}null{{else}}req{{/if}}, - requestOptions); + requestOptions).await(); }{{/handlers}} fun send(): Send { @@ -88,11 +89,11 @@ object {{generatedClassSimpleName}} { inner class Send { {{#handlers}} suspend fun {{name}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}requestOptions: dev.restate.sdk.client.RequestOptions = dev.restate.sdk.client.RequestOptions.DEFAULT): String { - return this@IngressClient.ingressClient.send( + return this@IngressClient.ingressClient.sendAsync( {{#if isObject}}Target.virtualObject(COMPONENT_NAME, this@IngressClient.key, "{{name}}"){{else}}Target.service(COMPONENT_NAME, "{{name}}"){{/if}}, {{inputSerdeFieldName}}, {{#if inputEmpty}}null{{else}}req{{/if}}, - requestOptions); + requestOptions).await(); }{{/handlers}} } } diff --git a/sdk-common/src/main/java/dev/restate/sdk/client/DefaultIngressClient.java b/sdk-common/src/main/java/dev/restate/sdk/client/DefaultIngressClient.java index ac157683..50289c54 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/client/DefaultIngressClient.java +++ b/sdk-common/src/main/java/dev/restate/sdk/client/DefaultIngressClient.java @@ -13,14 +13,15 @@ import com.fasterxml.jackson.core.JsonToken; import dev.restate.sdk.common.Serde; import dev.restate.sdk.common.Target; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; -import java.nio.charset.StandardCharsets; import java.util.Map; +import java.util.concurrent.CompletableFuture; public class DefaultIngressClient implements IngressClient { @@ -37,53 +38,58 @@ public DefaultIngressClient(HttpClient httpClient, String baseUri, Map Res call( + public CompletableFuture callAsync( Target target, Serde reqSerde, Serde resSerde, Req req, RequestOptions requestOptions) { HttpRequest request = prepareHttpRequest(target, false, reqSerde, req, requestOptions); - HttpResponse response; - try { - response = httpClient.send(request, HttpResponse.BodyHandlers.ofByteArray()); - } catch (IOException | InterruptedException e) { - throw new RuntimeException("Error when executing the request", e); - } - - if (response.statusCode() != 200) { - // Try to parse as string - String error = new String(response.body(), StandardCharsets.UTF_8); - throw new RuntimeException( - "Received non OK status code: " + response.statusCode() + ". Body: " + error); - } - - return resSerde.deserialize(response.body()); + return httpClient + .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) + .handle( + (response, throwable) -> { + if (throwable != null) { + throw new IngressException("Error when executing the request", throwable); + } + + if (response.statusCode() >= 300) { + handleNonSuccessResponse(response); + } + + try { + return resSerde.deserialize(response.body()); + } catch (Exception e) { + throw new IngressException( + "Cannot deserialize the response", response.statusCode(), response.body(), e); + } + }); } @Override - public String send(Target target, Serde reqSerde, Req req, RequestOptions options) { + public CompletableFuture sendAsync( + Target target, Serde reqSerde, Req req, RequestOptions options) { HttpRequest request = prepareHttpRequest(target, true, reqSerde, req, options); - HttpResponse response; - try { - response = httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); - } catch (IOException | InterruptedException e) { - throw new RuntimeException("Error when executing the request", e); - } - - try (InputStream in = response.body()) { - if (response.statusCode() >= 300) { - // Try to parse as string - String error = new String(in.readAllBytes(), StandardCharsets.UTF_8); - throw new RuntimeException( - "Received non OK status code: " + response.statusCode() + ". Body: " + error); - } - return deserializeInvocationId(in); - } catch (IOException e) { - throw new RuntimeException( - "Error when trying to read the response, when status code was " + response.statusCode(), - e); - } + return httpClient + .sendAsync(request, HttpResponse.BodyHandlers.ofByteArray()) + .handle( + (response, throwable) -> { + if (throwable != null) { + throw new IngressException("Error when executing the request", throwable); + } + + if (response.statusCode() >= 300) { + handleNonSuccessResponse(response); + } + + try { + return findStringFieldInJsonObject( + new ByteArrayInputStream(response.body()), "invocationId"); + } catch (Exception e) { + throw new IngressException( + "Cannot deserialize the response", response.statusCode(), response.body(), e); + } + }); } private URI toRequestURI(Target target, boolean isSend) { @@ -128,23 +134,43 @@ private HttpRequest prepareHttpRequest( return reqBuilder.POST(HttpRequest.BodyPublishers.ofByteArray(reqSerde.serialize(req))).build(); } - private static String deserializeInvocationId(InputStream body) throws IOException { + private void handleNonSuccessResponse(HttpResponse response) { + if (response.headers().firstValue("content-type").orElse("").contains("application/json")) { + String errorMessage; + // Let's try to parse the message field + try { + errorMessage = + findStringFieldInJsonObject(new ByteArrayInputStream(response.body()), "message"); + } catch (Exception e) { + throw new IngressException( + "Can't decode error response from ingress", response.statusCode(), response.body(), e); + } + throw new IngressException(errorMessage, response.statusCode(), response.body()); + } + + // Fallback error + throw new IngressException( + "Received non success status code", response.statusCode(), response.body()); + } + + private static String findStringFieldInJsonObject(InputStream body, String fieldName) + throws IOException { try (JsonParser parser = JSON_FACTORY.createParser(body)) { if (parser.nextToken() != JsonToken.START_OBJECT) { throw new IllegalStateException( "Expecting token " + JsonToken.START_OBJECT + ", got " + parser.getCurrentToken()); } - String fieldName = parser.nextFieldName(); - if (fieldName == null || !fieldName.equalsIgnoreCase("invocationid")) { - throw new IllegalStateException( - "Expecting token \"invocationId\", got " + parser.getCurrentToken()); - } - String invocationId = parser.nextTextValue(); - if (invocationId == null) { - throw new IllegalStateException( - "Expecting token " + JsonToken.VALUE_STRING + ", got " + parser.getCurrentToken()); + for (String actualFieldName = parser.nextFieldName(); + actualFieldName != null; + actualFieldName = parser.nextFieldName()) { + if (actualFieldName.equalsIgnoreCase(fieldName)) { + return parser.nextTextValue(); + } else { + parser.nextValue(); + } } - return invocationId; + throw new IllegalStateException( + "Expecting field name \"" + fieldName + "\", got " + parser.getCurrentToken()); } } } diff --git a/sdk-common/src/main/java/dev/restate/sdk/client/IngressClient.java b/sdk-common/src/main/java/dev/restate/sdk/client/IngressClient.java index 339a0c44..9105bcae 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/client/IngressClient.java +++ b/sdk-common/src/main/java/dev/restate/sdk/client/IngressClient.java @@ -13,18 +13,57 @@ import java.net.http.HttpClient; import java.util.Collections; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; public interface IngressClient { - Res call( + + CompletableFuture callAsync( Target target, Serde reqSerde, Serde resSerde, Req req, RequestOptions options); - default Res call(Target target, Serde reqSerde, Serde resSerde, Req req) { + default CompletableFuture callAsync( + Target target, Serde reqSerde, Serde resSerde, Req req) { + return callAsync(target, reqSerde, resSerde, req, RequestOptions.DEFAULT); + } + + default Res call( + Target target, Serde reqSerde, Serde resSerde, Req req, RequestOptions options) + throws IngressException { + try { + return callAsync(target, reqSerde, resSerde, req, options).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } + + default Res call(Target target, Serde reqSerde, Serde resSerde, Req req) + throws IngressException { return call(target, reqSerde, resSerde, req, RequestOptions.DEFAULT); } - String send(Target target, Serde reqSerde, Req req, RequestOptions options); + CompletableFuture sendAsync( + Target target, Serde reqSerde, Req req, RequestOptions options); + + default CompletableFuture sendAsync(Target target, Serde reqSerde, Req req) { + return sendAsync(target, reqSerde, req, RequestOptions.DEFAULT); + } + + default String send(Target target, Serde reqSerde, Req req, RequestOptions options) + throws IngressException { + try { + return sendAsync(target, reqSerde, req, options).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new RuntimeException(e.getCause()); + } + } - default String send(Target target, Serde reqSerde, Req req) { + default String send(Target target, Serde reqSerde, Req req) throws IngressException { return send(target, reqSerde, req, RequestOptions.DEFAULT); } diff --git a/sdk-common/src/main/java/dev/restate/sdk/client/IngressException.java b/sdk-common/src/main/java/dev/restate/sdk/client/IngressException.java new file mode 100644 index 00000000..4a134ad8 --- /dev/null +++ b/sdk-common/src/main/java/dev/restate/sdk/client/IngressException.java @@ -0,0 +1,54 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.client; + +import java.nio.charset.StandardCharsets; +import org.jspecify.annotations.Nullable; + +public class IngressException extends RuntimeException { + + private final int statusCode; + private final byte[] responseBody; + + public IngressException(String message, Throwable cause) { + this(message, -1, null, cause); + } + + public IngressException(String message, int statusCode, byte[] responseBody) { + this(message, statusCode, responseBody, null); + } + + public IngressException(String message, int statusCode, byte[] responseBody, Throwable cause) { + super(message, cause); + this.statusCode = statusCode; + this.responseBody = responseBody; + } + + public int getStatusCode() { + return statusCode; + } + + public byte @Nullable [] getResponseBody() { + return responseBody; + } + + @Override + public String toString() { + return "IngressException{" + + "statusCode=" + + statusCode + + ", responseBody='" + + new String(responseBody, StandardCharsets.UTF_8) + + '\'' + + ", message='" + + this.getMessage() + + '\'' + + '}'; + } +}