Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace websocket library #1

Merged
merged 3 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (b *HTTPBackend) Handle(conn wasabi.Connection, r wasabi.Request) error {
return err
}

return conn.Send(respBody.String())
return conn.Send(wasabi.MsgTypeText, respBody.Bytes())
}

func WithDefaultHTTPTimeout(timeout time.Duration) HTTPBackendOption {
Expand Down
2 changes: 1 addition & 1 deletion backend/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestHTTPBackend_Handle(t *testing.T) {

mockReq.EXPECT().Context().Return(context.Background())

mockConn.EXPECT().Send("OK").Return(nil)
mockConn.EXPECT().Send(wasabi.MsgTypeText, []byte("OK")).Return(nil)
mockReq.EXPECT().Data().Return([]byte("test request"))

backend := NewBackend(func(req wasabi.Request) (*http.Request, error) {
Expand Down
10 changes: 8 additions & 2 deletions channel/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"net/http"

"github.com/ksysoev/wasabi"
"golang.org/x/net/websocket"
"nhooyr.io/websocket"
)

// DefaultChannel is default implementation of Channel
Expand Down Expand Up @@ -51,7 +51,13 @@ func (c *DefaultChannel) Handler() http.Handler {
})
}

wsHandler := websocket.Handler(func(ws *websocket.Conn) {
wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws, err := websocket.Accept(w, r, nil)

if err != nil {
return
}

conn := c.connRegistry.AddConnection(ctx, ws, c.disptacher.Dispatch)
conn.HandleRequests()
})
Expand Down
29 changes: 13 additions & 16 deletions channel/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ import (
"errors"
"io"
"log/slog"
"net"
"sync"
"sync/atomic"

"github.com/google/uuid"
"golang.org/x/net/websocket"
"nhooyr.io/websocket"

"github.com/ksysoev/wasabi"
)
Expand Down Expand Up @@ -67,44 +66,42 @@ func (c *Conn) HandleRequests() {
defer c.close()

for c.ctx.Err() == nil {
var data []byte
err := websocket.Message.Receive(c.ws, &data)
msgType, reader, err := c.ws.Reader(c.ctx)

if err != nil {
return
}

data, err := io.ReadAll(reader)
if err != nil {
switch {
case c.isClosed.Load():
return
case errors.Is(err, io.EOF):
return
case errors.Is(err, websocket.ErrFrameTooLarge):
return
case errors.Is(err, net.ErrClosed):
case errors.Is(err, context.Canceled):
return
default:
slog.Warn("Error reading message: ", err)
}

slog.Warn("Error reading message: " + err.Error())

continue
return
}

c.reqWG.Add(1)

go func(wg *sync.WaitGroup) {
defer wg.Done()
c.onMessageCB(c, data)
c.onMessageCB(c, msgType, data)
}(c.reqWG)
}
}

// Send sends message to connection
func (c *Conn) Send(msg any) error {
func (c *Conn) Send(msgType wasabi.MessageType, msg []byte) error {
if c.isClosed.Load() || c.ctx.Err() != nil {
return ErrConnectionClosed
}

return websocket.Message.Send(c.ws, msg)
return c.ws.Write(c.ctx, msgType, msg)
}

// close closes the connection.
Expand All @@ -119,6 +116,6 @@ func (c *Conn) close() {
c.onClose <- c.id
c.isClosed.Store(true)

c.ws.Close()
_ = c.ws.CloseNow()
c.reqWG.Wait()
}
2 changes: 1 addition & 1 deletion channel/connection_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"sync"

"github.com/ksysoev/wasabi"
"golang.org/x/net/websocket"
"nhooyr.io/websocket"
)

// DefaultConnectionRegistry is default implementation of ConnectionRegistry
Expand Down
4 changes: 2 additions & 2 deletions channel/connection_registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ import (

"github.com/ksysoev/wasabi"
"github.com/ksysoev/wasabi/mocks"
"golang.org/x/net/websocket"
"nhooyr.io/websocket"
)

func TestDefaultConnectionRegistry_AddConnection(t *testing.T) {
ctx := context.Background()
ws := &websocket.Conn{}
cb := func(wasabi.Connection, []byte) {}
cb := func(wasabi.Connection, wasabi.MessageType, []byte) {}

registry := NewDefaultConnectionRegistry()

Expand Down
85 changes: 67 additions & 18 deletions channel/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,48 @@ package channel
import (
"context"
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

"github.com/ksysoev/wasabi"
"golang.org/x/net/websocket"
"nhooyr.io/websocket"
)

var wsHandlerEcho = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, nil)
if err != nil {
return
}
defer c.Close(websocket.StatusNormalClosure, "")

for {
_, wsr, err := c.Reader(r.Context())
if err != nil {
if err == io.EOF {
return
}
return
}

wsw, err := c.Writer(r.Context(), websocket.MessageText)
if err != nil {
return
}

_, err = io.Copy(wsw, wsr)
if err != nil {
return
}

err = wsw.Close()
if err != nil {
return
}
}
})

func TestConn_ID(t *testing.T) {
ws := &websocket.Conn{}
onClose := make(chan string)
Expand All @@ -33,71 +66,87 @@ func TestConn_Context(t *testing.T) {
}

func TestConn_HandleRequests(t *testing.T) {
server := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { _, _ = io.Copy(ws, ws) }))
server := httptest.NewServer(wsHandlerEcho)
defer server.Close()

