From 3af6e5d16565f3ba09bb758d3b53e9ccd8fdf0d3 Mon Sep 17 00:00:00 2001 From: ZhangJian He Date: Mon, 28 Oct 2024 20:26:52 +0800 Subject: [PATCH] feat: support config MaxResponseSize (#46) Signed-off-by: ZhangJian He --- opcua/enc/encoder.go | 16 ++++++++++++++-- opcua/enc/encoder_test.go | 14 +++++++------- opcua/enc/message_acknowledge_decode_test.go | 3 ++- opcua/enc/message_close_session_request_test.go | 8 ++++++++ opcua/enc/message_hello_decode_test.go | 3 ++- .../message_open_secure_channel_decode_test.go | 3 ++- opcua/secure_channel.go | 2 +- opcua/server.go | 4 ++++ 8 files changed, 40 insertions(+), 13 deletions(-) create mode 100644 opcua/enc/message_close_session_request_test.go diff --git a/opcua/enc/encoder.go b/opcua/enc/encoder.go index dbf52c0..63ccbb7 100644 --- a/opcua/enc/encoder.go +++ b/opcua/enc/encoder.go @@ -24,10 +24,13 @@ type FastEncoder interface { type DefaultEncoder struct { sequenceNumberGenerator func() uint32 + maxEncodedSize int } -func NewDefaultEncoder() *DefaultEncoder { - return &DefaultEncoder{} +func NewDefaultEncoder(maxEncodedSize int) *DefaultEncoder { + return &DefaultEncoder{ + maxEncodedSize: maxEncodedSize, + } } func (e *DefaultEncoder) Encode(v *uamsg.Message, chunkSize int) ([][]byte, error) { @@ -67,6 +70,7 @@ func (e *DefaultEncoder) Encode(v *uamsg.Message, chunkSize int) ([][]byte, erro leftBodySize := len(dataBytes) headerLength := messageHeaderLength + securityHeaderLength + sequenceHeaderLength + totalSize := 0 for leftBodySize > 0 { tempBuff := bytes.NewBuffer(nil) @@ -120,6 +124,11 @@ func (e *DefaultEncoder) Encode(v *uamsg.Message, chunkSize int) ([][]byte, erro return nil, err } } + + totalSize += tempBuff.Len() + if totalSize > e.maxEncodedSize { + return nil, errors.New("exceed max encoded size") + } chunks = append(chunks, tempBuff.Bytes()) } return chunks, nil @@ -167,6 +176,9 @@ func genericEncoder(v interface{}) ([]byte, error) { if err != nil { return nil, err } + if length == 0 || length == -1 { + return buff.Bytes(), nil + } for i := 0; i < value.Len(); i++ { elemBytes, err := genericEncoder(value.Index(i).Interface()) diff --git a/opcua/enc/encoder_test.go b/opcua/enc/encoder_test.go index 912df91..02e4f75 100644 --- a/opcua/enc/encoder_test.go +++ b/opcua/enc/encoder_test.go @@ -70,7 +70,7 @@ func showByteSlice(s []byte) { func getHelloMsgTestCase() *encodeTestCase { return &encodeTestCase{ name: "encode hello msg", - e: &DefaultEncoder{}, + e: NewDefaultEncoder(64 * 1024), args: encodeArgs{ v: &uamsg.Message{ MessageHeader: &uamsg.MessageHeader{ @@ -98,7 +98,7 @@ func getHelloMsgTestCase() *encodeTestCase { func getAcknowledgeMsgTestCase() *encodeTestCase { return &encodeTestCase{ name: "encode acknowledge msg", - e: &DefaultEncoder{}, + e: NewDefaultEncoder(64 * 1024), args: encodeArgs{ v: &uamsg.Message{ MessageHeader: &uamsg.MessageHeader{ @@ -124,7 +124,7 @@ func getAcknowledgeMsgTestCase() *encodeTestCase { func getOpenSecureChannelRequestMsgTestCase() *encodeTestCase { return &encodeTestCase{ name: "encode open secure channel request msg", - e: &DefaultEncoder{}, + e: NewDefaultEncoder(64 * 1024), args: encodeArgs{ v: &uamsg.Message{ MessageHeader: &uamsg.MessageHeader{ @@ -183,7 +183,7 @@ func getOpenSecureChannelRequestMsgTestCase() *encodeTestCase { func getOpenSecureChannelResponseMsgTestCase() *encodeTestCase { return &encodeTestCase{ name: "encode open secure channel response msg", - e: &DefaultEncoder{}, + e: NewDefaultEncoder(64 * 1024), args: encodeArgs{ v: &uamsg.Message{ MessageHeader: &uamsg.MessageHeader{ @@ -243,7 +243,7 @@ func getOpenSecureChannelResponseMsgTestCase() *encodeTestCase { func getCreateSessionRequestMsgTestCase() *encodeTestCase { return &encodeTestCase{ name: "encode create session request msg", - e: &DefaultEncoder{}, + e: NewDefaultEncoder(64 * 1024), args: encodeArgs{ v: &uamsg.Message{ MessageHeader: &uamsg.MessageHeader{ @@ -314,7 +314,7 @@ func getCreateSessionRequestMsgTestCase() *encodeTestCase { func getActiveSessionRequestMsgTestCase() *encodeTestCase { return &encodeTestCase{ name: "encode active session request msg", - e: &DefaultEncoder{}, + e: NewDefaultEncoder(64 * 1024), args: encodeArgs{ v: &uamsg.Message{ MessageHeader: &uamsg.MessageHeader{ @@ -386,7 +386,7 @@ func getActiveSessionRequestMsgTestCase() *encodeTestCase { func getActiveSessionResponseMsgTestCase() *encodeTestCase { return &encodeTestCase{ name: "encode active session response msg", - e: &DefaultEncoder{}, + e: NewDefaultEncoder(64 * 1024), args: encodeArgs{ v: &uamsg.Message{ MessageHeader: &uamsg.MessageHeader{ diff --git a/opcua/enc/message_acknowledge_decode_test.go b/opcua/enc/message_acknowledge_decode_test.go index dc9c0ab..75cf25c 100644 --- a/opcua/enc/message_acknowledge_decode_test.go +++ b/opcua/enc/message_acknowledge_decode_test.go @@ -2,8 +2,9 @@ package enc import ( "bytes" - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestDecodeAcknowledgeMessage(t *testing.T) { diff --git a/opcua/enc/message_close_session_request_test.go b/opcua/enc/message_close_session_request_test.go new file mode 100644 index 0000000..f93cad3 --- /dev/null +++ b/opcua/enc/message_close_session_request_test.go @@ -0,0 +1,8 @@ +package enc + +import ( + "testing" +) + +func TestDecodeCloseSessionRequestMessage(t *testing.T) { +} diff --git a/opcua/enc/message_hello_decode_test.go b/opcua/enc/message_hello_decode_test.go index a4158ea..cca1530 100644 --- a/opcua/enc/message_hello_decode_test.go +++ b/opcua/enc/message_hello_decode_test.go @@ -2,8 +2,9 @@ package enc import ( "bytes" - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestDecodeHelloMessage(t *testing.T) { diff --git a/opcua/enc/message_open_secure_channel_decode_test.go b/opcua/enc/message_open_secure_channel_decode_test.go index 47e7c37..0707f2a 100644 --- a/opcua/enc/message_open_secure_channel_decode_test.go +++ b/opcua/enc/message_open_secure_channel_decode_test.go @@ -2,8 +2,9 @@ package enc import ( "bytes" - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestDecodeOpenSecureChannelMessage(t *testing.T) { diff --git a/opcua/secure_channel.go b/opcua/secure_channel.go index 2ee9c41..aa90cde 100644 --- a/opcua/secure_channel.go +++ b/opcua/secure_channel.go @@ -53,7 +53,7 @@ func newSecureChannel(conn *Conn, handler: handler, logger: logger, decoder: enc.NewDefaultDecoder(conn, int64(svcConf.ReceiverBufferSize)), - encoder: enc.NewDefaultEncoder(), + encoder: enc.NewDefaultEncoder(svcConf.MaxResponseSize), } } diff --git a/opcua/server.go b/opcua/server.go index e6dba5d..df4bfb1 100644 --- a/opcua/server.go +++ b/opcua/server.go @@ -16,6 +16,7 @@ type ServerConfig struct { Handler ServerHandler ReceiverBufferSize int + MaxResponseSize int ReadRequestNodeLimit int @@ -52,6 +53,9 @@ func NewServer(config *ServerConfig) (*Server, error) { if config.ReceiverBufferSize < 9 { return nil, fmt.Errorf("receiver buffer size must be at least 9 bytes") } + if config.MaxResponseSize <= 0 { + config.MaxResponseSize = 64 * 1024 + } server := &Server{ config: config, channelIdGen: &ChannelIdGen{},