Skip to content

Commit 0be3dda

Browse files
committed
Replace low-level cycler with higher-level but broken cycler
1 parent f0be1c0 commit 0be3dda

File tree

3 files changed

+67
-200
lines changed

3 files changed

+67
-200
lines changed

api/client/proxy/client.go

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818
"context"
1919
"crypto/tls"
2020
"encoding/asn1"
21-
"io"
2221
"net"
2322
"slices"
2423
"sync/atomic"
@@ -74,23 +73,13 @@ type ClientConfig struct {
7473
// Used by proxy's web server to make calls on behalf of connected clients.
7574
PROXYHeaderGetter client.PROXYHeaderGetter
7675

77-
// DialContext allows a custom grpc.ClientConnInterface to be used by the
78-
// client. This allows for more customized behavior (e.g cycling the
79-
// underlying connection every X connections).
80-
DialContext func(ctx context.Context, target string, opts ...grpc.DialOption) (grpcClientConnInterfaceCloser, error)
81-
8276
// The below items are intended to be used by tests to connect without mTLS.
8377
// The gRPC transport credentials to use when establishing the connection to proxy.
8478
creds func(cluster string) (credentials.TransportCredentials, error)
8579
// The client credentials to use when establishing the connection to auth.
8680
clientCreds func(cluster string) (client.Credentials, error)
8781
}
8882

89-
type grpcClientConnInterfaceCloser = interface {
90-
grpc.ClientConnInterface
91-
io.Closer
92-
}
93-
9483
// CheckAndSetDefaults ensures required options are present and
9584
// sets the default value of any that are omitted.
9685
func (c *ClientConfig) CheckAndSetDefaults() error {
@@ -286,14 +275,7 @@ func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error
286275
return nil, trace.Wrap(err)
287276
}
288277

