Skip to content

Commit

Permalink
Merge pull request #49 from ksysoev/fix_duplicate_conn_openning
Browse files Browse the repository at this point in the history
Refactor WSBackend for connection management
  • Loading branch information
ksysoev authored May 18, 2024
2 parents 522ec97 + f1e2e9b commit dc190b1
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 12 deletions.
42 changes: 30 additions & 12 deletions backend/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@ package backend
import (
"bytes"
"errors"
"fmt"
"io"
"sync"

"github.com/ksysoev/wasabi"
"golang.org/x/sync/singleflight"
"nhooyr.io/websocket"
)

type WSBackend struct {
group *singleflight.Group
connections map[string]*websocket.Conn
lock *sync.RWMutex
factory WSRequestFactory
Expand All @@ -22,6 +25,7 @@ type WSRequestFactory func(r wasabi.Request) (websocket.MessageType, []byte, err
// NewWSBackend creates a new instance of WSBackend with the specified URL.
func NewWSBackend(url string, factory WSRequestFactory) *WSBackend {
return &WSBackend{
group: &singleflight.Group{},
connections: make(map[string]*websocket.Conn),
lock: &sync.RWMutex{},
factory: factory,
Expand Down Expand Up @@ -53,29 +57,43 @@ func (b *WSBackend) Handle(conn wasabi.Connection, r wasabi.Request) error {
// Otherwise, it establishes a new connection and returns it.
func (b *WSBackend) getConnection(conn wasabi.Connection) (*websocket.Conn, error) {
b.lock.RLock()
c, ok := b.connections[conn.ID()]
ws, ok := b.connections[conn.ID()]
b.lock.RUnlock()

if ok {
return c, nil
return ws, nil
}

c, resp, err := websocket.Dial(conn.Context(), b.URL, nil)
uws, err, _ := b.group.Do(conn.ID(), func() (interface{}, error) {
fmt.Println("Connecting to", b.URL, "for connection", conn.ID())
c, resp, err := websocket.Dial(conn.Context(), b.URL, nil)
if err != nil {
return nil, err
}

if resp.Body != nil {
defer resp.Body.Close()
}

go b.responseHandler(c, conn)

b.lock.Lock()
b.connections[conn.ID()] = c
b.lock.Unlock()

return c, nil
})

if err != nil {
return nil, err
}

if resp.Body != nil {
defer resp.Body.Close()
ws, ok = uws.(*websocket.Conn)
if !ok {
panic("unexpected type")
}

b.lock.Lock()
b.connections[conn.ID()] = c
b.lock.Unlock()

go b.responseHandler(c, conn)

return c, nil
return ws, nil
}

// responseHandler handles the response from the server to the client.
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
golang.org/x/sync v0.7.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e h1:723BNChdd0c2Wk6WOE320qGBiPtYx0F0Bbm1kriShfE=
golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down

0 comments on commit dc190b1

Please sign in to comment.