From 3342f4c1f943a201e6bd64b998772579a89fff51 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Fri, 21 Mar 2025 00:31:14 +0100 Subject: [PATCH] Clean up some request handling in the reverse tunnel agent (#53200) --- lib/reversetunnel/agent.go | 86 +++++++++++++------------------------- 1 file changed, 30 insertions(+), 56 deletions(-) diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index da41578411067..e52e3982ad00b 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -171,8 +171,6 @@ type agent struct { doneConnecting chan struct{} // hbChannel is the channel heartbeats are sent over. hbChannel *tracessh.Channel - // hbRequests are requests going over the heartbeat channel. - hbRequests <-chan *ssh.Request // discoveryC receives new discovery channels. discoveryC <-chan ssh.NewChannel // transportC receives new tranport channels. @@ -335,9 +333,11 @@ func (a *agent) Start(ctx context.Context) error { a.drainWG.Add(1) a.wg.Add(1) go func() { - if err := a.handleDrainChannels(); err != nil { + drainWGDone := sync.OnceFunc(a.drainWG.Done) + if err := a.handleDrainChannels(drainWGDone); err != nil { a.log.WithError(err).Debug("Failed to handle drainable channels.") } + drainWGDone() a.wg.Done() a.Stop() }() @@ -407,9 +407,9 @@ func (a *agent) sendFirstHeartbeat(ctx context.Context) error { return trace.Wrap(err) } sshutils.DiscardChannelData(channel) + go ssh.DiscardRequests(requests) a.hbChannel = channel - a.hbRequests = requests // Send the first ping right away. if _, err := a.hbChannel.SendRequest(ctx, "ping", false, nil); err != nil { @@ -450,9 +450,8 @@ func (a *agent) Stop() error { func (a *agent) handleGlobalRequests(ctx context.Context, requests <-chan *ssh.Request) error { for { select { - case r := <-requests: - // The request will be nil when the request channel is closing. - if r == nil { + case r, ok := <-requests: + if !ok { return trace.Errorf("global request channel is closing") } @@ -497,59 +496,29 @@ func (a *agent) handleGlobalRequests(ctx context.Context, requests <-chan *ssh.R } } -func (a *agent) isDraining() bool { - return a.drainCtx.Err() != nil -} - -// signalDraining will signal one time when the draining context is canceled. -func (a *agent) signalDraining() <-chan struct{} { - c := make(chan struct{}) - a.wg.Add(1) - go func() { - <-a.drainCtx.Done() - close(c) - a.wg.Done() - }() - - return c -} - // handleDrainChannels handles channels that should be stopped when the agent is draining. -func (a *agent) handleDrainChannels() error { +func (a *agent) handleDrainChannels(drainWGDone func()) error { ticker := time.NewTicker(a.keepAlive) defer ticker.Stop() - // once ensures drainWG.Done() is called one more time - // after no more transports will be created. - once := &sync.Once{} - drainWGDone := func() { - once.Do(func() { - a.drainWG.Done() - }) - } - defer drainWGDone() - drainSignal := a.signalDraining() + drainCtxDone := a.drainCtx.Done() for { - if a.isDraining() { - drainWGDone() - } - select { case <-a.ctx.Done(): return nil - // Signal once when the drain context is canceled to ensure we unblock - // to call drainWG.Done(). - case <-drainSignal: - continue - // Handle closed heartbeat channel. - case req := <-a.hbRequests: - if req == nil { - return trace.ConnectionProblem(nil, "heartbeat: connection closed") - } + case <-drainCtxDone: + // we synchronously do this here rather than using + // [context.AfterFunc] so we don't accidentally increase drainWG + // from 0 while something else might already be waiting + drainWGDone() + // don't re-enter this case of the select + drainCtxDone = nil + // for good measure + ticker.Stop() // Send ping over heartbeat channel. case <-ticker.C: - if a.isDraining() { + if a.drainCtx.Err() != nil { continue } bytes, _ := a.clock.Now().UTC().MarshalText() @@ -560,11 +529,16 @@ func (a *agent) handleDrainChannels() error { } a.log.Debugf("Ping -> %v.", a.client.RemoteAddr()) // Handle transport requests. - case nch := <-a.transportC: - if nch == nil { - continue + case nch, ok := <-a.transportC: + if !ok { + return trace.ConnectionProblem(nil, "transport: connection closed") } - if a.isDraining() { + + // once drainWGDone is called we can't add to the drain waitgroup so + // we have to reject transport requests beforehand; it gets called + // in this loop after drainCtx is done, so checking for the context + // error here is a stronger condition + if a.drainCtx.Err() != nil { err := nch.Reject(ssh.ConnectionFailed, "agent connection is draining") if err != nil { a.log.WithError(err).Warningf("Failed to reject transport channel.") @@ -597,9 +571,9 @@ func (a *agent) handleChannels() error { case <-a.ctx.Done(): return nil // new discovery request channel - case nch := <-a.discoveryC: - if nch == nil { - continue + case nch, ok := <-a.discoveryC: + if !ok { + return nil } a.log.Debugf("Discovery request channel opened: %v.", nch.ChannelType()) ch, req, err := nch.Accept()