From c7fc5850ffdf65e97a39b080d73f1f23346bf52a Mon Sep 17 00:00:00 2001 From: xu0o0 Date: Wed, 27 Mar 2024 06:03:33 +0800 Subject: [PATCH] Gracefully shutdown of the websocket client (#213) Resolves #163 Co-authored-by: Evan Bradley <11745660+evan-bradley@users.noreply.github.com> Co-authored-by: Srikanth Chekuri Co-authored-by: Tigran Najaryan <4194920+tigrannajaryan@users.noreply.github.com> --- client/internal/wsreceiver.go | 20 ++- client/internal/wssender.go | 33 ++++- client/wsclient.go | 87 ++++++++----- client/wsclient_test.go | 221 ++++++++++++++++++++++++++++++++++ 4 files changed, 326 insertions(+), 35 deletions(-) diff --git a/client/internal/wsreceiver.go b/client/internal/wsreceiver.go index b53b54d3..634830d6 100644 --- a/client/internal/wsreceiver.go +++ b/client/internal/wsreceiver.go @@ -17,6 +17,9 @@ type wsReceiver struct { sender *WSSender callbacks types.Callbacks processor receivedProcessor + + // Indicates that the receiver has fully stopped. + stopped chan struct{} } // NewWSReceiver creates a new Receiver that uses WebSocket to receive @@ -36,18 +39,32 @@ func NewWSReceiver( sender: sender, callbacks: callbacks, processor: newReceivedProcessor(logger, callbacks, sender, clientSyncedState, packagesStateProvider, capabilities), + stopped: make(chan struct{}), } return w } -// ReceiverLoop runs the receiver loop. To stop the receiver cancel the context. +// Start starts the receiver loop. +func (r *wsReceiver) Start(ctx context.Context) { + go r.ReceiverLoop(ctx) +} + +// IsStopped returns a channel that's closed when the receiver is stopped. +func (r *wsReceiver) IsStopped() <-chan struct{} { + return r.stopped +} + +// ReceiverLoop runs the receiver loop. +// To stop the receiver cancel the context and close the websocket connection func (r *wsReceiver) ReceiverLoop(ctx context.Context) { type receivedMessage struct { message *protobufs.ServerToAgent err error } + defer func() { close(r.stopped) }() + for { select { case <-ctx.Done(): @@ -55,6 +72,7 @@ func (r *wsReceiver) ReceiverLoop(ctx context.Context) { default: result := make(chan receivedMessage, 1) + // To stop this goroutine, close the websocket connection go func() { var message protobufs.ServerToAgent err := r.receiveMessage(&message) diff --git a/client/internal/wssender.go b/client/internal/wssender.go index 40ac937b..4bff484a 100644 --- a/client/internal/wssender.go +++ b/client/internal/wssender.go @@ -2,6 +2,7 @@ package internal import ( "context" + "time" "github.com/gorilla/websocket" "google.golang.org/protobuf/proto" @@ -11,13 +12,19 @@ import ( "github.com/open-telemetry/opamp-go/protobufs" ) +const ( + defaultSendCloseMessageTimeout = 5 * time.Second +) + // WSSender implements the WebSocket client's sending portion of OpAMP protocol. type WSSender struct { SenderCommon conn *websocket.Conn logger types.Logger + // Indicates that the sender has fully stopped. stopped chan struct{} + err error } // NewSender creates a new Sender that uses WebSocket to send @@ -37,15 +44,22 @@ func (s *WSSender) Start(ctx context.Context, conn *websocket.Conn) error { // Run the sender in the background. s.stopped = make(chan struct{}) + s.err = nil go s.run(ctx) return err } -// WaitToStop blocks until the sender is stopped. To stop the sender cancel the context -// that was passed to Start(). -func (s *WSSender) WaitToStop() { - <-s.stopped +// IsStopped returns a channel that's closed when the sender is stopped. +func (s *WSSender) IsStopped() <-chan struct{} { + return s.stopped +} + +// StoppingErr returns an error if there was a problem with stopping the sender. +// If stopping was successful will return nil. +// StoppingErr() can be called only after IsStopped() is signalled. +func (s *WSSender) StoppingErr() error { + return s.err } func (s *WSSender) run(ctx context.Context) { @@ -56,6 +70,9 @@ out: s.sendNextMessage(ctx) case <-ctx.Done(): + if err := s.sendCloseMessage(); err != nil && err != websocket.ErrCloseSent { + s.err = err + } break out } } @@ -63,6 +80,14 @@ out: close(s.stopped) } +func (s *WSSender) sendCloseMessage() error { + return s.conn.WriteControl( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Normal closure"), + time.Now().Add(defaultSendCloseMessageTimeout), + ) +} + func (s *WSSender) sendNextMessage(ctx context.Context) error { msgToSend := s.nextMessage.PopPending() if msgToSend != nil && !proto.Equal(msgToSend, &protobufs.AgentToServer{}) { diff --git a/client/wsclient.go b/client/wsclient.go index 3fbf21f4..b017a44e 100644 --- a/client/wsclient.go +++ b/client/wsclient.go @@ -18,6 +18,10 @@ import ( "github.com/open-telemetry/opamp-go/protobufs" ) +const ( + defaultShutdownTimeout = 5 * time.Second +) + // wsClient is an OpAMP Client implementation for WebSocket transport. // See specification: https://github.com/open-telemetry/opamp-spec/blob/main/specification.md#websocket-transport type wsClient struct { @@ -40,6 +44,10 @@ type wsClient struct { // last non-nil internal error that was encountered in the conn retry loop, // currently used only for testing. lastInternalErr atomic.Pointer[error] + + // Network connection timeout used for the WebSocket closing handshake. + // This field is currently only modified during testing. + connShutdownTimeout time.Duration } // NewWebSocket creates a new OpAMP Client that uses WebSocket transport. @@ -50,8 +58,9 @@ func NewWebSocket(logger types.Logger) *wsClient { sender := internal.NewSender(logger) w := &wsClient{ - common: internal.NewClientCommon(logger, sender), - sender: sender, + common: internal.NewClientCommon(logger, sender), + sender: sender, + connShutdownTimeout: defaultShutdownTimeout, } return w } @@ -85,15 +94,6 @@ func (c *wsClient) Start(ctx context.Context, settings types.StartSettings) erro } func (c *wsClient) Stop(ctx context.Context) error { - // Close connection if any. - c.connMutex.RLock() - conn := c.conn - c.connMutex.RUnlock() - - if conn != nil { - _ = conn.Close() - } - return c.common.Stop(ctx) } @@ -232,19 +232,25 @@ func (c *wsClient) ensureConnected(ctx context.Context) error { // runOneCycle performs the following actions: // 1. connect (try until succeeds). // 2. send first status report. -// 3. receive and process messages until error happens. +// 3. start the sender to wait for scheduled messages and send them to the server. +// 4. start the receiver to receive and process messages until an error happens. +// 5. wait until both the sender and receiver are stopped. // -// If it encounters an error it closes the connection and returns. -// Will stop and return if Stop() is called (ctx is cancelled, isStopping is set). +// runOneCycle will close the connection it created before it return. +// +// When Stop() is called (ctx is cancelled, isStopping is set), wsClient will shutdown gracefully: +// 1. sender will be cancelled by the ctx, send the close message to server and return the error via sender.Err(). +// 2. runOneCycle will handle that error and wait for the close message from server until timeout. func (c *wsClient) runOneCycle(ctx context.Context) { if err := c.ensureConnected(ctx); err != nil { // Can't connect, so can't move forward. This currently happens when we // are being stopped. return } + // Close the underlying connection. + defer c.conn.Close() if c.common.IsStopping() { - _ = c.conn.Close() return } @@ -256,15 +262,14 @@ func (c *wsClient) runOneCycle(ctx context.Context) { } // Create a cancellable context for background processors. - procCtx, procCancel := context.WithCancel(ctx) + senderCtx, stopSender := context.WithCancel(ctx) + defer stopSender() // Connected successfully. Start the sender. This will also send the first // status report. - if err := c.sender.Start(procCtx, c.conn); err != nil { - c.common.Logger.Errorf(procCtx, "Failed to send first status report: %v", err) + if err := c.sender.Start(senderCtx, c.conn); err != nil { + c.common.Logger.Errorf(senderCtx, "Failed to send first status report: %v", err) // We could not send the report, the only thing we can do is start over. - _ = c.conn.Close() - procCancel() return } @@ -278,19 +283,41 @@ func (c *wsClient) runOneCycle(ctx context.Context) { c.common.PackagesStateProvider, c.common.Capabilities, ) - r.ReceiverLoop(ctx) - - // Stop the background processors. - procCancel() - // If we exited receiverLoop it means there is a connection error, we cannot - // read messages anymore. We need to start over. + // When the wsclient is closed, the context passed to runOneCycle will be canceled. + // The receiver should keep running and processing messages + // until it received a Close message from the server which means the server has no more messages. + receiverCtx, stopReceiver := context.WithCancel(context.Background()) + defer stopReceiver() + r.Start(receiverCtx) + + select { + case <-c.sender.IsStopped(): + // sender will send close message to initiate the close handshake + if err := c.sender.StoppingErr(); err != nil { + c.common.Logger.Debugf(ctx, "Error stopping the sender: %v", err) + + stopReceiver() + <-r.IsStopped() + break + } - // Close the connection to unblock the WSSender as well. - _ = c.conn.Close() + c.common.Logger.Debugf(ctx, "Waiting for receiver to stop.") + select { + case <-r.IsStopped(): + c.common.Logger.Debugf(ctx, "Receiver stopped.") + case <-time.After(c.connShutdownTimeout): + c.common.Logger.Debugf(ctx, "Timeout waiting for receiver to stop.") + stopReceiver() + <-r.IsStopped() + } + case <-r.IsStopped(): + // If we exited receiverLoop it means there is a connection error, we cannot + // read messages anymore. We need to start over. - // Wait for WSSender to stop. - c.sender.WaitToStop() + stopSender() + <-c.sender.IsStopped() + } } func (c *wsClient) runUntilStopped(ctx context.Context) { diff --git a/client/wsclient_test.go b/client/wsclient_test.go index d0b8b4e6..839463a4 100644 --- a/client/wsclient_test.go +++ b/client/wsclient_test.go @@ -9,9 +9,11 @@ import ( "strings" "sync/atomic" "testing" + "time" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" "github.com/open-telemetry/opamp-go/client/internal" @@ -262,3 +264,222 @@ func TestRedirectWS(t *testing.T) { }) } } + +func TestHandlesStopBeforeStart(t *testing.T) { + client := NewWebSocket(nil) + require.Error(t, client.Stop(context.Background())) +} + +func TestPerformsClosingHandshake(t *testing.T) { + srv := internal.StartMockServer(t) + var wsConn *websocket.Conn + connected := make(chan struct{}) + closed := make(chan struct{}) + acked := make(chan struct{}) + + srv.OnWSConnect = func(conn *websocket.Conn) { + wsConn = conn + connected <- struct{}{} + } + + client := NewWebSocket(nil) + startClient(t, types.StartSettings{ + OpAMPServerURL: srv.GetHTTPTestServer().URL, + }, client) + + select { + case <-connected: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never established") + } + + eventually(t, func() bool { + client.connMutex.RLock() + conn := client.conn + client.connMutex.RUnlock() + return conn != nil + }) + + { + defhandler := client.conn.CloseHandler() + client.conn.SetCloseHandler(func(code int, msg string) error { + close(acked) + return defhandler(code, msg) + }) + } + + defHandler := wsConn.CloseHandler() + + wsConn.SetCloseHandler(func(code int, _ string) error { + require.Equal(t, websocket.CloseNormalClosure, code, "Client sent non-normal closing code") + + err := defHandler(code, "") + closed <- struct{}{} + return err + }) + + client.Stop(context.Background()) + + select { + case <-closed: + select { + case <-acked: + case <-time.After(2 * time.Second): + require.Fail(t, "Close connection without waiting for a close message from server") + } + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never closed") + } +} + +func TestHandlesSlowCloseMessageFromServer(t *testing.T) { + srv := internal.StartMockServer(t) + var wsConn *websocket.Conn + connected := make(chan struct{}) + closed := make(chan struct{}) + + srv.OnWSConnect = func(conn *websocket.Conn) { + wsConn = conn + connected <- struct{}{} + } + + client := NewWebSocket(nil) + client.connShutdownTimeout = 100 * time.Millisecond + startClient(t, types.StartSettings{ + OpAMPServerURL: srv.GetHTTPTestServer().URL, + }, client) + + select { + case <-connected: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never established") + } + + require.Eventually(t, func() bool { + client.connMutex.RLock() + conn := client.conn + client.connMutex.RUnlock() + return conn != nil + }, 2*time.Second, 250*time.Millisecond) + + defHandler := wsConn.CloseHandler() + + wsConn.SetCloseHandler(func(code int, _ string) error { + require.Equal(t, websocket.CloseNormalClosure, code, "Client sent non-normal closing code") + + time.Sleep(200 * time.Millisecond) + err := defHandler(code, "") + closed <- struct{}{} + return err + }) + + client.Stop(context.Background()) + + select { + case <-closed: + case <-time.After(1 * time.Second): + require.Fail(t, "Connection never closed") + } +} + +func TestHandlesNoCloseMessageFromServer(t *testing.T) { + srv := internal.StartMockServer(t) + var wsConn *websocket.Conn + connected := make(chan struct{}) + closed := make(chan struct{}) + + srv.OnWSConnect = func(conn *websocket.Conn) { + wsConn = conn + connected <- struct{}{} + } + + client := NewWebSocket(nil) + client.connShutdownTimeout = 100 * time.Millisecond + startClient(t, types.StartSettings{ + OpAMPServerURL: srv.GetHTTPTestServer().URL, + }, client) + + select { + case <-connected: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never established") + } + + require.Eventually(t, func() bool { + client.connMutex.RLock() + conn := client.conn + client.connMutex.RUnlock() + return conn != nil + }, 2*time.Second, 250*time.Millisecond) + + wsConn.SetCloseHandler(func(code int, _ string) error { + // Don't send close message + return nil + }) + + go func() { + client.Stop(context.Background()) + closed <- struct{}{} + }() + + select { + case <-closed: + case <-time.After(1 * time.Second): + require.Fail(t, "Connection never closed") + } +} + +func TestHandlesConnectionError(t *testing.T) { + srv := internal.StartMockServer(t) + var wsConn *websocket.Conn + connected := make(chan struct{}) + + srv.OnWSConnect = func(conn *websocket.Conn) { + wsConn = conn + connected <- struct{}{} + } + + client := NewWebSocket(nil) + startClient(t, types.StartSettings{ + OpAMPServerURL: srv.GetHTTPTestServer().URL, + }, client) + + select { + case <-connected: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never established") + } + + require.Eventually(t, func() bool { + client.connMutex.RLock() + conn := client.conn + client.connMutex.RUnlock() + return conn != nil + }, 2*time.Second, 250*time.Millisecond) + + // Write an invalid message to the connection. The client + // will take this as an error and reconnect to the server. + writer, err := wsConn.NextWriter(websocket.BinaryMessage) + require.NoError(t, err) + n, err := writer.Write([]byte{99, 1, 2, 3, 4, 5}) + require.NoError(t, err) + require.Equal(t, 6, n) + err = writer.Close() + require.NoError(t, err) + + select { + case <-connected: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never re-established") + } + + require.Eventually(t, func() bool { + client.connMutex.RLock() + conn := client.conn + client.connMutex.RUnlock() + return conn != nil + }, 2*time.Second, 250*time.Millisecond) + + err = client.Stop(context.Background()) + require.NoError(t, err) +}