diff --git a/opcua/conn.go b/opcua/conn.go new file mode 100644 index 0000000..a47d030 --- /dev/null +++ b/opcua/conn.go @@ -0,0 +1,30 @@ +package opcua + +import ( + "net" +) + +type Conn struct { + net.Conn + context any +} + +func (c *Conn) SetContext(value any) { + c.context = value +} + +func (c *Conn) Context() (value any) { + return c.context +} + +func (c *Conn) NetConn() net.Conn { + return c.Conn +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.Conn.RemoteAddr() +} + +func (c *Conn) LocalAddr() net.Addr { + return c.Conn.LocalAddr() +} diff --git a/opcua/opcua_handler.go b/opcua/opcua_handler.go index 12bbdae..b37e2e3 100644 --- a/opcua/opcua_handler.go +++ b/opcua/opcua_handler.go @@ -1,8 +1,23 @@ package opcua -import "net" +import "github.com/protocol-laboratory/opcua-go/opcua/uamsg" type ServerHandler interface { - ConnectionOpened(conn net.Conn) - ConnectionClosed(conn net.Conn) + ConnectionOpened(conn *Conn) + ConnectionClosed(conn *Conn) + + ValidateHello(conn *Conn, helloMessage *uamsg.HelloMessageExtras) error +} + +type DefaultServerHandler struct { +} + +func (d DefaultServerHandler) ConnectionOpened(conn *Conn) { +} + +func (d DefaultServerHandler) ConnectionClosed(conn *Conn) { +} + +func (d DefaultServerHandler) ValidateHello(conn *Conn, helloMessage *uamsg.HelloMessageExtras) error { + return nil } diff --git a/opcua/secure_channel.go b/opcua/secure_channel.go index ea1bfde..339264f 100644 --- a/opcua/secure_channel.go +++ b/opcua/secure_channel.go @@ -2,7 +2,6 @@ package opcua import ( "errors" - "net" "sync/atomic" "golang.org/x/exp/slog" @@ -13,7 +12,7 @@ import ( ) type SecureChannel struct { - conn net.Conn + conn *Conn channelId uint32 channelIdGen *ChannelIdGen logger *slog.Logger @@ -24,6 +23,8 @@ type SecureChannel struct { maxResponseMessageSize uint32 endpointUrl string + handler ServerHandler + // TODO conn set read timeout decoder enc.Decoder encoder enc.Encoder @@ -37,11 +38,17 @@ const ( ProtocolVersion uint32 = 0 ) -func newSecureChannel(conn net.Conn, svcConf *ServerConfig, channelId uint32, channelIdGen *ChannelIdGen, logger *slog.Logger) *SecureChannel { +func newSecureChannel(conn *Conn, + svcConf *ServerConfig, + channelId uint32, + channelIdGen *ChannelIdGen, + handler ServerHandler, + logger *slog.Logger) *SecureChannel { return &SecureChannel{ conn: conn, channelId: channelId, channelIdGen: channelIdGen, + handler: handler, logger: logger, decoder: enc.NewDefaultDecoder(conn, int64(svcConf.ReceiverBufferSize)), encoder: enc.NewDefaultEncoder(), @@ -83,9 +90,13 @@ func (secChan *SecureChannel) handleHello() error { secChan.sendMaxChunkSize = helloBody.ReceiveBufferSize secChan.maxChunkCount = helloBody.MaxChunkCount secChan.maxResponseMessageSize = helloBody.MaxMessageSize - // TODO need callback to validate endpoint secChan.endpointUrl = helloBody.EndpointUrl + err = secChan.handler.ValidateHello(secChan.conn, helloBody) + if err != nil { + return err + } + resp := &uamsg.Message{ MessageHeader: &uamsg.MessageHeader{ MessageType: uamsg.AcknowledgeMessageType, @@ -143,7 +154,7 @@ func (secChan *SecureChannel) handleOpenSecureChannel() error { return errors.New("only support NONE security mode") } - serverNonce := []byte{} + var serverNonce []byte channelId := secChan.channelIdGen.next() tokenId := secChan.getNextTokenId() diff --git a/opcua/server.go b/opcua/server.go index 6a87212..cd89668 100644 --- a/opcua/server.go +++ b/opcua/server.go @@ -51,9 +51,11 @@ func NewServer(config *ServerConfig) (*Server, error) { config: config, quit: make(chan bool), channelIdGen: &ChannelIdGen{}, - handler: config.handler, logger: config.Logger, } + if config.handler == nil { + server.handler = &DefaultServerHandler{} + } server.logger.Info("server initialized", slog.String("host", config.Host), slog.Int("port", config.Port)) return server, nil } @@ -100,34 +102,30 @@ func (s *Server) listenLoop() { } } go func() { - s.handleConn(netConn) + s.handleConn(&Conn{ + Conn: netConn, + }) }() } } -func (s *Server) handleConn(conn net.Conn) { - if s.handler != nil { - s.handler.ConnectionOpened(conn) - } +func (s *Server) handleConn(conn *Conn) { + s.handler.ConnectionOpened(conn) channelId := s.channelIdGen.next() channelLogger := s.logger.With(LogRemoteAddr, conn.RemoteAddr().String()).With(LogChannelId, channelId) channelLogger.Info("starting SecureChannel initialization") - secChannel := newSecureChannel(conn, s.config, channelId, s.channelIdGen, channelLogger) + secChannel := newSecureChannel(conn, s.config, channelId, s.channelIdGen, s.handler, channelLogger) err := secChannel.open() if err != nil { _ = conn.Close() - if s.handler != nil { - s.handler.ConnectionClosed(conn) - } + s.handler.ConnectionClosed(conn) channelLogger.Error("failed to open SecureChannel", slog.Any("err", err.Error())) return } err = secChannel.serve() if err != nil { _ = conn.Close() - if s.handler != nil { - s.handler.ConnectionClosed(conn) - } + s.handler.ConnectionClosed(conn) secChannel.logger.Error("processing request error", slog.Any("err", err)) } }