Skip to content

Commit 00eab7b

Browse files
committed
Example chanegs for closing websocket with an error code
1 parent 13a798c commit 00eab7b

File tree

3 files changed

+107
-7
lines changed

3 files changed

+107
-7
lines changed

context.go

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ func (blipCtx *Context) GetCancelCtx() context.Context {
147147
return context.TODO()
148148
}
149149

150-
// DialOptions is used by DialConfig to oepn a BLIP connection.
150+
// DialOptions is used by DialConfig to open a BLIP connection.
151151
type DialOptions struct {
152152
URL string
153153
HTTPClient *http.Client
@@ -197,7 +197,8 @@ func (blipCtx *Context) DialConfig(opts *DialOptions) (*Sender, error) {
197197
incrReceiverGoroutines()
198198
defer decrReceiverGoroutines()
199199

200-
err := sender.receiver.receiveLoop()
200+
var handlersStopped atomic.Bool
201+
err := sender.receiver.receiveLoop(&handlersStopped)
201202
if err != nil {
202203
if isCloseError(err) {
203204
// lower log level for close
@@ -224,22 +225,26 @@ type BlipWebsocketServer struct {
224225
blipCtx *Context
225226
ctx context.Context // Cancellable context to trigger server stop
226227
PostHandshakeCallback func(err error)
228+
websockets map[*websocket.Conn]struct{}
229+
handlersStopped atomic.Bool
227230
}
228231

229232
var _ http.Handler = &BlipWebsocketServer{}
230233

231234
// Creates an HTTP handler that accepts WebSocket connections and dispatches BLIP messages
232235
// to the Context.
233236
func (blipCtx *Context) WebSocketServer() *BlipWebsocketServer {
234-
return &BlipWebsocketServer{blipCtx: blipCtx}
237+
return &BlipWebsocketServer{blipCtx: blipCtx, websockets: make(map[*websocket.Conn]struct{})}
235238
}
236239

237240
func (bwss *BlipWebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
238241
ws, err := bwss.handshake(w, r)
239242
if err != nil {
240243
return
241244
}
245+
bwss.websockets[ws] = struct{}{}
242246
bwss.handle(ws)
247+
delete(bwss.websockets, ws)
243248
}
244249

245250
func (bwss *BlipWebsocketServer) handshake(w http.ResponseWriter, r *http.Request) (conn *websocket.Conn, err error) {
@@ -275,9 +280,12 @@ func (bwss *BlipWebsocketServer) handshake(w http.ResponseWriter, r *http.Reques
275280
func (bwss *BlipWebsocketServer) handle(ws *websocket.Conn) {
276281
bwss.blipCtx.log("Start BLIP/Websocket handler")
277282
sender := bwss.blipCtx.start(ws)
278-
err := sender.receiver.receiveLoop()
283+
err := sender.receiver.receiveLoop(&bwss.handlersStopped)
279284
sender.Stop()
280-
if err != nil && !isCloseError(err) {
285+
// if handlerStopped is true, it means the handler was stopped by StopHandler
286+
if bwss.handlersStopped.Load() {
287+
return
288+
} else if err != nil && !isCloseError(err) {
281289
bwss.blipCtx.log("BLIP/Websocket Handler exited with error: %v", err)
282290
if bwss.blipCtx.FatalErrorHandler != nil {
283291
bwss.blipCtx.FatalErrorHandler(err)
@@ -286,6 +294,23 @@ func (bwss *BlipWebsocketServer) handle(ws *websocket.Conn) {
286294
ws.Close(websocket.StatusNormalClosure, "")
287295
}
288296

297+
func (bwss *BlipWebsocketServer) StopHandlers(status websocket.StatusCode) error {
298+
bwss.handlersStopped.Store(true)
299+
fmt.Printf("Closing websocket connection with status: %v\n", status)
300+
var errs []error
301+
for ws := range bwss.websockets {
302+
fmt.Printf("Closing websocket connection\n")
303+
err := ws.Close(status, "")
304+
if err != nil {
305+
errs = append(errs, err)
306+
}
307+
}
308+
if len(errs) > 0 {
309+
return fmt.Errorf("errors closing websockets: %w", errors.Join(errs...))
310+
}
311+
return nil
312+
}
313+
289314
//////// DISPATCHING MESSAGES:
290315

291316
func (blipCtx *Context) dispatchRequest(request *Message, sender *Sender) {

context_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,78 @@ func assertHandlerNoError(t *testing.T, server *BlipWebsocketServer, wg *sync.Wa
728728
}
729729
}
730730

731+
// TestWebSocketServerStopHandler tests stopping the handler with a specific error code.
732+
func TestWebSocketServerStopHandler(t *testing.T) {
733+
734+
opts := ContextOptions{
735+
ProtocolIds: []string{BlipTestAppProtocolId},
736+
}
737+
blipContextEchoServer, err := NewContext(opts)
738+
require.NoError(t, err)
739+
740+
receivedRequests := sync.WaitGroup{}
741+
742+
// ----------------- Setup Echo Server that will be closed via cancellation context -------------------------
743+
744+
// Create a blip profile handler to respond to echo requests
745+
dispatchEcho := func(request *Message) {
746+
defer receivedRequests.Done()
747+
body, err := request.Body()
748+
require.NoError(t, err)
749+
require.Equal(t, "application/octet-stream", request.Properties["Content-Type"])
750+
if response := request.Response(); response != nil {
751+
response.SetBody(body)
752+
response.Properties["Content-Type"] = request.Properties["Content-Type"]
753+
}
754+
}
755+
756+
// Blip setup
757+
blipContextEchoServer.HandlerForProfile["BLIPTest/EchoData"] = dispatchEcho
758+
blipContextEchoServer.LogMessages = true
759+
blipContextEchoServer.LogFrames = true
760+
761+
// Websocket Server
762+
server := blipContextEchoServer.WebSocketServer()
763+
764+
// HTTP Handler wrapping websocket server
765+
http.Handle("/TestServerContextClose", server)
766+
listener, err := net.Listen("tcp", ":0")
767+
require.NoError(t, err)
768+
defer listener.Close()
769+
go func() {
770+
_ = http.Serve(listener, nil)
771+
}()
772+
773+
// ----------------- Setup Echo Client ----------------------------------------
774+
blipContextEchoClient, err := NewContext(defaultContextOptions)
775+
require.NoError(t, err)
776+
port := listener.Addr().(*net.TCPAddr).Port
777+
destUrl := fmt.Sprintf("ws://localhost:%d/TestServerContextClose", port)
778+
sender, err := blipContextEchoClient.Dial(destUrl)
779+
require.NoError(t, err)
780+
781+
// Create echo request
782+
echoResponseBody := []byte("hello")
783+
echoRequest := NewRequest()
784+
echoRequest.SetProfile("BLIPTest/EchoData")
785+
echoRequest.Properties["Content-Type"] = "application/octet-stream"
786+
echoRequest.SetBody(echoResponseBody)
787+
receivedRequests.Add(1)
788+
require.True(t, sender.Send(echoRequest))
789+
790+
// Read the echo response. Closed connection will result in empty response, as EOF message
791+
// isn't currently returned by blip client
792+
response := echoRequest.Response()
793+
responseBody, err := response.Body()
794+
require.NoError(t, err)
795+
require.Equal(t, echoResponseBody, responseBody)
796+
797+
fmt.Printf("Closing connection\n")
798+
server.StopHandler(websocket.StatusAbnormalClosure)
799+
//fmt.Printf("sender=%+v\n", sender.conn)
800+
require.True(t, false)
801+
}
802+
731803
// Wait for the WaitGroup, or return an error if the wg.Wait() doesn't return within timeout
732804
// TODO: this code is duplicated with code in Sync Gateway utilities_testing.go. Should be refactored to common repo.
733805
func WaitWithTimeout(wg *sync.WaitGroup, timeout time.Duration) error {

receiver.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ type receiver struct {
5050
pendingRequests msgStreamerMap // Unfinished REQ messages being assembled
5151
pendingResponses msgStreamerMap // Unfinished RES messages being assembled
5252
maxPendingResponseNumber MessageNumber // Largest RES # I've seen
53+
stopped atomic.Bool // True if I've been stopped by the caller
5354
}
5455

5556
func newReceiver(context *Context, conn *websocket.Conn) *receiver {
@@ -64,7 +65,7 @@ func newReceiver(context *Context, conn *websocket.Conn) *receiver {
6465
}
6566
}
6667

67-
func (r *receiver) receiveLoop() error {
68+
func (r *receiver) receiveLoop(handlerStopped *atomic.Bool) error {
6869
defer atomic.AddInt32(&r.activeGoroutines, -1)
6970
atomic.AddInt32(&r.activeGoroutines, 1)
7071
go r.parseLoop()
@@ -75,7 +76,9 @@ func (r *receiver) receiveLoop() error {
7576
// Receive the next raw WebSocket frame:
7677
_, frame, err := r.conn.Read(r.context.GetCancelCtx())
7778
if err != nil {
78-
if isCloseError(err) {
79+
if handlerStopped.Load() {
80+
return nil
81+
} else if isCloseError(err) {
7982
// lower log level for close
8083
r.context.logFrame("receiveLoop stopped: %v", err)
8184
} else if parseErr := errorFromChannel(r.parseError); parseErr != nil {

0 commit comments

Comments
 (0)