Skip to content

Commit

Permalink
Implement Shared handlers annotation for virtual objects (#288)
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper authored Apr 18, 2024
1 parent 5bb5164 commit 82fc68e
Show file tree
Hide file tree
Showing 23 changed files with 201 additions and 72 deletions.
5 changes: 4 additions & 1 deletion examples/src/main/java/my/restate/sdk/examples/Counter.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
package my.restate.sdk.examples;

import dev.restate.sdk.ObjectContext;
import dev.restate.sdk.SharedObjectContext;
import dev.restate.sdk.annotation.Handler;
import dev.restate.sdk.annotation.Shared;
import dev.restate.sdk.annotation.VirtualObject;
import dev.restate.sdk.common.CoreSerdes;
import dev.restate.sdk.common.StateKey;
Expand All @@ -36,8 +38,9 @@ public void add(ObjectContext ctx, Long request) {
ctx.set(TOTAL, newValue);
}

@Shared
@Handler
public Long get(ObjectContext ctx) {
public Long get(SharedObjectContext ctx) {
return ctx.get(TOTAL).orElse(0L);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
package my.restate.sdk.examples

import dev.restate.sdk.annotation.Handler
import dev.restate.sdk.annotation.Shared
import dev.restate.sdk.annotation.VirtualObject
import dev.restate.sdk.http.vertx.RestateHttpEndpointBuilder
import dev.restate.sdk.kotlin.KtStateKey
import dev.restate.sdk.kotlin.ObjectContext
import dev.restate.sdk.kotlin.SharedObjectContext
import kotlinx.serialization.Serializable
import org.apache.logging.log4j.LogManager
import org.apache.logging.log4j.Logger
Expand Down Expand Up @@ -40,7 +42,8 @@ class CounterKt {
}

@Handler
suspend fun get(ctx: ObjectContext): Long {
@Shared
suspend fun get(ctx: SharedObjectContext): Long {
return ctx.get(TOTAL) ?: 0L
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import dev.restate.sdk.Context;
import dev.restate.sdk.ObjectContext;
import dev.restate.sdk.SharedObjectContext;
import dev.restate.sdk.annotation.Exclusive;
import dev.restate.sdk.annotation.Shared;
import dev.restate.sdk.annotation.Workflow;
Expand Down Expand Up @@ -233,6 +234,8 @@ private void validateMethodSignature(
case SHARED:
if (serviceType == ServiceType.WORKFLOW) {
validateFirstParameterType(WorkflowSharedContext.class, element);
} else if (serviceType == ServiceType.VIRTUAL_OBJECT) {
validateFirstParameterType(SharedObjectContext.class, element);
} else {
messager.printMessage(
Diagnostic.Kind.ERROR,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public class {{generatedClassSimpleName}} implements dev.restate.sdk.common.Bind
public {{generatedClassSimpleName}}({{originalClassFqcn}} bindableService, dev.restate.sdk.Service.Options options) {
this.service = dev.restate.sdk.Service.{{#if isObject}}virtualObject{{else}}service{{/if}}(SERVICE_NAME)
{{#handlers}}
.with(
.{{#if isShared}}withShared{{else if isExclusive}}withExclusive{{else}}with{{/if}}(
dev.restate.sdk.Service.HandlerSignature.of("{{name}}", {{{inputSerdeDecl}}}, {{{outputSerdeDecl}}}),
(ctx, req) -> {
{{#if outputEmpty}}
Expand Down
14 changes: 11 additions & 3 deletions sdk-api-gen/src/test/java/dev/restate/sdk/CodegenTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
import static dev.restate.sdk.core.TestDefinitions.testInvocation;

import com.google.protobuf.ByteString;
import dev.restate.sdk.annotation.Exclusive;
import dev.restate.sdk.annotation.Handler;
import dev.restate.sdk.annotation.*;
import dev.restate.sdk.annotation.Service;
import dev.restate.sdk.annotation.VirtualObject;
import dev.restate.sdk.common.CoreSerdes;
import dev.restate.sdk.common.Target;
import dev.restate.sdk.core.ProtoUtils;
Expand All @@ -39,6 +37,12 @@ static class ObjectGreeter {
String greet(ObjectContext context, String request) {
return request;
}

@Handler
@Shared
String sharedGreet(SharedObjectContext context, String request) {
return request;
}
}

@VirtualObject
Expand Down Expand Up @@ -113,6 +117,10 @@ public Stream<TestDefinitions.TestDefinition> definitions() {
.withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco"))
.onlyUnbuffered()
.expectingOutput(outputMessage("Francesco"), END_MESSAGE),
testInvocation(ObjectGreeter::new, "sharedGreet")
.withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco"))
.onlyUnbuffered()
.expectingOutput(outputMessage("Francesco"), END_MESSAGE),
testInvocation(ObjectGreeterImplementedFromInterface::new, "greet")
.withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco"))
.onlyUnbuffered()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import dev.restate.sdk.gen.model.PayloadType
import dev.restate.sdk.gen.model.Service
import dev.restate.sdk.kotlin.Context
import dev.restate.sdk.kotlin.ObjectContext
import dev.restate.sdk.kotlin.SharedObjectContext
import java.util.regex.Pattern
import kotlin.reflect.KClass

Expand Down Expand Up @@ -128,7 +129,7 @@ class KElementConverter(private val logger: KSPLogger, private val builtIns: KSB
}

val isAnnotatedWithShared =
function.isAnnotationPresent(dev.restate.sdk.annotation.Service::class)
function.isAnnotationPresent(dev.restate.sdk.annotation.Shared::class)
val isAnnotatedWithExclusive =
function.isAnnotationPresent(dev.restate.sdk.annotation.Exclusive::class)

Expand Down Expand Up @@ -190,8 +191,13 @@ class KElementConverter(private val logger: KSPLogger, private val builtIns: KSB
}
when (handlerType) {
HandlerType.SHARED ->
logger.error(
"The annotation @Shared is not supported by the service type $serviceType", function)
if (serviceType == ServiceType.VIRTUAL_OBJECT) {
validateFirstParameterType(SharedObjectContext::class, function)
} else {
logger.error(
"The annotation @Shared is not supported by the service type $serviceType",
function)
}
HandlerType.EXCLUSIVE ->
if (serviceType == ServiceType.VIRTUAL_OBJECT) {
validateFirstParameterType(ObjectContext::class, function)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class {{generatedClassSimpleName}}(

val service: dev.restate.sdk.kotlin.Service = dev.restate.sdk.kotlin.Service.{{#if isObject}}virtualObject{{else}}service{{/if}}(SERVICE_NAME, options) {
{{#handlers}}
handler(dev.restate.sdk.kotlin.Service.HandlerSignature("{{name}}", {{{inputSerdeDecl}}}, {{{outputSerdeDecl}}})) { ctx, req ->
{{#if isShared}}sharedHandler{{else if isExclusive}}exclusiveHandler{{else}}handler{{/if}}(dev.restate.sdk.kotlin.Service.HandlerSignature("{{name}}", {{{inputSerdeDecl}}}, {{{outputSerdeDecl}}})) { ctx, req ->
{{#if inputEmpty}}bindableService.{{name}}(ctx){{else}}bindableService.{{name}}(ctx, req){{/if}}
}
{{/handlers}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
package dev.restate.sdk.kotlin

import com.google.protobuf.ByteString
import dev.restate.sdk.annotation.Exclusive
import dev.restate.sdk.annotation.Handler
import dev.restate.sdk.annotation.*
import dev.restate.sdk.annotation.Service
import dev.restate.sdk.annotation.VirtualObject
import dev.restate.sdk.common.CoreSerdes
import dev.restate.sdk.common.Target
import dev.restate.sdk.core.ProtoUtils.*
Expand All @@ -36,6 +34,12 @@ class CodegenTest : TestDefinitions.TestSuite {
suspend fun greet(context: ObjectContext, request: String): String {
return request
}

@Handler
@Shared
suspend fun sharedGreet(context: SharedObjectContext, request: String): String {
return request
}
}

@VirtualObject
Expand Down Expand Up @@ -104,6 +108,10 @@ class CodegenTest : TestDefinitions.TestSuite {
.withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco"))
.onlyUnbuffered()
.expectingOutput(outputMessage("Francesco"), END_MESSAGE),
testInvocation({ ObjectGreeter() }, "sharedGreet")
.withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco"))
.onlyUnbuffered()
.expectingOutput(outputMessage("Francesco"), END_MESSAGE),
testInvocation({ ObjectGreeterImplementedFromInterface() }, "greet")
.withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco"))
.onlyUnbuffered()
Expand Down
30 changes: 21 additions & 9 deletions sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Service.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
package dev.restate.sdk.kotlin

import com.google.protobuf.ByteString
import dev.restate.sdk.common.BindableService
import dev.restate.sdk.common.Serde
import dev.restate.sdk.common.ServiceType
import dev.restate.sdk.common.TerminalException
import dev.restate.sdk.common.*
import dev.restate.sdk.common.syscalls.*
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.CoroutineScope
Expand Down Expand Up @@ -65,18 +62,31 @@ private constructor(
class VirtualObjectBuilder internal constructor(private val name: String) {
private val handlers: MutableMap<String, Handler<*, *, ObjectContext>> = mutableMapOf()

fun <REQ, RES> handler(
fun <REQ, RES> sharedHandler(
sig: HandlerSignature<REQ, RES>,
runner: suspend (ObjectContext, REQ) -> RES
): VirtualObjectBuilder {
handlers[sig.name] = Handler(sig, runner)
handlers[sig.name] = Handler(sig, HandlerType.SHARED, runner)
return this
}

inline fun <reified REQ, reified RES> handler(
inline fun <reified REQ, reified RES> sharedHandler(
name: String,
noinline runner: suspend (ObjectContext, REQ) -> RES
) = this.handler(HandlerSignature(name, KtSerdes.json(), KtSerdes.json()), runner)
) = this.sharedHandler(HandlerSignature(name, KtSerdes.json(), KtSerdes.json()), runner)

fun <REQ, RES> exclusiveHandler(
sig: HandlerSignature<REQ, RES>,
runner: suspend (ObjectContext, REQ) -> RES
): VirtualObjectBuilder {
handlers[sig.name] = Handler(sig, HandlerType.EXCLUSIVE, runner)
return this
}

inline fun <reified REQ, reified RES> exclusiveHandler(
name: String,
noinline runner: suspend (ObjectContext, REQ) -> RES
) = this.exclusiveHandler(HandlerSignature(name, KtSerdes.json(), KtSerdes.json()), runner)

fun build(options: Options) = Service(this.name, true, this.handlers, options)
}
Expand All @@ -88,7 +98,7 @@ private constructor(
sig: HandlerSignature<REQ, RES>,
runner: suspend (Context, REQ) -> RES
): ServiceBuilder {
handlers[sig.name] = Handler(sig, runner)
handlers[sig.name] = Handler(sig, HandlerType.SHARED, runner)
return this
}

Expand All @@ -102,6 +112,7 @@ private constructor(

class Handler<REQ, RES, CTX : Context>(
private val handlerSignature: HandlerSignature<REQ, RES>,
private val handlerType: HandlerType,
private val runner: suspend (CTX, REQ) -> RES,
) : InvocationHandler<Options> {

Expand All @@ -112,6 +123,7 @@ private constructor(
fun toHandlerDefinition() =
HandlerDefinition(
handlerSignature.name,
handlerType,
handlerSignature.requestSerde.schema(),
handlerSignature.responseSerde.schema(),
this)
Expand Down
9 changes: 8 additions & 1 deletion sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ suspend inline fun <reified T : Any> Context.awakeable(): Awakeable<T> {
* This interface extends [Context] adding access to the virtual object instance key-value state
* storage.
*/
sealed interface ObjectContext : Context {
sealed interface SharedObjectContext : Context {

/** @return the key of this object */
fun key(): String
Expand All @@ -267,6 +267,13 @@ sealed interface ObjectContext : Context {
* @return the immutable collection of known state keys.
*/
suspend fun stateKeys(): Collection<String>
}

/**
* This interface extends [Context] adding access to the virtual object instance key-value state
* storage.
*/
sealed interface ObjectContext : SharedObjectContext {

/**
* Sets the given value under the given key, serializing the value using the [StateKey.serde].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class KotlinCoroutinesTests : TestRunner() {
): TestInvocationBuilder {
return TestDefinitions.testInvocation(
Service.virtualObject(name, Service.Options(Dispatchers.Unconfined)) {
handler("run", runner)
exclusiveHandler("run", runner)
},
"run")
}
Expand Down
28 changes: 1 addition & 27 deletions sdk-api/src/main/java/dev/restate/sdk/ObjectContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
package dev.restate.sdk;

import dev.restate.sdk.common.*;
import java.util.Collection;
import java.util.Optional;
import org.jspecify.annotations.NonNull;

/**
Expand All @@ -22,31 +20,7 @@
*
* @see Context
*/
public interface ObjectContext extends Context {

/**
* @return the key of this object
*/
String key();

/**
* Gets the state stored under key, deserializing the raw value using the {@link Serde} in the
* {@link StateKey}.
*
* @param key identifying the state to get and its type.
* @return an {@link Optional} containing the stored state deserialized or an empty {@link
* Optional} if not set yet.
* @throws RuntimeException when the state cannot be deserialized.
*/
<T> Optional<T> get(StateKey<T> key);

/**
* Gets all the known state keys for this virtual object instance.
*
* @return the immutable collection of known state keys.
*/
Collection<String> stateKeys();

public interface ObjectContext extends SharedObjectContext {
/**
* Clears the state stored under key.
*
Expand Down
16 changes: 13 additions & 3 deletions sdk-api/src/main/java/dev/restate/sdk/Service.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,15 @@ public static class VirtualObjectBuilder extends AbstractServiceBuilder {
super(name);
}

public <REQ, RES> VirtualObjectBuilder with(
public <REQ, RES> VirtualObjectBuilder withShared(
HandlerSignature<REQ, RES> sig, BiFunction<SharedObjectContext, REQ, RES> runner) {
this.handlers.put(sig.getName(), new Handler<>(sig, HandlerType.SHARED, runner));
return this;
}

public <REQ, RES> VirtualObjectBuilder withExclusive(
HandlerSignature<REQ, RES> sig, BiFunction<ObjectContext, REQ, RES> runner) {
this.handlers.put(sig.getName(), new Handler<>(sig, runner));
this.handlers.put(sig.getName(), new Handler<>(sig, HandlerType.EXCLUSIVE, runner));
return this;
}

Expand All @@ -90,7 +96,7 @@ public static class ServiceBuilder extends AbstractServiceBuilder {

public <REQ, RES> ServiceBuilder with(
HandlerSignature<REQ, RES> sig, BiFunction<Context, REQ, RES> runner) {
this.handlers.put(sig.getName(), new Handler<>(sig, runner));
this.handlers.put(sig.getName(), new Handler<>(sig, HandlerType.SHARED, runner));
return this;
}

Expand All @@ -102,14 +108,17 @@ public Service build(Service.Options options) {
@SuppressWarnings("unchecked")
public static class Handler<REQ, RES> implements InvocationHandler<Service.Options> {
private final HandlerSignature<REQ, RES> handlerSignature;
private final HandlerType handlerType;
private final BiFunction<Context, REQ, RES> runner;

private static final Logger LOG = LogManager.getLogger(Handler.class);

public Handler(
HandlerSignature<REQ, RES> handlerSignature,
HandlerType handlerType,
BiFunction<? extends Context, REQ, RES> runner) {
this.handlerSignature = handlerSignature;
this.handlerType = handlerType;
this.runner = (BiFunction<Context, REQ, RES>) runner;
}

Expand All @@ -124,6 +133,7 @@ public BiFunction<Context, REQ, RES> getRunner() {
public HandlerDefinition<Service.Options> toHandlerDefinition() {
return new HandlerDefinition<>(
this.handlerSignature.name,
this.handlerType,
this.handlerSignature.requestSerde.schema(),
this.handlerSignature.responseSerde.schema(),
this);
Expand Down
Loading

0 comments on commit 82fc68e

Please sign in to comment.