Skip to content

Commit

Permalink
Merge pull request #1 from ksysoev/replace_websocket_library
Browse files Browse the repository at this point in the history
Replace websocket library
  • Loading branch information
ksysoev committed Apr 13, 2024
2 parents ec4eec9 + 5a76f13 commit 3ab02ec
Show file tree
Hide file tree
Showing 20 changed files with 182 additions and 97 deletions.
2 changes: 1 addition & 1 deletion backend/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (b *HTTPBackend) Handle(conn wasabi.Connection, r wasabi.Request) error {
return err
}

return conn.Send(respBody.String())
return conn.Send(wasabi.MsgTypeText, respBody.Bytes())
}

func WithDefaultHTTPTimeout(timeout time.Duration) HTTPBackendOption {
Expand Down
2 changes: 1 addition & 1 deletion backend/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestHTTPBackend_Handle(t *testing.T) {

mockReq.EXPECT().Context().Return(context.Background())

mockConn.EXPECT().Send("OK").Return(nil)
mockConn.EXPECT().Send(wasabi.MsgTypeText, []byte("OK")).Return(nil)
mockReq.EXPECT().Data().Return([]byte("test request"))

backend := NewBackend(func(req wasabi.Request) (*http.Request, error) {
Expand Down
10 changes: 8 additions & 2 deletions channel/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"net/http"

"github.com/ksysoev/wasabi"
"golang.org/x/net/websocket"
"nhooyr.io/websocket"
)

// DefaultChannel is default implementation of Channel
Expand Down Expand Up @@ -51,7 +51,13 @@ func (c *DefaultChannel) Handler() http.Handler {
})
}

wsHandler := websocket.Handler(func(ws *websocket.Conn) {
wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws, err := websocket.Accept(w, r, nil)

if err != nil {
return
}

conn := c.connRegistry.AddConnection(ctx, ws, c.disptacher.Dispatch)
conn.HandleRequests()
})
Expand Down
29 changes: 13 additions & 16 deletions channel/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ import (
"errors"
"io"
"log/slog"
"net"
"sync"
"sync/atomic"

"github.com/google/uuid"
"golang.org/x/net/websocket"
"nhooyr.io/websocket"

"github.com/ksysoev/wasabi"
)
Expand Down Expand Up @@ -67,44 +66,42 @@ func (c *Conn) HandleRequests() {
defer c.close()

for c.ctx.Err() == nil {
var data []byte
err := websocket.Message.Receive(c.ws, &data)
msgType, reader, err := c.ws.Reader(c.ctx)

if err != nil {
return
}

data, err := io.ReadAll(reader)
if err != nil {
switch {
case c.isClosed.Load():
return
case errors.Is(err, io.EOF):
return
case errors.Is(err, websocket.ErrFrameTooLarge):
return
case errors.Is(err, net.ErrClosed):
case errors.Is(err, context.Canceled):
return
default:
slog.Warn("Error reading message: ", err)
}

slog.Warn("Error reading message: " + err.Error())

continue
return
}

c.reqWG.Add(1)

go func(wg *sync.WaitGroup) {
defer wg.Done()
c.onMessageCB(c, data)
c.onMessageCB(c, msgType, data)
}(c.reqWG)
}
}

// Send sends message to connection
func (c *Conn) Send(msg any) error {
func (c *Conn) Send(msgType wasabi.MessageType, msg []byte) error {
if c.isClosed.Load() || c.ctx.Err() != nil {
return ErrConnectionClosed
}

return websocket.Message.Send(c.ws, msg)
return c.ws.Write(c.ctx, msgType, msg)
}

// close closes the connection.
Expand All @@ -119,6 +116,6 @@ func (c *Conn) close() {
c.onClose <- c.id
c.isClosed.Store(true)

c.ws.Close()
_ = c.ws.CloseNow()
c.reqWG.Wait()
}
2 changes: 1 addition & 1 deletion channel/connection_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"sync"

"github.com/ksysoev/wasabi"
"golang.org/x/net/websocket"
"nhooyr.io/websocket"
)

// DefaultConnectionRegistry is default implementation of ConnectionRegistry
Expand Down
4 changes: 2 additions & 2 deletions channel/connection_registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ import (

"github.com/ksysoev/wasabi"
"github.com/ksysoev/wasabi/mocks"
"golang.org/x/net/websocket"
"nhooyr.io/websocket"
)

func TestDefaultConnectionRegistry_AddConnection(t *testing.T) {
ctx := context.Background()
ws := &websocket.Conn{}
cb := func(wasabi.Connection, []byte) {}
cb := func(wasabi.Connection, wasabi.MessageType, []byte) {}

registry := NewDefaultConnectionRegistry()

Expand Down
85 changes: 67 additions & 18 deletions channel/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,48 @@ package channel
import (
"context"
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

"github.com/ksysoev/wasabi"
"golang.org/x/net/websocket"
"nhooyr.io/websocket"
)

var wsHandlerEcho = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, nil)
if err != nil {
return
}
defer c.Close(websocket.StatusNormalClosure, "")

for {
_, wsr, err := c.Reader(r.Context())
if err != nil {
if err == io.EOF {
return
}
return
}

wsw, err := c.Writer(r.Context(), websocket.MessageText)
if err != nil {
return
}

_, err = io.Copy(wsw, wsr)
if err != nil {
return
}

err = wsw.Close()
if err != nil {
return
}
}
})

func TestConn_ID(t *testing.T) {
ws := &websocket.Conn{}
onClose := make(chan string)
Expand All @@ -33,71 +66,87 @@ func TestConn_Context(t *testing.T) {
}

func TestConn_HandleRequests(t *testing.T) {
server := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { _, _ = io.Copy(ws, ws) }))
server := httptest.NewServer(wsHandlerEcho)
defer server.Close()

