Skip to content

Commit

Permalink
Replace websocket library
Browse files Browse the repository at this point in the history
  • Loading branch information
ksysoev committed Apr 13, 2024
1 parent ec4eec9 commit e0acb79
Show file tree
Hide file tree
Showing 11 changed files with 102 additions and 54 deletions.
2 changes: 1 addition & 1 deletion backend/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (b *HTTPBackend) Handle(conn wasabi.Connection, r wasabi.Request) error {
return err
}

return conn.Send(respBody.String())
return conn.Send(respBody.Bytes())
}

func WithDefaultHTTPTimeout(timeout time.Duration) HTTPBackendOption {
Expand Down
2 changes: 1 addition & 1 deletion backend/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestHTTPBackend_Handle(t *testing.T) {

mockReq.EXPECT().Context().Return(context.Background())

mockConn.EXPECT().Send("OK").Return(nil)
mockConn.EXPECT().Send([]byte("OK")).Return(nil)
mockReq.EXPECT().Data().Return([]byte("test request"))

backend := NewBackend(func(req wasabi.Request) (*http.Request, error) {
Expand Down
10 changes: 8 additions & 2 deletions channel/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"net/http"

"github.com/ksysoev/wasabi"
"golang.org/x/net/websocket"
"nhooyr.io/websocket"
)

// DefaultChannel is default implementation of Channel
Expand Down Expand Up @@ -51,7 +51,13 @@ func (c *DefaultChannel) Handler() http.Handler {
})
}

wsHandler := websocket.Handler(func(ws *websocket.Conn) {
wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws, err := websocket.Accept(w, r, nil)

if err != nil {
return
}

conn := c.connRegistry.AddConnection(ctx, ws, c.disptacher.Dispatch)
conn.HandleRequests()
})
Expand Down
33 changes: 13 additions & 20 deletions channel/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ import (
"errors"
"io"
"log/slog"
"net"
"sync"
"sync/atomic"

"github.com/google/uuid"
"golang.org/x/net/websocket"
"nhooyr.io/websocket"

"github.com/ksysoev/wasabi"
)
Expand Down Expand Up @@ -67,26 +66,20 @@ func (c *Conn) HandleRequests() {
defer c.close()

for c.ctx.Err() == nil {
var data []byte
err := websocket.Message.Receive(c.ws, &data)
_, reader, err := c.ws.Reader(c.ctx)

if err != nil {
switch {
case c.isClosed.Load():
return
case errors.Is(err, io.EOF):
return
case errors.Is(err, websocket.ErrFrameTooLarge):
return
case errors.Is(err, net.ErrClosed):
return
default:
slog.Warn("Error reading message: ", err)
}
slog.Warn("Error reading message: " + err.Error())

return
}

data, err := io.ReadAll(reader)

if err != nil {
slog.Warn("Error reading message: " + err.Error())

continue
return
}

c.reqWG.Add(1)
Expand All @@ -99,12 +92,12 @@ func (c *Conn) HandleRequests() {
}

// Send sends message to connection
func (c *Conn) Send(msg any) error {
func (c *Conn) Send(msg []byte) error {
if c.isClosed.Load() || c.ctx.Err() != nil {
return ErrConnectionClosed
}

return websocket.Message.Send(c.ws, msg)
return c.ws.Write(c.ctx, websocket.MessageText, msg)
}

// close closes the connection.
Expand All @@ -119,6 +112,6 @@ func (c *Conn) close() {
c.onClose <- c.id
c.isClosed.Store(true)

c.ws.Close()
_ = c.ws.CloseNow()
c.reqWG.Wait()
}
2 changes: 1 addition & 1 deletion channel/connection_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"sync"

"github.com/ksysoev/wasabi"
"golang.org/x/net/websocket"
"nhooyr.io/websocket"
)

// DefaultConnectionRegistry is default implementation of ConnectionRegistry
Expand Down
2 changes: 1 addition & 1 deletion channel/connection_registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (

"github.com/ksysoev/wasabi"
"github.com/ksysoev/wasabi/mocks"
"golang.org/x/net/websocket"
"nhooyr.io/websocket"
)

func TestDefaultConnectionRegistry_AddConnection(t *testing.T) {
Expand Down
83 changes: 66 additions & 17 deletions channel/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,48 @@ package channel
import (
"context"
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

"github.com/ksysoev/wasabi"
"golang.org/x/net/websocket"
"nhooyr.io/websocket"
)

var wsHandlerEcho = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, nil)
if err != nil {
return
}
defer c.Close(websocket.StatusNormalClosure, "")

for {
_, wsr, err := c.Reader(r.Context())
if err != nil {
if err == io.EOF {
return
}
return
}

wsw, err := c.Writer(r.Context(), websocket.MessageText)
if err != nil {
return
}

_, err = io.Copy(wsw, wsr)
if err != nil {
return
}

err = wsw.Close()
if err != nil {
return
}
}
})

func TestConn_ID(t *testing.T) {
ws := &websocket.Conn{}
onClose := make(chan string)
Expand All @@ -33,50 +66,62 @@ func TestConn_Context(t *testing.T) {
}

func TestConn_HandleRequests(t *testing.T) {
server := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { _, _ = io.Copy(ws, ws) }))
server := httptest.NewServer(wsHandlerEcho)
defer server.Close()

url := "ws://" + server.Listener.Addr().String()

ws, err := websocket.Dial(url, "", "http://localhost/")
ws, resp, err := websocket.Dial(context.Background(), url, nil)

if err != nil {
t.Errorf("Unexpected error dialing websocket: %v", err)
}

defer ws.Close()
if resp.Body != nil {
resp.Body.Close()
}

defer func() { _ = ws.CloseNow() }()

onClose := make(chan string)
conn := NewConnection(context.Background(), ws, nil, onClose)

// Mock OnMessage callback
var wg sync.WaitGroup
received := make(chan struct{})

wg.Add(1)

conn.onMessageCB = func(c wasabi.Connection, data []byte) { wg.Done() }
conn.onMessageCB = func(c wasabi.Connection, data []byte) { received <- struct{}{} }

go conn.HandleRequests()

// Send message to trigger OnMessage callback
err = websocket.Message.Send(ws, []byte("test message"))
err = ws.Write(context.Background(), websocket.MessageText, []byte("test message"))
if err != nil {
t.Errorf("Unexpected error sending message: %v", err)
}

wg.Wait()
select {
case <-received:
// Expected
case <-time.After(50 * time.Millisecond):
t.Error("Expected OnMessage callback to be called")
}
}

