Skip to content

Commit a6caf7d

Browse files
authored
Clean up some request handling in the reverse tunnel agent (#53200) (#53281)
1 parent e9d1155 commit a6caf7d

File tree

1 file changed

+30
-56
lines changed

1 file changed

+30
-56
lines changed

lib/reversetunnel/agent.go

+30-56
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,6 @@ type agent struct {
171171
doneConnecting chan struct{}
172172
// hbChannel is the channel heartbeats are sent over.
173173
hbChannel *tracessh.Channel
174-
// hbRequests are requests going over the heartbeat channel.
175-
hbRequests <-chan *ssh.Request
176174
// discoveryC receives new discovery channels.
177175
discoveryC <-chan ssh.NewChannel
178176
// transportC receives new tranport channels.
@@ -335,9 +333,11 @@ func (a *agent) Start(ctx context.Context) error {
335333
a.drainWG.Add(1)
336334
a.wg.Add(1)
337335
go func() {
338-
if err := a.handleDrainChannels(); err != nil {
336+
drainWGDone := sync.OnceFunc(a.drainWG.Done)
337+
if err := a.handleDrainChannels(drainWGDone); err != nil {
339338
a.log.WithError(err).Debug("Failed to handle drainable channels.")
340339
}
340+
drainWGDone()
341341
a.wg.Done()
342342
a.Stop()
343343
}()
@@ -407,9 +407,9 @@ func (a *agent) sendFirstHeartbeat(ctx context.Context) error {
407407
return trace.Wrap(err)
408408
}
409409
sshutils.DiscardChannelData(channel)
410+
go ssh.DiscardRequests(requests)
410411

411412
a.hbChannel = channel
412-
a.hbRequests = requests
413413

414414
// Send the first ping right away.
415415
if _, err := a.hbChannel.SendRequest(ctx, "ping", false, nil); err != nil {
@@ -450,9 +450,8 @@ func (a *agent) Stop() error {
450450
func (a *agent) handleGlobalRequests(ctx context.Context, requests <-chan *ssh.Request) error {
451451
for {
452452
select {
453-
case r := <-requests:
454-
// The request will be nil when the request channel is closing.
455-
if r == nil {
453+
case r, ok := <-requests:
454+
if !ok {
456455
return trace.Errorf("global request channel is closing")
457456
}
458457

@@ -497,59 +496,29 @@ func (a *agent) handleGlobalRequests(ctx context.Context, requests <-chan *ssh.R
497496
}
498497
}
499498

500-
func (a *agent) isDraining() bool {
501-
return a.drainCtx.Err() != nil
502-
}
503-
504-
// signalDraining will signal one time when the draining context is canceled.
505-
func (a *agent) signalDraining() <-chan struct{} {
506-
c := make(chan struct{})
507-
a.wg.Add(1)
508-
go func() {
509-
<-a.drainCtx.Done()
510-
close(c)
511-
a.wg.Done()
512-
}()
513-
514-
return c
515-
}
516-
517499
// handleDrainChannels handles channels that should be stopped when the agent is draining.
518-
func (a *agent) handleDrainChannels() error {
500+
func (a *agent) handleDrainChannels(drainWGDone func()) error {
519501
ticker := time.NewTicker(a.keepAlive)
520502
defer ticker.Stop()
521503

522-
// once ensures drainWG.Done() is called one more time
523-
// after no more transports will be created.
524-
once := &sync.Once{}
525-
drainWGDone := func() {
526-
once.Do(func() {
527-
a.drainWG.Done()
528-
})
529-
}
530-
defer drainWGDone()
531-
drainSignal := a.signalDraining()
504+
drainCtxDone := a.drainCtx.Done()
532505

533506
for {
534-
if a.isDraining() {
535-
drainWGDone()
536-
}
537-
538507
select {
539508
case <-a.ctx.Done():
540509
return nil
541-
// Signal once when the drain context is canceled to ensure we unblock
542-
// to call drainWG.Done().
543-
case <-drainSignal:
544-
continue
545-
// Handle closed heartbeat channel.
546-
case req := <-a.hbRequests:
547-
if req == nil {
548-
return trace.ConnectionProblem(nil, "heartbeat: connection closed")
549-
}
510+
case <-drainCtxDone:
511+
// we synchronously do this here rather than using
512+
// [context.AfterFunc] so we don't accidentally increase drainWG
513+
// from 0 while something else might already be waiting
514+
drainWGDone()
515+
// don't re-enter this case of the select
516+
drainCtxDone = nil
517+
// for good measure
518+
ticker.Stop()
550519
// Send ping over heartbeat channel.
551520
case <-ticker.C:
552-
if a.isDraining() {
521+
if a.drainCtx.Err() != nil {
553522
continue
554523
}
555524
bytes, _ := a.clock.Now().UTC().MarshalText()
@@ -560,11 +529,16 @@ func (a *agent) handleDrainChannels() error {
560529
}
561530
a.log.Debugf("Ping -> %v.", a.client.RemoteAddr())
562531
// Handle transport requests.
563-
case nch := <-a.transportC:
564-
if nch == nil {
565-
continue
532+
case nch, ok := <-a.transportC:
533+
if !ok {
534+
return trace.ConnectionProblem(nil, "transport: connection closed")
566535
}
567-
if a.isDraining() {
536+
537+
// once drainWGDone is called we can't add to the drain waitgroup so
538+
// we have to reject transport requests beforehand; it gets called
539+
// in this loop after drainCtx is done, so checking for the context
540+
// error here is a stronger condition
541+
if a.drainCtx.Err() != nil {
568542
err := nch.Reject(ssh.ConnectionFailed, "agent connection is draining")
569543
if err != nil {
570544
a.log.WithError(err).Warningf("Failed to reject transport channel.")
@@ -597,9 +571,9 @@ func (a *agent) handleChannels() error {
597571
case <-a.ctx.Done():
598572
return nil
599573
// new discovery request channel
600-
case nch := <-a.discoveryC:
601-
if nch == nil {
602-
continue
574+
case nch, ok := <-a.discoveryC:
575+
if !ok {
576+
return nil
603577
}
604578
a.log.Debugf("Discovery request channel opened: %v.", nch.ChannelType())
605579
ch, req, err := nch.Accept()

0 commit comments

Comments
 (0)