diff --git a/context.go b/context.go index 5b795a0..3be538e 100644 --- a/context.go +++ b/context.go @@ -147,7 +147,7 @@ func (blipCtx *Context) GetCancelCtx() context.Context { return context.TODO() } -// DialOptions is used by DialConfig to oepn a BLIP connection. +// DialOptions is used by DialConfig to open a BLIP connection. type DialOptions struct { URL string HTTPClient *http.Client @@ -197,7 +197,8 @@ func (blipCtx *Context) DialConfig(opts *DialOptions) (*Sender, error) { incrReceiverGoroutines() defer decrReceiverGoroutines() - err := sender.receiver.receiveLoop() + var handlersStopped atomic.Bool + err := sender.receiver.receiveLoop(&handlersStopped) if err != nil { if isCloseError(err) { // lower log level for close @@ -224,6 +225,8 @@ type BlipWebsocketServer struct { blipCtx *Context ctx context.Context // Cancellable context to trigger server stop PostHandshakeCallback func(err error) + websockets map[*websocket.Conn]struct{} + handlersStopped atomic.Bool } var _ http.Handler = &BlipWebsocketServer{} @@ -231,7 +234,7 @@ var _ http.Handler = &BlipWebsocketServer{} // Creates an HTTP handler that accepts WebSocket connections and dispatches BLIP messages // to the Context. func (blipCtx *Context) WebSocketServer() *BlipWebsocketServer { - return &BlipWebsocketServer{blipCtx: blipCtx} + return &BlipWebsocketServer{blipCtx: blipCtx, websockets: make(map[*websocket.Conn]struct{})} } func (bwss *BlipWebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -239,7 +242,9 @@ func (bwss *BlipWebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Reques if err != nil { return } + bwss.websockets[ws] = struct{}{} bwss.handle(ws) + delete(bwss.websockets, ws) } func (bwss *BlipWebsocketServer) handshake(w http.ResponseWriter, r *http.Request) (conn *websocket.Conn, err error) { @@ -275,9 +280,12 @@ func (bwss *BlipWebsocketServer) handshake(w http.ResponseWriter, r *http.Reques func (bwss *BlipWebsocketServer) handle(ws *websocket.Conn) { bwss.blipCtx.log("Start BLIP/Websocket handler") sender := bwss.blipCtx.start(ws) - err := sender.receiver.receiveLoop() + err := sender.receiver.receiveLoop(&bwss.handlersStopped) sender.Stop() - if err != nil && !isCloseError(err) { + // if handlerStopped is true, it means the handler was stopped by StopHandler + if bwss.handlersStopped.Load() { + return + } else if err != nil && !isCloseError(err) { bwss.blipCtx.log("BLIP/Websocket Handler exited with error: %v", err) if bwss.blipCtx.FatalErrorHandler != nil { bwss.blipCtx.FatalErrorHandler(err) @@ -286,6 +294,23 @@ func (bwss *BlipWebsocketServer) handle(ws *websocket.Conn) { ws.Close(websocket.StatusNormalClosure, "") } +func (bwss *BlipWebsocketServer) StopHandlers(status websocket.StatusCode) error { + bwss.handlersStopped.Store(true) + fmt.Printf("Closing websocket connection with status: %v\n", status) + var errs []error + for ws := range bwss.websockets { + fmt.Printf("Closing websocket connection\n") + err := ws.Close(status, "") + if err != nil { + errs = append(errs, err) + } + } + if len(errs) > 0 { + return fmt.Errorf("errors closing websockets: %w", errors.Join(errs...)) + } + return nil +} + //////// DISPATCHING MESSAGES: func (blipCtx *Context) dispatchRequest(request *Message, sender *Sender) { diff --git a/context_test.go b/context_test.go index 5009233..2eb1108 100644 --- a/context_test.go +++ b/context_test.go @@ -728,6 +728,78 @@ func assertHandlerNoError(t *testing.T, server *BlipWebsocketServer, wg *sync.Wa } } +// TestWebSocketServerStopHandler tests stopping the handler with a specific error code. +func TestWebSocketServerStopHandler(t *testing.T) { + + opts := ContextOptions{ + ProtocolIds: []string{BlipTestAppProtocolId}, + } + blipContextEchoServer, err := NewContext(opts) + require.NoError(t, err) + + receivedRequests := sync.WaitGroup{} + + // ----------------- Setup Echo Server that will be closed via cancellation context ------------------------- + + // Create a blip profile handler to respond to echo requests + dispatchEcho := func(request *Message) { + defer receivedRequests.Done() + body, err := request.Body() + require.NoError(t, err) + require.Equal(t, "application/octet-stream", request.Properties["Content-Type"]) + if response := request.Response(); response != nil { + response.SetBody(body) + response.Properties["Content-Type"] = request.Properties["Content-Type"] + } + } + + // Blip setup + blipContextEchoServer.HandlerForProfile["BLIPTest/EchoData"] = dispatchEcho + blipContextEchoServer.LogMessages = true + blipContextEchoServer.LogFrames = true + + // Websocket Server + server := blipContextEchoServer.WebSocketServer() + + // HTTP Handler wrapping websocket server + http.Handle("/TestServerContextClose", server) + listener, err := net.Listen("tcp", ":0") + require.NoError(t, err) + defer listener.Close() + go func() { + _ = http.Serve(listener, nil) + }() + + // ----------------- Setup Echo Client ---------------------------------------- + blipContextEchoClient, err := NewContext(defaultContextOptions) + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + destUrl := fmt.Sprintf("ws://localhost:%d/TestServerContextClose", port) + sender, err := blipContextEchoClient.Dial(destUrl) + require.NoError(t, err) + + // Create echo request + echoResponseBody := []byte("hello") + echoRequest := NewRequest() + echoRequest.SetProfile("BLIPTest/EchoData") + echoRequest.Properties["Content-Type"] = "application/octet-stream" + echoRequest.SetBody(echoResponseBody) + receivedRequests.Add(1) + require.True(t, sender.Send(echoRequest)) + + // Read the echo response. Closed connection will result in empty response, as EOF message + // isn't currently returned by blip client + response := echoRequest.Response() + responseBody, err := response.Body() + require.NoError(t, err) + require.Equal(t, echoResponseBody, responseBody) + + fmt.Printf("Closing connection\n") + server.StopHandler(websocket.StatusAbnormalClosure) + //fmt.Printf("sender=%+v\n", sender.conn) + require.True(t, false) +} + // Wait for the WaitGroup, or return an error if the wg.Wait() doesn't return within timeout // TODO: this code is duplicated with code in Sync Gateway utilities_testing.go. Should be refactored to common repo. func WaitWithTimeout(wg *sync.WaitGroup, timeout time.Duration) error { diff --git a/receiver.go b/receiver.go index 4c50053..9b7d17d 100644 --- a/receiver.go +++ b/receiver.go @@ -50,6 +50,7 @@ type receiver struct { pendingRequests msgStreamerMap // Unfinished REQ messages being assembled pendingResponses msgStreamerMap // Unfinished RES messages being assembled maxPendingResponseNumber MessageNumber // Largest RES # I've seen + stopped atomic.Bool // True if I've been stopped by the caller } func newReceiver(context *Context, conn *websocket.Conn) *receiver { @@ -64,7 +65,7 @@ func newReceiver(context *Context, conn *websocket.Conn) *receiver { } } -func (r *receiver) receiveLoop() error { +func (r *receiver) receiveLoop(handlerStopped *atomic.Bool) error { defer atomic.AddInt32(&r.activeGoroutines, -1) atomic.AddInt32(&r.activeGoroutines, 1) go r.parseLoop() @@ -75,7 +76,9 @@ func (r *receiver) receiveLoop() error { // Receive the next raw WebSocket frame: _, frame, err := r.conn.Read(r.context.GetCancelCtx()) if err != nil { - if isCloseError(err) { + if handlerStopped.Load() { + return nil + } else if isCloseError(err) { // lower log level for close r.context.logFrame("receiveLoop stopped: %v", err) } else if parseErr := errorFromChannel(r.parseError); parseErr != nil {