func TestConn_Send(t *testing.T) {
server := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { _, _ = io.Copy(ws, ws) }))
server := httptest.NewServer(wsHandlerEcho)
defer server.Close()
url := "ws://" + server.Listener.Addr().String()

ws, err := websocket.Dial(url, "", "http://localhost/")
ws, resp, err := websocket.Dial(context.Background(), url, nil)
if err != nil {
t.Errorf("Unexpected error dialing websocket: %v", err)
}

defer ws.Close()
if resp.Body != nil {
resp.Body.Close()
}

defer func() { _ = ws.CloseNow() }()

onClose := make(chan string)
conn := NewConnection(context.Background(), ws, nil, onClose)
Expand All @@ -88,16 +133,20 @@ func TestConn_Send(t *testing.T) {
}

func TestConn_close(t *testing.T) {
server := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { _, _ = io.Copy(ws, ws) }))
server := httptest.NewServer(wsHandlerEcho)
defer server.Close()
url := "ws://" + server.Listener.Addr().String()

ws, err := websocket.Dial(url, "", "http://localhost/")
ws, resp, err := websocket.Dial(context.Background(), url, nil)
if err != nil {
t.Errorf("Unexpected error dialing websocket: %v", err)
}

defer ws.Close()
if resp.Body != nil {
resp.Body.Close()
}

defer func() { _ = ws.CloseNow() }()

onClose := make(chan string)
conn := NewConnection(context.Background(), ws, nil, onClose)
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ go 1.22.1

require (
github.com/google/uuid v1.5.0
github.com/ksysoev/ratestor v0.1.0
github.com/stretchr/testify v1.9.0
golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e
golang.org/x/net v0.19.0
nhooyr.io/websocket v1.8.11
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/ksysoev/ratestor v0.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e h1:723BNChdd0c2Wk6WOE320qGBiPtYx0F0Bbm1kriShfE=
golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
nhooyr.io/websocket v1.8.11 h1:f/qXNc2/3DpoSZkHt1DQu6rj4zGC8JmkkLkWss0MgN0=
nhooyr.io/websocket v1.8.11/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c=
2 changes: 1 addition & 1 deletion interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type OnMessage func(conn Connection, data []byte)

// Connection is interface for connections
type Connection interface {
Send(msg any) error
Send(msg []byte) error
Context() context.Context
ID() string
HandleRequests()
Expand Down
12 changes: 6 additions & 6 deletions mocks/mock_Connection.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit e0acb79

Please sign in to comment.