From 66869a4d51824b277460ea15a25952e90877c0cc Mon Sep 17 00:00:00 2001 From: Malte Mindedal Date: Wed, 24 Sep 2025 21:48:19 +0200 Subject: [PATCH 1/2] feat: implement JSON message handling and broadcasting in WebSocket server --- internal/server/server.go | 66 ++++++++++++++++++---- test/integration/websocket_test.go | 89 ++++++++++++++++++++++++------ test/unit/hub_test.go | 4 +- 3 files changed, 129 insertions(+), 30 deletions(-) diff --git a/internal/server/server.go b/internal/server/server.go index 03de56b..4695ce3 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -20,6 +20,7 @@ package server import ( + "encoding/json" "fmt" "io" "log" @@ -53,12 +54,24 @@ type Client struct { closed bool } +// Message represents the V1 JSON message format exchanged between clients. +type Message struct { + Content string `json:"content"` +} + +// BroadcastMessage encapsulates a message being broadcast by the hub, +// including the originating client so it can be excluded from delivery. +type BroadcastMessage struct { + Sender *Client + Payload []byte +} + // Hub manages all WebSocket client connections and handles message broadcasting. // It maintains client registration/unregistration and ensures thread-safe operations // through mutex protection. type Hub struct { clients map[*Client]bool - broadcast chan []byte + broadcast chan BroadcastMessage register chan *Client unregister chan *Client mutex sync.RWMutex @@ -69,7 +82,7 @@ type Hub struct { func NewHub() *Hub { return &Hub{ clients: make(map[*Client]bool), - broadcast: make(chan []byte), + broadcast: make(chan BroadcastMessage), register: make(chan *Client), unregister: make(chan *Client), } @@ -102,7 +115,7 @@ func (h *Hub) GetUnregisterChan() chan<- *Client { // GetBroadcastChan returns the channel used for broadcasting messages to all clients. // This channel is write-only from the caller's perspective. -func (h *Hub) GetBroadcastChan() chan<- []byte { +func (h *Hub) GetBroadcastChan() chan<- BroadcastMessage { return h.broadcast } @@ -145,12 +158,21 @@ func (h *Hub) Run() { for { select { case client := <-h.register: + if client == nil { + log.Printf("Received nil client registration; skipping") + continue + } + h.mutex.Lock() + client.closed = false h.clients[client] = true clientCount := len(h.clients) h.mutex.Unlock() log.Printf("Client registered from %s. Total clients: %d", client.addr, clientCount) + go client.writePump() + go client.readPump() + case client := <-h.unregister: h.mutex.Lock() if _, ok := h.clients[client]; ok { @@ -165,7 +187,7 @@ func (h *Hub) Run() { h.mutex.Unlock() } - case message := <-h.broadcast: + case broadcastMsg := <-h.broadcast: h.mutex.RLock() clientCount := len(h.clients) clients := make([]*Client, 0, clientCount) @@ -174,12 +196,23 @@ func (h *Hub) Run() { } h.mutex.RUnlock() - log.Printf("Broadcasting message to %d clients", clientCount) + targetCount := clientCount + if broadcastMsg.Sender != nil { + targetCount-- + } + if targetCount < 0 { + targetCount = 0 + } + + log.Printf("Broadcasting message to %d clients", targetCount) var clientsToRemove []*Client for _, client := range clients { - if !h.safeSend(client, message) { + if broadcastMsg.Sender != nil && client == broadcastMsg.Sender { + continue + } + if !h.safeSend(client, broadcastMsg.Payload) { clientsToRemove = append(clientsToRemove, client) } } @@ -226,7 +259,7 @@ func (c *Client) readPump() { }) for { - _, message, err := c.conn.ReadMessage() + _, rawMessage, err := c.conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.Printf("WebSocket error from %s: %v", c.addr, err) @@ -234,8 +267,20 @@ func (c *Client) readPump() { break } - log.Printf("Received message from %s: %s", c.addr, string(message)) - c.hub.broadcast <- message + var msg Message + if err := json.Unmarshal(rawMessage, &msg); err != nil { + log.Printf("Invalid message from %s: %v", c.addr, err) + continue + } + + normalizedMessage, err := json.Marshal(msg) + if err != nil { + log.Printf("Error normalizing message from %s: %v", c.addr, err) + continue + } + + log.Printf("Received message from %s: %s", c.addr, string(normalizedMessage)) + c.hub.broadcast <- BroadcastMessage{Sender: c, Payload: normalizedMessage} } } @@ -387,9 +432,8 @@ func WebSocketHandler(w http.ResponseWriter, r *http.Request) { client := NewClient(conn, hub, r.RemoteAddr) + // Register the client with the hub; the hub will launch the pump goroutines. client.hub.register <- client - go client.writePump() - client.readPump() } // HealthHandler provides a simple health check endpoint that returns server status. diff --git a/test/integration/websocket_test.go b/test/integration/websocket_test.go index 62d2689..879c74f 100644 --- a/test/integration/websocket_test.go +++ b/test/integration/websocket_test.go @@ -8,6 +8,8 @@ package integration import ( "context" + "encoding/json" + "net" "net/http" "net/http/httptest" "net/url" @@ -19,6 +21,39 @@ import ( "github.com/gorilla/websocket" ) +func mustMarshalMessage(t *testing.T, content string) []byte { + t.Helper() + if t == nil { + panic("testing.T is required") + } + payload, err := json.Marshal(server.Message{Content: content}) + if err != nil { + t.Fatalf("Failed to marshal message: %v", err) + } + return payload +} + +func expectNoMessage(t *testing.T, conn *websocket.Conn, timeout time.Duration) { + t.Helper() + if conn == nil { + t.Fatalf("nil connection provided to expectNoMessage") + } + if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil { + t.Fatalf("Failed to set read deadline: %v", err) + } + _, _, err := conn.ReadMessage() + if err == nil { + t.Fatalf("Expected no message, but received one") + } + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return + } + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + return + } + t.Fatalf("Unexpected error while waiting for absence of message: %v", err) +} + // TestWebSocketEndpointIntegration tests the WebSocket endpoint with full server integration. // It verifies that WebSocket connections can be established, messages can be sent and received, // and the complete WebSocket functionality works in a real server environment. @@ -49,7 +84,7 @@ func TestWebSocketEndpointIntegration(t *testing.T) { } testMessage := "Hello, WebSocket!" - err = conn.WriteMessage(websocket.TextMessage, []byte(testMessage)) + err = conn.WriteMessage(websocket.TextMessage, mustMarshalMessage(t, testMessage)) if err != nil { t.Errorf("Failed to send message: %v", err) } @@ -118,22 +153,18 @@ func TestWebSocketMessageBroadcasting(t *testing.T) { time.Sleep(50 * time.Millisecond) // Send a message from the first client - testMessage := "Hello from client 0!" - err = connections[0].WriteMessage(websocket.TextMessage, []byte(testMessage)) - if err != nil { + messageContent := "Hello from client 0!" + if err := connections[0].WriteMessage(websocket.TextMessage, mustMarshalMessage(t, messageContent)); err != nil { t.Fatalf("Failed to send message from client 0: %v", err) } // Check that all other clients receive the message for i := 1; i < numClients; i++ { - // Set a read deadline - err = connections[i].SetReadDeadline(time.Now().Add(2 * time.Second)) - if err != nil { + if err := connections[i].SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { t.Errorf("Failed to set read deadline for client %d: %v", i, err) continue } - // Read the broadcasted message messageType, message, err := connections[i].ReadMessage() if err != nil { t.Errorf("Client %d failed to receive broadcasted message: %v", i, err) @@ -144,11 +175,32 @@ func TestWebSocketMessageBroadcasting(t *testing.T) { t.Errorf("Client %d: Expected text message, got type %d", i, messageType) } - if string(message) != testMessage { - t.Errorf("Client %d: Expected message %q, got %q", i, testMessage, string(message)) + var received server.Message + if err := json.Unmarshal(message, &received); err != nil { + t.Errorf("Client %d: Failed to unmarshal message: %v", i, err) + continue + } + + if received.Content != messageContent { + t.Errorf("Client %d: Expected content %q, got %q", i, messageContent, received.Content) } } + // Ensure the sender does not receive its own message + expectNoMessage(t, connections[0], 200*time.Millisecond) + + // Send malformed JSON from another client and ensure it is ignored + if err := connections[1].WriteMessage(websocket.TextMessage, []byte("not valid json")); err != nil { + t.Fatalf("Failed to send malformed message: %v", err) + } + + for i := 0; i < numClients; i++ { + if i == 1 { + continue + } + expectNoMessage(t, connections[i], 150*time.Millisecond) + } + // Close all connections gracefully for i, conn := range connections { err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) @@ -205,8 +257,7 @@ func TestWebSocketConnectionLifecycle(t *testing.T) { // Send a test message testMsg := "Test message " + string(rune('A'+i)) - err = conn.WriteMessage(websocket.TextMessage, []byte(testMsg)) - if err != nil { + if err := conn.WriteMessage(websocket.TextMessage, mustMarshalMessage(t, testMsg)); err != nil { t.Errorf("Failed to send message on iteration %d: %v", i, err) } @@ -245,7 +296,13 @@ func TestWebSocketConcurrentConnections(t *testing.T) { // Start multiple clients concurrently for i := 0; i < numConcurrentClients; i++ { - go func(clientID int) { + message := "Message from client " + string(rune('0'+i)) + payload, err := json.Marshal(server.Message{Content: message}) + if err != nil { + t.Fatalf("Failed to marshal message for client %d: %v", i, err) + } + + go func(clientID int, msgPayload []byte) { defer func() { if r := recover(); r != nil { done <- err @@ -262,9 +319,7 @@ func TestWebSocketConcurrentConnections(t *testing.T) { defer func() { _ = resp.Body.Close() }() // Send a message - message := "Message from client " + string(rune('0'+clientID)) - err = conn.WriteMessage(websocket.TextMessage, []byte(message)) - if err != nil { + if err := conn.WriteMessage(websocket.TextMessage, msgPayload); err != nil { done <- err return } @@ -291,7 +346,7 @@ func TestWebSocketConcurrentConnections(t *testing.T) { <-ctx.Done() done <- nil - }(i) + }(i, payload) } // Wait for all clients to complete diff --git a/test/unit/hub_test.go b/test/unit/hub_test.go index ba928c5..64b7054 100644 --- a/test/unit/hub_test.go +++ b/test/unit/hub_test.go @@ -86,7 +86,7 @@ func TestHubBroadcastChannel(t *testing.T) { testMessage := []byte("test broadcast") select { - case hub.GetBroadcastChan() <- testMessage: + case hub.GetBroadcastChan() <- server.BroadcastMessage{Payload: testMessage}: case <-time.After(100 * time.Millisecond): t.Error("Failed to send message to broadcast channel") } @@ -149,7 +149,7 @@ func TestConcurrentHubOperations(t *testing.T) { message := []byte("concurrent message") select { - case hub.GetBroadcastChan() <- message: + case hub.GetBroadcastChan() <- server.BroadcastMessage{Payload: message}: case <-time.After(100 * time.Millisecond): } }(i) From e065c6e57f39123cca8601eee01c8c01d280a14b Mon Sep 17 00:00:00 2001 From: Malte Mindedal Date: Wed, 24 Sep 2025 21:55:36 +0200 Subject: [PATCH 2/2] fix: linting errors --- test/integration/websocket_test.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/integration/websocket_test.go b/test/integration/websocket_test.go index 879c74f..e950605 100644 --- a/test/integration/websocket_test.go +++ b/test/integration/websocket_test.go @@ -9,6 +9,7 @@ package integration import ( "context" "encoding/json" + "fmt" "net" "net/http" "net/http/httptest" @@ -22,10 +23,10 @@ import ( ) func mustMarshalMessage(t *testing.T, content string) []byte { - t.Helper() if t == nil { panic("testing.T is required") } + t.Helper() payload, err := json.Marshal(server.Message{Content: content}) if err != nil { t.Fatalf("Failed to marshal message: %v", err) @@ -305,14 +306,14 @@ func TestWebSocketConcurrentConnections(t *testing.T) { go func(clientID int, msgPayload []byte) { defer func() { if r := recover(); r != nil { - done <- err + done <- fmt.Errorf("client %d panic: %v", clientID, r) } }() // Connect conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil) if err != nil { - done <- err + done <- fmt.Errorf("client %d dial: %w", clientID, err) return } defer func() { _ = conn.Close() }() @@ -320,7 +321,7 @@ func TestWebSocketConcurrentConnections(t *testing.T) { // Send a message if err := conn.WriteMessage(websocket.TextMessage, msgPayload); err != nil { - done <- err + done <- fmt.Errorf("client %d write: %w", clientID, err) return }