@@ -30,6 +30,8 @@ import (
30
30
"os"
31
31
"path"
32
32
"strings"
33
+ "sync"
34
+ "sync/atomic"
33
35
"time"
34
36
35
37
"github.com/gravitational/trace"
@@ -167,7 +169,7 @@ func (s *SSHMultiplexerService) writeArtifacts(ctx context.Context, proxyHost st
167
169
168
170
func (s * SSHMultiplexerService ) setup (ctx context.Context ) (
169
171
_ * authclient.Client ,
170
- _ * proxyclient. Client ,
172
+ _ * cyclingHostDialClient ,
171
173
proxyHost string ,
172
174
_ * libclient.TSHConfig ,
173
175
_ error ,
@@ -228,7 +230,7 @@ func (s *SSHMultiplexerService) setup(ctx context.Context) (
228
230
}
229
231
230
232
// Create Proxy and Auth clients
231
- proxyClient , err := proxyclient . NewClient ( ctx , proxyclient.ClientConfig {
233
+ proxyClient := newCyclingHostDialClient ( 100 , proxyclient.ClientConfig {
232
234
ProxyAddress : proxyAddr ,
233
235
TLSRoutingEnabled : proxyPing .Proxy .TLSRoutingEnabled ,
234
236
TLSConfigFunc : func (cluster string ) (* tls.Config , error ) {
@@ -253,21 +255,12 @@ func (s *SSHMultiplexerService) setup(ctx context.Context) (
253
255
SSHConfig : sshConfig ,
254
256
InsecureSkipVerify : s .botCfg .Insecure ,
255
257
ALPNConnUpgradeRequired : connUpgradeRequired ,
256
-
257
- // Here we use a special dial context that will create a new connection
258
- // after the cycleCount has been reached. This prevents too many SSH
259
- // connections from sharing the same upstream connection.
260
- DialContext : newDialCycling (100 ),
261
258
})
262
- if err != nil {
263
- return nil , nil , "" , nil , trace .Wrap (err )
264
- }
265
259
266
260
authClient , err := clientForFacade (
267
261
ctx , s .log , s .botCfg , s .identity , s .resolver ,
268
262
)
269
263
if err != nil {
270
- _ = proxyClient .Close ()
271
264
return nil , nil , "" , nil , trace .Wrap (err )
272
265
}
273
266
@@ -371,12 +364,11 @@ func (s *SSHMultiplexerService) Run(ctx context.Context) (err error) {
371
364
)
372
365
defer func () { tracing .EndSpan (span , err ) }()
373
366
374
- authClient , proxyClient , proxyHost , tshConfig , err := s .setup (ctx )
367
+ authClient , hostDialer , proxyHost , tshConfig , err := s .setup (ctx )
375
368
if err != nil {
376
369
return trace .Wrap (err )
377
370
}
378
371
defer authClient .Close ()
379
- defer proxyClient .Close ()
380
372
381
373
dest := s .cfg .Destination .(* config.DestinationDirectory )
382
374
l , err := createListener (
@@ -416,7 +408,7 @@ func (s *SSHMultiplexerService) Run(ctx context.Context) (err error) {
416
408
go func () {
417
409
inflightConnectionsGauge .Inc ()
418
410
err := s .handleConn (
419
- egCtx , tshConfig , authClient , proxyClient , proxyHost , downstream ,
411
+ egCtx , tshConfig , authClient , hostDialer , proxyHost , downstream ,
420
412
)
421
413
inflightConnectionsGauge .Dec ()
422
414
status := "OK"
@@ -454,7 +446,7 @@ func (s *SSHMultiplexerService) handleConn(
454
446
ctx context.Context ,
455
447
tshConfig * libclient.TSHConfig ,
456
448
authClient * authclient.Client ,
457
- proxyClient * proxyclient. Client ,
449
+ hostDialer * cyclingHostDialClient ,
458
450
proxyHost string ,
459
451
downstream net.Conn ,
460
452
) (err error ) {
@@ -530,7 +522,7 @@ func (s *SSHMultiplexerService) handleConn(
530
522
target = net .JoinHostPort (node .GetName (), "0" )
531
523
}
532
524
533
- upstream , _ , err := proxyClient .DialHost (ctx , target , clusterName , nil )
525
+ upstream , _ , err := hostDialer .DialHost (ctx , target , clusterName , nil )
534
526
if err != nil {
535
527
return trace .Wrap (err )
536
528
}
@@ -544,7 +536,7 @@ func (s *SSHMultiplexerService) handleConn(
544
536
// if the connection is being resumed, it means that
545
537
// we didn't need the agent in the first place
546
538
var noAgent agent.ExtendedAgent
547
- conn , _ , err := proxyClient .DialHost (
539
+ conn , _ , err := hostDialer .DialHost (
548
540
ctx , net .JoinHostPort (hostID , "0" ), clusterName , noAgent ,
549
541
)
550
542
return conn , err
@@ -576,15 +568,68 @@ func (s *SSHMultiplexerService) String() string {
576
568
577
569
// cyclingHostDialClient
578
570
type cyclingHostDialClient struct {
579
- max int
571
+ max int32
580
572
config proxyclient.ClientConfig
581
- inner * proxyclient.Client
573
+
574
+ mu sync.Mutex
575
+ started int32
576
+ currentClt * refCountProxyClient
577
+ }
578
+
579
+ type refCountProxyClient struct {
580
+ clt * proxyclient.Client
581
+ refCount atomic.Int32
582
+ }
583
+
584
+ type refCountConn struct {
585
+ net.Conn
586
+ parent atomic.Pointer [refCountProxyClient ]
582
587
}
583
588
584
- func newCyclingHostDialClient (max int32 , config * proxyclient.ClientConfig ) * cyclingHostDialClient {
589
+ func (r * refCountConn ) Close () error {
590
+ err := r .Conn .Close ()
591
+ // Swap operation ensures only one of the conns closes the underlying
592
+ // client.
593
+ if parent := r .parent .Swap (nil ); parent != nil {
594
+ if parent .refCount .Add (- 1 ) <= 0 {
595
+ go parent .clt .Close ()
596
+ }
597
+ }
598
+ return trace .Wrap (err )
599
+ }
600
+
601
+ func newCyclingHostDialClient (max int32 , config proxyclient.ClientConfig ) * cyclingHostDialClient {
585
602
return & cyclingHostDialClient {max : max , config : config }
586
603
}
587
604
588
605
func (s * cyclingHostDialClient ) DialHost (ctx context.Context , target string , cluster string , keyring agent.ExtendedAgent ) (net.Conn , proxyclient.ClusterDetails , error ) {
589
- return s .inner .DialHost (ctx , target , cluster , keyring )
606
+ s .mu .Lock ()
607
+ if s .currentClt == nil {
608
+ clt , err := proxyclient .NewClient (ctx , s .config )
609
+ if err != nil {
610
+ s .mu .Unlock ()
611
+ return nil , proxyclient.ClusterDetails {}, trace .Wrap (err )
612
+ }
613
+ s .currentClt = & refCountProxyClient {clt : clt }
614
+ s .started = 0
615
+ }
616
+
617
+ currentClt := s .currentClt
618
+ s .started ++
619
+ if s .started > s .max {
620
+ s .currentClt = nil
621
+ }
622
+ s .mu .Unlock ()
623
+
624
+ innerConn , details , err := currentClt .clt .DialHost (ctx , target , cluster , keyring )
625
+ if err != nil {
626
+ return nil , details , trace .Wrap (err )
627
+ }
628
+ currentClt .refCount .Add (1 )
629
+
630
+ wrappedConn := & refCountConn {
631
+ Conn : innerConn ,
632
+ }
633
+ wrappedConn .parent .Store (currentClt )
634
+ return wrappedConn , details , nil
590
635
}
0 commit comments