Skip to content

Commit 5db84e5

Browse files
authored
finalize context correctly for stream requests (#1873)
1 parent 8bf6755 commit 5db84e5

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

pkg/wshutil/wshrpc.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ type rpcData struct {
186186
Command string
187187
Route string
188188
ResCh chan *RpcMessage
189-
Ctx context.Context
189+
Handler *RpcRequestHandler
190190
}
191191

192192
func validateServerImpl(serverImpl ServerImpl) {
@@ -405,21 +405,21 @@ func (w *WshRpc) SetServerImpl(serverImpl ServerImpl) {
405405
w.ServerImpl = serverImpl
406406
}
407407

408-
func (w *WshRpc) registerRpc(ctx context.Context, command string, route string, reqId string) chan *RpcMessage {
408+
func (w *WshRpc) registerRpc(handler *RpcRequestHandler, command string, route string, reqId string) chan *RpcMessage {
409409
w.Lock.Lock()
410410
defer w.Lock.Unlock()
411411
rpcCh := make(chan *RpcMessage, RespChSize)
412412
w.RpcMap[reqId] = &rpcData{
413+
Handler: handler,
413414
Command: command,
414415
Route: route,
415416
ResCh: rpcCh,
416-
Ctx: ctx,
417417
}
418418
go func() {
419419
defer func() {
420420
panichandler.PanicHandler("registerRpc:timeout", recover())
421421
}()
422-
<-ctx.Done()
422+
<-handler.ctx.Done()
423423
w.retrySendTimeout(reqId)
424424
}()
425425
return rpcCh
@@ -447,6 +447,7 @@ func (w *WshRpc) unregisterRpc(reqId string, err error) {
447447
}
448448
delete(w.RpcMap, reqId)
449449
close(rd.ResCh)
450+
rd.Handler.callContextCancelFn()
450451
}
451452

452453
// no response
@@ -541,16 +542,19 @@ func (handler *RpcRequestHandler) NextResponse() (any, error) {
541542
}
542543

543544
func (handler *RpcRequestHandler) finalize() {
544-
cancelFnPtr := handler.ctxCancelFn.Load()
545-
if cancelFnPtr != nil && *cancelFnPtr != nil {
546-
(*cancelFnPtr)()
547-
handler.ctxCancelFn.Store(nil)
548-
}
545+
handler.callContextCancelFn()
549546
if handler.reqId != "" {
550547
handler.w.unregisterRpc(handler.reqId, nil)
551548
}
552549
}
553550

551+
func (handler *RpcRequestHandler) callContextCancelFn() {
552+
cancelFnPtr := handler.ctxCancelFn.Swap(nil)
553+
if cancelFnPtr != nil && *cancelFnPtr != nil {
554+
(*cancelFnPtr)()
555+
}
556+
}
557+
554558
type RpcResponseHandler struct {
555559
w *WshRpc
556560
ctx context.Context
@@ -710,7 +714,7 @@ func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOp
710714
if err != nil {
711715
return nil, err
712716
}
713-
handler.respCh = w.registerRpc(handler.ctx, command, opts.Route, handler.reqId)
717+
handler.respCh = w.registerRpc(handler, command, opts.Route, handler.reqId)
714718
w.OutputCh <- barr
715719
return handler, nil
716720
}

0 commit comments

Comments
 (0)