url := "ws://" + server.Listener.Addr().String()

ws, err := websocket.Dial(url, "", "http://localhost/")
ws, resp, err := websocket.Dial(context.Background(), url, nil)

if err != nil {
t.Errorf("Unexpected error dialing websocket: %v", err)
}

defer ws.Close()
if resp.Body != nil {
resp.Body.Close()
}

defer func() { _ = ws.CloseNow() }()

onClose := make(chan string)
conn := NewConnection(context.Background(), ws, nil, onClose)

// Mock OnMessage callback
var wg sync.WaitGroup
received := make(chan struct{})

wg.Add(1)

conn.onMessageCB = func(c wasabi.Connection, data []byte) { wg.Done() }
conn.onMessageCB = func(c wasabi.Connection, msgType wasabi.MessageType, data []byte) { received <- struct{}{} }

go conn.HandleRequests()

// Send message to trigger OnMessage callback
err = websocket.Message.Send(ws, []byte("test message"))
err = ws.Write(context.Background(), websocket.MessageText, []byte("test message"))
if err != nil {
t.Errorf("Unexpected error sending message: %v", err)
}

wg.Wait()
select {
case <-received:
// Expected
case <-time.After(50 * time.Millisecond):
t.Error("Expected OnMessage callback to be called")
}
}

func TestConn_Send(t *testing.T) {
server := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { _, _ = io.Copy(ws, ws) }))
server := httptest.NewServer(wsHandlerEcho)
defer server.Close()
url := "ws://" + server.Listener.Addr().String()

ws, err := websocket.Dial(url, "", "http://localhost/")
ws, resp, err := websocket.Dial(context.Background(), url, nil)
if err != nil {
t.Errorf("Unexpected error dialing websocket: %v", err)
}

defer ws.Close()
if resp.Body != nil {
resp.Body.Close()
}

defer func() { _ = ws.CloseNow() }()

onClose := make(chan string)
conn := NewConnection(context.Background(), ws, nil, onClose)

err = conn.Send([]byte("test message"))
err = conn.Send(wasabi.MsgTypeText, []byte("test message"))
if err != nil {
t.Errorf("Unexpected error sending message: %v", err)
}
}

func TestConn_close(t *testing.T) {
server := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { _, _ = io.Copy(ws, ws) }))
server := httptest.NewServer(wsHandlerEcho)
defer server.Close()
url := "ws://" + server.Listener.Addr().String()

ws, err := websocket.Dial(url, "", "http://localhost/")
ws, resp, err := websocket.Dial(context.Background(), url, nil)
if err != nil {
t.Errorf("Unexpected error dialing websocket: %v", err)
}

defer ws.Close()
if resp.Body != nil {
resp.Body.Close()
}

defer func() { _ = ws.CloseNow() }()

onClose := make(chan string)
conn := NewConnection(context.Background(), ws, nil, onClose)
Expand Down
2 changes: 1 addition & 1 deletion dispatch/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ func (f RequestHandlerFunc) Handle(conn wasabi.Connection, req wasabi.Request) e
return f(conn, req)
}

type RequestParser func(conn wasabi.Connection, data []byte) wasabi.Request
type RequestParser func(conn wasabi.Connection, msgType wasabi.MessageType, data []byte) wasabi.Request
4 changes: 2 additions & 2 deletions dispatch/pipe_dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ func NewPipeDispatcher(backend wasabi.Backend) *PipeDispatcher {
}

// Dispatch dispatches request to backend
func (d *PipeDispatcher) Dispatch(conn wasabi.Connection, data []byte) {
req := NewRawRequest(conn.Context(), data)
func (d *PipeDispatcher) Dispatch(conn wasabi.Connection, msgType wasabi.MessageType, data []byte) {
req := NewRawRequest(conn.Context(), msgType, data)

err := d.useMiddleware(d.backend).Handle(conn, req)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions dispatch/pipe_dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ func TestPipeDispatcher_Dispatch(t *testing.T) {
testError := fmt.Errorf("test error")

conn.On("Context").Return(context.Background())
backend.EXPECT().Handle(conn, NewRawRequest(conn.Context(), data)).Return(testError)
backend.EXPECT().Handle(conn, NewRawRequest(conn.Context(), wasabi.MsgTypeText, data)).Return(testError)

dispatcher.Dispatch(conn, data)
dispatcher.Dispatch(conn, wasabi.MsgTypeText, data)
}

func TestPipeDispatcher_Use(t *testing.T) {
Expand Down
18 changes: 13 additions & 5 deletions dispatch/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,32 @@ import (
)

type RawRequest struct {
ctx context.Context
data []byte
ctx context.Context
data []byte
msgType wasabi.MessageType
}

func NewRawRequest(ctx context.Context, data []byte) *RawRequest {
func NewRawRequest(ctx context.Context, msgType wasabi.MessageType, data []byte) *RawRequest {
if ctx == nil {
panic("nil context")
}

return &RawRequest{ctx: ctx, data: data}
return &RawRequest{ctx: ctx, data: data, msgType: msgType}
}

func (r *RawRequest) Data() []byte {
return r.data
}

func (r *RawRequest) RoutingKey() string {
return ""
switch r.msgType {
case wasabi.MsgTypeText:
return "text"
case wasabi.MsgTypeBinary:
return "binary"
default:
panic("unknown message type " + r.msgType.String())
}
}

func (r *RawRequest) Context() context.Context {
Expand Down
Loading

0 comments on commit 3ab02ec

Please sign in to comment.