From 5a2b1ded016d83741a02b9defd139b8a48688f4a Mon Sep 17 00:00:00 2001 From: ZhangJian He Date: Tue, 17 Sep 2024 10:24:51 +0800 Subject: [PATCH] fix: protocol client deadlock when send request after close Signed-off-by: ZhangJian He --- opcua/client.go | 154 ++++++++++++++++++++++++------------------------ opcua/error.go | 7 +++ 2 files changed, 85 insertions(+), 76 deletions(-) create mode 100644 opcua/error.go diff --git a/opcua/client.go b/opcua/client.go index 69b6df6..c4bf071 100644 --- a/opcua/client.go +++ b/opcua/client.go @@ -1,6 +1,7 @@ package opcua import ( + "context" "crypto/tls" "encoding/binary" "fmt" @@ -34,7 +35,8 @@ type Client struct { eventsChan chan *sendRequest pendingQueue chan *sendRequest buffer *buffer.Buffer - closeCh chan struct{} + ctx context.Context + ctxCancel context.CancelFunc } func (c *Client) Hello(message *MessageHello) (*MessageAcknowledge, error) { @@ -67,95 +69,96 @@ func (c *Client) Send(buf *buffer.Buffer) (*buffer.Buffer, error) { } func (c *Client) sendAsync(buf *buffer.Buffer, callback func(*buffer.Buffer, error)) { - sr := &sendRequest{ - buf: buf, - callback: callback, + select { + case <-c.ctx.Done(): + callback(nil, ErrClientClosed) + default: + sr := &sendRequest{ + buf: buf, + callback: callback, + } + c.eventsChan <- sr } - c.eventsChan <- sr } func (c *Client) read() { - for { - select { - case req := <-c.pendingQueue: - n, err := c.conn.Read(c.buffer.WritableSlice()) - if err != nil { - req.callback(nil, err) - c.closeCh <- struct{}{} - break - } - err = c.buffer.AdjustWriteCursor(n) - if err != nil { - req.callback(nil, err) - c.closeCh <- struct{}{} - break - } - if c.buffer.ReadableSize() < 8 { - continue - } - bytes := make([]byte, 8) - err = c.buffer.PeekExactly(bytes) - if err != nil { - req.callback(nil, err) - c.closeCh <- struct{}{} - break - } - length := int(binary.LittleEndian.Uint32(bytes[4:8])) - if c.buffer.ReadableSize() < length { - continue - } - // in case ddos attack - if length > c.buffer.Capacity() { - req.callback(nil, fmt.Errorf("response length %d is too large", length)) - c.closeCh <- struct{}{} - break - } - data := make([]byte, length) - err = c.buffer.ReadExactly(data) - if err != nil { - req.callback(nil, err) - c.closeCh <- struct{}{} - break - } - c.buffer.Compact() - req.callback(buffer.NewBufferFromBytes(data), nil) - case <-c.closeCh: - return + for req := range c.pendingQueue { + n, err := c.conn.Read(c.buffer.WritableSlice()) + if err != nil { + req.callback(nil, err) + c.close() + break + } + err = c.buffer.AdjustWriteCursor(n) + if err != nil { + req.callback(nil, err) + c.close() + break + } + if c.buffer.ReadableSize() < 8 { + continue + } + bytes := make([]byte, 8) + err = c.buffer.PeekExactly(bytes) + if err != nil { + req.callback(nil, err) + c.close() + break + } + length := int(binary.LittleEndian.Uint32(bytes[4:8])) + if c.buffer.ReadableSize() < length { + continue + } + // in case ddos attack + if length > c.buffer.Capacity() { + req.callback(nil, fmt.Errorf("response length %d is too large", length)) + c.close() + break } + data := make([]byte, length) + err = c.buffer.ReadExactly(data) + if err != nil { + req.callback(nil, err) + c.close() + break + } + c.buffer.Compact() + req.callback(buffer.NewBufferFromBytes(data), nil) } } func (c *Client) write() { - for { - select { - case req := <-c.eventsChan: - bytes, err := req.buf.ReadNBytes(req.buf.ReadableSize()) - if err != nil { - req.callback(nil, err) - c.closeCh <- struct{}{} - break - } - n, err := c.conn.Write(bytes) - if err != nil { - req.callback(nil, err) - c.closeCh <- struct{}{} - break - } - if n != len(bytes) { - req.callback(nil, fmt.Errorf("write %d bytes, but expect %d bytes", n, len(bytes))) - c.closeCh <- struct{}{} - break - } - c.pendingQueue <- req - case <-c.closeCh: - return + for req := range c.eventsChan { + bytes, err := req.buf.ReadNBytes(req.buf.ReadableSize()) + if err != nil { + req.callback(nil, err) + c.close() + break + } + n, err := c.conn.Write(bytes) + if err != nil { + req.callback(nil, err) + c.close() + break + } + if n != len(bytes) { + req.callback(nil, fmt.Errorf("write %d bytes, but expect %d bytes", n, len(bytes))) + c.close() + break } + c.pendingQueue <- req } } +func (c *Client) close() { + c.Close() +} + func (c *Client) Close() { + c.ctxCancel() _ = c.conn.Close() - c.closeCh <- struct{}{} + close(c.eventsChan) + close(c.pendingQueue) } func NewClient(config *ClientConfig) (*Client, error) { @@ -182,7 +185,6 @@ func NewClient(config *ClientConfig) (*Client, error) { eventsChan: make(chan *sendRequest, config.SendQueueSize), pendingQueue: make(chan *sendRequest, config.PendingQueueSize), buffer: buffer.NewBuffer(config.BufferMax), - closeCh: make(chan struct{}), } go func() { client.read() diff --git a/opcua/error.go b/opcua/error.go new file mode 100644 index 0000000..a460ca4 --- /dev/null +++ b/opcua/error.go @@ -0,0 +1,7 @@ +package opcua + +import "errors" + +var ( + ErrClientClosed = errors.New("client state is closed") +)