Skip to content

[v17] Clean up some request handling in the reverse tunnel agent #53281

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 30 additions & 56 deletions lib/reversetunnel/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,6 @@ type agent struct {
doneConnecting chan struct{}
// hbChannel is the channel heartbeats are sent over.
hbChannel *tracessh.Channel
// hbRequests are requests going over the heartbeat channel.
hbRequests <-chan *ssh.Request
// discoveryC receives new discovery channels.
discoveryC <-chan ssh.NewChannel
// transportC receives new tranport channels.
Expand Down Expand Up @@ -335,9 +333,11 @@ func (a *agent) Start(ctx context.Context) error {
a.drainWG.Add(1)
a.wg.Add(1)
go func() {
if err := a.handleDrainChannels(); err != nil {
drainWGDone := sync.OnceFunc(a.drainWG.Done)
if err := a.handleDrainChannels(drainWGDone); err != nil {
a.log.WithError(err).Debug("Failed to handle drainable channels.")
}
drainWGDone()
a.wg.Done()
a.Stop()
}()
Expand Down Expand Up @@ -407,9 +407,9 @@ func (a *agent) sendFirstHeartbeat(ctx context.Context) error {
return trace.Wrap(err)
}
sshutils.DiscardChannelData(channel)
go ssh.DiscardRequests(requests)

a.hbChannel = channel
a.hbRequests = requests

// Send the first ping right away.
if _, err := a.hbChannel.SendRequest(ctx, "ping", false, nil); err != nil {
Expand Down Expand Up @@ -450,9 +450,8 @@ func (a *agent) Stop() error {
func (a *agent) handleGlobalRequests(ctx context.Context, requests <-chan *ssh.Request) error {
for {
select {
case r := <-requests:
// The request will be nil when the request channel is closing.
if r == nil {
case r, ok := <-requests:
if !ok {
return trace.Errorf("global request channel is closing")
}

Expand Down Expand Up @@ -497,59 +496,29 @@ func (a *agent) handleGlobalRequests(ctx context.Context, requests <-chan *ssh.R
}
}

func (a *agent) isDraining() bool {
return a.drainCtx.Err() != nil
}

// signalDraining will signal one time when the draining context is canceled.
func (a *agent) signalDraining() <-chan struct{} {
c := make(chan struct{})
a.wg.Add(1)
go func() {
<-a.drainCtx.Done()
close(c)
a.wg.Done()
}()

return c
}

// handleDrainChannels handles channels that should be stopped when the agent is draining.
func (a *agent) handleDrainChannels() error {
func (a *agent) handleDrainChannels(drainWGDone func()) error {
ticker := time.NewTicker(a.keepAlive)
defer ticker.Stop()

// once ensures drainWG.Done() is called one more time
// after no more transports will be created.
once := &sync.Once{}
drainWGDone := func() {
once.Do(func() {
a.drainWG.Done()
})
}
defer drainWGDone()
drainSignal := a.signalDraining()
drainCtxDone := a.drainCtx.Done()

for {
if a.isDraining() {
drainWGDone()
}

select {
case <-a.ctx.Done():
return nil
// Signal once when the drain context is canceled to ensure we unblock
// to call drainWG.Done().
case <-drainSignal:
continue
// Handle closed heartbeat channel.
case req := <-a.hbRequests:
if req == nil {
return trace.ConnectionProblem(nil, "heartbeat: connection closed")
}
case <-drainCtxDone:
// we synchronously do this here rather than using
// [context.AfterFunc] so we don't accidentally increase drainWG
// from 0 while something else might already be waiting
drainWGDone()
// don't re-enter this case of the select
drainCtxDone = nil
// for good measure
ticker.Stop()
// Send ping over heartbeat channel.
case <-ticker.C:
if a.isDraining() {
if a.drainCtx.Err() != nil {
continue
}
bytes, _ := a.clock.Now().UTC().MarshalText()
Expand All @@ -560,11 +529,16 @@ func (a *agent) handleDrainChannels() error {
}
a.log.Debugf("Ping -> %v.", a.client.RemoteAddr())
// Handle transport requests.
case nch := <-a.transportC:
if nch == nil {
continue
case nch, ok := <-a.transportC:
if !ok {
return trace.ConnectionProblem(nil, "transport: connection closed")
}
if a.isDraining() {

// once drainWGDone is called we can't add to the drain waitgroup so
// we have to reject transport requests beforehand; it gets called
// in this loop after drainCtx is done, so checking for the context
// error here is a stronger condition
if a.drainCtx.Err() != nil {
err := nch.Reject(ssh.ConnectionFailed, "agent connection is draining")
if err != nil {
a.log.WithError(err).Warningf("Failed to reject transport channel.")
Expand Down Expand Up @@ -597,9 +571,9 @@ func (a *agent) handleChannels() error {
case <-a.ctx.Done():
return nil
// new discovery request channel
case nch := <-a.discoveryC:
if nch == nil {
continue
case nch, ok := <-a.discoveryC:
if !ok {
return nil
}
a.log.Debugf("Discovery request channel opened: %v.", nch.ChannelType())
ch, req, err := nch.Accept()
Expand Down
Loading