Skip to content

Commit 666cc3e

Browse files
committed
Fixed NPE with SslMode.TUNNEL Usage
Motivation: A NPE was identified when utilizing `SslMode.TUNNEL`. The issue arises when `ConnectionContext#isMariaDb` is invoked from `SslBridgeHandler#isTls13Enabled`, leading to an NPE due to the `ConnectionContext` not being initialized at that time. Modification: Do not invoke `ConnectionContext#isMariaDb` when it is not initialized. Result: This change addresses the NPE issue, ensuring stability when `SslMode.TUNNEL` is selected. resolves GoogleCloudPlatform/cloud-sql-jdbc-socket-factory#1828
1 parent 6982acc commit 666cc3e

File tree

4 files changed

+314
-2
lines changed

4 files changed

+314
-2
lines changed

pom.xml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
<mbr.version>0.3.0.RELEASE</mbr.version>
8080
<jsr305.version>3.0.2</jsr305.version>
8181
<java-annotations.version>24.1.0</java-annotations.version>
82+
<bouncy-castle.version>1.77</bouncy-castle.version>
8283
</properties>
8384

8485
<dependencyManagement>
@@ -117,6 +118,12 @@
117118
<version>${java-annotations.version}</version>
118119
<scope>provided</scope>
119120
</dependency>
121+
<dependency>
122+
<groupId>org.bouncycastle</groupId>
123+
<artifactId>bcpkix-jdk18on</artifactId>
124+
<version>${bouncy-castle.version}</version>
125+
<scope>test</scope>
126+
</dependency>
120127
</dependencies>
121128
</dependencyManagement>
122129

@@ -240,6 +247,11 @@
240247
<artifactId>jackson-annotations</artifactId>
241248
<scope>test</scope>
242249
</dependency>
250+
<dependency>
251+
<groupId>org.bouncycastle</groupId>
252+
<artifactId>bcpkix-jdk18on</artifactId>
253+
<scope>test</scope>
254+
</dependency>
243255
</dependencies>
244256

245257
<build>

