diff --git a/backend/ws.go b/backend/ws.go index c99f9dd..b3b8b68 100644 --- a/backend/ws.go +++ b/backend/ws.go @@ -3,14 +3,17 @@ package backend import ( "bytes" "errors" + "fmt" "io" "sync" "github.com/ksysoev/wasabi" + "golang.org/x/sync/singleflight" "nhooyr.io/websocket" ) type WSBackend struct { + group *singleflight.Group connections map[string]*websocket.Conn lock *sync.RWMutex factory WSRequestFactory @@ -22,6 +25,7 @@ type WSRequestFactory func(r wasabi.Request) (websocket.MessageType, []byte, err // NewWSBackend creates a new instance of WSBackend with the specified URL. func NewWSBackend(url string, factory WSRequestFactory) *WSBackend { return &WSBackend{ + group: &singleflight.Group{}, connections: make(map[string]*websocket.Conn), lock: &sync.RWMutex{}, factory: factory, @@ -53,29 +57,43 @@ func (b *WSBackend) Handle(conn wasabi.Connection, r wasabi.Request) error { // Otherwise, it establishes a new connection and returns it. func (b *WSBackend) getConnection(conn wasabi.Connection) (*websocket.Conn, error) { b.lock.RLock() - c, ok := b.connections[conn.ID()] + ws, ok := b.connections[conn.ID()] b.lock.RUnlock() if ok { - return c, nil + return ws, nil } - c, resp, err := websocket.Dial(conn.Context(), b.URL, nil) + uws, err, _ := b.group.Do(conn.ID(), func() (interface{}, error) { + fmt.Println("Connecting to", b.URL, "for connection", conn.ID()) + c, resp, err := websocket.Dial(conn.Context(), b.URL, nil) + if err != nil { + return nil, err + } + + if resp.Body != nil { + defer resp.Body.Close() + } + + go b.responseHandler(c, conn) + + b.lock.Lock() + b.connections[conn.ID()] = c + b.lock.Unlock() + + return c, nil + }) + if err != nil { return nil, err } - if resp.Body != nil { - defer resp.Body.Close() + ws, ok = uws.(*websocket.Conn) + if !ok { + panic("unexpected type") } - b.lock.Lock() - b.connections[conn.ID()] = c - b.lock.Unlock() - - go b.responseHandler(c, conn) - - return c, nil + return ws, nil } // responseHandler handles the response from the server to the client. diff --git a/go.mod b/go.mod index 36c2c9b..24fe4a9 100644 --- a/go.mod +++ b/go.mod @@ -14,5 +14,6 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect + golang.org/x/sync v0.7.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ff8e672..67a5853 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e h1:723BNChdd0c2Wk6WOE320qGBiPtYx0F0Bbm1kriShfE= golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=