289-
dialContext := cfg.DialContext
290-
if dialContext == nil {
291-
dialContext = func(ctx context.Context, target string, opts ...grpc.DialOption) (grpcClientConnInterfaceCloser, error) {
292-
return grpc.DialContext(ctx, target, opts...)
293-
}
294-
}
295-
296-
conn, err := dialContext(
278+
conn, err := grpc.DialContext(
297279
dialCtx,
298280
cfg.ProxyAddress,
299281
append([]grpc.DialOption{

lib/tbot/dial_cycling.go

Lines changed: 0 additions & 160 deletions
This file was deleted.

lib/tbot/service_ssh_multiplexer.go

Lines changed: 66 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ import (
3030
"os"
3131
"path"
3232
"strings"
33+
"sync"
34+
"sync/atomic"
3335
"time"
3436

3537
"github.com/gravitational/trace"
@@ -167,7 +169,7 @@ func (s *SSHMultiplexerService) writeArtifacts(ctx context.Context, proxyHost st
167169

168170
func (s *SSHMultiplexerService) setup(ctx context.Context) (
169171
_ *authclient.Client,
170-
_ *proxyclient.Client,
172+
_ *cyclingHostDialClient,
171173
proxyHost string,
172174
_ *libclient.TSHConfig,
173175
_ error,
@@ -228,7 +230,7 @@ func (s *SSHMultiplexerService) setup(ctx context.Context) (
228230
}
229231

230232
// Create Proxy and Auth clients
231-
proxyClient, err := proxyclient.NewClient(ctx, proxyclient.ClientConfig{
233+
proxyClient := newCyclingHostDialClient(100, proxyclient.ClientConfig{
232234
ProxyAddress: proxyAddr,
233235
TLSRoutingEnabled: proxyPing.Proxy.TLSRoutingEnabled,
234236
TLSConfigFunc: func(cluster string) (*tls.Config, error) {
@@ -253,21 +255,12 @@ func (s *SSHMultiplexerService) setup(ctx context.Context) (
253255
SSHConfig: sshConfig,
254256
InsecureSkipVerify: s.botCfg.Insecure,
255257
ALPNConnUpgradeRequired: connUpgradeRequired,
256-
257-
// Here we use a special dial context that will create a new connection
258-
// after the cycleCount has been reached. This prevents too many SSH
259-
// connections from sharing the same upstream connection.
260-
DialContext: newDialCycling(100),
261258
})
262-
if err != nil {
263-
return nil, nil, "", nil, trace.Wrap(err)
264-
}
265259

266260
authClient, err := clientForFacade(
267261
ctx, s.log, s.botCfg, s.identity, s.resolver,
268262
)
269263
if err != nil {
270-
_ = proxyClient.Close()
271264
return nil, nil, "", nil, trace.Wrap(err)
272265
}
273266

@@ -371,12 +364,11 @@ func (s *SSHMultiplexerService) Run(ctx context.Context) (err error) {
371364
)
372365
defer func() { tracing.EndSpan(span, err) }()
373366

374-
authClient, proxyClient, proxyHost, tshConfig, err := s.setup(ctx)
367+
authClient, hostDialer, proxyHost, tshConfig, err := s.setup(ctx)
375368
if err != nil {
376369
return trace.Wrap(err)
377370
}
378371
defer authClient.Close()
379-
defer proxyClient.Close()
380372

381373
dest := s.cfg.Destination.(*config.DestinationDirectory)
382374
l, err := createListener(
@@ -416,7 +408,7 @@ func (s *SSHMultiplexerService) Run(ctx context.Context) (err error) {
416408
go func() {
417409
inflightConnectionsGauge.Inc()
418410
err := s.handleConn(
419-
egCtx, tshConfig, authClient, proxyClient, proxyHost, downstream,
411+
egCtx, tshConfig, authClient, hostDialer, proxyHost, downstream,
420412
)
421413
inflightConnectionsGauge.Dec()
422414
status := "OK"
@@ -454,7 +446,7 @@ func (s *SSHMultiplexerService) handleConn(
454446
ctx context.Context,
455447
tshConfig *libclient.TSHConfig,
456448
authClient *authclient.Client,
457-
proxyClient *proxyclient.Client,
449+
hostDialer *cyclingHostDialClient,
458450
proxyHost string,
459451
downstream net.Conn,
460452
) (err error) {
@@ -530,7 +522,7 @@ func (s *SSHMultiplexerService) handleConn(
530522
target = net.JoinHostPort(node.GetName(), "0")
531523
}
532524

533-
upstream, _, err := proxyClient.DialHost(ctx, target, clusterName, nil)
525+
upstream, _, err := hostDialer.DialHost(ctx, target, clusterName, nil)
534526
if err != nil {
535527
return trace.Wrap(err)
536528
}
@@ -544,7 +536,7 @@ func (s *SSHMultiplexerService) handleConn(
544536
// if the connection is being resumed, it means that
545537
// we didn't need the agent in the first place
546538
var noAgent agent.ExtendedAgent
547-
conn, _, err := proxyClient.DialHost(
539+
conn, _, err := hostDialer.DialHost(
548540
ctx, net.JoinHostPort(hostID, "0"), clusterName, noAgent,
549541
)
550542
return conn, err
@@ -576,15 +568,68 @@ func (s *SSHMultiplexerService) String() string {
576568

577569
// cyclingHostDialClient
578570
type cyclingHostDialClient struct {
579-
max int
571+
max int32
580572
config proxyclient.ClientConfig
581-
inner *proxyclient.Client
573+
574+
mu sync.Mutex
575+
started int32
576+
currentClt *refCountProxyClient
577+
}
578+
579+
type refCountProxyClient struct {
580+
clt *proxyclient.Client
581+
refCount atomic.Int32
582+
}
583+
584+
type refCountConn struct {
585+
net.Conn
586+
parent atomic.Pointer[refCountProxyClient]
582587
}
583588

584-
func newCyclingHostDialClient(max int32, config *proxyclient.ClientConfig) *cyclingHostDialClient {
589+
func (r *refCountConn) Close() error {
590+
err := r.Conn.Close()
591+
// Swap operation ensures only one of the conns closes the underlying
592+
// client.
593+
if parent := r.parent.Swap(nil); parent != nil {
594+
if parent.refCount.Add(-1) <= 0 {
595+
go parent.clt.Close()
596+
}
597+
}
598+
return trace.Wrap(err)
599+
}
600+
601+
func newCyclingHostDialClient(max int32, config proxyclient.ClientConfig) *cyclingHostDialClient {
585602
return &cyclingHostDialClient{max: max, config: config}
586603
}
587604

588605
func (s *cyclingHostDialClient) DialHost(ctx context.Context, target string, cluster string, keyring agent.ExtendedAgent) (net.Conn, proxyclient.ClusterDetails, error) {
589-
return s.inner.DialHost(ctx, target, cluster, keyring)
606+
s.mu.Lock()
607+
if s.currentClt == nil {
608+
clt, err := proxyclient.NewClient(ctx, s.config)
609+
if err != nil {
610+
s.mu.Unlock()
611+
return nil, proxyclient.ClusterDetails{}, trace.Wrap(err)
612+
}
613+
s.currentClt = &refCountProxyClient{clt: clt}
614+
s.started = 0
615+
}
616+
617+
currentClt := s.currentClt
618+
s.started++
619+
if s.started > s.max {
620+
s.currentClt = nil
621+
}
622+
s.mu.Unlock()
623+
624+
innerConn, details, err := currentClt.clt.DialHost(ctx, target, cluster, keyring)
625+
if err != nil {
626+
return nil, details, trace.Wrap(err)
627+
}
628+
currentClt.refCount.Add(1)
629+
630+
wrappedConn := &refCountConn{
631+
Conn: innerConn,
632+
}
633+
wrappedConn.parent.Store(currentClt)
634+
return wrappedConn, details, nil
590635
}

0 commit comments

Comments
 (0)