diff --git a/pom.xml b/pom.xml
index 6e697487b..1f009e813 100644
--- a/pom.xml
+++ b/pom.xml
@@ -79,6 +79,7 @@
0.3.0.RELEASE
3.0.2
24.1.0
+ 1.77
@@ -117,6 +118,12 @@
${java-annotations.version}
provided
+
+ org.bouncycastle
+ bcpkix-jdk18on
+ ${bouncy-castle.version}
+ test
+
@@ -240,6 +247,11 @@
jackson-annotations
test
+
+ org.bouncycastle
+ bcpkix-jdk18on
+ test
+
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java b/src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java
index 3039ea19b..f8f594fdc 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java
@@ -57,6 +57,7 @@ public final class ConnectionContext implements CodecContext {
*/
private volatile short serverStatuses = ServerStatuses.AUTO_COMMIT;
+ @Nullable
private volatile Capability capability = null;
ConnectionContext(ZeroDateOption zeroDateOption, @Nullable Path localInfilePath,
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/client/SslBridgeHandler.java b/src/main/java/io/asyncer/r2dbc/mysql/client/SslBridgeHandler.java
index 823a59187..952be3917 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/client/SslBridgeHandler.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/client/SslBridgeHandler.java
@@ -220,8 +220,10 @@ static MySqlSslContextSpec forClient(MySqlSslConfiguration ssl, ConnectionContex
.applicationProtocolConfig(null);
String[] tlsProtocols = ssl.getTlsVersion();
- if (tlsProtocols.length > 0) {
- builder.protocols(tlsProtocols);
+ if (tlsProtocols.length > 0 || ssl.getSslMode() == SslMode.TUNNEL) {
+ if (tlsProtocols.length > 0) {
+ builder.protocols(tlsProtocols);
+ }
} else if (isTls13Enabled(context)) {
builder.protocols(TLS_PROTOCOLS);
} else {
diff --git a/src/test/java/io/asyncer/r2dbc/mysql/SslTunnelIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/SslTunnelIntegrationTest.java
new file mode 100644
index 000000000..0ad39029e
--- /dev/null
+++ b/src/test/java/io/asyncer/r2dbc/mysql/SslTunnelIntegrationTest.java
@@ -0,0 +1,297 @@
+/*
+ * Copyright 2024 asyncer.io projects
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.asyncer.r2dbc.mysql;
+
+
+import io.asyncer.r2dbc.mysql.constant.SslMode;
+import io.netty.bootstrap.Bootstrap;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.handler.ssl.SslContext;
+import io.netty.handler.ssl.SslContextBuilder;
+import io.netty.handler.ssl.util.SelfSignedCertificate;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import javax.net.ssl.SSLException;
+import java.net.InetSocketAddress;
+import java.security.cert.CertificateException;
+import java.time.Duration;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class SslTunnelIntegrationTest {
+
+ private SelfSignedCertificate server;
+
+ private SelfSignedCertificate client;
+
+ private SslTunnelServer sslTunnelServer;
+
+ @Before
+ public void setUp() throws CertificateException, SSLException {
+ server = new SelfSignedCertificate();
+ client = new SelfSignedCertificate();
+ final SslContext sslContext = SslContextBuilder.forServer(server.key(), server.cert()).build();
+ sslTunnelServer = new SslTunnelServer("localhost", 3306, sslContext);
+ sslTunnelServer.setUp();
+ }
+
+ @After
+ public void tearDown() throws InterruptedException {
+ server.delete();
+ client.delete();
+ sslTunnelServer.tearDown();
+ }
+
+ @Test
+ public void sslTunnelConnectionTest() {
+ final String password = System.getProperty("test.mysql.password");
+ assertThat(password).withFailMessage("Property test.mysql.password must exists and not be empty")
+ .isNotNull()
+ .isNotEmpty();
+
+ final MySqlConnectionConfiguration configuration = MySqlConnectionConfiguration
+ .builder()
+ .host("localhost")
+ .port(sslTunnelServer.getLocalPort())
+ .connectTimeout(Duration.ofSeconds(3))
+ .user("root")
+ .password(password)
+ .database("r2dbc")
+ .createDatabaseIfNotExist(true)
+ .sslMode(SslMode.TUNNEL)
+ .sslKey(client.privateKey().getAbsolutePath())
+ .sslCert(client.certificate().getAbsolutePath())
+ .sslCa(server.certificate().getAbsolutePath())
+ .build();
+
+ final MySqlConnectionFactory connectionFactory = MySqlConnectionFactory.from(configuration);
+
+ final MySqlConnection connection = connectionFactory.create().block();
+ connection.createStatement("SELECT 3").execute()
+ .flatMap(it -> it.map((row, rowMetadata) -> row.get(0)))
+ .doOnNext(it -> assertThat(it).isEqualTo(3L))
+ .blockLast();
+
+ connection.close().block();
+ }
+
+ private static class SslTunnelServer {
+
+ private final String remoteHost;
+
+ private final int remotePort;
+
+ private final SslContext sslContext;
+
+ private volatile ChannelFuture channelFuture;
+
+
+ private SslTunnelServer(String remoteHost, int remotePort, SslContext sslContext) {
+ this.remoteHost = remoteHost;
+ this.remotePort = remotePort;
+ this.sslContext = sslContext;
+ }
+
+ void setUp() {
+ // Configure the server.
+ try {
+ ServerBootstrap b = new ServerBootstrap();
+ b.localAddress(0)
+ .group(new NioEventLoopGroup())
+ .channel(NioServerSocketChannel.class)
+ .childHandler(new ProxyInitializer(remoteHost, remotePort, sslContext))
+ .childOption(ChannelOption.AUTO_READ, false);
+
+
+ // Start the server.
+ channelFuture = b.bind().sync();
+
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+
+ void tearDown() throws InterruptedException {
+ channelFuture.channel().close().sync();
+ }
+
+ int getLocalPort() {
+ return ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
+ }
+
+ }
+
+
+ private static class ProxyInitializer extends ChannelInitializer {
+
+ private final String remoteHost;
+
+ private final int remotePort;
+
+ private final SslContext sslContext;
+
+ ProxyInitializer(String remoteHost, int remotePort, SslContext sslContext) {
+ this.remoteHost = remoteHost;
+ this.remotePort = remotePort;
+ this.sslContext = sslContext;
+ }
+
+ @Override
+ public void initChannel(SocketChannel ch) {
+ ch.pipeline().addLast(sslContext.newHandler(ch.alloc()));
+ ch.pipeline().addLast(new ProxyFrontendHandler(remoteHost, remotePort));
+ }
+ }
+
+ private static class ProxyFrontendHandler extends ChannelInboundHandlerAdapter {
+
+ private final String remoteHost;
+ private final int remotePort;
+
+ // As we use inboundChannel.eventLoop() when building the Bootstrap this does not need to be volatile as
+ // the outboundChannel will use the same EventLoop (and therefore Thread) as the inboundChannel.
+ private Channel outboundChannel;
+
+ public ProxyFrontendHandler(String remoteHost, int remotePort) {
+ this.remoteHost = remoteHost;
+ this.remotePort = remotePort;
+ }
+
+ @Override
+ public void channelActive(ChannelHandlerContext ctx) {
+ final Channel inboundChannel = ctx.channel();
+
+ // Start the connection attempt.
+ Bootstrap b = new Bootstrap();
+ b.group(inboundChannel.eventLoop())
+ .channel(ctx.channel().getClass())
+ .handler(new ProxyBackendHandler(inboundChannel))
+ .option(ChannelOption.AUTO_READ, false);
+ ChannelFuture f = b.connect(remoteHost, remotePort);
+ outboundChannel = f.channel();
+ f.addListener(new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) {
+ if (future.isSuccess()) {
+ // connection complete start to read first data
+ inboundChannel.read();
+ } else {
+ // Close the connection if the connection attempt has failed.
+ inboundChannel.close();
+ }
+ }
+ });
+ }
+
+ @Override
+ public void channelRead(final ChannelHandlerContext ctx, Object msg) {
+ if (outboundChannel.isActive()) {
+ outboundChannel.writeAndFlush(msg).addListener(new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) {
+ if (future.isSuccess()) {
+ // was able to flush out data, start to read the next chunk
+ ctx.channel().read();
+ } else {
+ future.channel().close();
+ }
+ }
+ });
+ }
+ }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) {
+ if (outboundChannel != null) {
+ closeOnFlush(outboundChannel);
+ }
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+ cause.printStackTrace();
+ closeOnFlush(ctx.channel());
+ }
+
+ /**
+ * Closes the specified channel after all queued write requests are flushed.
+ */
+ static void closeOnFlush(Channel ch) {
+ if (ch.isActive()) {
+ ch.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
+ }
+ }
+ }
+
+ private static class ProxyBackendHandler extends ChannelInboundHandlerAdapter {
+
+ private final Channel inboundChannel;
+
+ public ProxyBackendHandler(Channel inboundChannel) {
+ this.inboundChannel = inboundChannel;
+ }
+
+ @Override
+ public void channelActive(ChannelHandlerContext ctx) {
+ if (!inboundChannel.isActive()) {
+ ProxyFrontendHandler.closeOnFlush(ctx.channel());
+ } else {
+ ctx.read();
+ }
+ }
+
+ @Override
+ public void channelRead(final ChannelHandlerContext ctx, Object msg) {
+ inboundChannel.writeAndFlush(msg).addListener(new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) {
+ if (future.isSuccess()) {
+ ctx.channel().read();
+ } else {
+ future.channel().close();
+ }
+ }
+ });
+ }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) {
+ ProxyFrontendHandler.closeOnFlush(inboundChannel);
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+ cause.printStackTrace();
+ ProxyFrontendHandler.closeOnFlush(ctx.channel());
+ }
+ }
+
+
+}