diff --git a/pkg/revdial/revdial.go b/pkg/revdial/revdial.go index 8486b81c926..7ccd232020a 100644 --- a/pkg/revdial/revdial.go +++ b/pkg/revdial/revdial.go @@ -439,8 +439,11 @@ func (ln *Listener) grabConn(path string) { return } + c := wsconnadapter.New(wsConn) + c.Ping() + select { - case ln.connc <- wsconnadapter.New(wsConn): + case ln.connc <- c: case <-ln.donec: } } diff --git a/pkg/wsconnadapter/wsconnadapter.go b/pkg/wsconnadapter/wsconnadapter.go index e43c481ea97..c95915eb571 100644 --- a/pkg/wsconnadapter/wsconnadapter.go +++ b/pkg/wsconnadapter/wsconnadapter.go @@ -31,6 +31,9 @@ type Adapter struct { reader io.Reader stopPingCh chan struct{} pongCh chan bool + pingOnce sync.Once + closeOnce sync.Once + closeErr error Logger *log.Entry CreatedAt time.Time } @@ -72,53 +75,49 @@ func New(conn *websocket.Conn, options ...Option) *Adapter { } func (a *Adapter) Ping() chan bool { - if a.pongCh != nil { - a.Logger.Debug("pong channel is not null") + a.pingOnce.Do(func() { + a.stopPingCh = make(chan struct{}) + a.pongCh = make(chan bool) - return a.pongCh - } - - a.stopPingCh = make(chan struct{}) - a.pongCh = make(chan bool) - - timeout := time.AfterFunc(pongTimeout, func() { - a.Logger.Debug("close connection due pong timeout") - - _ = a.Close() - }) + timeout := time.AfterFunc(pongTimeout, func() { + a.Logger.Debug("close connection due pong timeout") - a.conn.SetPongHandler(func(_ string) error { - timeout.Reset(pongTimeout) - a.Logger.Trace("pong timeout") + _ = a.Close() + }) - // non-blocking channel write - select { - case a.pongCh <- true: - a.Logger.Trace("write true to pong channel") - default: - } + a.conn.SetPongHandler(func(_ string) error { + timeout.Reset(pongTimeout) + a.Logger.Trace("pong timeout") - return nil - }) + // non-blocking channel write + select { + case a.pongCh <- true: + a.Logger.Trace("write true to pong channel") + default: + } - // ping loop - go func() { - ticker := time.NewTicker(pingInterval) - defer ticker.Stop() + return nil + }) - for { - select { - case <-ticker.C: - if err := a.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil { - a.Logger.WithError(err).Error("failed to write ping message") + // ping loop + go func() { + ticker := time.NewTicker(pingInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := a.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil { + a.Logger.WithError(err).Error("failed to write ping message") + } + case <-a.stopPingCh: + a.Logger.Debug("stop ping message received") + + return } - case <-a.stopPingCh: - a.Logger.Debug("stop ping message received") - - return } - } - }() + }() + }) return a.pongCh } @@ -182,19 +181,16 @@ func (a *Adapter) Write(b []byte) (int, error) { } func (a *Adapter) Close() error { - select { - case <-a.stopPingCh: - a.Logger.Debug("stop ping message received") - default: + a.closeOnce.Do(func() { if a.stopPingCh != nil { - a.stopPingCh <- struct{}{} close(a.stopPingCh) - a.Logger.Debug("stop ping channel closed") } - } - return a.conn.Close() + a.closeErr = a.conn.Close() + }) + + return a.closeErr } func (a *Adapter) LocalAddr() net.Addr { diff --git a/tests/ssh_test.go b/tests/ssh_test.go index ed835740d2b..f39eeeb1d88 100644 --- a/tests/ssh_test.go +++ b/tests/ssh_test.go @@ -981,7 +981,7 @@ func testSSHWithVersion(t *testing.T, connectionVersion int) { // interrupted, so each attempt runs in a goroutine with a // per-attempt timeout. On timeout we abort immediately // (the deferred sess/conn Close unblocks the reader). - deadline := time.Now().Add(5 * time.Second) + deadline := time.Now().Add(30 * time.Second) var lastOutput string matched := false for attempt := 0; !matched && time.Now().Before(deadline); attempt++ { @@ -1002,7 +1002,7 @@ func testSSHWithVersion(t *testing.T, connectionVersion int) { require.NoError(t, r.err) lastOutput = r.output matched = (r.output == expected) - case <-time.After(2 * time.Second): + case <-time.After(10 * time.Second): require.Fail(t, "timeout reading stty output", "marker=%s expected=%s", marker, expected) }