@@ -171,8 +171,6 @@ type agent struct {
171
171
doneConnecting chan struct {}
172
172
// hbChannel is the channel heartbeats are sent over.
173
173
hbChannel * tracessh.Channel
174
- // hbRequests are requests going over the heartbeat channel.
175
- hbRequests <- chan * ssh.Request
176
174
// discoveryC receives new discovery channels.
177
175
discoveryC <- chan ssh.NewChannel
178
176
// transportC receives new transport channels.
@@ -335,9 +333,11 @@ func (a *agent) Start(ctx context.Context) error {
335
333
a .drainWG .Add (1 )
336
334
a .wg .Add (1 )
337
335
go func () {
338
- if err := a .handleDrainChannels (); err != nil {
336
+ drainWGDone := sync .OnceFunc (a .drainWG .Done )
337
+ if err := a .handleDrainChannels (drainWGDone ); err != nil {
339
338
a .log .WithError (err ).Debug ("Failed to handle drainable channels." )
340
339
}
340
+ drainWGDone ()
341
341
a .wg .Done ()
342
342
a .Stop ()
343
343
}()
@@ -407,9 +407,9 @@ func (a *agent) sendFirstHeartbeat(ctx context.Context) error {
407
407
return trace .Wrap (err )
408
408
}
409
409
sshutils .DiscardChannelData (channel )
410
+ go ssh .DiscardRequests (requests )
410
411
411
412
a .hbChannel = channel
412
- a .hbRequests = requests
413
413
414
414
// Send the first ping right away.
415
415
if _ , err := a .hbChannel .SendRequest (ctx , "ping" , false , nil ); err != nil {
@@ -450,9 +450,8 @@ func (a *agent) Stop() error {
450
450
func (a * agent ) handleGlobalRequests (ctx context.Context , requests <- chan * ssh.Request ) error {
451
451
for {
452
452
select {
453
- case r := <- requests :
454
- // The request will be nil when the request channel is closing.
455
- if r == nil {
453
+ case r , ok := <- requests :
454
+ if ! ok {
456
455
return trace .Errorf ("global request channel is closing" )
457
456
}
458
457
@@ -497,59 +496,29 @@ func (a *agent) handleGlobalRequests(ctx context.Context, requests <-chan *ssh.R
497
496
}
498
497
}
499
498
500
- func (a * agent ) isDraining () bool {
501
- return a .drainCtx .Err () != nil
502
- }
503
-
504
- // signalDraining will signal one time when the draining context is canceled.
505
- func (a * agent ) signalDraining () <- chan struct {} {
506
- c := make (chan struct {})
507
- a .wg .Add (1 )
508
- go func () {
509
- <- a .drainCtx .Done ()
510
- close (c )
511
- a .wg .Done ()
512
- }()
513
-
514
- return c
515
- }
516
-
517
499
// handleDrainChannels handles channels that should be stopped when the agent is draining.
518
- func (a * agent ) handleDrainChannels () error {
500
+ func (a * agent ) handleDrainChannels (drainWGDone func () ) error {
519
501
ticker := time .NewTicker (a .keepAlive )
520
502
defer ticker .Stop ()
521
503
522
- // once ensures drainWG.Done() is called one more time
523
- // after no more transports will be created.
524
- once := & sync.Once {}
525
- drainWGDone := func () {
526
- once .Do (func () {
527
- a .drainWG .Done ()
528
- })
529
- }
530
- defer drainWGDone ()
531
- drainSignal := a .signalDraining ()
504
+ drainCtxDone := a .drainCtx .Done ()
532
505
533
506
for {
534
- if a .isDraining () {
535
- drainWGDone ()
536
- }
537
-
538
507
select {
539
508
case <- a .ctx .Done ():
540
509
return nil
541
- // Signal once when the drain context is canceled to ensure we unblock
542
- // to call drainWG.Done().
543
- case <- drainSignal :
544
- continue
545
- // Handle closed heartbeat channel.
546
- case req := <- a . hbRequests :
547
- if req == nil {
548
- return trace . ConnectionProblem ( nil , "heartbeat: connection closed" )
549
- }
510
+ case <- drainCtxDone :
511
+ // we synchronously do this here rather than using
512
+ // [context.AfterFunc] so we don't accidentally increase drainWG
513
+ // from 0 while something else might already be waiting
514
+ drainWGDone ()
515
+ // don't re-enter this case of the select
516
+ drainCtxDone = nil
517
+ // for good measure
518
+ ticker . Stop ()
550
519
// Send ping over heartbeat channel.
551
520
case <- ticker .C :
552
- if a .isDraining () {
521
+ if a .drainCtx . Err () != nil {
553
522
continue
554
523
}
555
524
bytes , _ := a .clock .Now ().UTC ().MarshalText ()
@@ -560,11 +529,16 @@ func (a *agent) handleDrainChannels() error {
560
529
}
561
530
a .log .Debugf ("Ping -> %v." , a .client .RemoteAddr ())
562
531
// Handle transport requests.
563
- case nch := <- a .transportC :
564
- if nch == nil {
565
- continue
532
+ case nch , ok := <- a .transportC :
533
+ if ! ok {
534
+ return trace . ConnectionProblem ( nil , "transport: connection closed" )
566
535
}
567
- if a .isDraining () {
536
+
537
+ // once drainWGDone is called we can't add to the drain waitgroup so
538
+ // we have to reject transport requests beforehand; it gets called
539
+ // in this loop after drainCtx is done, so checking for the context
540
+ // error here is a stronger condition
541
+ if a .drainCtx .Err () != nil {
568
542
err := nch .Reject (ssh .ConnectionFailed , "agent connection is draining" )
569
543
if err != nil {
570
544
a .log .WithError (err ).Warningf ("Failed to reject transport channel." )
@@ -597,9 +571,9 @@ func (a *agent) handleChannels() error {
597
571
case <- a .ctx .Done ():
598
572
return nil
599
573
// new discovery request channel
600
- case nch := <- a .discoveryC :
601
- if nch == nil {
602
- continue
574
+ case nch , ok := <- a .discoveryC :
575
+ if ! ok {
576
+ return nil
603
577
}
604
578
a .log .Debugf ("Discovery request channel opened: %v." , nch .ChannelType ())
605
579
ch , req , err := nch .Accept ()
0 commit comments