diff --git a/common/testing/http-junit5/src/main/java/io/helidon/common/testing/http/junit5/SocketHttpClient.java b/common/testing/http-junit5/src/main/java/io/helidon/common/testing/http/junit5/SocketHttpClient.java index 743bdf146c6..557aa1641ec 100644 --- a/common/testing/http-junit5/src/main/java/io/helidon/common/testing/http/junit5/SocketHttpClient.java +++ b/common/testing/http-junit5/src/main/java/io/helidon/common/testing/http/junit5/SocketHttpClient.java @@ -453,6 +453,22 @@ public void request(String method, String path, String protocol, String host, It } } + /** + * Write raw proxy protocol header before a request. + * + * @param header header to write + */ + public void writeProxyHeader(byte[] header) { + try { + if (socket == null) { + connect(); + } + socket.getOutputStream().write(header); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + /** * Disconnect from server socket. */ diff --git a/webserver/http2/src/main/java/io/helidon/webserver/http2/Http2Connection.java b/webserver/http2/src/main/java/io/helidon/webserver/http2/Http2Connection.java index c1947c521d5..172f1d4d8f9 100644 --- a/webserver/http2/src/main/java/io/helidon/webserver/http2/Http2Connection.java +++ b/webserver/http2/src/main/java/io/helidon/webserver/http2/Http2Connection.java @@ -64,6 +64,8 @@ import io.helidon.webserver.http2.spi.Http2SubProtocolSelector; import io.helidon.webserver.spi.ServerConnection; +import static io.helidon.http.HeaderNames.X_FORWARDED_FOR; +import static io.helidon.http.HeaderNames.X_FORWARDED_PORT; import static io.helidon.http.HeaderNames.X_HELIDON_CN; import static io.helidon.http.http2.Http2Util.PREFACE_LENGTH; import static java.lang.System.Logger.Level.DEBUG; @@ -614,6 +616,19 @@ private void doHeaders(Semaphore requestSemaphore) { ctx.remotePeer().tlsCertificates() .flatMap(TlsUtils::parseCn) .ifPresent(cn -> connectionHeaders.add(X_HELIDON_CN, cn)); + + // proxy protocol related headers X-Forwarded-For and X-Forwarded-Port + ctx.proxyProtocolData().ifPresent(proxyProtocolData -> { + String sourceAddress = proxyProtocolData.sourceAddress(); + if (!sourceAddress.isEmpty()) { + connectionHeaders.add(X_FORWARDED_FOR, sourceAddress); + } + int sourcePort = proxyProtocolData.sourcePort(); + if (sourcePort != -1) { + connectionHeaders.set(X_FORWARDED_PORT, sourcePort); + } + }); + initConnectionHeaders = false; } diff --git a/webserver/http2/src/main/java/io/helidon/webserver/http2/Http2ServerRequest.java b/webserver/http2/src/main/java/io/helidon/webserver/http2/Http2ServerRequest.java index cc556612d88..ddb84736946 100644 --- a/webserver/http2/src/main/java/io/helidon/webserver/http2/Http2ServerRequest.java +++ b/webserver/http2/src/main/java/io/helidon/webserver/http2/Http2ServerRequest.java @@ -18,6 +18,7 @@ import java.io.InputStream; import java.util.Objects; +import java.util.Optional; import java.util.function.Supplier; import java.util.function.UnaryOperator; @@ -38,6 +39,7 @@ import io.helidon.http.media.ReadableEntity; import io.helidon.webserver.ConnectionContext; import io.helidon.webserver.ListenerContext; +import io.helidon.webserver.ProxyProtocolData; import io.helidon.webserver.http.HttpSecurity; import io.helidon.webserver.http.RoutingRequest; @@ -220,6 +222,11 @@ public void streamFilter(UnaryOperator filterFunction) { this.streamFilter = it -> filterFunction.apply(current.apply(it)); } + @Override + public Optional proxyProtocolData() { + return ctx.proxyProtocolData(); + } + private UriInfo createUriInfo() { return ctx.listenerContext().config().requestedUriDiscoveryContext().uriInfo(remotePeer().address().toString(), localPeer().address().toString(), diff --git a/webserver/tests/webserver/src/test/java/io/helidon/webserver/tests/ProxyProtocolTest.java b/webserver/tests/webserver/src/test/java/io/helidon/webserver/tests/ProxyProtocolTest.java new file mode 100644 index 00000000000..281d65579a0 --- /dev/null +++ b/webserver/tests/webserver/src/test/java/io/helidon/webserver/tests/ProxyProtocolTest.java @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2023 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.webserver.tests; + +import java.util.HexFormat; + +import io.helidon.common.testing.http.junit5.SocketHttpClient; +import io.helidon.http.HeaderNames; +import io.helidon.http.Method; +import io.helidon.http.Status; +import io.helidon.webserver.ProxyProtocolData; +import io.helidon.webserver.WebServerConfig; +import io.helidon.webserver.http.HttpRules; +import io.helidon.webserver.testing.junit5.ServerTest; +import io.helidon.webserver.testing.junit5.SetUpRoute; +import io.helidon.webserver.testing.junit5.SetUpServer; +import org.junit.jupiter.api.Test; + +import static java.nio.charset.StandardCharsets.US_ASCII; +import static org.hamcrest.CoreMatchers.startsWith; +import static org.hamcrest.MatcherAssert.assertThat; + +@ServerTest +class ProxyProtocolTest { + + static final String V2_PREFIX = "0D:0A:0D:0A:00:0D:0A:51:55:49:54:0A"; + + private final static HexFormat hexFormat = HexFormat.of().withUpperCase().withDelimiter(":"); + + private final SocketHttpClient socketHttpClient; + + ProxyProtocolTest(SocketHttpClient socketHttpClient) { + this.socketHttpClient = socketHttpClient; + } + + @SetUpServer + static void setupServer(WebServerConfig.Builder builder) { + builder.enableProxyProtocol(true); + } + + @SetUpRoute + static void routing(HttpRules routing) { + routing.get("/", (req, res) -> { + ProxyProtocolData data = req.proxyProtocolData().orElse(null); + if (data != null + && data.family() == ProxyProtocolData.Family.IPv4 + && data.protocol() == ProxyProtocolData.Protocol.TCP + && data.sourceAddress().equals("192.168.0.1") + && data.destAddress().equals("192.168.0.11") + && data.sourcePort() == 56324 + && data.destPort() == 443 + && "192.168.0.1".equals(req.headers().first(HeaderNames.X_FORWARDED_FOR).orElse(null)) + && "56324".equals(req.headers().first(HeaderNames.X_FORWARDED_PORT).orElse(null))) { + res.status(Status.OK_200).send(); + return; + } + res.status(Status.INTERNAL_SERVER_ERROR_500).send(); + }); + } + + /** + * V1 encoding in this test was manually verified with Wireshark. + */ + @Test + void testProxyProtocolV1IPv4() { + socketHttpClient.writeProxyHeader("PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\n".getBytes(US_ASCII)); + String s = socketHttpClient.sendAndReceive(Method.GET, ""); + assertThat(s, startsWith("HTTP/1.1 200 OK")); + } + + /** + * V2 encoding in this test was manually verified with Wireshark. + */ + @Test + void testProxyProtocolV2IPv4() { + String header = V2_PREFIX + + ":20:11:00:0C" // version, family/protocol, length + + ":C0:A8:00:01" // 192.168.0.1 + + ":C0:A8:00:0B" // 192.168.0.11 + + ":DC:04" // 56324 + + ":01:BB"; // 443 + socketHttpClient.writeProxyHeader(hexFormat.parseHex(header)); + String s = socketHttpClient.sendAndReceive(Method.GET, ""); + assertThat(s, startsWith("HTTP/1.1 200 OK")); + } +} diff --git a/webserver/webserver/src/main/java/io/helidon/webserver/ConnectionContext.java b/webserver/webserver/src/main/java/io/helidon/webserver/ConnectionContext.java index bbca8281788..1c14c051e91 100644 --- a/webserver/webserver/src/main/java/io/helidon/webserver/ConnectionContext.java +++ b/webserver/webserver/src/main/java/io/helidon/webserver/ConnectionContext.java @@ -16,6 +16,7 @@ package io.helidon.webserver; +import java.util.Optional; import java.util.concurrent.ExecutorService; import io.helidon.common.buffers.DataReader; @@ -60,4 +61,14 @@ public interface ConnectionContext extends SocketContext { * @return rouer */ Router router(); + + /** + * Proxy protocol header data. + * + * @return protocol header data if proxy protocol is enabled on socket + * @see ListenerConfig#enableProxyProtocol() + */ + default Optional proxyProtocolData() { + return Optional.empty(); + } } diff --git a/webserver/webserver/src/main/java/io/helidon/webserver/ConnectionHandler.java b/webserver/webserver/src/main/java/io/helidon/webserver/ConnectionHandler.java index c4bb7ab11b0..e3c47720f99 100644 --- a/webserver/webserver/src/main/java/io/helidon/webserver/ConnectionHandler.java +++ b/webserver/webserver/src/main/java/io/helidon/webserver/ConnectionHandler.java @@ -22,6 +22,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.Semaphore; @@ -64,11 +65,13 @@ class ConnectionHandler implements InterruptableTask, ConnectionContext { private final String serverChannelId; private final Router router; private final Tls tls; + private final ListenerConfig listenerConfig; private ServerConnection connection; private HelidonSocket helidonSocket; private DataReader reader; private SocketWriter writer; + private ProxyProtocolData proxyProtocolData; ConnectionHandler(ListenerContext listenerContext, Semaphore connectionSemaphore, @@ -89,6 +92,7 @@ class ConnectionHandler implements InterruptableTask, ConnectionContext { this.serverChannelId = serverChannelId; this.router = router; this.tls = tls; + this.listenerConfig = listenerContext.config(); } @Override @@ -100,6 +104,12 @@ public boolean canInterrupt() { public final void run() { String channelId = "0x" + HexFormat.of().toHexDigits(System.identityHashCode(socket)); + // proxy protocol before SSL handshake + if (listenerConfig.enableProxyProtocol()) { + ProxyProtocolHandler handler = new ProxyProtocolHandler(socket, channelId); + proxyProtocolData = handler.get(); + } + // handle SSL and init helidonSocket, reader and writer try { if (tls.enabled()) { @@ -226,6 +236,11 @@ public Router router() { return router; } + @Override + public Optional proxyProtocolData() { + return Optional.ofNullable(proxyProtocolData); + } + private ServerConnection identifyConnection() { try { reader.ensureAvailable(); diff --git a/webserver/webserver/src/main/java/io/helidon/webserver/ListenerConfigBlueprint.java b/webserver/webserver/src/main/java/io/helidon/webserver/ListenerConfigBlueprint.java index 40056e5e0d2..df40279f690 100644 --- a/webserver/webserver/src/main/java/io/helidon/webserver/ListenerConfigBlueprint.java +++ b/webserver/webserver/src/main/java/io/helidon/webserver/ListenerConfigBlueprint.java @@ -327,6 +327,19 @@ interface ListenerConfigBlueprint { */ Optional listenerContext(); + /** + * Enable proxy protocol support for this socket. This protocol is supported by + * some load balancers/reverse proxies as a means to convey client information that + * would otherwise be lost. If enabled, the proxy protocol header must be present + * on every new connection established with your server. For more information, + * see + * the specification. Default is {@code false}. + * + * @return proxy support status + */ + @Option.Default("false") + boolean enableProxyProtocol(); + /** * Requested URI discovery context. * diff --git a/webserver/webserver/src/main/java/io/helidon/webserver/ProxyProtocolData.java b/webserver/webserver/src/main/java/io/helidon/webserver/ProxyProtocolData.java new file mode 100644 index 00000000000..dedcc71812e --- /dev/null +++ b/webserver/webserver/src/main/java/io/helidon/webserver/ProxyProtocolData.java @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2023 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.webserver; + +/** + * Proxy protocol data parsed by {@link ProxyProtocolHandler}. + */ +public interface ProxyProtocolData { + + /** + * Protocol family. + */ + enum Family { + /** + * Unknown family. + */ + UNKNOWN, + + /** + * IP version 4. + */ + IPv4, + + /** + * IP version 6. + */ + IPv6, + + /** + * Unix. + */ + UNIX; + + static Family fromString(String s) { + return switch (s) { + case "TCP4" -> IPv4; + case "TCP6" -> IPv6; + case "UNIX" -> UNIX; + case "UNKNOWN" -> UNKNOWN; + default -> throw new IllegalArgumentException("Unknown family " + s); + }; + } + } + + /** + * Protocol type. + */ + enum Protocol { + /** + * Unknown protocol. + */ + UNKNOWN, + + /** + * TCP streams protocol. + */ + TCP, + + /** + * UDP datagram protocol. + */ + UDP; + + static Protocol fromString(String s) { + return switch (s) { + case "TCP4", "TCP6" -> TCP; + case "UDP" -> UDP; + case "UNKNOWN" -> UNKNOWN; + default -> throw new IllegalArgumentException("Unknown protocol " + s); + }; + } + } + + /** + * Family from protocol header. + * + * @return family + */ + Family family(); + + /** + * Protocol from protocol header. + * + * @return protocol + */ + Protocol protocol(); + + /** + * Source address that is either IP4 or IP6 depending on {@link #family()}. + * + * @return source address or {@code ""} if not provided + */ + String sourceAddress(); + + /** + * Destination address that is either IP4 or IP46 depending on {@link #family()}. + * + * @return source address or (@code ""} if not provided + */ + String destAddress(); + + /** + * Source port number. + * + * @return source port. + */ + int sourcePort(); + + /** + * Destination port number. + * + * @return port number. + */ + int destPort(); +} + + diff --git a/webserver/webserver/src/main/java/io/helidon/webserver/ProxyProtocolHandler.java b/webserver/webserver/src/main/java/io/helidon/webserver/ProxyProtocolHandler.java new file mode 100644 index 00000000000..663d2d916ad --- /dev/null +++ b/webserver/webserver/src/main/java/io/helidon/webserver/ProxyProtocolHandler.java @@ -0,0 +1,319 @@ +/* + * Copyright (c) 2023 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.webserver; + +import java.io.IOException; +import java.io.InputStream; +import java.io.PushbackInputStream; +import java.io.UncheckedIOException; +import java.lang.System.Logger.Level; +import java.net.Inet6Address; +import java.net.Socket; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.function.Supplier; + +import io.helidon.http.DirectHandler; +import io.helidon.http.RequestException; +import io.helidon.webserver.ProxyProtocolData.Family; +import io.helidon.webserver.ProxyProtocolData.Protocol; + +class ProxyProtocolHandler implements Supplier { + private static final System.Logger LOGGER = System.getLogger(ProxyProtocolHandler.class.getName()); + + private static final int MAX_V1_FIELD_LENGTH = 40; + private static final int MAX_TLV_BYTES_TO_SKIP = 128 * 4; // 128 entries + + static final byte[] V1_PREFIX = { + (byte) 'P', + (byte) 'R', + (byte) 'O', + (byte) 'X', + (byte) 'Y', + }; + + static final byte[] V2_PREFIX_1 = { + (byte) 0x0D, + (byte) 0x0A, + (byte) 0x0D, + (byte) 0x0A, + (byte) 0x00, + }; + + static final byte[] V2_PREFIX_2 = { + (byte) 0x0D, + (byte) 0x0A, + (byte) 0x51, + (byte) 0x55, + (byte) 0x49, + (byte) 0x54, + (byte) 0x0A + }; + + static final RequestException BAD_PROTOCOL_EXCEPTION = RequestException.builder() + .type(DirectHandler.EventType.OTHER) + .message("Unable to parse proxy protocol header") + .build(); + + private final Socket socket; + private final String channelId; + + ProxyProtocolHandler(Socket socket, String channelId) { + this.socket = socket; + this.channelId = channelId; + } + + @Override + public ProxyProtocolData get() { + LOGGER.log(Level.DEBUG, "Reading proxy protocol data for channel %s", channelId); + + try { + byte[] prefix = new byte[V1_PREFIX.length]; + PushbackInputStream inputStream = new PushbackInputStream(socket.getInputStream(), 1); + int n = inputStream.read(prefix); + if (n < V1_PREFIX.length) { + throw BAD_PROTOCOL_EXCEPTION; + } + if (arrayEquals(prefix, V1_PREFIX, V1_PREFIX.length)) { + return handleV1Protocol(inputStream); + } else if (arrayEquals(prefix, V2_PREFIX_1, V2_PREFIX_1.length)) { + return handleV2Protocol(inputStream); + } else { + throw BAD_PROTOCOL_EXCEPTION; + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + static ProxyProtocolData handleV1Protocol(PushbackInputStream inputStream) throws IOException { + try { + int n; + byte[] buffer = new byte[MAX_V1_FIELD_LENGTH]; + + match(inputStream, (byte) ' '); + + // protocol and family + n = readUntil(inputStream, buffer, (byte) ' ', (byte) '\r'); + String familyProtocol = new String(buffer, 0, n, StandardCharsets.US_ASCII); + var family = Family.fromString(familyProtocol); + var protocol = Protocol.fromString(familyProtocol); + byte b = readNext(inputStream); + if (b == (byte) '\r') { + // special case for just UNKNOWN family + if (family == ProxyProtocolData.Family.UNKNOWN) { + return new ProxyProtocolDataImpl(Family.UNKNOWN, Protocol.UNKNOWN, + "", "", -1, -1); + } + } + + match(b, (byte) ' '); + + // source address + n = readUntil(inputStream, buffer, (byte) ' '); + var sourceAddress = new String(buffer, 0, n, StandardCharsets.US_ASCII); + match(inputStream, (byte) ' '); + + // destination address + n = readUntil(inputStream, buffer, (byte) ' '); + var destAddress = new String(buffer, 0, n, StandardCharsets.US_ASCII); + match(inputStream, (byte) ' '); + + // source port + n = readUntil(inputStream, buffer, (byte) ' '); + int sourcePort = Integer.parseInt(new String(buffer, 0, n, StandardCharsets.US_ASCII)); + match(inputStream, (byte) ' '); + + // destination port + n = readUntil(inputStream, buffer, (byte) '\r'); + int destPort = Integer.parseInt(new String(buffer, 0, n, StandardCharsets.US_ASCII)); + match(inputStream, (byte) '\r'); + match(inputStream, (byte) '\n'); + + return new ProxyProtocolDataImpl(family, protocol, sourceAddress, destAddress, sourcePort, destPort); + } catch (IllegalArgumentException e) { + throw BAD_PROTOCOL_EXCEPTION; + } + } + + static ProxyProtocolData handleV2Protocol(PushbackInputStream inputStream) throws IOException { + // match rest of prefix + match(inputStream, V2_PREFIX_2); + + // only accept version 2, ignore LOCAL/PROXY + int b = readNext(inputStream); + if (b >>> 4 != 0x02) { + throw BAD_PROTOCOL_EXCEPTION; + } + + // protocol and family + b = readNext(inputStream); + var family = switch (b >>> 4) { + case 0x1 -> Family.IPv4; + case 0x2 -> Family.IPv6; + case 0x3 -> Family.UNIX; + default -> Family.UNKNOWN; + }; + var protocol = switch (b & 0x0F) { + case 0x1 -> Protocol.TCP; + case 0x2 -> Protocol.UDP; + default -> Protocol.UNKNOWN; + }; + + // length + b = readNext(inputStream); + int headerLength = ((b << 8) & 0xFF00) | (readNext(inputStream) & 0xFF); + + // decode addresses and ports + String sourceAddress = ""; + String destAddress = ""; + int sourcePort = -1; + int destPort = -1; + switch (family) { + case IPv4 -> { + byte[] buffer = new byte[12]; + int n = inputStream.read(buffer, 0, buffer.length); + if (n < buffer.length) { + throw BAD_PROTOCOL_EXCEPTION; + } + sourceAddress = (buffer[0] & 0xFF) + + "." + (buffer[1] & 0xFF) + + "." + (buffer[2] & 0xFF) + + "." + (buffer[3] & 0xFF); + destAddress = (buffer[4] & 0xFF) + + "." + (buffer[5] & 0xFF) + + "." + (buffer[6] & 0xFF) + + "." + (buffer[7] & 0xFF); + sourcePort = buffer[9] & 0xFF + | ((buffer[8] << 8) & 0xFF00); + destPort = buffer[11] & 0xFF + | ((buffer[10] << 8) & 0xFF00); + headerLength -= buffer.length; + } + case IPv6 -> { + byte[] buffer = new byte[16]; + int n = inputStream.read(buffer, 0, buffer.length); + if (n < buffer.length) { + throw BAD_PROTOCOL_EXCEPTION; + } + sourceAddress = Inet6Address.getByAddress(buffer).getHostAddress(); + n = inputStream.read(buffer, 0, buffer.length); + if (n < buffer.length) { + throw BAD_PROTOCOL_EXCEPTION; + } + destAddress = Inet6Address.getByAddress(buffer).getHostAddress(); + n = inputStream.read(buffer, 0, 4); + if (n < 4) { + throw BAD_PROTOCOL_EXCEPTION; + } + sourcePort = buffer[1] & 0xFF + | ((buffer[0] << 8) & 0xFF00); + destPort = buffer[3] & 0xFF + | ((buffer[2] << 8) & 0xFF00); + headerLength -= 2 * buffer.length + 4; + } + case UNIX -> { + byte[] buffer = new byte[216]; + int n = inputStream.read(buffer, 0, buffer.length); + if (n < buffer.length) { + throw BAD_PROTOCOL_EXCEPTION; + } + sourceAddress = new String(buffer, 0, 108, StandardCharsets.US_ASCII); + destAddress = new String(buffer, 108, buffer.length, StandardCharsets.US_ASCII); + headerLength -= buffer.length; + } + default -> { + // falls through + } + } + + // skip any TLV vectors up to our max for security reasons + if (headerLength > MAX_TLV_BYTES_TO_SKIP) { + throw BAD_PROTOCOL_EXCEPTION; + } + while (headerLength > 0) { + headerLength -= (int) inputStream.skip(headerLength); + } + + return new ProxyProtocolDataImpl(family, protocol, sourceAddress, destAddress, + sourcePort, destPort); + } + + private static byte readNext(InputStream inputStream) throws IOException { + int b = inputStream.read(); + if (b < 0) { + throw BAD_PROTOCOL_EXCEPTION; + } + return (byte) b; + } + + private static void match(byte a, byte b) { + if (a != b) { + throw BAD_PROTOCOL_EXCEPTION; + } + } + + private static void match(PushbackInputStream inputStream, byte b) throws IOException { + if (inputStream.read() != b) { + throw BAD_PROTOCOL_EXCEPTION; + } + } + + private static void match(PushbackInputStream inputStream, byte... bs) throws IOException { + for (byte b : bs) { + int c = inputStream.read(); + if (((byte) c) != b) { + throw BAD_PROTOCOL_EXCEPTION; + } + } + } + + private static int readUntil(PushbackInputStream inputStream, byte[] buffer, byte... delims) throws IOException { + int n = 0; + do { + byte b = readNext(inputStream); + if (arrayContains(delims, b)) { + inputStream.unread(b); + return n; + } + buffer[n++] = b; + if (n >= buffer.length) { + throw BAD_PROTOCOL_EXCEPTION; + } + } while (true); + } + + private static boolean arrayEquals(byte[] array1, byte[] array2, int prefix) { + return Arrays.equals(array1, 0, prefix, array2, 0, prefix); + } + + private static boolean arrayContains(byte[] array, byte b) { + for (byte a : array) { + if (a == b) { + return true; + } + } + return false; + } + + record ProxyProtocolDataImpl(Family family, + Protocol protocol, + String sourceAddress, + String destAddress, + int sourcePort, + int destPort) implements ProxyProtocolData { + } +} diff --git a/webserver/webserver/src/main/java/io/helidon/webserver/http/ServerRequest.java b/webserver/webserver/src/main/java/io/helidon/webserver/http/ServerRequest.java index e9f6138e567..deb358e15ca 100644 --- a/webserver/webserver/src/main/java/io/helidon/webserver/http/ServerRequest.java +++ b/webserver/webserver/src/main/java/io/helidon/webserver/http/ServerRequest.java @@ -17,12 +17,14 @@ package io.helidon.webserver.http; import java.io.InputStream; +import java.util.Optional; import java.util.function.UnaryOperator; import io.helidon.common.context.Context; import io.helidon.http.RoutedPath; import io.helidon.http.media.ReadableEntity; import io.helidon.webserver.ListenerContext; +import io.helidon.webserver.ProxyProtocolData; /** * HTTP server request. @@ -110,4 +112,12 @@ public interface ServerRequest extends HttpRequest { * @param filterFunction the function to replace input stream of this request with a user provided one */ void streamFilter(UnaryOperator filterFunction); + + /** + * Access proxy protocol data for the connection on which this request was sent. + * + * @return proxy protocol data, if available + * @see io.helidon.webserver.ListenerConfig#enableProxyProtocol() + */ + Optional proxyProtocolData(); } diff --git a/webserver/webserver/src/main/java/io/helidon/webserver/http1/Http1Connection.java b/webserver/webserver/src/main/java/io/helidon/webserver/http1/Http1Connection.java index 4496805b7cb..cb8e5dc572e 100644 --- a/webserver/webserver/src/main/java/io/helidon/webserver/http1/Http1Connection.java +++ b/webserver/webserver/src/main/java/io/helidon/webserver/http1/Http1Connection.java @@ -47,11 +47,14 @@ import io.helidon.http.encoding.ContentEncodingContext; import io.helidon.webserver.CloseConnectionException; import io.helidon.webserver.ConnectionContext; +import io.helidon.webserver.ProxyProtocolData; import io.helidon.webserver.http.DirectTransportRequest; import io.helidon.webserver.http.HttpRouting; import io.helidon.webserver.http1.spi.Http1Upgrader; import io.helidon.webserver.spi.ServerConnection; +import static io.helidon.http.HeaderNames.X_FORWARDED_FOR; +import static io.helidon.http.HeaderNames.X_FORWARDED_PORT; import static io.helidon.http.HeaderNames.X_HELIDON_CN; import static java.lang.System.Logger.Level.TRACE; import static java.lang.System.Logger.Level.WARNING; @@ -128,6 +131,9 @@ public boolean canInterrupt() { public void handle(Semaphore requestSemaphore) throws InterruptedException { this.myThread = Thread.currentThread(); try { + // look for protocol data + ProxyProtocolData proxyProtocolData = ctx.proxyProtocolData().orElse(null); + // handle connection until an exception (or explicit connection close) while (canRun) { // prologue (first line of request) @@ -145,6 +151,18 @@ public void handle(Semaphore requestSemaphore) throws InterruptedException { .ifPresent(name -> headers.set(X_HELIDON_CN, name)); recvListener.headers(ctx, headers); + // proxy protocol related headers X-Forwarded-For and X-Forwarded-Port + if (proxyProtocolData != null) { + String sourceAddress = proxyProtocolData.sourceAddress(); + if (!sourceAddress.isEmpty()) { + headers.add(X_FORWARDED_FOR, sourceAddress); + } + int sourcePort = proxyProtocolData.sourcePort(); + if (sourcePort != -1) { + headers.add(X_FORWARDED_PORT, sourcePort); + } + } + if (canUpgrade) { if (headers.contains(HeaderNames.UPGRADE)) { Http1Upgrader upgrader = upgradeProviderMap.get(headers.get(HeaderNames.UPGRADE).get()); diff --git a/webserver/webserver/src/main/java/io/helidon/webserver/http1/Http1ServerRequest.java b/webserver/webserver/src/main/java/io/helidon/webserver/http1/Http1ServerRequest.java index 0058084cdc2..f0b03e811ee 100644 --- a/webserver/webserver/src/main/java/io/helidon/webserver/http1/Http1ServerRequest.java +++ b/webserver/webserver/src/main/java/io/helidon/webserver/http1/Http1ServerRequest.java @@ -16,6 +16,7 @@ package io.helidon.webserver.http1; +import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.function.Supplier; @@ -36,6 +37,7 @@ import io.helidon.http.encoding.ContentDecoder; import io.helidon.webserver.ConnectionContext; import io.helidon.webserver.ListenerContext; +import io.helidon.webserver.ProxyProtocolData; import io.helidon.webserver.http.HttpSecurity; import io.helidon.webserver.http.RoutingRequest; @@ -207,6 +209,11 @@ public UriInfo requestedUri() { return uriInfo.get(); } + @Override + public Optional proxyProtocolData() { + return ctx.proxyProtocolData(); + } + private UriInfo createUriInfo() { return ctx.listenerContext().config().requestedUriDiscoveryContext().uriInfo(remotePeer().address().toString(), localPeer().address().toString(), diff --git a/webserver/webserver/src/test/java/io/helidon/webserver/ProxyProtocolHandlerTest.java b/webserver/webserver/src/test/java/io/helidon/webserver/ProxyProtocolHandlerTest.java new file mode 100644 index 00000000000..a1dca086868 --- /dev/null +++ b/webserver/webserver/src/test/java/io/helidon/webserver/ProxyProtocolHandlerTest.java @@ -0,0 +1,180 @@ +/* + * Copyright (c) 2023 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.webserver; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.PushbackInputStream; +import java.nio.charset.StandardCharsets; +import java.util.HexFormat; + +import io.helidon.http.RequestException; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class ProxyProtocolHandlerTest { + + static final String V2_PREFIX_2 = "0D:0A:51:55:49:54:0A:"; + + private static final HexFormat hexFormat = HexFormat.of().withUpperCase().withDelimiter(":"); + + @Test + void basicV1Test() throws IOException { + String header = " TCP4 192.168.0.1 192.168.0.11 56324 443\r\n"; // excludes PROXY prefix + ProxyProtocolData data = ProxyProtocolHandler.handleV1Protocol(new PushbackInputStream( + new ByteArrayInputStream(header.getBytes(StandardCharsets.US_ASCII)))); + assertThat(data.family(), is(ProxyProtocolData.Family.IPv4)); + assertThat(data.protocol(), is(ProxyProtocolData.Protocol.TCP)); + assertThat(data.sourceAddress(), is("192.168.0.1")); + assertThat(data.destAddress(), is("192.168.0.11")); + assertThat(data.sourcePort(), is(56324)); + assertThat(data.destPort(), is(443)); + } + + @Test + void unknownV1Test() throws IOException { + String header = " UNKNOWN\r\n"; // excludes PROXY prefix + ProxyProtocolData data = ProxyProtocolHandler.handleV1Protocol(new PushbackInputStream( + new ByteArrayInputStream(header.getBytes(StandardCharsets.US_ASCII)))); + assertThat(data.family(), is(ProxyProtocolData.Family.UNKNOWN)); + assertThat(data.protocol(), is(ProxyProtocolData.Protocol.UNKNOWN)); + assertThat(data.sourceAddress(), is("")); + assertThat(data.destAddress(), is("")); + assertThat(data.sourcePort(), is(-1)); + assertThat(data.destPort(), is(-1)); + } + + @Test + void badV1Test() { + String header1 = " MYPROTOCOL 192.168.0.1 192.168.0.11 56324 443\r\n"; + assertThrows(RequestException.class, () -> + ProxyProtocolHandler.handleV1Protocol(new PushbackInputStream( + new ByteArrayInputStream(header1.getBytes(StandardCharsets.US_ASCII))))); + String header2 = " TCP4 192.168.0.1 192.168.0.11 56324\r\n"; + assertThrows(RequestException.class, () -> + ProxyProtocolHandler.handleV1Protocol(new PushbackInputStream( + new ByteArrayInputStream(header2.getBytes(StandardCharsets.US_ASCII))))); + String header3 = " TCP4 192.168.0.1 192.168.0.11 56324 443"; + assertThrows(RequestException.class, () -> + ProxyProtocolHandler.handleV1Protocol(new PushbackInputStream( + new ByteArrayInputStream(header3.getBytes(StandardCharsets.US_ASCII))))); + String header4 = " TCP4 192.168.0.1 56324 443\r\n"; + assertThrows(RequestException.class, () -> + ProxyProtocolHandler.handleV1Protocol(new PushbackInputStream( + new ByteArrayInputStream(header4.getBytes(StandardCharsets.US_ASCII))))); + } + + @Test + void basicV2TestIPv4() throws IOException { + String header = V2_PREFIX_2 + + "20:11:00:0C:" // version, family/protocol, length + + "C0:A8:00:01:" // 192.168.0.1 + + "C0:A8:00:0B:" // 192.168.0.11 + + "DC:04:" // 56324 + + "01:BB"; // 443 + ProxyProtocolData data = ProxyProtocolHandler.handleV2Protocol(new PushbackInputStream( + new ByteArrayInputStream(hexFormat.parseHex(header)))); + assertThat(data.family(), is(ProxyProtocolData.Family.IPv4)); + assertThat(data.protocol(), is(ProxyProtocolData.Protocol.TCP)); + assertThat(data.sourceAddress(), is("192.168.0.1")); + assertThat(data.destAddress(), is("192.168.0.11")); + assertThat(data.sourcePort(), is(56324)); + assertThat(data.destPort(), is(443)); + } + + @Test + void basicV2TestIPv6() throws IOException { + String header = V2_PREFIX_2 + + "20:21:00:0C:" // version, family/protocol, length + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" // source + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" // dest + + "DC:04:" // 56324 + + "01:BB"; // 443 + ProxyProtocolData data = ProxyProtocolHandler.handleV2Protocol(new PushbackInputStream( + new ByteArrayInputStream(hexFormat.parseHex(header)))); + assertThat(data.family(), is(ProxyProtocolData.Family.IPv6)); + assertThat(data.protocol(), is(ProxyProtocolData.Protocol.TCP)); + assertThat(data.sourceAddress(), is("aaaa:bbbb:cccc:dddd:aaaa:bbbb:cccc:dddd")); + assertThat(data.destAddress(), is("aaaa:bbbb:cccc:dddd:aaaa:bbbb:cccc:dddd")); + assertThat(data.sourcePort(), is(56324)); + assertThat(data.destPort(), is(443)); + } + + @Test + void unknownV2Test() throws IOException { + String header = V2_PREFIX_2 + + "20:00:00:40:" // version, family/protocol, length=64 + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:CC:CC:DD:DD"; + ProxyProtocolData data = ProxyProtocolHandler.handleV2Protocol(new PushbackInputStream( + new ByteArrayInputStream(hexFormat.parseHex(header)))); + assertThat(data.family(), is(ProxyProtocolData.Family.UNKNOWN)); + assertThat(data.protocol(), is(ProxyProtocolData.Protocol.UNKNOWN)); + assertThat(data.sourceAddress(), is("")); + assertThat(data.destAddress(), is("")); + assertThat(data.sourcePort(), is(-1)); + assertThat(data.destPort(), is(-1)); + } + + @Test + void badV2Test() { + String header1 = V2_PREFIX_2 + + "20:21:00:0C:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:" // bad source + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "DC:04:" + + "01:BB"; + assertThrows(RequestException.class, () -> + ProxyProtocolHandler.handleV2Protocol(new PushbackInputStream( + new ByteArrayInputStream(hexFormat.parseHex(header1))))); + + String header2 = V2_PREFIX_2 + + "20:21:0F:FF:" // bad length, over our limit + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "DC:04:" + + "01:BB"; + assertThrows(RequestException.class, () -> + ProxyProtocolHandler.handleV2Protocol(new PushbackInputStream( + new ByteArrayInputStream(hexFormat.parseHex(header2))))); + + String header3 = V2_PREFIX_2 + + "20:21:00:0C:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "AA:AA:BB:BB:CC:CC:DD:DD:" + + "DC:04"; // missing dest port + assertThrows(RequestException.class, () -> + ProxyProtocolHandler.handleV2Protocol(new PushbackInputStream( + new ByteArrayInputStream(hexFormat.parseHex(header3))))); + } +}