Skip to content

Commit

Permalink
feat: support config MaxResponseSize (#46)
Browse files Browse the repository at this point in the history
Signed-off-by: ZhangJian He <shoothzj@gmail.com>
  • Loading branch information
shoothzj authored Oct 28, 2024
1 parent cdb2753 commit 3af6e5d
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 13 deletions.
16 changes: 14 additions & 2 deletions opcua/enc/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@ type FastEncoder interface {

type DefaultEncoder struct {
sequenceNumberGenerator func() uint32
maxEncodedSize int
}

func NewDefaultEncoder() *DefaultEncoder {
return &DefaultEncoder{}
func NewDefaultEncoder(maxEncodedSize int) *DefaultEncoder {
return &DefaultEncoder{
maxEncodedSize: maxEncodedSize,
}
}

func (e *DefaultEncoder) Encode(v *uamsg.Message, chunkSize int) ([][]byte, error) {
Expand Down Expand Up @@ -67,6 +70,7 @@ func (e *DefaultEncoder) Encode(v *uamsg.Message, chunkSize int) ([][]byte, erro

leftBodySize := len(dataBytes)
headerLength := messageHeaderLength + securityHeaderLength + sequenceHeaderLength
totalSize := 0

for leftBodySize > 0 {
tempBuff := bytes.NewBuffer(nil)
Expand Down Expand Up @@ -120,6 +124,11 @@ func (e *DefaultEncoder) Encode(v *uamsg.Message, chunkSize int) ([][]byte, erro
return nil, err
}
}

totalSize += tempBuff.Len()
if totalSize > e.maxEncodedSize {
return nil, errors.New("exceed max encoded size")
}
chunks = append(chunks, tempBuff.Bytes())
}
return chunks, nil
Expand Down Expand Up @@ -167,6 +176,9 @@ func genericEncoder(v interface{}) ([]byte, error) {
if err != nil {
return nil, err
}
if length == 0 || length == -1 {
return buff.Bytes(), nil
}

for i := 0; i < value.Len(); i++ {
elemBytes, err := genericEncoder(value.Index(i).Interface())
Expand Down
14 changes: 7 additions & 7 deletions opcua/enc/encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func showByteSlice(s []byte) {
func getHelloMsgTestCase() *encodeTestCase {
return &encodeTestCase{
name: "encode hello msg",
e: &DefaultEncoder{},
e: NewDefaultEncoder(64 * 1024),
args: encodeArgs{
v: &uamsg.Message{
MessageHeader: &uamsg.MessageHeader{
Expand Down Expand Up @@ -98,7 +98,7 @@ func getHelloMsgTestCase() *encodeTestCase {
func getAcknowledgeMsgTestCase() *encodeTestCase {
return &encodeTestCase{
name: "encode acknowledge msg",
e: &DefaultEncoder{},
e: NewDefaultEncoder(64 * 1024),
args: encodeArgs{
v: &uamsg.Message{
MessageHeader: &uamsg.MessageHeader{
Expand All @@ -124,7 +124,7 @@ func getAcknowledgeMsgTestCase() *encodeTestCase {
func getOpenSecureChannelRequestMsgTestCase() *encodeTestCase {
return &encodeTestCase{
name: "encode open secure channel request msg",
e: &DefaultEncoder{},
e: NewDefaultEncoder(64 * 1024),
args: encodeArgs{
v: &uamsg.Message{
MessageHeader: &uamsg.MessageHeader{
Expand Down Expand Up @@ -183,7 +183,7 @@ func getOpenSecureChannelRequestMsgTestCase() *encodeTestCase {
func getOpenSecureChannelResponseMsgTestCase() *encodeTestCase {
return &encodeTestCase{
name: "encode open secure channel response msg",
e: &DefaultEncoder{},
e: NewDefaultEncoder(64 * 1024),
args: encodeArgs{
v: &uamsg.Message{
MessageHeader: &uamsg.MessageHeader{
Expand Down Expand Up @@ -243,7 +243,7 @@ func getOpenSecureChannelResponseMsgTestCase() *encodeTestCase {
func getCreateSessionRequestMsgTestCase() *encodeTestCase {
return &encodeTestCase{
name: "encode create session request msg",
e: &DefaultEncoder{},
e: NewDefaultEncoder(64 * 1024),
args: encodeArgs{
v: &uamsg.Message{
MessageHeader: &uamsg.MessageHeader{
Expand Down Expand Up @@ -314,7 +314,7 @@ func getCreateSessionRequestMsgTestCase() *encodeTestCase {
func getActiveSessionRequestMsgTestCase() *encodeTestCase {
return &encodeTestCase{
name: "encode active session request msg",
e: &DefaultEncoder{},
e: NewDefaultEncoder(64 * 1024),
args: encodeArgs{
v: &uamsg.Message{
MessageHeader: &uamsg.MessageHeader{
Expand Down Expand Up @@ -386,7 +386,7 @@ func getActiveSessionRequestMsgTestCase() *encodeTestCase {
func getActiveSessionResponseMsgTestCase() *encodeTestCase {
return &encodeTestCase{
name: "encode active session response msg",
e: &DefaultEncoder{},
e: NewDefaultEncoder(64 * 1024),
args: encodeArgs{
v: &uamsg.Message{
MessageHeader: &uamsg.MessageHeader{
Expand Down
3 changes: 2 additions & 1 deletion opcua/enc/message_acknowledge_decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package enc

import (
"bytes"
"github.com/stretchr/testify/require"
"testing"

"github.com/stretchr/testify/require"
)

func TestDecodeAcknowledgeMessage(t *testing.T) {
Expand Down
8 changes: 8 additions & 0 deletions opcua/enc/message_close_session_request_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package enc

import (
"testing"
)

func TestDecodeCloseSessionRequestMessage(t *testing.T) {
}
3 changes: 2 additions & 1 deletion opcua/enc/message_hello_decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package enc

import (
"bytes"
"github.com/stretchr/testify/require"
"testing"

"github.com/stretchr/testify/require"
)

func TestDecodeHelloMessage(t *testing.T) {
Expand Down
3 changes: 2 additions & 1 deletion opcua/enc/message_open_secure_channel_decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package enc

import (
"bytes"
"github.com/stretchr/testify/require"
"testing"

"github.com/stretchr/testify/require"
)

func TestDecodeOpenSecureChannelMessage(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion opcua/secure_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func newSecureChannel(conn *Conn,
handler: handler,
logger: logger,
decoder: enc.NewDefaultDecoder(conn, int64(svcConf.ReceiverBufferSize)),
encoder: enc.NewDefaultEncoder(),
encoder: enc.NewDefaultEncoder(svcConf.MaxResponseSize),
}
}

Expand Down
4 changes: 4 additions & 0 deletions opcua/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type ServerConfig struct {
Handler ServerHandler

ReceiverBufferSize int
MaxResponseSize int

ReadRequestNodeLimit int

Expand Down Expand Up @@ -52,6 +53,9 @@ func NewServer(config *ServerConfig) (*Server, error) {
if config.ReceiverBufferSize < 9 {
return nil, fmt.Errorf("receiver buffer size must be at least 9 bytes")
}
if config.MaxResponseSize <= 0 {
config.MaxResponseSize = 64 * 1024
}
server := &Server{
config: config,
channelIdGen: &ChannelIdGen{},
Expand Down

0 comments on commit 3af6e5d

Please sign in to comment.