diff --git a/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java b/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java index 0da51bf47f7..238e8e9d4c9 100644 --- a/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java +++ b/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java @@ -31,11 +31,17 @@ import io.grpc.Status; import io.grpc.xds.client.Bootstrapper; import io.grpc.xds.client.XdsTransportFactory; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; final class GrpcXdsTransportFactory implements XdsTransportFactory { private final CallCredentials callCredentials; + // The map of xDS server info to its corresponding gRPC xDS transport. + // This enables reusing and sharing the same underlying gRPC channel. + private static final Map xdsServerInfoToTransportMap = + new ConcurrentHashMap<>(); GrpcXdsTransportFactory(CallCredentials callCredentials) { this.callCredentials = callCredentials; @@ -43,12 +49,25 @@ final class GrpcXdsTransportFactory implements XdsTransportFactory { @Override public XdsTransport create(Bootstrapper.ServerInfo serverInfo) { - return new GrpcXdsTransport(serverInfo, callCredentials); + return xdsServerInfoToTransportMap.compute( + serverInfo, + (info, transport) -> { + if (transport == null) { + transport = new GrpcXdsTransport(serverInfo, callCredentials); + } + ++transport.refCount; + return transport; + }); } @VisibleForTesting public XdsTransport createForTest(ManagedChannel channel) { - return new GrpcXdsTransport(channel, callCredentials); + return new GrpcXdsTransport(channel, callCredentials, null); + } + + @VisibleForTesting + static boolean hasTransport(Bootstrapper.ServerInfo serverInfo) { + return xdsServerInfoToTransportMap.containsKey(serverInfo); } @VisibleForTesting @@ -56,6 +75,9 @@ static class GrpcXdsTransport implements XdsTransport { private final ManagedChannel channel; private final CallCredentials callCredentials; + private final Bootstrapper.ServerInfo serverInfo; + // Must only be accessed within the provided atomic methods of ConcurrentHashMap. + private int refCount = 0; public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo) { this(serverInfo, null); @@ -63,7 +85,7 @@ public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo) { @VisibleForTesting public GrpcXdsTransport(ManagedChannel channel) { - this(channel, null); + this(channel, null, null); } public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, CallCredentials callCredentials) { @@ -73,12 +95,17 @@ public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, CallCredentials call .keepAliveTime(5, TimeUnit.MINUTES) .build(); this.callCredentials = callCredentials; + this.serverInfo = serverInfo; } @VisibleForTesting - public GrpcXdsTransport(ManagedChannel channel, CallCredentials callCredentials) { + public GrpcXdsTransport( + ManagedChannel channel, + CallCredentials callCredentials, + Bootstrapper.ServerInfo serverInfo) { this.channel = checkNotNull(channel, "channel"); this.callCredentials = callCredentials; + this.serverInfo = serverInfo; } @Override @@ -98,7 +125,19 @@ public StreamingCall createStreamingCall( @Override public void shutdown() { - channel.shutdown(); + if (serverInfo == null) { + channel.shutdown(); + return; + } + xdsServerInfoToTransportMap.computeIfPresent( + serverInfo, + (info, transport) -> { + if (--transport.refCount == 0) { // Prefix decrement and return the updated value. + transport.channel.shutdown(); + return null; // Remove mapping. + } + return transport; + }); } private class XdsStreamingCall implements diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java index 66e0d4b3198..e261624b6a4 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java @@ -118,6 +118,68 @@ public void callApis() throws Exception { xdsTransport.shutdown(); } + @Test + public void refCountedXdsTransport_sameXdsServerAddress_returnsExistingTransport() { + Bootstrapper.ServerInfo xdsServerInfo = + Bootstrapper.ServerInfo.create( + "localhost:" + server.getPort(), InsecureChannelCredentials.create()); + GrpcXdsTransportFactory xdsTransportFactory = new GrpcXdsTransportFactory(null); + // Verify calling create() for the first time creates a new GrpcXdsTransport instance. + // The ref count was previously 0 and now is 1. + XdsTransportFactory.XdsTransport transport1 = xdsTransportFactory.create(xdsServerInfo); + assertThat(GrpcXdsTransportFactory.hasTransport(xdsServerInfo)).isTrue(); + // Verify calling create() for the second time to the same xDS server address returns the same + // GrpcXdsTransport instance. The ref count was previously 1 and now is 2. + XdsTransportFactory.XdsTransport transport2 = xdsTransportFactory.create(xdsServerInfo); + assertThat(transport1).isSameInstanceAs(transport2); + assertThat(GrpcXdsTransportFactory.hasTransport(xdsServerInfo)).isTrue(); + // Verify calling shutdown() for the first time does not shut down the GrpcXdsTransport + // instance. The ref count was previously 2 and now is 1. + transport1.shutdown(); + assertThat(GrpcXdsTransportFactory.hasTransport(xdsServerInfo)).isTrue(); + // Verify calling shutdown() for the second time shuts down and cleans up the + // GrpcXdsTransport instance. The ref count was previously 1 and now is 0. + transport2.shutdown(); + assertThat(GrpcXdsTransportFactory.hasTransport(xdsServerInfo)).isFalse(); + } + + @Test + public void refCountedXdsTransport_differentXdsServerAddress_returnsDifferentTransport() + throws Exception { + // Create and start a second xDS serverĀ on a different port. + Server server2 = + Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) + .addService(echoAdsService()) + .build() + .start(); + Bootstrapper.ServerInfo xdsServerInfo1 = + Bootstrapper.ServerInfo.create( + "localhost:" + server.getPort(), InsecureChannelCredentials.create()); + Bootstrapper.ServerInfo xdsServerInfo2 = + Bootstrapper.ServerInfo.create( + "localhost:" + server2.getPort(), InsecureChannelCredentials.create()); + GrpcXdsTransportFactory xdsTransportFactory = new GrpcXdsTransportFactory(null); + // Verify calling create() to the first xDS server creates a new GrpcXdsTransport instance. + // The ref count was previously 0 and now is 1. + XdsTransportFactory.XdsTransport transport1 = xdsTransportFactory.create(xdsServerInfo1); + assertThat(GrpcXdsTransportFactory.hasTransport(xdsServerInfo1)).isTrue(); + // Verify calling create() to the second xDS server creates a different GrpcXdsTransport + // instance. The ref count was previously 0 and now is 1. + XdsTransportFactory.XdsTransport transport2 = xdsTransportFactory.create(xdsServerInfo2); + assertThat(transport1).isNotSameInstanceAs(transport2); + assertThat(GrpcXdsTransportFactory.hasTransport(xdsServerInfo2)).isTrue(); + // Verify calling shutdown() shuts down and cleans up the GrpcXdsTransport instance for + // the first xDS server. The ref count was previously 1 and now is 0. + transport1.shutdown(); + assertThat(GrpcXdsTransportFactory.hasTransport(xdsServerInfo1)).isFalse(); + // Verify calling shutdown() shuts down and cleans up the GrpcXdsTransport instance for + // the second xDS server. The ref count was previously 1 and now is 0. + transport2.shutdown(); + assertThat(GrpcXdsTransportFactory.hasTransport(xdsServerInfo2)).isFalse(); + // Clean up the second xDS server. + server2.shutdown(); + } + private static class FakeEventHandler implements XdsTransportFactory.EventHandler { private final BlockingQueue respQ = new LinkedBlockingQueue<>();