From e46edce5fa8d347acec54c5905f2031eeb2c389f Mon Sep 17 00:00:00 2001 From: "Guillaume L." Date: Tue, 24 Sep 2024 11:20:56 +0200 Subject: [PATCH] Use already pre existing reponse header to do websocket handshake (#5324) --- .../core/http/impl/ServerWebSocketImpl.java | 3 +- .../io/vertx/core/http/WebSocketTest.java | 28 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/main/java/io/vertx/core/http/impl/ServerWebSocketImpl.java b/src/main/java/io/vertx/core/http/impl/ServerWebSocketImpl.java index 842d8b19620..da50efbd838 100644 --- a/src/main/java/io/vertx/core/http/impl/ServerWebSocketImpl.java +++ b/src/main/java/io/vertx/core/http/impl/ServerWebSocketImpl.java @@ -14,6 +14,7 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker; import io.vertx.core.AsyncResult; @@ -189,7 +190,7 @@ private void doHandshake() { Channel channel = conn.channel(); Http1xServerResponse response = request.response(); try { - handshaker.handshake(channel, request.nettyRequest()); + handshaker.handshake(channel, request.nettyRequest(), (HttpHeaders) response.headers(), channel.newPromise()); } catch (Exception e) { response.setStatusCode(BAD_REQUEST.code()).end(); throw e; diff --git a/src/test/java/io/vertx/core/http/WebSocketTest.java b/src/test/java/io/vertx/core/http/WebSocketTest.java index 964b0f3506f..b2a8d856804 100644 --- a/src/test/java/io/vertx/core/http/WebSocketTest.java +++ b/src/test/java/io/vertx/core/http/WebSocketTest.java @@ -3940,4 +3940,32 @@ public void testClientWebSocketExceptionHandlerIsCalled() { await(); } + @Test + public void testCustomResponseHeadersBeforeUpgrade() { + String path = "/some/path"; + String message = "here is some text data"; + String headerKey = "custom"; + String headerValue = "value"; + server = vertx.createHttpServer(new HttpServerOptions().setPort(DEFAULT_HTTP_PORT)).requestHandler(req -> { + req.response().headers().set(headerKey, headerValue); + req.toWebSocket() + .onComplete(event -> { + ServerWebSocket serverWebSocket = event.result(); + serverWebSocket.accept(); + serverWebSocket.writeFinalTextFrame(message); + }); + }); + server.listen(onSuccess(s -> { + client = vertx.createWebSocketClient(); + client.connect(DEFAULT_HTTP_PORT, HttpTestBase.DEFAULT_HTTP_HOST, path, onSuccess(ws -> { + assertTrue(ws.headers().contains(headerKey)); + assertEquals(headerValue, ws.headers().get(headerKey)); + ws.handler(buff -> { + assertEquals(message, buff.toString("UTF-8")); + testComplete(); + }); + })); + })); + await(); + } }