Skip to content

Commit

Permalink
feat: add validate hello method (#39)
Browse files Browse the repository at this point in the history
Signed-off-by: ZhangJian He <shoothzj@gmail.com>
  • Loading branch information
shoothzj authored Oct 16, 2024
1 parent c21236d commit 201fd58
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 21 deletions.
30 changes: 30 additions & 0 deletions opcua/conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package opcua

import (
"net"
)

type Conn struct {
net.Conn
context any
}

func (c *Conn) SetContext(value any) {
c.context = value
}

func (c *Conn) Context() (value any) {
return c.context
}

func (c *Conn) NetConn() net.Conn {
return c.Conn
}

func (c *Conn) RemoteAddr() net.Addr {
return c.Conn.RemoteAddr()
}

func (c *Conn) LocalAddr() net.Addr {
return c.Conn.LocalAddr()
}
21 changes: 18 additions & 3 deletions opcua/opcua_handler.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
package opcua

import "net"
import "github.com/protocol-laboratory/opcua-go/opcua/uamsg"

type ServerHandler interface {
ConnectionOpened(conn net.Conn)
ConnectionClosed(conn net.Conn)
ConnectionOpened(conn *Conn)
ConnectionClosed(conn *Conn)

ValidateHello(conn *Conn, helloMessage *uamsg.HelloMessageExtras) error
}

type DefaultServerHandler struct {
}

func (d DefaultServerHandler) ConnectionOpened(conn *Conn) {
}

func (d DefaultServerHandler) ConnectionClosed(conn *Conn) {
}

func (d DefaultServerHandler) ValidateHello(conn *Conn, helloMessage *uamsg.HelloMessageExtras) error {
return nil
}
21 changes: 16 additions & 5 deletions opcua/secure_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package opcua

import (
"errors"
"net"
"sync/atomic"

"golang.org/x/exp/slog"
Expand All @@ -13,7 +12,7 @@ import (
)

type SecureChannel struct {
conn net.Conn
conn *Conn
channelId uint32
channelIdGen *ChannelIdGen
logger *slog.Logger
Expand All @@ -24,6 +23,8 @@ type SecureChannel struct {
maxResponseMessageSize uint32
endpointUrl string

handler ServerHandler

// TODO conn set read timeout
decoder enc.Decoder
encoder enc.Encoder
Expand All @@ -37,11 +38,17 @@ const (
ProtocolVersion uint32 = 0
)

func newSecureChannel(conn net.Conn, svcConf *ServerConfig, channelId uint32, channelIdGen *ChannelIdGen, logger *slog.Logger) *SecureChannel {
func newSecureChannel(conn *Conn,
svcConf *ServerConfig,
channelId uint32,
channelIdGen *ChannelIdGen,
handler ServerHandler,
logger *slog.Logger) *SecureChannel {
return &SecureChannel{
conn: conn,
channelId: channelId,
channelIdGen: channelIdGen,
handler: handler,
logger: logger,
decoder: enc.NewDefaultDecoder(conn, int64(svcConf.ReceiverBufferSize)),
encoder: enc.NewDefaultEncoder(),
Expand Down Expand Up @@ -83,9 +90,13 @@ func (secChan *SecureChannel) handleHello() error {
secChan.sendMaxChunkSize = helloBody.ReceiveBufferSize
secChan.maxChunkCount = helloBody.MaxChunkCount
secChan.maxResponseMessageSize = helloBody.MaxMessageSize
// TODO need callback to validate endpoint
secChan.endpointUrl = helloBody.EndpointUrl

err = secChan.handler.ValidateHello(secChan.conn, helloBody)
if err != nil {
return err
}

resp := &uamsg.Message{
MessageHeader: &uamsg.MessageHeader{
MessageType: uamsg.AcknowledgeMessageType,
Expand Down Expand Up @@ -143,7 +154,7 @@ func (secChan *SecureChannel) handleOpenSecureChannel() error {
return errors.New("only support NONE security mode")
}

serverNonce := []byte{}
var serverNonce []byte
channelId := secChan.channelIdGen.next()
tokenId := secChan.getNextTokenId()

Expand Down
24 changes: 11 additions & 13 deletions opcua/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ func NewServer(config *ServerConfig) (*Server, error) {
config: config,
quit: make(chan bool),
channelIdGen: &ChannelIdGen{},
handler: config.handler,
logger: config.Logger,
}
if config.handler == nil {
server.handler = &DefaultServerHandler{}
}
server.logger.Info("server initialized", slog.String("host", config.Host), slog.Int("port", config.Port))
return server, nil
}
Expand Down Expand Up @@ -100,34 +102,30 @@ func (s *Server) listenLoop() {
}
}
go func() {
s.handleConn(netConn)
s.handleConn(&Conn{
Conn: netConn,
})
}()
}
}

func (s *Server) handleConn(conn net.Conn) {
if s.handler != nil {
s.handler.ConnectionOpened(conn)
}
func (s *Server) handleConn(conn *Conn) {
s.handler.ConnectionOpened(conn)
channelId := s.channelIdGen.next()
channelLogger := s.logger.With(LogRemoteAddr, conn.RemoteAddr().String()).With(LogChannelId, channelId)
channelLogger.Info("starting SecureChannel initialization")
secChannel := newSecureChannel(conn, s.config, channelId, s.channelIdGen, channelLogger)
secChannel := newSecureChannel(conn, s.config, channelId, s.channelIdGen, s.handler, channelLogger)
err := secChannel.open()
if err != nil {
_ = conn.Close()
if s.handler != nil {
s.handler.ConnectionClosed(conn)
}
s.handler.ConnectionClosed(conn)
channelLogger.Error("failed to open SecureChannel", slog.Any("err", err.Error()))
return
}
err = secChannel.serve()
if err != nil {
_ = conn.Close()
if s.handler != nil {
s.handler.ConnectionClosed(conn)
}
s.handler.ConnectionClosed(conn)
secChannel.logger.Error("processing request error", slog.Any("err", err))
}
}
Expand Down

0 comments on commit 201fd58

Please sign in to comment.