diff --git a/vertx-grpc-client/src/main/java/io/vertx/grpc/client/GrpcClientOptions.java b/vertx-grpc-client/src/main/java/io/vertx/grpc/client/GrpcClientOptions.java index 70e03dd..bb979a9 100644 --- a/vertx-grpc-client/src/main/java/io/vertx/grpc/client/GrpcClientOptions.java +++ b/vertx-grpc-client/src/main/java/io/vertx/grpc/client/GrpcClientOptions.java @@ -36,9 +36,15 @@ public class GrpcClientOptions { */ public static final TimeUnit DEFAULT_TIMEOUT_UNIT = TimeUnit.SECONDS; + /** + * The default maximum message size in bytes accepted from a server = {@code 256KB} + */ + public static final long DEFAULT_MAX_MESSAGE_SIZE = 256 * 1024; + private boolean scheduleDeadlineAutomatically; private int timeout; private TimeUnit timeoutUnit; + private long maxMessageSize; /** * Default constructor. @@ -47,6 +53,7 @@ public GrpcClientOptions() { scheduleDeadlineAutomatically = DEFAULT_SCHEDULE_DEADLINE_AUTOMATICALLY; timeout = DEFAULT_TIMEOUT; timeoutUnit = DEFAULT_TIMEOUT_UNIT; + this.maxMessageSize = DEFAULT_MAX_MESSAGE_SIZE; } /** @@ -58,6 +65,7 @@ public GrpcClientOptions(GrpcClientOptions other) { scheduleDeadlineAutomatically = other.scheduleDeadlineAutomatically; timeout = other.timeout; timeoutUnit = other.timeoutUnit; + maxMessageSize = other.maxMessageSize; } /** @@ -127,4 +135,27 @@ public GrpcClientOptions setTimeoutUnit(TimeUnit timeoutUnit) { this.timeoutUnit = Objects.requireNonNull(timeoutUnit); return this; } + + /** + * @return the maximum message size in bytes accepted by the client + */ + public long getMaxMessageSize() { + return maxMessageSize; + } + + /** + * Set the maximum message size in bytes accepted from a server, the maximum value is {@code 0xFFFFFFFF} + * @param maxMessageSize the size + * @return a reference to this, so the API can be used fluently + */ + public GrpcClientOptions setMaxMessageSize(long maxMessageSize) { + if (maxMessageSize <= 0) { + throw new IllegalArgumentException("Max message size must be > 0"); + } + if (maxMessageSize > 0xFFFFFFFFL) { + throw new IllegalArgumentException("Max message size must be <= 0xFFFFFFFF"); + } + this.maxMessageSize = maxMessageSize; + return this; + } } diff --git a/vertx-grpc-client/src/main/java/io/vertx/grpc/client/impl/GrpcClientImpl.java b/vertx-grpc-client/src/main/java/io/vertx/grpc/client/impl/GrpcClientImpl.java index ee853b6..67faac5 100644 --- a/vertx-grpc-client/src/main/java/io/vertx/grpc/client/impl/GrpcClientImpl.java +++ b/vertx-grpc-client/src/main/java/io/vertx/grpc/client/impl/GrpcClientImpl.java @@ -40,6 +40,7 @@ public class GrpcClientImpl implements GrpcClient { private HttpClient client; private boolean closeClient; private final boolean scheduleDeadlineAutomatically; + private final long maxMessageSize; private final int timeout; private final TimeUnit timeoutUnit; @@ -51,6 +52,7 @@ protected GrpcClientImpl(Vertx vertx, GrpcClientOptions grpcOptions, HttpClient this.vertx = vertx; this.client = client; this.scheduleDeadlineAutomatically = grpcOptions.getScheduleDeadlineAutomatically(); + this.maxMessageSize = grpcOptions.getMaxMessageSize();; this.timeout = grpcOptions.getTimeout(); this.timeoutUnit = grpcOptions.getTimeoutUnit(); this.closeClient = close; @@ -59,7 +61,12 @@ protected GrpcClientImpl(Vertx vertx, GrpcClientOptions grpcOptions, HttpClient public Future> request(RequestOptions options) { return client.request(options) .map(httpRequest -> { - GrpcClientRequestImpl grpcRequest = new GrpcClientRequestImpl<>(httpRequest, scheduleDeadlineAutomatically, GrpcMessageEncoder.IDENTITY, GrpcMessageDecoder.IDENTITY); + GrpcClientRequestImpl grpcRequest = new GrpcClientRequestImpl<>( + httpRequest, + maxMessageSize, + scheduleDeadlineAutomatically, + GrpcMessageEncoder.IDENTITY, + GrpcMessageDecoder.IDENTITY); grpcRequest.init(); configureTimeout(grpcRequest); return grpcRequest; @@ -107,7 +114,12 @@ public Future> request(Address server, private Future> request(RequestOptions options, ServiceMethod method) { return client.request(options) .map(request -> { - GrpcClientRequestImpl call = new GrpcClientRequestImpl<>(request, scheduleDeadlineAutomatically, method.encoder(), method.decoder()); + GrpcClientRequestImpl call = new GrpcClientRequestImpl<>( + request, + maxMessageSize, + scheduleDeadlineAutomatically, + method.encoder(), + method.decoder()); call.init(); call.serviceName(method.serviceName()); call.methodName(method.methodName()); diff --git a/vertx-grpc-client/src/main/java/io/vertx/grpc/client/impl/GrpcClientRequestImpl.java b/vertx-grpc-client/src/main/java/io/vertx/grpc/client/impl/GrpcClientRequestImpl.java index dcd8dab..eef845b 100644 --- a/vertx-grpc-client/src/main/java/io/vertx/grpc/client/impl/GrpcClientRequestImpl.java +++ b/vertx-grpc-client/src/main/java/io/vertx/grpc/client/impl/GrpcClientRequestImpl.java @@ -48,6 +48,7 @@ public class GrpcClientRequestImpl extends GrpcWriteStreamBase messageEncoder, GrpcMessageDecoder messageDecoder) { @@ -76,8 +77,18 @@ public GrpcClientRequestImpl(HttpClientRequest httpRequest, } } if (format != null || status != null) { - GrpcClientResponseImpl grpcResponse = new GrpcClientResponseImpl<>(context, this, format, status, httpResponse, messageDecoder); + GrpcClientResponseImpl grpcResponse = new GrpcClientResponseImpl<>( + context, + this, + format, + maxMessageSize, + status, + httpResponse, + messageDecoder); grpcResponse.init(this); + grpcResponse.invalidMessageHandler(invalidMsg -> { + cancel(); + }); return Future.succeededFuture(grpcResponse); } httpResponse.request().reset(GrpcError.CANCELLED.http2ResetCode); diff --git a/vertx-grpc-client/src/main/java/io/vertx/grpc/client/impl/GrpcClientResponseImpl.java b/vertx-grpc-client/src/main/java/io/vertx/grpc/client/impl/GrpcClientResponseImpl.java index 452b012..9c1e65a 100644 --- a/vertx-grpc-client/src/main/java/io/vertx/grpc/client/impl/GrpcClientResponseImpl.java +++ b/vertx-grpc-client/src/main/java/io/vertx/grpc/client/impl/GrpcClientResponseImpl.java @@ -36,9 +36,10 @@ public class GrpcClientResponseImpl extends GrpcReadStreamBase request, WireFormat format, + long maxMessageSize, GrpcStatus status, HttpClientResponse httpResponse, GrpcMessageDecoder messageDecoder) { - super(context, httpResponse, httpResponse.headers().get("grpc-encoding"), format, messageDecoder); + super(context, httpResponse, httpResponse.headers().get("grpc-encoding"), format, maxMessageSize, messageDecoder); this.request = request; this.httpResponse = httpResponse; this.status = status; diff --git a/vertx-grpc-client/src/test/java/io/vertx/tests/client/ClientRequestTest.java b/vertx-grpc-client/src/test/java/io/vertx/tests/client/ClientRequestTest.java index 77c0352..d04ec95 100644 --- a/vertx-grpc-client/src/test/java/io/vertx/tests/client/ClientRequestTest.java +++ b/vertx-grpc-client/src/test/java/io/vertx/tests/client/ClientRequestTest.java @@ -17,6 +17,7 @@ import io.grpc.examples.streaming.Empty; import io.grpc.examples.streaming.Item; import io.grpc.examples.streaming.StreamingGrpc; +import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import io.vertx.core.http.HttpClientOptions; import io.vertx.core.http.HttpHeaders; @@ -36,7 +37,10 @@ import java.io.File; import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Base64; +import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -699,4 +703,81 @@ public void testJsonMessageFormat(TestContext should) throws Exception { done.awaitSuccess(); } + + @Test + public void testDefaultMessageSizeOverflow(TestContext should) throws Exception { + + Async test = should.async(); + + Item item = Item.newBuilder().setValue("Asmoranomardicadaistinaculdacar").build(); + int itemLen = item.getSerializedSize(); + + StreamingGrpc.StreamingImplBase called = new StreamingGrpc.StreamingImplBase() { + @Override + public void source(Empty request, StreamObserver responseObserver) { + ServerCallStreamObserver callStreamObserver = (ServerCallStreamObserver) responseObserver; + callStreamObserver.setOnCancelHandler(() -> { + test.complete(); + }); + responseObserver.onNext(item); + } + }; + startServer(called); + + GrpcClient client = GrpcClient.client(vertx, new GrpcClientOptions().setMaxMessageSize(itemLen - 1)); + client.request(SocketAddress.inetSocketAddress(port, "localhost"), STREAMING_SOURCE) + .onComplete(should.asyncAssertSuccess(callRequest -> { + callRequest.response().onComplete(should.asyncAssertSuccess(callResponse -> { + callResponse.handler(msg -> { + should.fail(); + }); + })); + callRequest.end(Empty.getDefaultInstance()); + })); + + test.awaitSuccess(20_000); + } + + @Test + public void testInvalidMessageHandlerStream(TestContext should) throws Exception { + + Async test = should.async(); + + List items = Arrays.asList( + Item.newBuilder().setValue("msg1").build(), + Item.newBuilder().setValue("Asmoranomardicadaistinaculdacar").build(), + Item.newBuilder().setValue("msg3").build() + ); + + int itemLen = items.get(1).getSerializedSize(); + + StreamingGrpc.StreamingImplBase called = new StreamingGrpc.StreamingImplBase() { + @Override + public void source(Empty request, StreamObserver responseObserver) { + items.forEach(item -> responseObserver.onNext(item)); + responseObserver.onCompleted(); + } + }; + startServer(called); + + GrpcClient client = GrpcClient.client(vertx, new GrpcClientOptions().setMaxMessageSize(itemLen - 1)); + client.request(SocketAddress.inetSocketAddress(port, "localhost"), STREAMING_SOURCE) + .onComplete(should.asyncAssertSuccess(callRequest -> { + callRequest.response().onComplete(should.asyncAssertSuccess(callResponse -> { + List received = new ArrayList<>(); + callResponse.invalidMessageHandler(received::add); + callResponse.handler(received::add); + callResponse.endHandler(v -> { + should.assertEquals(Item.class, received.get(0).getClass()); + should.assertEquals(MessageSizeOverflowException.class, received.get(1).getClass()); + should.assertEquals(Item.class, received.get(2).getClass()); + should.assertEquals(3, received.size()); + test.complete(); + }); + })); + callRequest.end(Empty.getDefaultInstance()); + })); + + test.awaitSuccess(20_000); + } } diff --git a/vertx-grpc-common/src/main/java/io/vertx/grpc/common/GrpcMessageDecoder.java b/vertx-grpc-common/src/main/java/io/vertx/grpc/common/GrpcMessageDecoder.java index 8efd62e..27e6421 100644 --- a/vertx-grpc-common/src/main/java/io/vertx/grpc/common/GrpcMessageDecoder.java +++ b/vertx-grpc-common/src/main/java/io/vertx/grpc/common/GrpcMessageDecoder.java @@ -41,7 +41,7 @@ public T decode(GrpcMessage msg) throws CodecException { try { return parser.parseFrom(msg.payload().getBytes()); } catch (InvalidProtocolBufferException e) { - return null; + throw new CodecException(e); } } @Override diff --git a/vertx-grpc-common/src/main/java/io/vertx/grpc/common/GrpcReadStream.java b/vertx-grpc-common/src/main/java/io/vertx/grpc/common/GrpcReadStream.java index 1ca6e5c..d771cb4 100644 --- a/vertx-grpc-common/src/main/java/io/vertx/grpc/common/GrpcReadStream.java +++ b/vertx-grpc-common/src/main/java/io/vertx/grpc/common/GrpcReadStream.java @@ -1,6 +1,7 @@ package io.vertx.grpc.common; import io.vertx.codegen.annotations.Fluent; +import io.vertx.codegen.annotations.GenIgnore; import io.vertx.codegen.annotations.Nullable; import io.vertx.codegen.annotations.VertxGen; import io.vertx.core.Future; @@ -36,6 +37,19 @@ public interface GrpcReadStream extends ReadStream { @Fluent GrpcReadStream messageHandler(@Nullable Handler handler); + /** + * Set a message handler that is reported with invalid message errors. + * + *

Warning: setting this handler overwrite the default handler which takes appropriate measure + * when an invalid message is encountered such as cancelling the stream. This handler should be set + * when control over invalid messages is required.

+ * + * @param handler the invalid message handler + * @return a reference to this, so the API can be used fluently + */ + @GenIgnore(GenIgnore.PERMITTED_TYPE) + GrpcReadStream invalidMessageHandler(@Nullable Handler handler); + /** * Set a handler to be notified with gRPC errors. * diff --git a/vertx-grpc-common/src/main/java/io/vertx/grpc/common/InvalidMessageException.java b/vertx-grpc-common/src/main/java/io/vertx/grpc/common/InvalidMessageException.java new file mode 100644 index 0000000..9dd15b7 --- /dev/null +++ b/vertx-grpc-common/src/main/java/io/vertx/grpc/common/InvalidMessageException.java @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2011-2024 Contributors to the Eclipse Foundation + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 + * which is available at https://www.apache.org/licenses/LICENSE-2.0. + * + * SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + */ +package io.vertx.grpc.common; + +import io.vertx.core.VertxException; + +/** + * Signals an invalid message. + * + * @author Julien Viet + */ +public abstract class InvalidMessageException extends VertxException { + + InvalidMessageException() { + super((String) null, true); + } + + InvalidMessageException(Throwable cause) { + super(cause, true); + } +} diff --git a/vertx-grpc-common/src/main/java/io/vertx/grpc/common/InvalidMessagePayloadException.java b/vertx-grpc-common/src/main/java/io/vertx/grpc/common/InvalidMessagePayloadException.java new file mode 100644 index 0000000..6aa472e --- /dev/null +++ b/vertx-grpc-common/src/main/java/io/vertx/grpc/common/InvalidMessagePayloadException.java @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2011-2024 Contributors to the Eclipse Foundation + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 + * which is available at https://www.apache.org/licenses/LICENSE-2.0. + * + * SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + */ +package io.vertx.grpc.common; + +/** + * Signals a message with an invalid payload, i.e. that could not be decoded by the protobuf codec. + * + * @author Julien Viet + */ +public final class InvalidMessagePayloadException extends InvalidMessageException { + + private GrpcMessage message; + + public InvalidMessagePayloadException(GrpcMessage message, Throwable cause) { + super(cause); + this.message = message; + } + + /** + * @return the invalid message that could not be decoded. + */ + public GrpcMessage message() { + return message; + } +} diff --git a/vertx-grpc-common/src/main/java/io/vertx/grpc/common/MessageSizeOverflowException.java b/vertx-grpc-common/src/main/java/io/vertx/grpc/common/MessageSizeOverflowException.java new file mode 100644 index 0000000..b735cef --- /dev/null +++ b/vertx-grpc-common/src/main/java/io/vertx/grpc/common/MessageSizeOverflowException.java @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2011-2024 Contributors to the Eclipse Foundation + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 + * which is available at https://www.apache.org/licenses/LICENSE-2.0. + * + * SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + */ +package io.vertx.grpc.common; + +/** + * Signals a message that is longer than the maximum configured size. + * + * @author Julien Viet + */ +public final class MessageSizeOverflowException extends InvalidMessageException { + + private final long messageSize; + + public MessageSizeOverflowException(long messageSize) { + this.messageSize = messageSize; + } + + public long messageSize() { + return messageSize; + } +} diff --git a/vertx-grpc-common/src/main/java/io/vertx/grpc/common/impl/GrpcReadStreamBase.java b/vertx-grpc-common/src/main/java/io/vertx/grpc/common/impl/GrpcReadStreamBase.java index 64abbdf..84a3956 100644 --- a/vertx-grpc-common/src/main/java/io/vertx/grpc/common/impl/GrpcReadStreamBase.java +++ b/vertx-grpc-common/src/main/java/io/vertx/grpc/common/impl/GrpcReadStreamBase.java @@ -48,22 +48,31 @@ public Buffer payload() { protected final ContextInternal context; private final String encoding; + private final long maxMessageSize; private final WireFormat format; private final ReadStream stream; private final InboundMessageQueue queue; private Buffer buffer; + private long bytesToSkip; private Handler exceptionHandler; private Handler messageHandler; private Handler endHandler; + private Handler invalidMessageHandler; private GrpcMessage last; private final GrpcMessageDecoder messageDecoder; private final Promise end; private GrpcWriteStreamBase ws; - protected GrpcReadStreamBase(Context context, ReadStream stream, String encoding, WireFormat format, GrpcMessageDecoder messageDecoder) { + protected GrpcReadStreamBase(Context context, + ReadStream stream, + String encoding, + WireFormat format, + long maxMessageSize, + GrpcMessageDecoder messageDecoder) { ContextInternal ctx = (ContextInternal) context; this.context = ctx; this.encoding = encoding; + this.maxMessageSize = maxMessageSize; this.stream = stream; this.format = format; this.queue = new InboundMessageQueue<>(ctx.nettyEventLoop(), ctx, 8, 16) { @@ -160,6 +169,34 @@ public final S messageHandler(Handler handler) { return (S) this; } + @Override + public final S invalidMessageHandler(@Nullable Handler handler) { + invalidMessageHandler = handler; + return (S) this; + } + + @Override + public S handler(@Nullable Handler handler) { + if (handler != null) { + return messageHandler(msg -> { + T decoded; + try { + decoded = decodeMessage(msg); + } catch (CodecException e) { + Handler errorHandler = invalidMessageHandler; + if (errorHandler != null) { + InvalidMessagePayloadException impe = new InvalidMessagePayloadException(msg, e); + errorHandler.handle(impe); + } + return; + } + handler.handle(decoded); + }); + } else { + return messageHandler(null); + } + } + @Override public final S endHandler(Handler endHandler) { this.endHandler = endHandler; @@ -167,19 +204,49 @@ public final S endHandler(Handler endHandler) { } public void handle(Buffer chunk) { + if (bytesToSkip > 0L) { + int len = chunk.length(); + if (len <= bytesToSkip) { + bytesToSkip -= len; + return; + } + chunk = chunk.slice((int)bytesToSkip, len); + bytesToSkip = 0L; + } if (buffer == null) { buffer = chunk; } else { buffer.appendBuffer(chunk); } int idx = 0; - int len; - while (idx + 5 <= buffer.length() && (idx + 5 + (len = buffer.getInt(idx + 1)))<= buffer.length()) { + while (true) { + if (idx + 5 > buffer.length()) { + break; + } + long len = ((long)buffer.getInt(idx + 1)) & 0xFFFFFFFFL; + if (len > maxMessageSize) { + Handler handler = invalidMessageHandler; + if (handler != null) { + MessageSizeOverflowException msoe = new MessageSizeOverflowException(len); + context.dispatch(msoe, handler); + } + if (buffer.length() < (len + 5)) { + bytesToSkip = (len + 5) - buffer.length(); + buffer = null; + return; + } else { + buffer = buffer.slice((int)(len + 5), buffer.length()); + continue; + } + } + if (len > buffer.length() - (idx + 5)) { + break; + } boolean compressed = buffer.getByte(idx) == 1; if (compressed && encoding == null) { throw new UnsupportedOperationException("Handle me"); } - Buffer payload = buffer.slice(idx + 5, idx + 5 + len); + Buffer payload = buffer.slice(idx + 5, (int)(idx + 5 + len)); GrpcMessage message = GrpcMessage.message(compressed ? encoding : "identity", format, payload); queue.write(message); idx += 5 + len; @@ -207,7 +274,7 @@ protected void handleEnd() { } } - protected void handleMessage(GrpcMessage msg) { + private void handleMessage(GrpcMessage msg) { last = msg; Handler handler = messageHandler; if (handler != null) { diff --git a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/GrpcServerOptions.java b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/GrpcServerOptions.java index 2830c36..58784d2 100644 --- a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/GrpcServerOptions.java +++ b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/GrpcServerOptions.java @@ -38,9 +38,15 @@ public class GrpcServerOptions { */ public static final boolean DEFAULT_PROPAGATE_DEADLINE = false; + /** + * The default maximum message size in bytes accepted from a client = {@code 256KB} + */ + public static final long DEFAULT_MAX_MESSAGE_SIZE = 256 * 1024; + private boolean grpcWebEnabled; private boolean scheduleDeadlineAutomatically; private boolean deadlinePropagation; + private long maxMessageSize; /** * Default options. @@ -49,6 +55,7 @@ public GrpcServerOptions() { grpcWebEnabled = DEFAULT_GRPC_WEB_ENABLED; scheduleDeadlineAutomatically = DEFAULT_SCHEDULE_DEADLINE_AUTOMATICALLY; deadlinePropagation = DEFAULT_PROPAGATE_DEADLINE; + maxMessageSize = DEFAULT_MAX_MESSAGE_SIZE; } /** @@ -58,6 +65,7 @@ public GrpcServerOptions(GrpcServerOptions other) { grpcWebEnabled = other.grpcWebEnabled; scheduleDeadlineAutomatically = other.scheduleDeadlineAutomatically; deadlinePropagation = other.deadlinePropagation; + maxMessageSize = other.maxMessageSize; } /** @@ -130,6 +138,30 @@ public GrpcServerOptions setDeadlinePropagation(boolean deadlinePropagation) { return this; } + + /** + * @return the maximum message size in bytes accepted by the server + */ + public long getMaxMessageSize() { + return maxMessageSize; + } + + /** + * Set the maximum message size in bytes accepted from a client, the maximum value is {@code 0xFFFFFFFF} + * @param maxMessageSize the size + * @return a reference to this, so the API can be used fluently + */ + public GrpcServerOptions setMaxMessageSize(long maxMessageSize) { + if (maxMessageSize <= 0) { + throw new IllegalArgumentException("Max message size must be > 0"); + } + if (maxMessageSize > 0xFFFFFFFFL) { + throw new IllegalArgumentException("Max message size must be <= 0xFFFFFFFF"); + } + this.maxMessageSize = maxMessageSize; + return this; + } + /** * @return a JSON representation of options */ diff --git a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/impl/GrpcServerImpl.java b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/impl/GrpcServerImpl.java index 9badef7..982a27a 100644 --- a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/impl/GrpcServerImpl.java +++ b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/impl/GrpcServerImpl.java @@ -143,6 +143,7 @@ private void handle(HttpServerRequest httpRequest, options.getScheduleDeadlineAutomatically(), protocol, format, + options.getMaxMessageSize(), httpRequest, messageDecoder, messageEncoder, @@ -152,6 +153,13 @@ private void handle(HttpServerRequest httpRequest, context.putLocal(GrpcLocal.CONTEXT_LOCAL_KEY, AccessMode.CONCURRENT, new GrpcLocal(deadline)); } grpcRequest.init(grpcRequest.response); + grpcRequest.invalidMessageHandler(invalidMsg -> { + if (invalidMsg instanceof MessageSizeOverflowException) { + grpcRequest.response().status(GrpcStatus.RESOURCE_EXHAUSTED).end(); + } else { + grpcRequest.response.cancel(); + } + }); context.dispatch(grpcRequest, handler); } diff --git a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/impl/GrpcServerRequestImpl.java b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/impl/GrpcServerRequestImpl.java index 22051b8..28e925d 100644 --- a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/impl/GrpcServerRequestImpl.java +++ b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/impl/GrpcServerRequestImpl.java @@ -83,11 +83,12 @@ public GrpcServerRequestImpl(io.vertx.core.internal.ContextInternal context, boolean scheduleDeadline, GrpcProtocol protocol, WireFormat format, + long maxMessageSize, HttpServerRequest httpRequest, GrpcMessageDecoder messageDecoder, GrpcMessageEncoder messageEncoder, GrpcMethodCall methodCall) { - super(context, httpRequest, httpRequest.headers().get("grpc-encoding"), format, messageDecoder); + super(context, httpRequest, httpRequest.headers().get("grpc-encoding"), format, maxMessageSize, messageDecoder); String timeoutHeader = httpRequest.getHeader("grpc-timeout"); long timeout = timeoutHeader != null ? parseTimeout(timeoutHeader) : 0L; @@ -153,7 +154,7 @@ public String methodName() { } @Override - public GrpcServerRequest handler(Handler handler) { + public GrpcServerRequestImpl handler(Handler handler) { if (handler != null) { return messageHandler(msg -> { Req decoded; diff --git a/vertx-grpc-server/src/test/java/io/vertx/tests/server/ServerRequestTest.java b/vertx-grpc-server/src/test/java/io/vertx/tests/server/ServerRequestTest.java index 27f6863..6a2d61f 100644 --- a/vertx-grpc-server/src/test/java/io/vertx/tests/server/ServerRequestTest.java +++ b/vertx-grpc-server/src/test/java/io/vertx/tests/server/ServerRequestTest.java @@ -21,6 +21,7 @@ import io.grpc.stub.StreamObserver; import io.vertx.core.MultiMap; import io.vertx.core.Timer; +import io.vertx.core.buffer.Buffer; import io.vertx.core.http.*; import io.vertx.core.internal.ContextInternal; import io.vertx.core.json.JsonObject; @@ -35,7 +36,10 @@ import java.io.File; import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Base64; +import java.util.List; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -433,4 +437,115 @@ public void testJsonMessageFormat(TestContext should) throws Exception { super.testJsonMessageFormat(should, "application/grpc+json"); } + + @Test + public void testDefaultMessageSizeOverflow(TestContext should) { + + HelloRequest request = HelloRequest.newBuilder().setName("Asmoranomardicadaistinaculdacar").build(); + int requestLen = request.getSerializedSize(); + + startServer(GrpcServer.server(vertx, new GrpcServerOptions().setMaxMessageSize(requestLen - 1)) + .callHandler(GREETER_SAY_HELLO, call -> { + })); + + channel = ManagedChannelBuilder.forAddress("localhost", port) + .usePlaintext() + .build(); + + GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(channel); + + try { + stub.sayHello(request); + should.fail(); + } catch (StatusRuntimeException ignore) { + should.assertEquals(Status.RESOURCE_EXHAUSTED.getCode(), ignore.getStatus().getCode()); + } + } + + @Test + public void testInvalidMessageHandler(TestContext should) { + + HelloRequest request = HelloRequest.newBuilder().setName("Asmoranomardicadaistinaculdacar").build(); + int requestLen = request.getSerializedSize(); + + startServer(GrpcServer.server(vertx, new GrpcServerOptions().setMaxMessageSize(requestLen - 1)) + .callHandler(GREETER_SAY_HELLO, call -> { + AtomicInteger invalid = new AtomicInteger(); + call.handler(msg -> { + should.fail(); + }); + call.invalidMessageHandler(err -> { + should.assertEquals(0, invalid.getAndIncrement()); + }); + call.endHandler(v -> { + call.response().end(HelloReply.newBuilder().setMessage("Hola").build()); + }); + })); + + channel = ManagedChannelBuilder.forAddress("localhost", port) + .usePlaintext() + .build(); + + GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(channel); + + HelloReply resp = stub.sayHello(request); + should.assertEquals("Hola", resp.getMessage()); + } + + @Test + public void testInvalidMessageHandlerStream(TestContext should) { + + List messages = Arrays.asList( + Buffer.buffer(Item.newBuilder().setValue("msg1").build().toByteArray()), + Buffer.buffer(Item.newBuilder().setValue("msg2-invalid").build().toByteArray()), + Buffer.buffer(Item.newBuilder().setValue("msg3").build().toByteArray()), + Buffer.buffer(new byte[]{ 0,1,2,3,4,5,6,7 }), + Buffer.buffer(Item.newBuilder().setValue("msg5").build().toByteArray()) + ); + + int invalidLen = messages.get(1).length() - 1; + + startServer(GrpcServer.server(vertx, new GrpcServerOptions().setMaxMessageSize(invalidLen - 1)).callHandler(STREAMING_SINK, call -> { + List received = new ArrayList<>(); + call.invalidMessageHandler(received::add); + call.handler(received::add); + call.endHandler(v -> { + should.assertEquals(Item.class, received.get(0).getClass()); + should.assertEquals(MessageSizeOverflowException.class, received.get(1).getClass()); + should.assertEquals(Item.class, received.get(2).getClass()); + should.assertEquals(InvalidMessagePayloadException.class, received.get(3).getClass()); + should.assertEquals(Item.class, received.get(4).getClass()); + should.assertEquals(5, received.size()); + call.response().end(Empty.getDefaultInstance()); + }); + })); + + Async test = should.async(); + + HttpClient client = vertx.createHttpClient(new HttpClientOptions() + .setProtocolVersion(HttpVersion.HTTP_2) + .setHttp2ClearTextUpgrade(false) + ); + + client.request(HttpMethod.POST, 8080, "localhost", "/" + StreamingGrpc.SERVICE_NAME + "/Sink") + .onComplete(should.asyncAssertSuccess(request -> { + request.putHeader("grpc-encoding", "gzip"); + request.setChunked(true); + messages.forEach(msg -> { + Buffer buffer = Buffer.buffer(); + buffer.appendByte((byte) 0); // Uncompressed + buffer.appendInt(msg.length()); + buffer.appendBuffer(msg); + request.write(buffer); + }); + request.end(); + request.response().onComplete(should.asyncAssertSuccess(response -> { + response.end().onComplete(should.asyncAssertSuccess(v -> { + test.complete(); + })); + })); + })); + + test.awaitSuccess(20_000); + } }