diff --git a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WebSocketHandler.java b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WebSocketHandler.java index a23d490794f..d4f5e726b98 100644 --- a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WebSocketHandler.java +++ b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WebSocketHandler.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, 2023 Oracle and/or its affiliates. + * Copyright (c) 2022, 2024 Oracle and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -70,6 +70,7 @@ class WebSocketHandler extends SimpleChannelInboundHandler { private volatile Connection connection; private final WebSocketEngine.UpgradeInfo upgradeInfo; private final BufferedEmittingPublisher emitter; + private final TyrusUpgradeResponse upgradeResponse = new TyrusUpgradeResponse(); WebSocketHandler(ChannelHandlerContext ctx, String path, FullHttpRequest upgradeRequest, @@ -140,6 +141,10 @@ public WebSocketEngine getWebSocketEngine() { this.upgradeInfo = upgrade(ctx); } + TyrusUpgradeResponse upgradeResponse() { + return upgradeResponse; + } + @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { LOGGER.log(Level.SEVERE, "WS handler ERROR ", cause); @@ -195,9 +200,7 @@ WebSocketEngine.UpgradeInfo upgrade(ChannelHandlerContext ctx) { upgradeRequest.headers().forEach(e -> requestContext.getHeaders().put(e.getKey(), List.of(e.getValue()))); // Use Tyrus to process a WebSocket upgrade request - final TyrusUpgradeResponse upgradeResponse = new TyrusUpgradeResponse(); - final WebSocketEngine.UpgradeInfo upgradeInfo = engine.upgrade(requestContext, upgradeResponse); - + WebSocketEngine.UpgradeInfo upgradeInfo = engine.upgrade(requestContext, upgradeResponse); upgradeResponse.getHeaders().forEach(this.upgradeResponseHeaders::add); return upgradeInfo; } diff --git a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WebSocketUpgradeCodec.java b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WebSocketUpgradeCodec.java index d5fac2e8918..367c3658bfe 100644 --- a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WebSocketUpgradeCodec.java +++ b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WebSocketUpgradeCodec.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Oracle and/or its affiliates. + * Copyright (c) 2022, 2024 Oracle and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,23 +15,38 @@ */ package io.helidon.webserver.websocket; +import java.nio.charset.Charset; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.logging.Level; import java.util.logging.Logger; +import io.helidon.common.http.Http; import io.helidon.webserver.ForwardingHandler; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.EmptyHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpServerUpgradeHandler; +import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; +import org.glassfish.tyrus.core.TyrusUpgradeResponse; class WebSocketUpgradeCodec implements HttpServerUpgradeHandler.UpgradeCodec { - private static final Logger LOGGER = Logger.getLogger(WebSocketUpgradeCodec.class.getName()); + private static final String SEC_WEBSOCKET_ACCEPT = "Sec-WebSocket-Accept"; + private static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; + private final WebSocketRouting webSocketRouting; private String path; private WebSocketHandler wsHandler; @@ -52,14 +67,43 @@ public boolean prepareUpgradeResponse(ChannelHandlerContext ctx, HttpHeaders upgradeResponseHeaders) { try { path = upgradeRequest.uri(); - upgradeResponseHeaders.remove("upgrade"); - upgradeResponseHeaders.remove("connection"); - this.wsHandler = new WebSocketHandler(ctx, path, upgradeRequest, upgradeResponseHeaders, webSocketRouting); - return true; + upgradeResponseHeaders.remove(Http.Header.UPGRADE); + upgradeResponseHeaders.remove(Http.Header.CONNECTION); + wsHandler = new WebSocketHandler(ctx, path, upgradeRequest, upgradeResponseHeaders, webSocketRouting); + + // if not 101 code, create and write to channel a custom user response of + // type text/plain using reason as payload and return false back to Netty + TyrusUpgradeResponse upgradeResponse = wsHandler.upgradeResponse(); + if (upgradeResponse.getStatus() != Http.Status.SWITCHING_PROTOCOLS_101.code()) { + // prepare headers for failed response + Map> upgradeHeaders = upgradeResponse.getHeaders(); + upgradeHeaders.remove(Http.Header.UPGRADE); + upgradeHeaders.remove(Http.Header.CONNECTION); + upgradeHeaders.remove(SEC_WEBSOCKET_ACCEPT); + upgradeHeaders.remove(SEC_WEBSOCKET_PROTOCOL); + HttpHeaders headers = new DefaultHttpHeaders(); + upgradeHeaders.forEach(headers::add); + + // set payload as text/plain with reason phrase + headers.add(Http.Header.CONTENT_TYPE, "text/plain"); + String reasonPhrase = upgradeResponse.getReasonPhrase() == null ? "" + : upgradeResponse.getReasonPhrase(); + HttpResponse httpResponse = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, + HttpResponseStatus.valueOf(upgradeResponse.getStatus()), + Unpooled.wrappedBuffer(reasonPhrase.getBytes(Charset.defaultCharset())), + headers, + EmptyHttpHeaders.INSTANCE); // trailing headers + + // write, flush and later close connection + ChannelFuture writeComplete = ctx.writeAndFlush(httpResponse); + writeComplete.addListener(ChannelFutureListener.CLOSE); + return false; + } } catch (Throwable cause) { LOGGER.log(Level.SEVERE, "Error during upgrade to WebSocket", cause); return false; } + return true; } @Override diff --git a/webserver/websocket/src/test/java/io/helidon/webserver/websocket/test/EchoEndpoint.java b/webserver/websocket/src/test/java/io/helidon/webserver/websocket/test/EchoEndpoint.java index 2ee68c32af1..b7f54c4017c 100644 --- a/webserver/websocket/src/test/java/io/helidon/webserver/websocket/test/EchoEndpoint.java +++ b/webserver/websocket/src/test/java/io/helidon/webserver/websocket/test/EchoEndpoint.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Oracle and/or its affiliates. + * Copyright (c) 2022, 2024 Oracle and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package io.helidon.webserver.websocket.test; import java.io.IOException; +import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import java.util.logging.Logger; @@ -28,6 +29,7 @@ import jakarta.websocket.server.HandshakeRequest; import jakarta.websocket.server.ServerEndpoint; import jakarta.websocket.server.ServerEndpointConfig; +import org.glassfish.tyrus.core.TyrusUpgradeResponse; import static io.helidon.webserver.websocket.test.UppercaseCodec.isDecoded; @@ -86,6 +88,19 @@ public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, LOGGER.info("ServerConfigurator called during handshake"); super.modifyHandshake(sec, request, response); EchoEndpoint.modifyHandshakeCalled.set(true); + + // if not user Helidon, fail to authenticate, return reason and user header + String user = getUserFromParams(request); + if (!user.equals("Helidon") && response instanceof TyrusUpgradeResponse tyrusResponse) { + tyrusResponse.setStatus(401); + tyrusResponse.setReasonPhrase("Failed to authenticate"); + tyrusResponse.getHeaders().put("Endpoint", List.of("EchoEndpoint")); + } + } + + private String getUserFromParams(HandshakeRequest request) { + List values = request.getParameterMap().get("user"); + return values != null && !values.isEmpty() ? values.get(0) : ""; } } diff --git a/webserver/websocket/src/test/java/io/helidon/webserver/websocket/test/HandshakeFailureTest.java b/webserver/websocket/src/test/java/io/helidon/webserver/websocket/test/HandshakeFailureTest.java new file mode 100644 index 00000000000..3c11483452f --- /dev/null +++ b/webserver/websocket/src/test/java/io/helidon/webserver/websocket/test/HandshakeFailureTest.java @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2022, 2024 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.webserver.websocket.test; + +import java.net.URI; +import java.util.List; +import java.util.Map; + +import io.helidon.common.http.Http; + +import jakarta.websocket.DeploymentException; +import jakarta.websocket.HandshakeResponse; +import jakarta.websocket.server.HandshakeRequest; +import jakarta.websocket.server.ServerEndpointConfig; +import org.glassfish.tyrus.client.auth.AuthenticationException; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.fail; + +class HandshakeFailureTest extends TyrusSupportBaseTest { + + @BeforeAll + static void startServer() throws Exception { + webServer(true, EchoEndpoint.class); + } + + /** + * Should fail because user is not Helidon. See server handshake at + * {@link EchoEndpoint.ServerConfigurator#modifyHandshake(ServerEndpointConfig, HandshakeRequest, HandshakeResponse)}. + */ + @Test + void testEchoSingleUpgradeFail() { + URI uri = URI.create("ws://localhost:" + webServer().port() + "/tyrus/echo?user=Unknown"); + EchoClient echoClient = new EchoClient(uri); + try { + echoClient.echo("One"); + } catch (Exception e) { + assertThat(e, instanceOf(DeploymentException.class)); + assertThat(e.getCause(), instanceOf(AuthenticationException.class)); + AuthenticationException ae = (AuthenticationException) e.getCause(); + assertThat(ae.getHttpStatusCode(), is(401)); + assertThat(ae.getMessage(), is("Authentication failed.")); + return; + } + fail("Exception not thrown"); + } + + /** + * Should fail because user is not Helidon. See server handshake at + * {@link EchoEndpoint.ServerConfigurator#modifyHandshake(ServerEndpointConfig, HandshakeRequest, HandshakeResponse)}. + */ + @Test + void testEchoSingleUpgradeFailRaw() throws Exception { + String response = SocketHttpClient.sendAndReceive("/tyrus/echo?user=Unknown", + Http.Method.GET, + List.of("Connection:Upgrade", + "Upgrade:websocket", + "Sec-WebSocket-Key:0SBbaRkS/idPrmvImDNHBA==", + "Sec-WebSocket-Version:13"), + webServer()); + + assertThat(SocketHttpClient.statusFromResponse(response), + is(Http.Status.UNAUTHORIZED_401)); + assertThat(SocketHttpClient.entityFromResponse(response, false), + is("Failed to authenticate\n")); + Map headers = SocketHttpClient.headersFromResponse(response); + assertThat(headers.get("Endpoint"), is("EchoEndpoint")); + assertFalse(headers.containsKey("Connection") || headers.containsKey("connection")); + assertFalse(headers.containsKey("Upgrade") || headers.containsKey("upgrade")); + } +} diff --git a/webserver/websocket/src/test/java/io/helidon/webserver/websocket/test/SocketHttpClient.java b/webserver/websocket/src/test/java/io/helidon/webserver/websocket/test/SocketHttpClient.java new file mode 100644 index 00000000000..537a97f2ac5 --- /dev/null +++ b/webserver/websocket/src/test/java/io/helidon/webserver/websocket/test/SocketHttpClient.java @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2024 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.webserver.websocket.test; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.net.Socket; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.logging.Logger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import io.helidon.common.http.Http; +import io.helidon.webserver.WebServer; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.CoreMatchers.startsWith; +import static org.hamcrest.MatcherAssert.assertThat; + +/** + * A raw HTTP client to test WebSocket failed upgrades. Similar to SocketHttpClient + * in webserver, but simpler. + */ +public class SocketHttpClient implements AutoCloseable { + + private static final Logger LOGGER = Logger.getLogger(SocketHttpClient.class.getName()); + static final String EOL = "\r\n"; + private static final Pattern FIRST_LINE_PATTERN = Pattern.compile("HTTP/\\d+\\.\\d+ (\\d\\d\\d) (.*)"); + + private final Socket socket; + private final BufferedReader socketReader; + + /** + * Creates the instance linked with the provided webserver. + * + * @param webServer the webserver to link this client with + * @throws IOException in case of an error + */ + public SocketHttpClient(WebServer webServer) throws IOException { + socket = new Socket("localhost", webServer.port()); + socket.setSoTimeout(10000); + socketReader = new BufferedReader(new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)); + } + + /** + * A helper method that sends the given payload at the given path with the provided method to the webserver. + * + * @param path the path to access + * @param method the http method + * @param headers HTTP request headers + * @param webServer the webserver where to send the payload + * @return the exact string returned by webserver (including {@code HTTP/1.1 200 OK} line for instance) + * @throws Exception in case of an error + */ + public static String sendAndReceive(String path, + Http.RequestMethod method, + Iterable headers, + WebServer webServer) throws Exception { + try (SocketHttpClient s = new SocketHttpClient(webServer)) { + s.request(method, path, headers); + return s.receive(); + } + } + + /** + * Find headers in response and parse them. + * + * @param response full HTTP response + * @return headers map + */ + public static Map headersFromResponse(String response) { + assertThat(response, notNullValue()); + int index = response.indexOf("\n\n"); + if (index < 0) { + throw new AssertionError("Missing end of headers in response!"); + } + String hdrsPart = response.substring(0, index); + String[] lines = hdrsPart.split("\\n"); + if (lines.length <= 1) { + return Collections.emptyMap(); + } + Map result = new HashMap<>(lines.length - 1); + boolean first = true; + for (String line : lines) { + if (first) { + first = false; + continue; + } + int i = line.indexOf(':'); + if (i < 0) { + throw new AssertionError("Header without semicolon - " + line); + } + result.put(line.substring(0, i).trim(), line.substring(i + 1).trim()); + } + return result; + } + + /** + * Find the status line and return response HTTP status. + * + * @param response full HTTP response + * @return status + */ + public static Http.ResponseStatus statusFromResponse(String response) { + // response should start with HTTP/1.1 000 reasonPhrase\n + int eol = response.indexOf('\n'); + assertThat("There must be at least a line end after first line: " + response, eol > -1); + String firstLine = response.substring(0, eol).trim(); + + Matcher matcher = FIRST_LINE_PATTERN.matcher(firstLine); + assertThat("Status line must match the patter of 'HTTP/0.0 000 ReasonPhrase', but is: " + response, + matcher.matches()); + + int statusCode = Integer.parseInt(matcher.group(1)); + String phrase = matcher.group(2); + + return Http.ResponseStatus.create(statusCode, phrase); + } + + /** + * Get entity from response. + * + * @param response response with initial line, headers, and entity + * @param validateHeaderFormat whether to validate headers are correctly formatted + * @return entity string + */ + public static String entityFromResponse(String response, boolean validateHeaderFormat) { + assertThat(response, notNullValue()); + int index = response.indexOf("\n\n"); + if (index < 0) { + throw new AssertionError("Missing end of headers in response!"); + } + if (validateHeaderFormat) { + String headers = response.substring(0, index); + String[] lines = headers.split("\\n"); + assertThat(lines[0], startsWith("HTTP/")); + for (int i = 1; i < lines.length; i++) { + assertThat(lines[i], containsString(":")); + } + } + return response.substring(index + 2); + } + + /** + * Read the data from the socket. If socket is closed, an empty string is returned. + * + * @return the read data + * @throws IOException in case of an IO error + */ + public String receive() throws IOException { + StringBuilder sb = new StringBuilder(); + String t; + boolean ending = false; + int contentLength = -1; + while ((t = socketReader.readLine()) != null) { + LOGGER.finest("Received: " + t); + + if (t.toLowerCase().startsWith("content-length")) { + int k = t.indexOf(':'); + contentLength = Integer.parseInt(t.substring(k + 1).trim()); + } + + sb.append(t) + .append("\n"); + + if ("".equalsIgnoreCase(t) && contentLength >= 0) { + char[] content = new char[contentLength]; + socketReader.read(content); + sb.append(content); + break; + } + if (ending && "".equalsIgnoreCase(t)) { + break; + } + if (!ending && ("0".equalsIgnoreCase(t))) { + ending = true; + } + } + return sb.toString(); + } + + /** + * Sends a request to the webserver. + * + * @param path the path to access + * @param method the http method + * @param headers the headers (e.g., {@code Content-Type: application/json}) + * @throws IOException in case of an IO error + */ + public void request(Http.RequestMethod method, String path, Iterable headers) throws IOException { + request(method.name(), path, "HTTP/1.1", "localhost", headers); + } + + /** + * Send raw data to the server. + * + * @param method HTTP Method + * @param path path + * @param protocol protocol + * @param host host header value (if null, host header is not sent) + * @param headers headers (if null, additional headers are not sent) + * + * @throws IOException in case of an IO error + */ + public void request(String method, String path, String protocol, String host, Iterable headers) + throws IOException { + List usedHeaders = new LinkedList<>(); + if (headers != null) { + headers.forEach(usedHeaders::add); + } + if (host != null) { + usedHeaders.add(0, "Host: " + host); + } + PrintWriter pw = new PrintWriter(new OutputStreamWriter(socket.getOutputStream(), StandardCharsets.UTF_8)); + pw.print(method); + pw.print(" "); + pw.print(path); + pw.print(" "); + pw.print(protocol); + pw.print(EOL); + + for (String header : usedHeaders) { + pw.print(header); + pw.print(EOL); + } + + pw.print(EOL); + pw.print(EOL); + pw.flush(); + } + + @Override + public void close() throws Exception { + socket.close(); + } +}