Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package io.quarkus.grpc.server.interceptors;

import static org.assertj.core.api.Assertions.assertThatThrownBy;

import java.time.Duration;
import java.util.logging.LogRecord;

import jakarta.enterprise.context.ApplicationScoped;

import org.assertj.core.api.Condition;
import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.asset.StringAsset;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.examples.helloworld.Greeter;
import io.grpc.examples.helloworld.GreeterBean;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import io.quarkus.grpc.GlobalInterceptor;
import io.quarkus.grpc.GrpcClient;
import io.quarkus.grpc.server.services.HelloService;
import io.quarkus.test.QuarkusUnitTest;
import io.smallrye.mutiny.Uni;

public class ClosingCallInInterceptorTest {

private static final Metadata.Key<String> CLOSE_REASON_KEY = Metadata.Key.of("CUSTOM_CLOSE_REASON",
Metadata.ASCII_STRING_MARSHALLER);
private static final String STATED_REASON_TO_CLOSE = "Because I want to close it.";

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest().setArchiveProducer(
() -> ShrinkWrap.create(JavaArchive.class)
.addPackage(GreeterGrpc.class.getPackage())
.addClasses(MyClosingCallInterceptor.class, GreeterBean.class, HelloRequest.class, HelloService.class)
.addAsResource(new StringAsset("quarkus.grpc.server.use-separate-server=false" + System.lineSeparator()),
"application.properties"))
.setLogRecordPredicate(
record -> record.getMessage() != null && record.getMessage().contains("Closing gRPC call due to an error"))
.assertLogRecords(logRecords -> {
if (!logRecords.isEmpty()) {
for (LogRecord logRecord : logRecords) {
if (logRecord.getThrown() instanceof IllegalStateException ise
&& ise.getMessage().contains("Already closed")) {
Assertions.fail("Log contains message with 'java.lang.IllegalStateException: Already closed'");
}
}
}
});

@GrpcClient
Greeter greeter;

@Test
void test() {
Uni<HelloReply> result = greeter.sayHello(HelloRequest.newBuilder().setName("ServiceA").build());
assertThatThrownBy(() -> result.await().atMost(Duration.ofSeconds(4)))
.isInstanceOf(StatusRuntimeException.class)
.has(new Condition<Throwable>(t -> {
if (t instanceof StatusRuntimeException statusRuntimeException) {
var trailers = statusRuntimeException.getTrailers();
if (trailers != null) {
return STATED_REASON_TO_CLOSE.equals(trailers.get(CLOSE_REASON_KEY));
}
}
return false;
}, "Checking close reason returned in metadata"))
.hasMessageContaining("UNAUTHENTICATED");
}

@ApplicationScoped
@GlobalInterceptor
public static class MyClosingCallInterceptor implements ServerInterceptor {

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
var metadata = new Metadata();
metadata.put(CLOSE_REASON_KEY, STATED_REASON_TO_CLOSE);
call.close(Status.UNAUTHENTICATED, metadata);
return new ServerCall.Listener<>() {
};
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,20 @@
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.function.Function;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.spi.Prioritized;
import jakarta.inject.Inject;

import org.jboss.logging.Logger;

import io.grpc.ForwardingServerCall;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusException;
import io.quarkus.grpc.ExceptionHandlerProvider;
import io.quarkus.grpc.GlobalInterceptor;
Expand Down Expand Up @@ -54,15 +56,27 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
}
}

private <ReqT, RespT> Supplier<ServerCall.Listener<ReqT>> nextCall(ServerCall<ReqT, RespT> call,
private <ReqT, RespT> Function<Runnable, ServerCall.Listener<ReqT>> nextCall(ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
// Must be sure to call next.startCall on the right context
io.grpc.Context current = io.grpc.Context.current();
return () -> {
return onClose -> {
io.grpc.Context previous = current.attach();
try {
return next.startCall(call, headers);
var forwardingCall = new ForwardingServerCall<ReqT, RespT>() {
@Override
protected ServerCall<ReqT, RespT> delegate() {
return call;
}

@Override
public void close(Status status, Metadata trailers) {
onClose.run();
super.close(status, trailers);
}
};
return next.startCall(forwardingCall, headers);
} finally {
current.detach(previous);
}
Expand All @@ -77,31 +91,35 @@ public int getPriority() {
static class ListenedOnDuplicatedContext<ReqT, RespT> extends ServerCall.Listener<ReqT> {

private final Context context;
private final Supplier<ServerCall.Listener<ReqT>> supplier;
private final Function<Runnable, ServerCall.Listener<ReqT>> listenerCreator;
private final ExceptionHandlerProvider ehp;
private final ServerCall<ReqT, RespT> call;
private ServerCall.Listener<ReqT> delegate;
private volatile ServerCall.Listener<ReqT> delegate;

private final AtomicBoolean closed = new AtomicBoolean();

public ListenedOnDuplicatedContext(
ExceptionHandlerProvider ehp,
ServerCall<ReqT, RespT> call, Supplier<ServerCall.Listener<ReqT>> supplier, Context context) {
ServerCall<ReqT, RespT> call, Function<Runnable, ServerCall.Listener<ReqT>> listenerCreator, Context context) {
this.ehp = ehp;
this.context = context;
this.supplier = supplier;
this.listenerCreator = listenerCreator;
this.call = call;
}

private synchronized ServerCall.Listener<ReqT> getDelegate() {
if (delegate == null) {
try {
delegate = supplier.get();
} catch (Throwable t) {
// If the interceptor supplier throws an exception, catch it, and close the call.
log.warn("Unable to retrieve gRPC Server call listener, see the cause below.");
close(t);
return null;
private ServerCall.Listener<ReqT> getDelegate() {
if (delegate == null && !closed.get()) {
synchronized (this) {
if (delegate == null && !closed.get()) {
try {
delegate = listenerCreator.apply(() -> closed.set(true));
} catch (Throwable t) {
// If the interceptor supplier throws an exception, catch it, and close the call.
log.warn("Unable to retrieve gRPC Server call listener, see the cause below.");
close(t);
return null;
}
}
}
}
return delegate;
Expand Down
Loading