src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ public final class ConnectionContext implements CodecContext {
5757
*/
5858
private volatile short serverStatuses = ServerStatuses.AUTO_COMMIT;
5959

60+
@Nullable
6061
private volatile Capability capability = null;
6162

6263
ConnectionContext(ZeroDateOption zeroDateOption, @Nullable Path localInfilePath,

src/main/java/io/asyncer/r2dbc/mysql/client/SslBridgeHandler.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,10 @@ static MySqlSslContextSpec forClient(MySqlSslConfiguration ssl, ConnectionContex
220220
.applicationProtocolConfig(null);
221221
String[] tlsProtocols = ssl.getTlsVersion();
222222

223-
if (tlsProtocols.length > 0) {
224-
builder.protocols(tlsProtocols);
223+
if (tlsProtocols.length > 0 || ssl.getSslMode() == SslMode.TUNNEL) {
224+
if (tlsProtocols.length > 0) {
225+
builder.protocols(tlsProtocols);
226+
}
225227
} else if (isTls13Enabled(context)) {
226228
builder.protocols(TLS_PROTOCOLS);
227229
} else {
Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
/*
2+
* Copyright 2024 asyncer.io projects
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.asyncer.r2dbc.mysql;
18+
19+
20+
import io.asyncer.r2dbc.mysql.constant.SslMode;
21+
import io.netty.bootstrap.Bootstrap;
22+
import io.netty.bootstrap.ServerBootstrap;
23+
import io.netty.buffer.Unpooled;
24+
import io.netty.channel.Channel;
25+
import io.netty.channel.ChannelFuture;
26+
import io.netty.channel.ChannelFutureListener;
27+
import io.netty.channel.ChannelHandlerContext;
28+
import io.netty.channel.ChannelInboundHandlerAdapter;
29+
import io.netty.channel.ChannelInitializer;
30+
import io.netty.channel.ChannelOption;
31+
import io.netty.channel.nio.NioEventLoopGroup;
32+
import io.netty.channel.socket.SocketChannel;
33+
import io.netty.channel.socket.nio.NioServerSocketChannel;
34+
import io.netty.handler.ssl.SslContext;
35+
import io.netty.handler.ssl.SslContextBuilder;
36+
import io.netty.handler.ssl.util.SelfSignedCertificate;
37+
import org.junit.After;
38+
import org.junit.Before;
39+
import org.junit.Test;
40+
41+
import javax.net.ssl.SSLException;
42+
import java.net.InetSocketAddress;
43+
import java.security.cert.CertificateException;
44+
import java.time.Duration;
45+
46+
import static org.assertj.core.api.Assertions.assertThat;
47+
48+
public class SslTunnelIntegrationTest {
49+
50+
private SelfSignedCertificate server;
51+
52+
private SelfSignedCertificate client;
53+
54+
private SslTunnelServer sslTunnelServer;
55+
56+
@Before
57+
public void setUp() throws CertificateException, SSLException {
58+
server = new SelfSignedCertificate();
59+
client = new SelfSignedCertificate();
60+
final SslContext sslContext = SslContextBuilder.forServer(server.key(), server.cert()).build();
61+
sslTunnelServer = new SslTunnelServer("localhost", 3306, sslContext);
62+
sslTunnelServer.setUp();
63+
}
64+
65+
@After
66+
public void tearDown() throws InterruptedException {
67+
server.delete();
68+
client.delete();
69+
sslTunnelServer.tearDown();
70+
}
71+
72+
@Test
73+
public void sslTunnelConnectionTest() {
74+
final String password = System.getProperty("test.mysql.password");
75+
assertThat(password).withFailMessage("Property test.mysql.password must exists and not be empty")
76+
.isNotNull()
77+
.isNotEmpty();
78+
79+
final MySqlConnectionConfiguration configuration = MySqlConnectionConfiguration
80+
.builder()
81+
.host("localhost")
82+
.port(sslTunnelServer.getLocalPort())
83+
.connectTimeout(Duration.ofSeconds(3))
84+
.user("root")
85+
.password(password)
86+
.database("r2dbc")
87+
.createDatabaseIfNotExist(true)
88+
.sslMode(SslMode.TUNNEL)
89+
.sslKey(client.privateKey().getAbsolutePath())
90+
.sslCert(client.certificate().getAbsolutePath())
91+
.sslCa(server.certificate().getAbsolutePath())
92+
.build();
93+
94+
final MySqlConnectionFactory connectionFactory = MySqlConnectionFactory.from(configuration);
95+
96+
final MySqlConnection connection = connectionFactory.create().block();
97+
connection.createStatement("SELECT 3").execute()
98+
.flatMap(it -> it.map((row, rowMetadata) -> row.get(0)))
99+
.doOnNext(it -> assertThat(it).isEqualTo(3L))
100+
.blockLast();
101+
102+
connection.close().block();
103+
}
104+
105+
private static class SslTunnelServer {
106+
107+
private final String remoteHost;
108+
109+
private final int remotePort;
110+
111+
private final SslContext sslContext;
112+
113+
private volatile ChannelFuture channelFuture;
114+
115+
116+
private SslTunnelServer(String remoteHost, int remotePort, SslContext sslContext) {
117+
this.remoteHost = remoteHost;
118+
this.remotePort = remotePort;
119+
this.sslContext = sslContext;
120+
}
121+
122+
void setUp() {
123+
// Configure the server.
124+
try {
125+
ServerBootstrap b = new ServerBootstrap();
126+
b.localAddress(0)
127+
.group(new NioEventLoopGroup())
128+
.channel(NioServerSocketChannel.class)
129+
.childHandler(new ProxyInitializer(remoteHost, remotePort, sslContext))
130+
.childOption(ChannelOption.AUTO_READ, false);
131+
132+
133+
// Start the server.
134+
channelFuture = b.bind().sync();
135+
136+
} catch (InterruptedException e) {
137+
e.printStackTrace();
138+
}
139+
}
140+
141+
void tearDown() throws InterruptedException {
142+
channelFuture.channel().close().sync();
143+
}
144+
145+
int getLocalPort() {
146+
return ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
147+
}
148+
149+
}
150+
151+
152+
private static class ProxyInitializer extends ChannelInitializer<SocketChannel> {
153+
154+
private final String remoteHost;
155+
156+
private final int remotePort;
157+
158+
private final SslContext sslContext;
159+
160+
ProxyInitializer(String remoteHost, int remotePort, SslContext sslContext) {
161+
this.remoteHost = remoteHost;
162+
this.remotePort = remotePort;
163+
this.sslContext = sslContext;
164+
}
165+
166+
@Override
167+
public void initChannel(SocketChannel ch) {
168+
ch.pipeline().addLast(sslContext.newHandler(ch.alloc()));
169+
ch.pipeline().addLast(new ProxyFrontendHandler(remoteHost, remotePort));
170+
}
171+
}
172+
173+
private static class ProxyFrontendHandler extends ChannelInboundHandlerAdapter {
174+
175+
private final String remoteHost;
176+
private final int remotePort;
177+
178+
// As we use inboundChannel.eventLoop() when building the Bootstrap this does not need to be volatile as
179+
// the outboundChannel will use the same EventLoop (and therefore Thread) as the inboundChannel.
180+
private Channel outboundChannel;
181+
182+
public ProxyFrontendHandler(String remoteHost, int remotePort) {
183+
this.remoteHost = remoteHost;
184+
this.remotePort = remotePort;
185+
}
186+
187+
@Override
188+
public void channelActive(ChannelHandlerContext ctx) {
189+
final Channel inboundChannel = ctx.channel();
190+
191+
// Start the connection attempt.
192+
Bootstrap b = new Bootstrap();
193+
b.group(inboundChannel.eventLoop())
194+
.channel(ctx.channel().getClass())
195+
.handler(new ProxyBackendHandler(inboundChannel))
196+
.option(ChannelOption.AUTO_READ, false);
197+
ChannelFuture f = b.connect(remoteHost, remotePort);
198+
outboundChannel = f.channel();
199+
f.addListener(new ChannelFutureListener() {
200+
@Override
201+
public void operationComplete(ChannelFuture future) {
202+
if (future.isSuccess()) {
203+
// connection complete start to read first data
204+
inboundChannel.read();
205+
} else {
206+
// Close the connection if the connection attempt has failed.
207+
inboundChannel.close();
208+
}
209+
}
210+
});
211+
}
212+
213+
@Override
214+
public void channelRead(final ChannelHandlerContext ctx, Object msg) {
215+
if (outboundChannel.isActive()) {
216+
outboundChannel.writeAndFlush(msg).addListener(new ChannelFutureListener() {
217+
@Override
218+
public void operationComplete(ChannelFuture future) {
219+
if (future.isSuccess()) {
220+
// was able to flush out data, start to read the next chunk
221+
ctx.channel().read();
222+
} else {
223+
future.channel().close();
224+
}
225+
}
226+
});
227+
}
228+
}
229+
230+
@Override
231+
public void channelInactive(ChannelHandlerContext ctx) {
232+
if (outboundChannel != null) {
233+
closeOnFlush(outboundChannel);
234+
}
235+
}
236+
237+
@Override
238+
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
239+
cause.printStackTrace();
240+
closeOnFlush(ctx.channel());
241+
}
242+
243+
/**
244+
* Closes the specified channel after all queued write requests are flushed.
245+
*/
246+
static void closeOnFlush(Channel ch) {
247+
if (ch.isActive()) {
248+
ch.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
249+
}
250+
}
251+
}
252+
253+
private static class ProxyBackendHandler extends ChannelInboundHandlerAdapter {
254+
255+
private final Channel inboundChannel;
256+
257+
public ProxyBackendHandler(Channel inboundChannel) {
258+
this.inboundChannel = inboundChannel;
259+
}
260+
261+
@Override
262+
public void channelActive(ChannelHandlerContext ctx) {
263+
if (!inboundChannel.isActive()) {
264+
ProxyFrontendHandler.closeOnFlush(ctx.channel());
265+
} else {
266+
ctx.read();
267+
}
268+
}
269+
270+
@Override
271+
public void channelRead(final ChannelHandlerContext ctx, Object msg) {
272+
inboundChannel.writeAndFlush(msg).addListener(new ChannelFutureListener() {
273+
@Override
274+
public void operationComplete(ChannelFuture future) {
275+
if (future.isSuccess()) {
276+
ctx.channel().read();
277+
} else {
278+
future.channel().close();
279+
}
280+
}
281+
});
282+
}
283+
284+
@Override
285+
public void channelInactive(ChannelHandlerContext ctx) {
286+
ProxyFrontendHandler.closeOnFlush(inboundChannel);
287+
}
288+
289+
@Override
290+
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
291+
cause.printStackTrace();
292+
ProxyFrontendHandler.closeOnFlush(ctx.channel());
293+
}
294+
}
295+
296+
297+
}

0 commit comments

Comments
 (0)