diff --git a/backend/http.go b/backend/http.go index 325479d..e8a30e6 100644 --- a/backend/http.go +++ b/backend/http.go @@ -62,7 +62,7 @@ func (b *HTTPBackend) Handle(conn wasabi.Connection, r wasabi.Request) error { return err } - return conn.Send(respBody.String()) + return conn.Send(respBody.Bytes()) } func WithDefaultHTTPTimeout(timeout time.Duration) HTTPBackendOption { diff --git a/backend/http_test.go b/backend/http_test.go index 12c6c36..faf6420 100644 --- a/backend/http_test.go +++ b/backend/http_test.go @@ -41,7 +41,7 @@ func TestHTTPBackend_Handle(t *testing.T) { mockReq.EXPECT().Context().Return(context.Background()) - mockConn.EXPECT().Send("OK").Return(nil) + mockConn.EXPECT().Send([]byte("OK")).Return(nil) mockReq.EXPECT().Data().Return([]byte("test request")) backend := NewBackend(func(req wasabi.Request) (*http.Request, error) { diff --git a/channel/channel.go b/channel/channel.go index 623e75f..3ac426c 100644 --- a/channel/channel.go +++ b/channel/channel.go @@ -5,7 +5,7 @@ import ( "net/http" "github.com/ksysoev/wasabi" - "golang.org/x/net/websocket" + "nhooyr.io/websocket" ) // DefaultChannel is default implementation of Channel @@ -51,7 +51,13 @@ func (c *DefaultChannel) Handler() http.Handler { }) } - wsHandler := websocket.Handler(func(ws *websocket.Conn) { + wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws, err := websocket.Accept(w, r, nil) + + if err != nil { + return + } + conn := c.connRegistry.AddConnection(ctx, ws, c.disptacher.Dispatch) conn.HandleRequests() }) diff --git a/channel/connection.go b/channel/connection.go index a311104..a2f5ec9 100644 --- a/channel/connection.go +++ b/channel/connection.go @@ -5,12 +5,11 @@ import ( "errors" "io" "log/slog" - "net" "sync" "sync/atomic" "github.com/google/uuid" - "golang.org/x/net/websocket" + "nhooyr.io/websocket" "github.com/ksysoev/wasabi" ) @@ -67,26 +66,20 @@ func (c *Conn) HandleRequests() { defer c.close() for c.ctx.Err() == nil { - var data []byte - err := websocket.Message.Receive(c.ws, &data) + _, reader, err := c.ws.Reader(c.ctx) if err != nil { - switch { - case c.isClosed.Load(): - return - case errors.Is(err, io.EOF): - return - case errors.Is(err, websocket.ErrFrameTooLarge): - return - case errors.Is(err, net.ErrClosed): - return - default: - slog.Warn("Error reading message: ", err) - } + slog.Warn("Error reading message: " + err.Error()) + + return + } + data, err := io.ReadAll(reader) + + if err != nil { slog.Warn("Error reading message: " + err.Error()) - continue + return } c.reqWG.Add(1) @@ -99,12 +92,12 @@ func (c *Conn) HandleRequests() { } // Send sends message to connection -func (c *Conn) Send(msg any) error { +func (c *Conn) Send(msg []byte) error { if c.isClosed.Load() || c.ctx.Err() != nil { return ErrConnectionClosed } - return websocket.Message.Send(c.ws, msg) + return c.ws.Write(c.ctx, websocket.MessageText, msg) } // close closes the connection. @@ -119,6 +112,6 @@ func (c *Conn) close() { c.onClose <- c.id c.isClosed.Store(true) - c.ws.Close() + _ = c.ws.CloseNow() c.reqWG.Wait() } diff --git a/channel/connection_registry.go b/channel/connection_registry.go index f695b5d..7e8f043 100644 --- a/channel/connection_registry.go +++ b/channel/connection_registry.go @@ -5,7 +5,7 @@ import ( "sync" "github.com/ksysoev/wasabi" - "golang.org/x/net/websocket" + "nhooyr.io/websocket" ) // DefaultConnectionRegistry is default implementation of ConnectionRegistry diff --git a/channel/connection_registry_test.go b/channel/connection_registry_test.go index 26371aa..b49bbfa 100644 --- a/channel/connection_registry_test.go +++ b/channel/connection_registry_test.go @@ -7,7 +7,7 @@ import ( "github.com/ksysoev/wasabi" "github.com/ksysoev/wasabi/mocks" - "golang.org/x/net/websocket" + "nhooyr.io/websocket" ) func TestDefaultConnectionRegistry_AddConnection(t *testing.T) { diff --git a/channel/connection_test.go b/channel/connection_test.go index eeeaaf7..4a090e7 100644 --- a/channel/connection_test.go +++ b/channel/connection_test.go @@ -3,15 +3,48 @@ package channel import ( "context" "io" + "net/http" "net/http/httptest" - "sync" "testing" "time" "github.com/ksysoev/wasabi" - "golang.org/x/net/websocket" + "nhooyr.io/websocket" ) +var wsHandlerEcho = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, nil) + if err != nil { + return + } + defer c.Close(websocket.StatusNormalClosure, "") + + for { + _, wsr, err := c.Reader(r.Context()) + if err != nil { + if err == io.EOF { + return + } + return + } + + wsw, err := c.Writer(r.Context(), websocket.MessageText) + if err != nil { + return + } + + _, err = io.Copy(wsw, wsr) + if err != nil { + return + } + + err = wsw.Close() + if err != nil { + return + } + } +}) + func TestConn_ID(t *testing.T) { ws := &websocket.Conn{} onClose := make(chan string) @@ -33,50 +66,62 @@ func TestConn_Context(t *testing.T) { } func TestConn_HandleRequests(t *testing.T) { - server := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { _, _ = io.Copy(ws, ws) })) + server := httptest.NewServer(wsHandlerEcho) defer server.Close() url := "ws://" + server.Listener.Addr().String() - ws, err := websocket.Dial(url, "", "http://localhost/") + ws, resp, err := websocket.Dial(context.Background(), url, nil) + if err != nil { t.Errorf("Unexpected error dialing websocket: %v", err) } - defer ws.Close() + if resp.Body != nil { + resp.Body.Close() + } + + defer func() { _ = ws.CloseNow() }() onClose := make(chan string) conn := NewConnection(context.Background(), ws, nil, onClose) // Mock OnMessage callback - var wg sync.WaitGroup + received := make(chan struct{}) - wg.Add(1) - - conn.onMessageCB = func(c wasabi.Connection, data []byte) { wg.Done() } + conn.onMessageCB = func(c wasabi.Connection, data []byte) { received <- struct{}{} } go conn.HandleRequests() // Send message to trigger OnMessage callback - err = websocket.Message.Send(ws, []byte("test message")) + err = ws.Write(context.Background(), websocket.MessageText, []byte("test message")) if err != nil { t.Errorf("Unexpected error sending message: %v", err) } - wg.Wait() + select { + case <-received: + // Expected + case <-time.After(50 * time.Millisecond): + t.Error("Expected OnMessage callback to be called") + } } func TestConn_Send(t *testing.T) { - server := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { _, _ = io.Copy(ws, ws) })) + server := httptest.NewServer(wsHandlerEcho) defer server.Close() url := "ws://" + server.Listener.Addr().String() - ws, err := websocket.Dial(url, "", "http://localhost/") + ws, resp, err := websocket.Dial(context.Background(), url, nil) if err != nil { t.Errorf("Unexpected error dialing websocket: %v", err) } - defer ws.Close() + if resp.Body != nil { + resp.Body.Close() + } + + defer func() { _ = ws.CloseNow() }() onClose := make(chan string) conn := NewConnection(context.Background(), ws, nil, onClose) @@ -88,16 +133,20 @@ func TestConn_Send(t *testing.T) { } func TestConn_close(t *testing.T) { - server := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { _, _ = io.Copy(ws, ws) })) + server := httptest.NewServer(wsHandlerEcho) defer server.Close() url := "ws://" + server.Listener.Addr().String() - ws, err := websocket.Dial(url, "", "http://localhost/") + ws, resp, err := websocket.Dial(context.Background(), url, nil) if err != nil { t.Errorf("Unexpected error dialing websocket: %v", err) } - defer ws.Close() + if resp.Body != nil { + resp.Body.Close() + } + + defer func() { _ = ws.CloseNow() }() onClose := make(chan string) conn := NewConnection(context.Background(), ws, nil, onClose) diff --git a/go.mod b/go.mod index 6867533..36c2c9b 100644 --- a/go.mod +++ b/go.mod @@ -4,14 +4,14 @@ go 1.22.1 require ( github.com/google/uuid v1.5.0 + github.com/ksysoev/ratestor v0.1.0 github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e - golang.org/x/net v0.19.0 + nhooyr.io/websocket v1.8.11 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/ksysoev/ratestor v0.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 10e1689..ff8e672 100644 --- a/go.sum +++ b/go.sum @@ -12,9 +12,9 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e h1:723BNChdd0c2Wk6WOE320qGBiPtYx0F0Bbm1kriShfE= golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nhooyr.io/websocket v1.8.11 h1:f/qXNc2/3DpoSZkHt1DQu6rj4zGC8JmkkLkWss0MgN0= +nhooyr.io/websocket v1.8.11/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= diff --git a/interfaces.go b/interfaces.go index 45b9218..7b2fbb5 100644 --- a/interfaces.go +++ b/interfaces.go @@ -26,7 +26,7 @@ type OnMessage func(conn Connection, data []byte) // Connection is interface for connections type Connection interface { - Send(msg any) error + Send(msg []byte) error Context() context.Context ID() string HandleRequests() diff --git a/mocks/mock_Connection.go b/mocks/mock_Connection.go index 6a85bb8..04d5fb1 100644 --- a/mocks/mock_Connection.go +++ b/mocks/mock_Connection.go @@ -148,7 +148,7 @@ func (_c *MockConnection_ID_Call) RunAndReturn(run func() string) *MockConnectio } // Send provides a mock function with given fields: msg -func (_m *MockConnection) Send(msg interface{}) error { +func (_m *MockConnection) Send(msg []byte) error { ret := _m.Called(msg) if len(ret) == 0 { @@ -156,7 +156,7 @@ func (_m *MockConnection) Send(msg interface{}) error { } var r0 error - if rf, ok := ret.Get(0).(func(interface{}) error); ok { + if rf, ok := ret.Get(0).(func([]byte) error); ok { r0 = rf(msg) } else { r0 = ret.Error(0) @@ -171,14 +171,14 @@ type MockConnection_Send_Call struct { } // Send is a helper method to define mock.On call -// - msg interface{} +// - msg []byte func (_e *MockConnection_Expecter) Send(msg interface{}) *MockConnection_Send_Call { return &MockConnection_Send_Call{Call: _e.mock.On("Send", msg)} } -func (_c *MockConnection_Send_Call) Run(run func(msg interface{})) *MockConnection_Send_Call { +func (_c *MockConnection_Send_Call) Run(run func(msg []byte)) *MockConnection_Send_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(interface{})) + run(args[0].([]byte)) }) return _c } @@ -188,7 +188,7 @@ func (_c *MockConnection_Send_Call) Return(_a0 error) *MockConnection_Send_Call return _c } -func (_c *MockConnection_Send_Call) RunAndReturn(run func(interface{}) error) *MockConnection_Send_Call { +func (_c *MockConnection_Send_Call) RunAndReturn(run func([]byte) error) *MockConnection_Send_Call { _c.Call.Return(run) return _c }