Skip to content

Commit

Permalink
fix: protocol client deadlock when send request after close
Browse files Browse the repository at this point in the history
Signed-off-by: ZhangJian He <shoothzj@gmail.com>
  • Loading branch information
shoothzj committed Sep 17, 2024
1 parent 2338d63 commit 5a2b1de
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 76 deletions.
154 changes: 78 additions & 76 deletions opcua/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package opcua

import (
"context"
"crypto/tls"
"encoding/binary"
"fmt"
Expand Down Expand Up @@ -34,7 +35,8 @@ type Client struct {
eventsChan chan *sendRequest
pendingQueue chan *sendRequest
buffer *buffer.Buffer
closeCh chan struct{}
ctx context.Context
ctxCancel context.CancelFunc
}

func (c *Client) Hello(message *MessageHello) (*MessageAcknowledge, error) {
Expand Down Expand Up @@ -67,95 +69,96 @@ func (c *Client) Send(buf *buffer.Buffer) (*buffer.Buffer, error) {
}

func (c *Client) sendAsync(buf *buffer.Buffer, callback func(*buffer.Buffer, error)) {
sr := &sendRequest{
buf: buf,
callback: callback,
select {
case <-c.ctx.Done():
callback(nil, ErrClientClosed)
default:
sr := &sendRequest{
buf: buf,
callback: callback,
}
c.eventsChan <- sr
}
c.eventsChan <- sr
}

func (c *Client) read() {
for {
select {
case req := <-c.pendingQueue:
n, err := c.conn.Read(c.buffer.WritableSlice())
if err != nil {
req.callback(nil, err)
c.closeCh <- struct{}{}
break
}
err = c.buffer.AdjustWriteCursor(n)
if err != nil {
req.callback(nil, err)
c.closeCh <- struct{}{}
break
}
if c.buffer.ReadableSize() < 8 {
continue
}
bytes := make([]byte, 8)
err = c.buffer.PeekExactly(bytes)
if err != nil {
req.callback(nil, err)
c.closeCh <- struct{}{}
break
}
length := int(binary.LittleEndian.Uint32(bytes[4:8]))
if c.buffer.ReadableSize() < length {
continue
}
// in case ddos attack
if length > c.buffer.Capacity() {
req.callback(nil, fmt.Errorf("response length %d is too large", length))
c.closeCh <- struct{}{}
break
}
data := make([]byte, length)
err = c.buffer.ReadExactly(data)
if err != nil {
req.callback(nil, err)
c.closeCh <- struct{}{}
break
}
c.buffer.Compact()
req.callback(buffer.NewBufferFromBytes(data), nil)
case <-c.closeCh:
return
for req := range c.pendingQueue {
n, err := c.conn.Read(c.buffer.WritableSlice())
if err != nil {
req.callback(nil, err)
c.close()
break
}
err = c.buffer.AdjustWriteCursor(n)
if err != nil {
req.callback(nil, err)
c.close()
break
}
if c.buffer.ReadableSize() < 8 {
continue
}
bytes := make([]byte, 8)
err = c.buffer.PeekExactly(bytes)
if err != nil {
req.callback(nil, err)
c.close()
break
}
length := int(binary.LittleEndian.Uint32(bytes[4:8]))
if c.buffer.ReadableSize() < length {
continue
}
// in case ddos attack
if length > c.buffer.Capacity() {
req.callback(nil, fmt.Errorf("response length %d is too large", length))
c.close()
break
}
data := make([]byte, length)
err = c.buffer.ReadExactly(data)
if err != nil {
req.callback(nil, err)
c.close()
break
}
c.buffer.Compact()
req.callback(buffer.NewBufferFromBytes(data), nil)
}
}

func (c *Client) write() {
for {
select {
case req := <-c.eventsChan:
bytes, err := req.buf.ReadNBytes(req.buf.ReadableSize())
if err != nil {
req.callback(nil, err)
c.closeCh <- struct{}{}
break
}
n, err := c.conn.Write(bytes)
if err != nil {
req.callback(nil, err)
c.closeCh <- struct{}{}
break
}
if n != len(bytes) {
req.callback(nil, fmt.Errorf("write %d bytes, but expect %d bytes", n, len(bytes)))
c.closeCh <- struct{}{}
break
}
c.pendingQueue <- req
case <-c.closeCh:
return
for req := range c.eventsChan {
bytes, err := req.buf.ReadNBytes(req.buf.ReadableSize())
if err != nil {
req.callback(nil, err)
c.close()
break
}
n, err := c.conn.Write(bytes)
if err != nil {
req.callback(nil, err)
c.close()
break
}
if n != len(bytes) {
req.callback(nil, fmt.Errorf("write %d bytes, but expect %d bytes", n, len(bytes)))
c.close()
break
}
c.pendingQueue <- req
}
}

func (c *Client) close() {
c.Close()
}

func (c *Client) Close() {
c.ctxCancel()
_ = c.conn.Close()
c.closeCh <- struct{}{}
close(c.eventsChan)
close(c.pendingQueue)
}

func NewClient(config *ClientConfig) (*Client, error) {
Expand All @@ -182,7 +185,6 @@ func NewClient(config *ClientConfig) (*Client, error) {
eventsChan: make(chan *sendRequest, config.SendQueueSize),
pendingQueue: make(chan *sendRequest, config.PendingQueueSize),
buffer: buffer.NewBuffer(config.BufferMax),
closeCh: make(chan struct{}),
}
go func() {
client.read()
Expand Down
7 changes: 7 additions & 0 deletions opcua/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package opcua

import "errors"

var (
ErrClientClosed = errors.New("client state is closed")
)

0 comments on commit 5a2b1de

Please sign in to comment.