url := "ws://" + server.Listener.Addr().String()

ws, err := websocket.Dial(url, "", "http://localhost/")
ws, resp, err := websocket.Dial(context.Background(), url, nil)

if err != nil {
t.Errorf("Unexpected error dialing websocket: %v", err)
}

defer ws.Close()
if resp.Body != nil {
resp.Body.Close()
}

defer func() { _ = ws.CloseNow() }()

onClose := make(chan string)
conn := NewConnection(context.Background(), ws, nil, onClose)

// Mock OnMessage callback
var wg sync.WaitGroup
received := make(chan struct{})

wg.Add(1)

conn.onMessageCB = func(c wasabi.Connection, data []byte) { wg.Done() }
conn.onMessageCB = func(c wasabi.Connection, msgType wasabi.MessageType, data []byte) { received <- struct{}{} }

go conn.HandleRequests()

// Send message to trigger OnMessage callback
err = websocket.Message.Send(ws, []byte("test message"))
err = ws.Write(context.Background(), websocket.MessageText, []byte("test message"))
if err != nil {
t.Errorf("Unexpected error sending message: %v", err)
}

wg.Wait()
select {
case <-received:
// Expected
case <-time.After(50 * time.Millisecond):
t.Error("Expected OnMessage callback to be called")
}
}

func TestConn_Send(t *testing.T) {
server := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { _, _ = io.Copy(ws, ws) }))
server := httptest.NewServer(wsHandlerEcho)
defer server.Close()
url := "ws://" + server.Listener.Addr().String()

ws, err := websocket.Dial(url, "", "http://localhost/")
ws, resp, err := websocket.Dial(context.Background(), url, nil)
if err != nil {
t.Errorf("Unexpected error dialing websocket: %v", err)
}

defer ws.Close()
if resp.Body != nil {
resp.Body.Close()
}

defer func() { _ = ws.CloseNow() }()

onClose := make(chan string)
conn := NewConnection(context.Background(), ws, nil, onClose)

err = conn.Send([]byte("test message"))
err = conn.Send(wasabi.MsgTypeText, []byte("test message"))
if err != nil {
t.Errorf("Unexpected error sending message: %v", err)
}
}

func TestConn_close(t *testing.T) {
server := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { _, _ = io.Copy(ws, ws) }))
server := httptest.NewServer(wsHandlerEcho)
defer server.Close()
url := "ws://" + server.Listener.Addr().String()

ws, err := websocket.Dial(url, "", "http://localhost/")
ws, resp, err := websocket.Dial(context.Background(), url, nil)
if err != nil {
t.Errorf("Unexpected error dialing websocket: %v", err)
}

defer ws.Close()
if resp.Body != nil {
resp.Body.Close()
}

defer func() { _ = ws.CloseNow() }()

onClose := make(chan string)
conn := NewConnection(context.Background(), ws, nil, onClose)
Expand Down
2 changes: 1 addition & 1 deletion dispatch/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ func (f RequestHandlerFunc) Handle(conn wasabi.Connection, req wasabi.Request) e
return f(conn, req)
}

type RequestParser func(conn wasabi.Connection, data []byte) wasabi.Request
type RequestParser func(conn wasabi.Connection, msgType wasabi.MessageType, data []byte) wasabi.Request
4 changes: 2 additions & 2 deletions dispatch/pipe_dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ func NewPipeDispatcher(backend wasabi.Backend) *PipeDispatcher {
}

// Dispatch dispatches request to backend
func (d *PipeDispatcher) Dispatch(conn wasabi.Connection, data []byte) {
req := NewRawRequest(conn.Context(), data)
func (d *PipeDispatcher) Dispatch(conn wasabi.Connection, msgType wasabi.MessageType, data []byte) {
req := NewRawRequest(conn.Context(), msgType, data)

err := d.useMiddleware(d.backend).Handle(conn, req)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions dispatch/pipe_dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ func TestPipeDispatcher_Dispatch(t *testing.T) {
testError := fmt.Errorf("test error")

conn.On("Context").Return(context.Background())
backend.EXPECT().Handle(conn, NewRawRequest(conn.Context(), data)).Return(testError)
backend.EXPECT().Handle(conn, NewRawRequest(conn.Context(), wasabi.MsgTypeText, data)).Return(testError)

dispatcher.Dispatch(conn, data)
dispatcher.Dispatch(conn, wasabi.MsgTypeText, data)
}

func TestPipeDispatcher_Use(t *testing.T) {
Expand Down
18 changes: 13 additions & 5 deletions dispatch/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,32 @@ import (
)

type RawRequest struct {
ctx context.Context
data []byte
ctx context.Context
data []byte
msgType wasabi.MessageType
}

func NewRawRequest(ctx context.Context, data []byte) *RawRequest {
func NewRawRequest(ctx context.Context, msgType wasabi.MessageType, data []byte) *RawRequest {
if ctx == nil {
panic("nil context")
}

return &RawRequest{ctx: ctx, data: data}
return &RawRequest{ctx: ctx, data: data, msgType: msgType}
}

func (r *RawRequest) Data() []byte {
return r.data
}

func (r *RawRequest) RoutingKey() string {
return ""
switch r.msgType {
case wasabi.MsgTypeText:
return "text"
case wasabi.MsgTypeBinary:
return "binary"
default:
panic("unknown message type " + r.msgType.String())
}
}

func (r *RawRequest) Context() context.Context {
Expand Down
Loading
Loading