Skip to content

Commit

Permalink
Add connections mutex to websocket server
Browse files Browse the repository at this point in the history
Signed-off-by: Lorenzo Donini <lorenzo.donini90@gmail.com>
  • Loading branch information
lorenzodonini committed May 2, 2021
1 parent 3c4ad35 commit 7a1f9e8
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions ws/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,16 @@ type Server struct {
timeoutConfig ServerTimeoutConfig
upgrader websocket.Upgrader
errC chan error
connMutex sync.RWMutex
}

// Creates a new simple websocket server (the websockets are not secured).
func NewServer() *Server {
return &Server{timeoutConfig: NewServerTimeoutConfig(), upgrader: websocket.Upgrader{Subprotocols: []string{}}}
return &Server{
httpServer: &http.Server{},
timeoutConfig: NewServerTimeoutConfig(),
upgrader: websocket.Upgrader{Subprotocols: []string{}},
}
}

// Creates a new secure websocket server. All created websocket channels will use TLS.
Expand Down Expand Up @@ -338,6 +343,8 @@ func (server *Server) Stop() {
}

func (server *Server) Write(webSocketId string, data []byte) error {
server.connMutex.Lock()
defer server.connMutex.Unlock()
ws, ok := server.connections[webSocketId]
if !ok {
return fmt.Errorf("couldn't write to websocket. No socket with id %v is open", webSocketId)
Expand Down Expand Up @@ -394,12 +401,14 @@ func (server *Server) wsHandler(w http.ResponseWriter, r *http.Request) {
ws := WebSocket{
connection: conn,
id: url.Path,
outQueue: make(chan []byte),
outQueue: make(chan []byte, 1),
closeSignal: make(chan error, 1),
pingMessage: make(chan []byte, 1),
tlsConnectionState: r.TLS,
}
server.connections[url.Path] = &ws
server.connMutex.Lock()
defer server.connMutex.Unlock()
server.connections[ws.id] = &ws
// Read and write routines are started in separate goroutines and function will return immediately
go server.writePump(&ws)
go server.readPump(&ws)
Expand Down Expand Up @@ -449,6 +458,9 @@ func (server *Server) writePump(ws *WebSocket) {
conn := ws.connection
defer func() {
_ = conn.Close()
server.connMutex.Lock()
defer server.connMutex.Unlock()
delete(server.connections, ws.id)
}()

for {
Expand Down Expand Up @@ -782,6 +794,7 @@ func (client *Client) Start(url string) error {
for _, option := range client.dialOptions {
option(&dialer)
}
// Connect
ws, resp, err := dialer.Dial(url, client.header)
if err != nil {
if resp != nil {
Expand All @@ -794,7 +807,6 @@ func (client *Client) Start(url string) error {
}
err = httpError
}
client.error(fmt.Errorf("connect failed: %w", err))
return err
}

Expand Down

0 comments on commit 7a1f9e8

Please sign in to comment.