Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 55 additions & 11 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package server

import (
"encoding/json"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -53,12 +54,24 @@ type Client struct {
closed bool
}

// Message represents the V1 JSON message format exchanged between clients.
type Message struct {
Content string `json:"content"`
}

// BroadcastMessage encapsulates a message being broadcast by the hub,
// including the originating client so it can be excluded from delivery.
type BroadcastMessage struct {
Sender *Client
Payload []byte
}

// Hub manages all WebSocket client connections and handles message broadcasting.
// It maintains client registration/unregistration and ensures thread-safe operations
// through mutex protection.
type Hub struct {
clients map[*Client]bool
broadcast chan []byte
broadcast chan BroadcastMessage
register chan *Client
unregister chan *Client
mutex sync.RWMutex
Expand All @@ -69,7 +82,7 @@ type Hub struct {
func NewHub() *Hub {
return &Hub{
clients: make(map[*Client]bool),
broadcast: make(chan []byte),
broadcast: make(chan BroadcastMessage),
register: make(chan *Client),
unregister: make(chan *Client),
}
Expand Down Expand Up @@ -102,7 +115,7 @@ func (h *Hub) GetUnregisterChan() chan<- *Client {

// GetBroadcastChan returns the channel used for broadcasting messages to all clients.
// This channel is write-only from the caller's perspective.
func (h *Hub) GetBroadcastChan() chan<- []byte {
func (h *Hub) GetBroadcastChan() chan<- BroadcastMessage {
return h.broadcast
}

Expand Down Expand Up @@ -145,12 +158,21 @@ func (h *Hub) Run() {
for {
select {
case client := <-h.register:
if client == nil {
log.Printf("Received nil client registration; skipping")
continue
}

h.mutex.Lock()
client.closed = false
h.clients[client] = true
clientCount := len(h.clients)
h.mutex.Unlock()
log.Printf("Client registered from %s. Total clients: %d", client.addr, clientCount)

go client.writePump()
go client.readPump()

case client := <-h.unregister:
h.mutex.Lock()
if _, ok := h.clients[client]; ok {
Expand All @@ -165,7 +187,7 @@ func (h *Hub) Run() {
h.mutex.Unlock()
}

case message := <-h.broadcast:
case broadcastMsg := <-h.broadcast:
h.mutex.RLock()
clientCount := len(h.clients)
clients := make([]*Client, 0, clientCount)
Expand All @@ -174,12 +196,23 @@ func (h *Hub) Run() {
}
h.mutex.RUnlock()

log.Printf("Broadcasting message to %d clients", clientCount)
targetCount := clientCount
if broadcastMsg.Sender != nil {
targetCount--
}
if targetCount < 0 {
targetCount = 0
}

log.Printf("Broadcasting message to %d clients", targetCount)

var clientsToRemove []*Client

for _, client := range clients {
if !h.safeSend(client, message) {
if broadcastMsg.Sender != nil && client == broadcastMsg.Sender {
continue
}
if !h.safeSend(client, broadcastMsg.Payload) {
clientsToRemove = append(clientsToRemove, client)
}
}
Expand Down Expand Up @@ -226,16 +259,28 @@ func (c *Client) readPump() {
})

for {
_, message, err := c.conn.ReadMessage()
_, rawMessage, err := c.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("WebSocket error from %s: %v", c.addr, err)
}
break
}

log.Printf("Received message from %s: %s", c.addr, string(message))
c.hub.broadcast <- message
var msg Message
if err := json.Unmarshal(rawMessage, &msg); err != nil {
log.Printf("Invalid message from %s: %v", c.addr, err)
continue
}

normalizedMessage, err := json.Marshal(msg)
if err != nil {
log.Printf("Error normalizing message from %s: %v", c.addr, err)
continue
}

log.Printf("Received message from %s: %s", c.addr, string(normalizedMessage))
c.hub.broadcast <- BroadcastMessage{Sender: c, Payload: normalizedMessage}
}
}

Expand Down Expand Up @@ -387,9 +432,8 @@ func WebSocketHandler(w http.ResponseWriter, r *http.Request) {

client := NewClient(conn, hub, r.RemoteAddr)

// Register the client with the hub; the hub will launch the pump goroutines.
client.hub.register <- client
go client.writePump()
client.readPump()
}

// HealthHandler provides a simple health check endpoint that returns server status.
Expand Down
96 changes: 76 additions & 20 deletions test/integration/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ package integration

import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
Expand All @@ -19,6 +22,39 @@ import (
"github.com/gorilla/websocket"
)

func mustMarshalMessage(t *testing.T, content string) []byte {
if t == nil {
panic("testing.T is required")
}
t.Helper()
payload, err := json.Marshal(server.Message{Content: content})
if err != nil {
t.Fatalf("Failed to marshal message: %v", err)
}
return payload
}

func expectNoMessage(t *testing.T, conn *websocket.Conn, timeout time.Duration) {
t.Helper()
if conn == nil {
t.Fatalf("nil connection provided to expectNoMessage")
}
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
t.Fatalf("Failed to set read deadline: %v", err)
}
_, _, err := conn.ReadMessage()
if err == nil {
t.Fatalf("Expected no message, but received one")
}
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return
}
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
return
}
t.Fatalf("Unexpected error while waiting for absence of message: %v", err)
}

// TestWebSocketEndpointIntegration tests the WebSocket endpoint with full server integration.
// It verifies that WebSocket connections can be established, messages can be sent and received,
// and the complete WebSocket functionality works in a real server environment.
Expand Down Expand Up @@ -49,7 +85,7 @@ func TestWebSocketEndpointIntegration(t *testing.T) {
}

testMessage := "Hello, WebSocket!"
err = conn.WriteMessage(websocket.TextMessage, []byte(testMessage))
err = conn.WriteMessage(websocket.TextMessage, mustMarshalMessage(t, testMessage))
if err != nil {
t.Errorf("Failed to send message: %v", err)
}
Expand Down Expand Up @@ -118,22 +154,18 @@ func TestWebSocketMessageBroadcasting(t *testing.T) {
time.Sleep(50 * time.Millisecond)

// Send a message from the first client
testMessage := "Hello from client 0!"
err = connections[0].WriteMessage(websocket.TextMessage, []byte(testMessage))
if err != nil {
messageContent := "Hello from client 0!"
if err := connections[0].WriteMessage(websocket.TextMessage, mustMarshalMessage(t, messageContent)); err != nil {
t.Fatalf("Failed to send message from client 0: %v", err)
}

// Check that all other clients receive the message
for i := 1; i < numClients; i++ {
// Set a read deadline
err = connections[i].SetReadDeadline(time.Now().Add(2 * time.Second))
if err != nil {
if err := connections[i].SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Errorf("Failed to set read deadline for client %d: %v", i, err)
continue
}

// Read the broadcasted message
messageType, message, err := connections[i].ReadMessage()
if err != nil {
t.Errorf("Client %d failed to receive broadcasted message: %v", i, err)
Expand All @@ -144,11 +176,32 @@ func TestWebSocketMessageBroadcasting(t *testing.T) {
t.Errorf("Client %d: Expected text message, got type %d", i, messageType)
}

if string(message) != testMessage {
t.Errorf("Client %d: Expected message %q, got %q", i, testMessage, string(message))
var received server.Message
if err := json.Unmarshal(message, &received); err != nil {
t.Errorf("Client %d: Failed to unmarshal message: %v", i, err)
continue
}

if received.Content != messageContent {
t.Errorf("Client %d: Expected content %q, got %q", i, messageContent, received.Content)
}
}

// Ensure the sender does not receive its own message
expectNoMessage(t, connections[0], 200*time.Millisecond)

// Send malformed JSON from another client and ensure it is ignored
if err := connections[1].WriteMessage(websocket.TextMessage, []byte("not valid json")); err != nil {
t.Fatalf("Failed to send malformed message: %v", err)
}

for i := 0; i < numClients; i++ {
if i == 1 {
continue
}
expectNoMessage(t, connections[i], 150*time.Millisecond)
}

// Close all connections gracefully
for i, conn := range connections {
err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
Expand Down Expand Up @@ -205,8 +258,7 @@ func TestWebSocketConnectionLifecycle(t *testing.T) {

// Send a test message
testMsg := "Test message " + string(rune('A'+i))
err = conn.WriteMessage(websocket.TextMessage, []byte(testMsg))
if err != nil {
if err := conn.WriteMessage(websocket.TextMessage, mustMarshalMessage(t, testMsg)); err != nil {
t.Errorf("Failed to send message on iteration %d: %v", i, err)
}

Expand Down Expand Up @@ -245,27 +297,31 @@ func TestWebSocketConcurrentConnections(t *testing.T) {

// Start multiple clients concurrently
for i := 0; i < numConcurrentClients; i++ {
go func(clientID int) {
message := "Message from client " + string(rune('0'+i))
payload, err := json.Marshal(server.Message{Content: message})
if err != nil {
t.Fatalf("Failed to marshal message for client %d: %v", i, err)
}

go func(clientID int, msgPayload []byte) {
defer func() {
if r := recover(); r != nil {
done <- err
done <- fmt.Errorf("client %d panic: %v", clientID, r)
}
}()

// Connect
conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
if err != nil {
done <- err
done <- fmt.Errorf("client %d dial: %w", clientID, err)
return
}
defer func() { _ = conn.Close() }()
defer func() { _ = resp.Body.Close() }()

// Send a message
message := "Message from client " + string(rune('0'+clientID))
err = conn.WriteMessage(websocket.TextMessage, []byte(message))
if err != nil {
done <- err
if err := conn.WriteMessage(websocket.TextMessage, msgPayload); err != nil {
done <- fmt.Errorf("client %d write: %w", clientID, err)
return
}

Expand All @@ -291,7 +347,7 @@ func TestWebSocketConcurrentConnections(t *testing.T) {

<-ctx.Done()
done <- nil
}(i)
}(i, payload)
}

// Wait for all clients to complete
Expand Down
4 changes: 2 additions & 2 deletions test/unit/hub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func TestHubBroadcastChannel(t *testing.T) {
testMessage := []byte("test broadcast")

select {
case hub.GetBroadcastChan() <- testMessage:
case hub.GetBroadcastChan() <- server.BroadcastMessage{Payload: testMessage}:
case <-time.After(100 * time.Millisecond):
t.Error("Failed to send message to broadcast channel")
}
Expand Down Expand Up @@ -149,7 +149,7 @@ func TestConcurrentHubOperations(t *testing.T) {

message := []byte("concurrent message")
select {
case hub.GetBroadcastChan() <- message:
case hub.GetBroadcastChan() <- server.BroadcastMessage{Payload: message}:
case <-time.After(100 * time.Millisecond):
}
}(i)
Expand Down
Loading