From e0acb79bc3d2e7c7b9f57f6f7327416e7caf1c0c Mon Sep 17 00:00:00 2001 From: Kirill Sysoev Date: Sat, 13 Apr 2024 18:17:16 +0800 Subject: [PATCH 1/3] Replace websocket library --- backend/http.go | 2 +- backend/http_test.go | 2 +- channel/channel.go | 10 +++- channel/connection.go | 33 +++++------- channel/connection_registry.go | 2 +- channel/connection_registry_test.go | 2 +- channel/connection_test.go | 83 +++++++++++++++++++++++------ go.mod | 4 +- go.sum | 4 +- interfaces.go | 2 +- mocks/mock_Connection.go | 12 ++--- 11 files changed, 102 insertions(+), 54 deletions(-) diff --git a/backend/http.go b/backend/http.go index 325479d..e8a30e6 100644 --- a/backend/http.go +++ b/backend/http.go @@ -62,7 +62,7 @@ func (b *HTTPBackend) Handle(conn wasabi.Connection, r wasabi.Request) error { return err } - return conn.Send(respBody.String()) + return conn.Send(respBody.Bytes()) } func WithDefaultHTTPTimeout(timeout time.Duration) HTTPBackendOption { diff --git a/backend/http_test.go b/backend/http_test.go index 12c6c36..faf6420 100644 --- a/backend/http_test.go +++ b/backend/http_test.go @@ -41,7 +41,7 @@ func TestHTTPBackend_Handle(t *testing.T) { mockReq.EXPECT().Context().Return(context.Background()) - mockConn.EXPECT().Send("OK").Return(nil) + mockConn.EXPECT().Send([]byte("OK")).Return(nil) mockReq.EXPECT().Data().Return([]byte("test request")) backend := NewBackend(func(req wasabi.Request) (*http.Request, error) { diff --git a/channel/channel.go b/channel/channel.go index 623e75f..3ac426c 100644 --- a/channel/channel.go +++ b/channel/channel.go @@ -5,7 +5,7 @@ import ( "net/http" "github.com/ksysoev/wasabi" - "golang.org/x/net/websocket" + "nhooyr.io/websocket" ) // DefaultChannel is default implementation of Channel @@ -51,7 +51,13 @@ func (c *DefaultChannel) Handler() http.Handler { }) } - wsHandler := websocket.Handler(func(ws *websocket.Conn) { + wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws, err := websocket.Accept(w, r, nil) + + if err != nil { + return + } + conn := c.connRegistry.AddConnection(ctx, ws, c.disptacher.Dispatch) conn.HandleRequests() }) diff --git a/channel/connection.go b/channel/connection.go index a311104..a2f5ec9 100644 --- a/channel/connection.go +++ b/channel/connection.go @@ -5,12 +5,11 @@ import ( "errors" "io" "log/slog" - "net" "sync" "sync/atomic" "github.com/google/uuid" - "golang.org/x/net/websocket" + "nhooyr.io/websocket" "github.com/ksysoev/wasabi" ) @@ -67,26 +66,20 @@ func (c *Conn) HandleRequests() { defer c.close() for c.ctx.Err() == nil { - var data []byte - err := websocket.Message.Receive(c.ws, &data) + _, reader, err := c.ws.Reader(c.ctx) if err != nil { - switch { - case c.isClosed.Load(): - return - case errors.Is(err, io.EOF): - return - case errors.Is(err, websocket.ErrFrameTooLarge): - return - case errors.Is(err, net.ErrClosed): - return - default: - slog.Warn("Error reading message: ", err) - } + slog.Warn("Error reading message: " + err.Error()) + + return + } + data, err := io.ReadAll(reader) + + if err != nil { slog.Warn("Error reading message: " + err.Error()) - continue + return } c.reqWG.Add(1) @@ -99,12 +92,12 @@ func (c *Conn) HandleRequests() { } // Send sends message to connection -func (c *Conn) Send(msg any) error { +func (c *Conn) Send(msg []byte) error { if c.isClosed.Load() || c.ctx.Err() != nil { return ErrConnectionClosed } - return websocket.Message.Send(c.ws, msg) + return c.ws.Write(c.ctx, websocket.MessageText, msg) } // close closes the connection. @@ -119,6 +112,6 @@ func (c *Conn) close() { c.onClose <- c.id c.isClosed.Store(true) - c.ws.Close() + _ = c.ws.CloseNow() c.reqWG.Wait() } diff --git a/channel/connection_registry.go b/channel/connection_registry.go index f695b5d..7e8f043 100644 --- a/channel/connection_registry.go +++ b/channel/connection_registry.go @@ -5,7 +5,7 @@ import ( "sync" "github.com/ksysoev/wasabi" - "golang.org/x/net/websocket" + "nhooyr.io/websocket" ) // DefaultConnectionRegistry is default implementation of ConnectionRegistry diff --git a/channel/connection_registry_test.go b/channel/connection_registry_test.go index 26371aa..b49bbfa 100644 --- a/channel/connection_registry_test.go +++ b/channel/connection_registry_test.go @@ -7,7 +7,7 @@ import ( "github.com/ksysoev/wasabi" "github.com/ksysoev/wasabi/mocks" - "golang.org/x/net/websocket" + "nhooyr.io/websocket" ) func TestDefaultConnectionRegistry_AddConnection(t *testing.T) { diff --git a/channel/connection_test.go b/channel/connection_test.go index eeeaaf7..4a090e7 100644 --- a/channel/connection_test.go +++ b/channel/connection_test.go @@ -3,15 +3,48 @@ package channel import ( "context" "io" + "net/http" "net/http/httptest" - "sync" "testing" "time" "github.com/ksysoev/wasabi" - "golang.org/x/net/websocket" + "nhooyr.io/websocket" ) +var wsHandlerEcho = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, nil) + if err != nil { + return + } + defer c.Close(websocket.StatusNormalClosure, "") + + for { + _, wsr, err := c.Reader(r.Context()) + if err != nil { + if err == io.EOF { + return + } + return + } + + wsw, err := c.Writer(r.Context(), websocket.MessageText) + if err != nil { + return + } + + _, err = io.Copy(wsw, wsr) + if err != nil { + return + } + + err = wsw.Close() + if err != nil { + return + } + } +}) + func TestConn_ID(t *testing.T) { ws := &websocket.Conn{} onClose := make(chan string) @@ -33,50 +66,62 @@ func TestConn_Context(t *testing.T) { } func TestConn_HandleRequests(t *testing.T) { - server := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { _, _ = io.Copy(ws, ws) })) + server := httptest.NewServer(wsHandlerEcho) defer server.Close() url := "ws://" + server.Listener.Addr().String() - ws, err := websocket.Dial(url, "", "http://localhost/") + ws, resp, err := websocket.Dial(context.Background(), url, nil) + if err != nil { t.Errorf("Unexpected error dialing websocket: %v", err) } - defer ws.Close() + if resp.Body != nil { + resp.Body.Close() + } + + defer func() { _ = ws.CloseNow() }() onClose := make(chan string) conn := NewConnection(context.Background(), ws, nil, onClose) // Mock OnMessage callback - var wg sync.WaitGroup + received := make(chan struct{}) - wg.Add(1) - - conn.onMessageCB = func(c wasabi.Connection, data []byte) { wg.Done() } + conn.onMessageCB = func(c wasabi.Connection, data []byte) { received <- struct{}{} } go conn.HandleRequests() // Send message to trigger OnMessage callback - err = websocket.Message.Send(ws, []byte("test message")) + err = ws.Write(context.Background(), websocket.MessageText, []byte("test message")) if err != nil { t.Errorf("Unexpected error sending message: %v", err) } - wg.Wait() + select { + case <-received: + // Expected + case <-time.After(50 * time.Millisecond): + t.Error("Expected OnMessage callback to be called") + } } func TestConn_Send(t *testing.T) { - server := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { _, _ = io.Copy(ws, ws) })) + server := httptest.NewServer(wsHandlerEcho) defer server.Close() url := "ws://" + server.Listener.Addr().String() - ws, err := websocket.Dial(url, "", "http://localhost/") + ws, resp, err := websocket.Dial(context.Background(), url, nil) if err != nil { t.Errorf("Unexpected error dialing websocket: %v", err) } - defer ws.Close() + if resp.Body != nil { + resp.Body.Close() + } + + defer func() { _ = ws.CloseNow() }() onClose := make(chan string) conn := NewConnection(context.Background(), ws, nil, onClose) @@ -88,16 +133,20 @@ func TestConn_Send(t *testing.T) { } func TestConn_close(t *testing.T) { - server := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { _, _ = io.Copy(ws, ws) })) + server := httptest.NewServer(wsHandlerEcho) defer server.Close() url := "ws://" + server.Listener.Addr().String() - ws, err := websocket.Dial(url, "", "http://localhost/") + ws, resp, err := websocket.Dial(context.Background(), url, nil) if err != nil { t.Errorf("Unexpected error dialing websocket: %v", err) } - defer ws.Close() + if resp.Body != nil { + resp.Body.Close() + } + + defer func() { _ = ws.CloseNow() }() onClose := make(chan string) conn := NewConnection(context.Background(), ws, nil, onClose) diff --git a/go.mod b/go.mod index 6867533..36c2c9b 100644 --- a/go.mod +++ b/go.mod @@ -4,14 +4,14 @@ go 1.22.1 require ( github.com/google/uuid v1.5.0 + github.com/ksysoev/ratestor v0.1.0 github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e - golang.org/x/net v0.19.0 + nhooyr.io/websocket v1.8.11 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/ksysoev/ratestor v0.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 10e1689..ff8e672 100644 --- a/go.sum +++ b/go.sum @@ -12,9 +12,9 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e h1:723BNChdd0c2Wk6WOE320qGBiPtYx0F0Bbm1kriShfE= golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nhooyr.io/websocket v1.8.11 h1:f/qXNc2/3DpoSZkHt1DQu6rj4zGC8JmkkLkWss0MgN0= +nhooyr.io/websocket v1.8.11/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= diff --git a/interfaces.go b/interfaces.go index 45b9218..7b2fbb5 100644 --- a/interfaces.go +++ b/interfaces.go @@ -26,7 +26,7 @@ type OnMessage func(conn Connection, data []byte) // Connection is interface for connections type Connection interface { - Send(msg any) error + Send(msg []byte) error Context() context.Context ID() string HandleRequests() diff --git a/mocks/mock_Connection.go b/mocks/mock_Connection.go index 6a85bb8..04d5fb1 100644 --- a/mocks/mock_Connection.go +++ b/mocks/mock_Connection.go @@ -148,7 +148,7 @@ func (_c *MockConnection_ID_Call) RunAndReturn(run func() string) *MockConnectio } // Send provides a mock function with given fields: msg -func (_m *MockConnection) Send(msg interface{}) error { +func (_m *MockConnection) Send(msg []byte) error { ret := _m.Called(msg) if len(ret) == 0 { @@ -156,7 +156,7 @@ func (_m *MockConnection) Send(msg interface{}) error { } var r0 error - if rf, ok := ret.Get(0).(func(interface{}) error); ok { + if rf, ok := ret.Get(0).(func([]byte) error); ok { r0 = rf(msg) } else { r0 = ret.Error(0) @@ -171,14 +171,14 @@ type MockConnection_Send_Call struct { } // Send is a helper method to define mock.On call -// - msg interface{} +// - msg []byte func (_e *MockConnection_Expecter) Send(msg interface{}) *MockConnection_Send_Call { return &MockConnection_Send_Call{Call: _e.mock.On("Send", msg)} } -func (_c *MockConnection_Send_Call) Run(run func(msg interface{})) *MockConnection_Send_Call { +func (_c *MockConnection_Send_Call) Run(run func(msg []byte)) *MockConnection_Send_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(interface{})) + run(args[0].([]byte)) }) return _c } @@ -188,7 +188,7 @@ func (_c *MockConnection_Send_Call) Return(_a0 error) *MockConnection_Send_Call return _c } -func (_c *MockConnection_Send_Call) RunAndReturn(run func(interface{}) error) *MockConnection_Send_Call { +func (_c *MockConnection_Send_Call) RunAndReturn(run func([]byte) error) *MockConnection_Send_Call { _c.Call.Return(run) return _c } From f0a4bd6b2c21bf68c39f5d3586996264539bf2cb Mon Sep 17 00:00:00 2001 From: Kirill Sysoev Date: Sat, 13 Apr 2024 19:28:51 +0800 Subject: [PATCH 2/3] Expose message type to Wasabi components --- backend/http.go | 2 +- backend/http_test.go | 2 +- channel/connection.go | 8 ++++---- channel/connection_registry_test.go | 2 +- channel/connection_test.go | 4 ++-- dispatch/common.go | 2 +- dispatch/pipe_dispatcher.go | 4 ++-- dispatch/pipe_dispatcher_test.go | 4 ++-- dispatch/request.go | 18 +++++++++++----- dispatch/request_test.go | 12 ++++++----- dispatch/router_dipatcher.go | 4 ++-- dispatch/router_dipatcher_test.go | 32 ++++++++++++++++++----------- examples/http_backend/main.go | 2 +- interfaces.go | 15 +++++++++++--- mocks/mock_Connection.go | 23 ++++++++++++--------- mocks/mock_Dispatcher.go | 19 +++++++++-------- 16 files changed, 93 insertions(+), 60 deletions(-) diff --git a/backend/http.go b/backend/http.go index e8a30e6..eb0192c 100644 --- a/backend/http.go +++ b/backend/http.go @@ -62,7 +62,7 @@ func (b *HTTPBackend) Handle(conn wasabi.Connection, r wasabi.Request) error { return err } - return conn.Send(respBody.Bytes()) + return conn.Send(wasabi.MsgTypeText, respBody.Bytes()) } func WithDefaultHTTPTimeout(timeout time.Duration) HTTPBackendOption { diff --git a/backend/http_test.go b/backend/http_test.go index faf6420..c5647b9 100644 --- a/backend/http_test.go +++ b/backend/http_test.go @@ -41,7 +41,7 @@ func TestHTTPBackend_Handle(t *testing.T) { mockReq.EXPECT().Context().Return(context.Background()) - mockConn.EXPECT().Send([]byte("OK")).Return(nil) + mockConn.EXPECT().Send(wasabi.MsgTypeText, []byte("OK")).Return(nil) mockReq.EXPECT().Data().Return([]byte("test request")) backend := NewBackend(func(req wasabi.Request) (*http.Request, error) { diff --git a/channel/connection.go b/channel/connection.go index a2f5ec9..b874857 100644 --- a/channel/connection.go +++ b/channel/connection.go @@ -66,7 +66,7 @@ func (c *Conn) HandleRequests() { defer c.close() for c.ctx.Err() == nil { - _, reader, err := c.ws.Reader(c.ctx) + msgType, reader, err := c.ws.Reader(c.ctx) if err != nil { slog.Warn("Error reading message: " + err.Error()) @@ -86,18 +86,18 @@ func (c *Conn) HandleRequests() { go func(wg *sync.WaitGroup) { defer wg.Done() - c.onMessageCB(c, data) + c.onMessageCB(c, msgType, data) }(c.reqWG) } } // Send sends message to connection -func (c *Conn) Send(msg []byte) error { +func (c *Conn) Send(msgType wasabi.MessageType, msg []byte) error { if c.isClosed.Load() || c.ctx.Err() != nil { return ErrConnectionClosed } - return c.ws.Write(c.ctx, websocket.MessageText, msg) + return c.ws.Write(c.ctx, msgType, msg) } // close closes the connection. diff --git a/channel/connection_registry_test.go b/channel/connection_registry_test.go index b49bbfa..73c9791 100644 --- a/channel/connection_registry_test.go +++ b/channel/connection_registry_test.go @@ -13,7 +13,7 @@ import ( func TestDefaultConnectionRegistry_AddConnection(t *testing.T) { ctx := context.Background() ws := &websocket.Conn{} - cb := func(wasabi.Connection, []byte) {} + cb := func(wasabi.Connection, wasabi.MessageType, []byte) {} registry := NewDefaultConnectionRegistry() diff --git a/channel/connection_test.go b/channel/connection_test.go index 4a090e7..a48e676 100644 --- a/channel/connection_test.go +++ b/channel/connection_test.go @@ -89,7 +89,7 @@ func TestConn_HandleRequests(t *testing.T) { // Mock OnMessage callback received := make(chan struct{}) - conn.onMessageCB = func(c wasabi.Connection, data []byte) { received <- struct{}{} } + conn.onMessageCB = func(c wasabi.Connection, msgType wasabi.MessageType, data []byte) { received <- struct{}{} } go conn.HandleRequests() @@ -126,7 +126,7 @@ func TestConn_Send(t *testing.T) { onClose := make(chan string) conn := NewConnection(context.Background(), ws, nil, onClose) - err = conn.Send([]byte("test message")) + err = conn.Send(wasabi.MsgTypeText, []byte("test message")) if err != nil { t.Errorf("Unexpected error sending message: %v", err) } diff --git a/dispatch/common.go b/dispatch/common.go index 312b5b5..b4b8d30 100644 --- a/dispatch/common.go +++ b/dispatch/common.go @@ -13,4 +13,4 @@ func (f RequestHandlerFunc) Handle(conn wasabi.Connection, req wasabi.Request) e return f(conn, req) } -type RequestParser func(conn wasabi.Connection, data []byte) wasabi.Request +type RequestParser func(conn wasabi.Connection, msgType wasabi.MessageType, data []byte) wasabi.Request diff --git a/dispatch/pipe_dispatcher.go b/dispatch/pipe_dispatcher.go index 4557a1c..75f0be0 100644 --- a/dispatch/pipe_dispatcher.go +++ b/dispatch/pipe_dispatcher.go @@ -19,8 +19,8 @@ func NewPipeDispatcher(backend wasabi.Backend) *PipeDispatcher { } // Dispatch dispatches request to backend -func (d *PipeDispatcher) Dispatch(conn wasabi.Connection, data []byte) { - req := NewRawRequest(conn.Context(), data) +func (d *PipeDispatcher) Dispatch(conn wasabi.Connection, msgType wasabi.MessageType, data []byte) { + req := NewRawRequest(conn.Context(), msgType, data) err := d.useMiddleware(d.backend).Handle(conn, req) if err != nil { diff --git a/dispatch/pipe_dispatcher_test.go b/dispatch/pipe_dispatcher_test.go index 7da2ad3..e4335c1 100644 --- a/dispatch/pipe_dispatcher_test.go +++ b/dispatch/pipe_dispatcher_test.go @@ -31,9 +31,9 @@ func TestPipeDispatcher_Dispatch(t *testing.T) { testError := fmt.Errorf("test error") conn.On("Context").Return(context.Background()) - backend.EXPECT().Handle(conn, NewRawRequest(conn.Context(), data)).Return(testError) + backend.EXPECT().Handle(conn, NewRawRequest(conn.Context(), wasabi.MsgTypeText, data)).Return(testError) - dispatcher.Dispatch(conn, data) + dispatcher.Dispatch(conn, wasabi.MsgTypeText, data) } func TestPipeDispatcher_Use(t *testing.T) { diff --git a/dispatch/request.go b/dispatch/request.go index bb9fcf6..5158e1e 100644 --- a/dispatch/request.go +++ b/dispatch/request.go @@ -7,16 +7,17 @@ import ( ) type RawRequest struct { - ctx context.Context - data []byte + ctx context.Context + data []byte + msgType wasabi.MessageType } -func NewRawRequest(ctx context.Context, data []byte) *RawRequest { +func NewRawRequest(ctx context.Context, msgType wasabi.MessageType, data []byte) *RawRequest { if ctx == nil { panic("nil context") } - return &RawRequest{ctx: ctx, data: data} + return &RawRequest{ctx: ctx, data: data, msgType: msgType} } func (r *RawRequest) Data() []byte { @@ -24,7 +25,14 @@ func (r *RawRequest) Data() []byte { } func (r *RawRequest) RoutingKey() string { - return "" + switch r.msgType { + case wasabi.MsgTypeText: + return "text" + case wasabi.MsgTypeBinary: + return "binary" + default: + panic("unknown message type " + r.msgType.String()) + } } func (r *RawRequest) Context() context.Context { diff --git a/dispatch/request_test.go b/dispatch/request_test.go index 260be4b..6d2a3ab 100644 --- a/dispatch/request_test.go +++ b/dispatch/request_test.go @@ -4,11 +4,13 @@ import ( "bytes" "context" "testing" + + "github.com/ksysoev/wasabi" ) func TestRawRequest_Data(t *testing.T) { data := []byte("test data") - req := NewRawRequest(context.Background(), data) + req := NewRawRequest(context.Background(), wasabi.MsgTypeText, data) if !bytes.Equal(req.Data(), data) { t.Errorf("Expected data to be '%s', but got '%s'", data, req.Data()) @@ -16,16 +18,16 @@ func TestRawRequest_Data(t *testing.T) { } func TestRawRequest_RoutingKey(t *testing.T) { - req := NewRawRequest(context.Background(), []byte{}) + req := NewRawRequest(context.Background(), wasabi.MsgTypeText, []byte{}) - if req.RoutingKey() != "" { + if req.RoutingKey() != "text" { t.Errorf("Expected routing key to be empty, but got %v", req.RoutingKey()) } } func TestRawRequest_Context(t *testing.T) { ctx := context.Background() - req := NewRawRequest(ctx, []byte{}) + req := NewRawRequest(ctx, wasabi.MsgTypeText, []byte{}) if req.Context() != ctx { t.Errorf("Expected context to be %v, but got %v", ctx, req.Context()) @@ -34,7 +36,7 @@ func TestRawRequest_Context(t *testing.T) { func TestRawRequest_WithContext(t *testing.T) { ctx := context.Background() - req := NewRawRequest(context.Background(), []byte{}) + req := NewRawRequest(context.Background(), wasabi.MsgTypeText, []byte{}) newReq := req.WithContext(ctx) diff --git a/dispatch/router_dipatcher.go b/dispatch/router_dipatcher.go index adf61cb..4da386d 100644 --- a/dispatch/router_dipatcher.go +++ b/dispatch/router_dipatcher.go @@ -43,8 +43,8 @@ func (d *RouterDispatcher) AddBackend(backend wasabi.Backend, routingKeys []stri // Dispatch handles the incoming connection and data by parsing the request, // determining the appropriate backend, and handling the request using middleware. // If an error occurs during handling, it is logged. -func (d *RouterDispatcher) Dispatch(conn wasabi.Connection, data []byte) { - req := d.parser(conn, data) +func (d *RouterDispatcher) Dispatch(conn wasabi.Connection, msgType wasabi.MessageType, data []byte) { + req := d.parser(conn, msgType, data) if req == nil { return diff --git a/dispatch/router_dipatcher_test.go b/dispatch/router_dipatcher_test.go index a332273..6c92666 100644 --- a/dispatch/router_dipatcher_test.go +++ b/dispatch/router_dipatcher_test.go @@ -11,7 +11,9 @@ import ( func TestNewRouterDispatcher(t *testing.T) { defaultBackend := mocks.NewMockBackend(t) - parser := func(_ wasabi.Connection, _ []byte) wasabi.Request { return mocks.NewMockRequest(t) } + parser := func(_ wasabi.Connection, _ wasabi.MessageType, _ []byte) wasabi.Request { + return mocks.NewMockRequest(t) + } dispatcher := NewRouterDispatcher(defaultBackend, parser) @@ -25,7 +27,9 @@ func TestNewRouterDispatcher(t *testing.T) { } func TestRouterDispatcher_AddBackend(t *testing.T) { defaultBackend := mocks.NewMockBackend(t) - parser := func(_ wasabi.Connection, _ []byte) wasabi.Request { return mocks.NewMockRequest(t) } + parser := func(_ wasabi.Connection, _ wasabi.MessageType, _ []byte) wasabi.Request { + return mocks.NewMockRequest(t) + } dispatcher := NewRouterDispatcher(defaultBackend, parser) backend := mocks.NewMockBackend(t) @@ -58,7 +62,7 @@ func TestRouterDispatcher_DispatchDefault(t *testing.T) { defaultBackend := mocks.NewMockBackend(t) req := mocks.NewMockRequest(t) - parser := func(_ wasabi.Connection, _ []byte) wasabi.Request { return req } + parser := func(_ wasabi.Connection, _ wasabi.MessageType, _ []byte) wasabi.Request { return req } dispatcher := NewRouterDispatcher(defaultBackend, parser) conn := mocks.NewMockConnection(t) @@ -70,13 +74,13 @@ func TestRouterDispatcher_DispatchDefault(t *testing.T) { defaultBackend.EXPECT().Handle(conn, req).Return(nil) - dispatcher.Dispatch(conn, data) + dispatcher.Dispatch(conn, wasabi.MsgTypeText, data) } func TestRouterDispatcher_DispatchByRoutingKey(t *testing.T) { defaultBackend := mocks.NewMockBackend(t) req := mocks.NewMockRequest(t) - parser := func(_ wasabi.Connection, _ []byte) wasabi.Request { return req } + parser := func(_ wasabi.Connection, _ wasabi.MessageType, _ []byte) wasabi.Request { return req } dispatcher := NewRouterDispatcher(defaultBackend, parser) conn := mocks.NewMockConnection(t) @@ -90,24 +94,24 @@ func TestRouterDispatcher_DispatchByRoutingKey(t *testing.T) { mockBackend.EXPECT().Handle(conn, req).Return(nil) dispatcher.backendMap[routingKey] = mockBackend - dispatcher.Dispatch(conn, data) + dispatcher.Dispatch(conn, wasabi.MsgTypeText, data) } func TestRouterDispatcher_DispatchWrongRequest(t *testing.T) { defaultBackend := mocks.NewMockBackend(t) - parser := func(_ wasabi.Connection, _ []byte) wasabi.Request { return nil } + parser := func(_ wasabi.Connection, _ wasabi.MessageType, _ []byte) wasabi.Request { return nil } dispatcher := NewRouterDispatcher(defaultBackend, parser) conn := mocks.NewMockConnection(t) data := []byte("test data") - dispatcher.Dispatch(conn, data) + dispatcher.Dispatch(conn, wasabi.MsgTypeText, data) } func TestRouterDispatcher_DispatchErrorHandlingRequest(t *testing.T) { defaultBackend := mocks.NewMockBackend(t) req := mocks.NewMockRequest(t) - parser := func(_ wasabi.Connection, _ []byte) wasabi.Request { return req } + parser := func(_ wasabi.Connection, _ wasabi.MessageType, _ []byte) wasabi.Request { return req } dispatcher := NewRouterDispatcher(defaultBackend, parser) conn := mocks.NewMockConnection(t) @@ -120,11 +124,13 @@ func TestRouterDispatcher_DispatchErrorHandlingRequest(t *testing.T) { mockBackend.EXPECT().Handle(conn, req).Return(fmt.Errorf("test error")) dispatcher.backendMap[routingKey] = mockBackend - dispatcher.Dispatch(conn, data) + dispatcher.Dispatch(conn, wasabi.MsgTypeText, data) } func TestRouterDispatcher_Use(t *testing.T) { defaultBackend := mocks.NewMockBackend(t) - parser := func(_ wasabi.Connection, _ []byte) wasabi.Request { return mocks.NewMockRequest(t) } + parser := func(_ wasabi.Connection, _ wasabi.MessageType, _ []byte) wasabi.Request { + return mocks.NewMockRequest(t) + } dispatcher := NewRouterDispatcher(defaultBackend, parser) middleware := RequestMiddlewere(func(next wasabi.RequestHandler) wasabi.RequestHandler { return next }) @@ -143,7 +149,9 @@ func TestRouterDispatcher_UseMiddleware(t *testing.T) { defaultBackend := mocks.NewMockBackend(t) defaultBackend.EXPECT().Handle(mockConn, mockReq).Return(testError) - parser := func(_ wasabi.Connection, _ []byte) wasabi.Request { return mocks.NewMockRequest(t) } + parser := func(_ wasabi.Connection, _ wasabi.MessageType, _ []byte) wasabi.Request { + return mocks.NewMockRequest(t) + } dispatcher := NewRouterDispatcher(defaultBackend, parser) middleware1 := RequestMiddlewere(func(next wasabi.RequestHandler) wasabi.RequestHandler { return next }) diff --git a/examples/http_backend/main.go b/examples/http_backend/main.go index 8330540..f4f5bd9 100644 --- a/examples/http_backend/main.go +++ b/examples/http_backend/main.go @@ -32,7 +32,7 @@ func main() { }) ErrHandler := request.NewErrorHandlingMiddleware(func(conn wasabi.Connection, req wasabi.Request, err error) error { - conn.Send([]byte("Failed to process request: " + err.Error())) + conn.Send(wasabi.MsgTypeText, []byte("Failed to process request: "+err.Error())) return nil }) diff --git a/interfaces.go b/interfaces.go index 7b2fbb5..4b76fff 100644 --- a/interfaces.go +++ b/interfaces.go @@ -3,6 +3,15 @@ package wasabi import ( "context" "net/http" + + "nhooyr.io/websocket" +) + +type MessageType = websocket.MessageType + +const ( + MsgTypeText MessageType = websocket.MessageText + MsgTypeBinary MessageType = websocket.MessageBinary ) type Request interface { @@ -18,15 +27,15 @@ type Backend interface { // Dispatcher is interface for dispatchers type Dispatcher interface { - Dispatch(conn Connection, data []byte) + Dispatch(conn Connection, msgType MessageType, data []byte) } // OnMessage is type for OnMessage callback -type OnMessage func(conn Connection, data []byte) +type OnMessage func(conn Connection, msgType MessageType, data []byte) // Connection is interface for connections type Connection interface { - Send(msg []byte) error + Send(msgType MessageType, msg []byte) error Context() context.Context ID() string HandleRequests() diff --git a/mocks/mock_Connection.go b/mocks/mock_Connection.go index 04d5fb1..4409b58 100644 --- a/mocks/mock_Connection.go +++ b/mocks/mock_Connection.go @@ -8,6 +8,8 @@ import ( context "context" mock "github.com/stretchr/testify/mock" + + websocket "nhooyr.io/websocket" ) // MockConnection is an autogenerated mock type for the Connection type @@ -147,17 +149,17 @@ func (_c *MockConnection_ID_Call) RunAndReturn(run func() string) *MockConnectio return _c } -// Send provides a mock function with given fields: msg -func (_m *MockConnection) Send(msg []byte) error { - ret := _m.Called(msg) +// Send provides a mock function with given fields: msgType, msg +func (_m *MockConnection) Send(msgType websocket.MessageType, msg []byte) error { + ret := _m.Called(msgType, msg) if len(ret) == 0 { panic("no return value specified for Send") } var r0 error - if rf, ok := ret.Get(0).(func([]byte) error); ok { - r0 = rf(msg) + if rf, ok := ret.Get(0).(func(websocket.MessageType, []byte) error); ok { + r0 = rf(msgType, msg) } else { r0 = ret.Error(0) } @@ -171,14 +173,15 @@ type MockConnection_Send_Call struct { } // Send is a helper method to define mock.On call +// - msgType websocket.MessageType // - msg []byte -func (_e *MockConnection_Expecter) Send(msg interface{}) *MockConnection_Send_Call { - return &MockConnection_Send_Call{Call: _e.mock.On("Send", msg)} +func (_e *MockConnection_Expecter) Send(msgType interface{}, msg interface{}) *MockConnection_Send_Call { + return &MockConnection_Send_Call{Call: _e.mock.On("Send", msgType, msg)} } -func (_c *MockConnection_Send_Call) Run(run func(msg []byte)) *MockConnection_Send_Call { +func (_c *MockConnection_Send_Call) Run(run func(msgType websocket.MessageType, msg []byte)) *MockConnection_Send_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]byte)) + run(args[0].(websocket.MessageType), args[1].([]byte)) }) return _c } @@ -188,7 +191,7 @@ func (_c *MockConnection_Send_Call) Return(_a0 error) *MockConnection_Send_Call return _c } -func (_c *MockConnection_Send_Call) RunAndReturn(run func([]byte) error) *MockConnection_Send_Call { +func (_c *MockConnection_Send_Call) RunAndReturn(run func(websocket.MessageType, []byte) error) *MockConnection_Send_Call { _c.Call.Return(run) return _c } diff --git a/mocks/mock_Dispatcher.go b/mocks/mock_Dispatcher.go index ba287cb..0ee4d86 100644 --- a/mocks/mock_Dispatcher.go +++ b/mocks/mock_Dispatcher.go @@ -7,6 +7,8 @@ package mocks import ( wasabi "github.com/ksysoev/wasabi" mock "github.com/stretchr/testify/mock" + + websocket "nhooyr.io/websocket" ) // MockDispatcher is an autogenerated mock type for the Dispatcher type @@ -22,9 +24,9 @@ func (_m *MockDispatcher) EXPECT() *MockDispatcher_Expecter { return &MockDispatcher_Expecter{mock: &_m.Mock} } -// Dispatch provides a mock function with given fields: conn, data -func (_m *MockDispatcher) Dispatch(conn wasabi.Connection, data []byte) { - _m.Called(conn, data) +// Dispatch provides a mock function with given fields: conn, msgType, data +func (_m *MockDispatcher) Dispatch(conn wasabi.Connection, msgType websocket.MessageType, data []byte) { + _m.Called(conn, msgType, data) } // MockDispatcher_Dispatch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Dispatch' @@ -34,14 +36,15 @@ type MockDispatcher_Dispatch_Call struct { // Dispatch is a helper method to define mock.On call // - conn wasabi.Connection +// - msgType websocket.MessageType // - data []byte -func (_e *MockDispatcher_Expecter) Dispatch(conn interface{}, data interface{}) *MockDispatcher_Dispatch_Call { - return &MockDispatcher_Dispatch_Call{Call: _e.mock.On("Dispatch", conn, data)} +func (_e *MockDispatcher_Expecter) Dispatch(conn interface{}, msgType interface{}, data interface{}) *MockDispatcher_Dispatch_Call { + return &MockDispatcher_Dispatch_Call{Call: _e.mock.On("Dispatch", conn, msgType, data)} } -func (_c *MockDispatcher_Dispatch_Call) Run(run func(conn wasabi.Connection, data []byte)) *MockDispatcher_Dispatch_Call { +func (_c *MockDispatcher_Dispatch_Call) Run(run func(conn wasabi.Connection, msgType websocket.MessageType, data []byte)) *MockDispatcher_Dispatch_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(wasabi.Connection), args[1].([]byte)) + run(args[0].(wasabi.Connection), args[1].(websocket.MessageType), args[2].([]byte)) }) return _c } @@ -51,7 +54,7 @@ func (_c *MockDispatcher_Dispatch_Call) Return() *MockDispatcher_Dispatch_Call { return _c } -func (_c *MockDispatcher_Dispatch_Call) RunAndReturn(run func(wasabi.Connection, []byte)) *MockDispatcher_Dispatch_Call { +func (_c *MockDispatcher_Dispatch_Call) RunAndReturn(run func(wasabi.Connection, websocket.MessageType, []byte)) *MockDispatcher_Dispatch_Call { _c.Call.Return(run) return _c } From 5a76f1312b81b2bf6ad164aca73810791037995b Mon Sep 17 00:00:00 2001 From: Kirill Sysoev Date: Sat, 13 Apr 2024 19:44:39 +0800 Subject: [PATCH 3/3] Improves error handling on reading from WS connections --- channel/connection.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/channel/connection.go b/channel/connection.go index b874857..b8fceca 100644 --- a/channel/connection.go +++ b/channel/connection.go @@ -69,14 +69,18 @@ func (c *Conn) HandleRequests() { msgType, reader, err := c.ws.Reader(c.ctx) if err != nil { - slog.Warn("Error reading message: " + err.Error()) - return } data, err := io.ReadAll(reader) - if err != nil { + switch { + case errors.Is(err, io.EOF): + return + case errors.Is(err, context.Canceled): + return + } + slog.Warn("Error reading message: " + err.Error()) return