Skip to content

Commit

Permalink
Adds support for users to control the outcome of a WebSocket upgrade …
Browse files Browse the repository at this point in the history
…request. If the user handler returns a non-101 code, the protocol upgrade fails and a response is written back based on the data returned by the handler, including the error code, headers and the reason for the failure. See issue 7953. Some new tests.
  • Loading branch information
spericas committed Apr 1, 2024
1 parent f661067 commit f0fc904
Show file tree
Hide file tree
Showing 5 changed files with 414 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, 2023 Oracle and/or its affiliates.
* Copyright (c) 2022, 2024 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -70,6 +70,7 @@ class WebSocketHandler extends SimpleChannelInboundHandler<Object> {
private volatile Connection connection;
private final WebSocketEngine.UpgradeInfo upgradeInfo;
private final BufferedEmittingPublisher<ByteBuf> emitter;
private final TyrusUpgradeResponse upgradeResponse = new TyrusUpgradeResponse();

WebSocketHandler(ChannelHandlerContext ctx, String path,
FullHttpRequest upgradeRequest,
Expand Down Expand Up @@ -140,6 +141,10 @@ public WebSocketEngine getWebSocketEngine() {
this.upgradeInfo = upgrade(ctx);
}

TyrusUpgradeResponse upgradeResponse() {
return upgradeResponse;
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
LOGGER.log(Level.SEVERE, "WS handler ERROR ", cause);
Expand Down Expand Up @@ -195,9 +200,7 @@ WebSocketEngine.UpgradeInfo upgrade(ChannelHandlerContext ctx) {
upgradeRequest.headers().forEach(e -> requestContext.getHeaders().put(e.getKey(), List.of(e.getValue())));

// Use Tyrus to process a WebSocket upgrade request
final TyrusUpgradeResponse upgradeResponse = new TyrusUpgradeResponse();
final WebSocketEngine.UpgradeInfo upgradeInfo = engine.upgrade(requestContext, upgradeResponse);

WebSocketEngine.UpgradeInfo upgradeInfo = engine.upgrade(requestContext, upgradeResponse);
upgradeResponse.getHeaders().forEach(this.upgradeResponseHeaders::add);
return upgradeInfo;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates.
* Copyright (c) 2022, 2024 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,18 +15,30 @@
*/
package io.helidon.webserver.websocket;

import java.nio.charset.Charset;
import java.util.Collection;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

import io.helidon.common.http.Http;
import io.helidon.webserver.ForwardingHandler;

import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.EmptyHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpServerUpgradeHandler;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import org.glassfish.tyrus.core.TyrusUpgradeResponse;

class WebSocketUpgradeCodec implements HttpServerUpgradeHandler.UpgradeCodec {

Expand All @@ -52,14 +64,40 @@ public boolean prepareUpgradeResponse(ChannelHandlerContext ctx,
HttpHeaders upgradeResponseHeaders) {
try {
path = upgradeRequest.uri();
upgradeResponseHeaders.remove("upgrade");
upgradeResponseHeaders.remove("connection");
this.wsHandler = new WebSocketHandler(ctx, path, upgradeRequest, upgradeResponseHeaders, webSocketRouting);
return true;
upgradeResponseHeaders.remove(Http.Header.UPGRADE);
upgradeResponseHeaders.remove(Http.Header.CONNECTION);
wsHandler = new WebSocketHandler(ctx, path, upgradeRequest, upgradeResponseHeaders, webSocketRouting);

// if not 101 code, create and write to channel a custom user response of
// type text/plain using reason as payload and return false back to Netty
TyrusUpgradeResponse upgradeResponse = wsHandler.upgradeResponse();
if (upgradeResponse.getStatus() != Http.Status.SWITCHING_PROTOCOLS_101.code()) {
// prepare headers for failed response
upgradeResponse.getHeaders().remove(Http.Header.UPGRADE);
upgradeResponse.getHeaders().remove(Http.Header.CONNECTION);
upgradeResponse.getHeaders().remove("sec-websocket-accept");
HttpHeaders headers = new DefaultHttpHeaders();
upgradeResponse.getHeaders().forEach(headers::add);

// set payload as text/plain with reason phrase
headers.add(Http.Header.CONTENT_TYPE, "text/plain");
String reasonPhrase = upgradeResponse.getReasonPhrase() == null ? ""
: upgradeResponse.getReasonPhrase();
HttpResponse r = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
HttpResponseStatus.valueOf(upgradeResponse.getStatus()),
Unpooled.wrappedBuffer(reasonPhrase.getBytes(Charset.defaultCharset())),
headers, EmptyHttpHeaders.INSTANCE);

// write, flush and later close connection
ChannelFuture writeComplete = ctx.writeAndFlush(r);
writeComplete.addListener(ChannelFutureListener.CLOSE);
return false;
}
} catch (Throwable cause) {
LOGGER.log(Level.SEVERE, "Error during upgrade to WebSocket", cause);
return false;
}
return true;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates.
* Copyright (c) 2022, 2024 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@
package io.helidon.webserver.websocket.test;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Logger;

Expand All @@ -28,6 +29,7 @@
import jakarta.websocket.server.HandshakeRequest;
import jakarta.websocket.server.ServerEndpoint;
import jakarta.websocket.server.ServerEndpointConfig;
import org.glassfish.tyrus.core.TyrusUpgradeResponse;

import static io.helidon.webserver.websocket.test.UppercaseCodec.isDecoded;

Expand Down Expand Up @@ -86,6 +88,19 @@ public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request,
LOGGER.info("ServerConfigurator called during handshake");
super.modifyHandshake(sec, request, response);
EchoEndpoint.modifyHandshakeCalled.set(true);

// if not user Helidon, fail to authenticate, return reason and user header
String user = getUserFromParams(request);
if (!user.equals("Helidon") && response instanceof TyrusUpgradeResponse tyrusResponse) {
tyrusResponse.setStatus(401);
tyrusResponse.setReasonPhrase("Failed to authenticate");
tyrusResponse.getHeaders().put("Endpoint", List.of("EchoEndpoint"));
}
}

private String getUserFromParams(HandshakeRequest request) {
List<String> values = request.getParameterMap().get("user");
return values != null && !values.isEmpty() ? values.get(0) : "";
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright (c) 2022, 2024 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.helidon.webserver.websocket.test;

import java.net.URI;
import java.util.List;
import java.util.Map;

import io.helidon.common.http.Http;

import jakarta.websocket.DeploymentException;
import jakarta.websocket.HandshakeResponse;
import jakarta.websocket.server.HandshakeRequest;
import jakarta.websocket.server.ServerEndpointConfig;
import org.glassfish.tyrus.client.auth.AuthenticationException;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.fail;

class HandshakeFailureTest extends TyrusSupportBaseTest {

@BeforeAll
static void startServer() throws Exception {
webServer(true, EchoEndpoint.class);
}

/**
* Should fail because user is not Helidon. See server handshake at
* {@link EchoEndpoint.ServerConfigurator#modifyHandshake(ServerEndpointConfig, HandshakeRequest, HandshakeResponse)}.
*/
@Test
void testEchoSingleUpgradeFail() {
URI uri = URI.create("ws://localhost:" + webServer().port() + "/tyrus/echo?user=Unknown");
EchoClient echoClient = new EchoClient(uri);
try {
echoClient.echo("One");
} catch (Exception e) {
assertThat(e, instanceOf(DeploymentException.class));
assertThat(e.getCause(), instanceOf(AuthenticationException.class));
AuthenticationException ae = (AuthenticationException) e.getCause();
assertThat(ae.getHttpStatusCode(), is(401));
assertThat(ae.getMessage(), is("Authentication failed."));
return;
}
fail("Exception not thrown");
}

/**
* Should fail because user is not Helidon. See server handshake at
* {@link EchoEndpoint.ServerConfigurator#modifyHandshake(ServerEndpointConfig, HandshakeRequest, HandshakeResponse)}.
*/
@Test
void testEchoSingleUpgradeFailRaw() throws Exception {
String response = SocketHttpClient.sendAndReceive("/tyrus/echo?user=Unknown",
Http.Method.GET,
List.of("Connection:Upgrade",
"Upgrade:websocket",
"Sec-WebSocket-Key:0SBbaRkS/idPrmvImDNHBA==",
"Sec-WebSocket-Version:13"),
webServer());

assertThat(SocketHttpClient.statusFromResponse(response),
is(Http.Status.UNAUTHORIZED_401));
assertThat(SocketHttpClient.entityFromResponse(response, false),
is("Failed to authenticate\n"));
Map<String, String> headers = SocketHttpClient.headersFromResponse(response);
assertThat(headers.get("Endpoint"), is("EchoEndpoint"));
assertFalse(headers.containsKey("Connection") || headers.containsKey("connection"));
assertFalse(headers.containsKey("Upgrade") || headers.containsKey("upgrade"));
}
}
Loading

0 comments on commit f0fc904

Please sign in to comment.