diff --git a/ws/network_test.go b/ws/network_test.go index 887c6082..b1f44a45 100644 --- a/ws/network_test.go +++ b/ws/network_test.go @@ -44,8 +44,8 @@ func (s *NetworkTestSuite) TearDownSuite() { } func (s *NetworkTestSuite) SetupTest() { - s.server = NewWebsocketServer(s.T(), nil) - s.client = NewWebsocketClient(s.T(), nil) + s.server = newWebsocketServer(s.T(), nil) + s.client = newWebsocketClient(s.T(), nil) } func (s *NetworkTestSuite) TearDownTest() { @@ -55,7 +55,7 @@ func (s *NetworkTestSuite) TearDownTest() { func (s *NetworkTestSuite) TestClientConnectionFailed() { t := s.T() - s.server = NewWebsocketServer(t, nil) + s.server = newWebsocketServer(t, nil) s.server.SetNewClientHandler(func(ws Channel) { assert.Fail(t, "should not accept new clients") }) @@ -88,7 +88,7 @@ func (s *NetworkTestSuite) TestClientConnectionFailedTimeout() { // Set timeouts for test s.client.timeoutConfig.HandshakeTimeout = 2 * time.Second // Setup - s.server = NewWebsocketServer(t, nil) + s.server = newWebsocketServer(t, nil) s.server.SetNewClientHandler(func(ws Channel) { assert.Fail(t, "should not accept new clients") }) @@ -125,7 +125,7 @@ func (s *NetworkTestSuite) TestClientAutoReconnect() { serverOnDisconnected := make(chan bool, 1) clientOnDisconnected := make(chan bool, 1) reconnected := make(chan bool, 1) - s.server = NewWebsocketServer(t, nil) + s.server = newWebsocketServer(t, nil) s.server.SetNewClientHandler(func(ws Channel) { assert.NotNil(t, ws) conn := s.server.connections[ws.GetID()] diff --git a/ws/websocket_test.go b/ws/websocket_test.go index 729a4fa1..00b2f37e 100644 --- a/ws/websocket_test.go +++ b/ws/websocket_test.go @@ -31,7 +31,7 @@ const ( testPath = "/ws/testws" ) -func NewWebsocketServer(t *testing.T, onMessage func(data []byte) ([]byte, error)) *Server { +func newWebsocketServer(t *testing.T, onMessage func(data []byte) ([]byte, error)) *Server { wsServer := NewServer() wsServer.SetMessageHandler(func(ws Channel, data []byte) error { assert.NotNil(t, ws) @@ -49,7 +49,7 @@ func NewWebsocketServer(t *testing.T, onMessage func(data []byte) ([]byte, error return wsServer } -func NewWebsocketClient(t *testing.T, onMessage func(data []byte) ([]byte, error)) *Client { +func newWebsocketClient(t *testing.T, onMessage func(data []byte) ([]byte, error)) *Client { wsClient := NewClient() wsClient.SetMessageHandler(func(data []byte) error { assert.NotNil(t, data) @@ -67,7 +67,7 @@ func NewWebsocketClient(t *testing.T, onMessage func(data []byte) ([]byte, error } func TestWebsocketSetConnected(t *testing.T) { - wsClient := NewWebsocketClient(t, func(data []byte) ([]byte, error) { + wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { return nil, nil }) assert.False(t, wsClient.IsConnected()) @@ -80,7 +80,7 @@ func TestWebsocketSetConnected(t *testing.T) { func TestWebsocketEcho(t *testing.T) { message := []byte("Hello WebSocket!") var wsServer *Server - wsServer = NewWebsocketServer(t, func(data []byte) ([]byte, error) { + wsServer = newWebsocketServer(t, func(data []byte) ([]byte, error) { assert.True(t, bytes.Equal(message, data)) return data, nil }) @@ -88,7 +88,7 @@ func TestWebsocketEcho(t *testing.T) { time.Sleep(1 * time.Second) // Test message - wsClient := NewWebsocketClient(t, func(data []byte) ([]byte, error) { + wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { assert.True(t, bytes.Equal(message, data)) return nil, nil }) @@ -122,7 +122,7 @@ func TestTLSWebsocketEcho(t *testing.T) { message := []byte("Hello Secure WebSocket!") var wsServer *Server // Use NewTLSServer() when in different package - wsServer = NewWebsocketServer(t, func(data []byte) ([]byte, error) { + wsServer = newWebsocketServer(t, func(data []byte) ([]byte, error) { assert.True(t, bytes.Equal(message, data)) return data, nil }) @@ -141,11 +141,11 @@ func TestTLSWebsocketEcho(t *testing.T) { time.Sleep(1 * time.Second) // Create TLS client - wsClient := NewWebsocketClient(t, func(data []byte) ([]byte, error) { + wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { assert.True(t, bytes.Equal(message, data)) return nil, nil }) - wsClient.dialOptions = append(wsClient.dialOptions, func(dialer *websocket.Dialer) { + wsClient.AddOption(func(dialer *websocket.Dialer) { certPool := x509.NewCertPool() data, err := ioutil.ReadFile(certFilename) assert.Nil(t, err) @@ -185,7 +185,7 @@ func TestWebsocketClientConnectionBreak(t *testing.T) { newClient := make(chan bool) disconnected := make(chan bool) var wsServer *Server - wsServer = NewWebsocketServer(t, nil) + wsServer = newWebsocketServer(t, nil) wsServer.SetNewClientHandler(func(ws Channel) { newClient <- true }) @@ -196,7 +196,7 @@ func TestWebsocketClientConnectionBreak(t *testing.T) { time.Sleep(1 * time.Second) // Test - wsClient := NewWebsocketClient(t, nil) + wsClient := newWebsocketClient(t, nil) host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} // Wait for connection to be established, then break the connection @@ -219,7 +219,7 @@ func TestWebsocketClientConnectionBreak(t *testing.T) { func TestWebsocketServerConnectionBreak(t *testing.T) { var wsServer *Server disconnected := make(chan bool) - wsServer = NewWebsocketServer(t, nil) + wsServer = newWebsocketServer(t, nil) wsServer.SetNewClientHandler(func(ws Channel) { assert.NotNil(t, ws) conn := wsServer.connections[ws.GetID()] @@ -235,7 +235,7 @@ func TestWebsocketServerConnectionBreak(t *testing.T) { time.Sleep(1 * time.Second) // Test - wsClient := NewWebsocketClient(t, nil) + wsClient := newWebsocketClient(t, nil) host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} err := wsClient.Start(u.String()) @@ -334,11 +334,14 @@ func TestInvalidBasicAuth(t *testing.T) { wsClient := NewTLSClient(&tls.Config{ RootCAs: certPool, }) + // Test connection without bssic auth -> error expected + host := fmt.Sprintf("localhost:%v", serverPort) + u := url.URL{Scheme: "wss", Host: host, Path: testPath} + err = wsClient.Start(u.String()) + assert.Error(t, err) // Add basic auth wsClient.SetBasicAuth(authUsername, "invalidPassword") // Test connection - host := fmt.Sprintf("localhost:%v", serverPort) - u := url.URL{Scheme: "wss", Host: host, Path: testPath} err = wsClient.Start(u.String()) assert.NotNil(t, err) httpError, ok := err.(HttpConnectionError) @@ -351,7 +354,7 @@ func TestInvalidBasicAuth(t *testing.T) { func TestInvalidOriginHeader(t *testing.T) { var wsServer *Server - wsServer = NewWebsocketServer(t, func(data []byte) ([]byte, error) { + wsServer = newWebsocketServer(t, func(data []byte) ([]byte, error) { assert.Fail(t, "no message should be received from client!") return nil, nil }) @@ -362,7 +365,7 @@ func TestInvalidOriginHeader(t *testing.T) { time.Sleep(500 * time.Millisecond) // Test message - wsClient := NewWebsocketClient(t, func(data []byte) ([]byte, error) { + wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { assert.Fail(t, "no message should be received from server!") return nil, nil }) @@ -386,7 +389,7 @@ func TestCustomOriginHeaderHandler(t *testing.T) { var wsServer *Server origin := "example.org" connected := make(chan bool) - wsServer = NewWebsocketServer(t, func(data []byte) ([]byte, error) { + wsServer = newWebsocketServer(t, func(data []byte) ([]byte, error) { assert.Fail(t, "no message should be received from client!") return nil, nil }) @@ -400,7 +403,7 @@ func TestCustomOriginHeaderHandler(t *testing.T) { time.Sleep(500 * time.Millisecond) // Test message - wsClient := NewWebsocketClient(t, func(data []byte) ([]byte, error) { + wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { assert.Fail(t, "no message should be received from server!") return nil, nil }) @@ -428,7 +431,6 @@ func TestCustomOriginHeaderHandler(t *testing.T) { } func TestValidClientTLSCertificate(t *testing.T) { - var wsServer *Server // Create self-signed TLS certificate clientCertFilename := "/tmp/client.pem" clientKeyFilename := "/tmp/client_key.pem" @@ -449,7 +451,7 @@ func TestValidClientTLSCertificate(t *testing.T) { require.Nil(t, err) ok := certPool.AppendCertsFromPEM(data) require.True(t, ok) - wsServer = NewTLSServer(serverCertFilename, serverKeyFilename, &tls.Config{ + wsServer := NewTLSServer(serverCertFilename, serverKeyFilename, &tls.Config{ ClientCAs: certPool, ClientAuth: tls.RequireAndVerifyClientCert, }) @@ -486,7 +488,6 @@ func TestValidClientTLSCertificate(t *testing.T) { } func TestInvalidClientTLSCertificate(t *testing.T) { - var wsServer *Server // Create self-signed TLS certificate clientCertFilename := "/tmp/client.pem" clientKeyFilename := "/tmp/client_key.pem" @@ -507,7 +508,7 @@ func TestInvalidClientTLSCertificate(t *testing.T) { require.Nil(t, err) ok := certPool.AppendCertsFromPEM(data) require.True(t, ok) - wsServer = NewTLSServer(serverCertFilename, serverKeyFilename, &tls.Config{ + wsServer := NewTLSServer(serverCertFilename, serverKeyFilename, &tls.Config{ ClientCAs: certPool, // Contains server certificate as allowed client CA ClientAuth: tls.RequireAndVerifyClientCert, // Requires client certificate signed by allowed CA (server) }) @@ -547,7 +548,7 @@ func TestInvalidClientTLSCertificate(t *testing.T) { func TestUnsupportedSubprotocol(t *testing.T) { var wsServer *Server disconnected := make(chan bool) - wsServer = NewWebsocketServer(t, nil) + wsServer = newWebsocketServer(t, nil) wsServer.SetNewClientHandler(func(ws Channel) { assert.Fail(t, "invalid subprotocol expected, but hit client handler instead") t.Fail() @@ -556,12 +557,17 @@ func TestUnsupportedSubprotocol(t *testing.T) { disconnected <- true }) wsServer.AddSupportedSubprotocol(defaultSubProtocol) + assert.Len(t, wsServer.upgrader.Subprotocols, 1) + // Test duplicate subprotocol + wsServer.AddSupportedSubprotocol(defaultSubProtocol) + assert.Len(t, wsServer.upgrader.Subprotocols, 1) + // Start server go wsServer.Start(serverPort, serverPath) time.Sleep(1 * time.Second) - wsClient := NewWebsocketClient(t, nil) + wsClient := newWebsocketClient(t, nil) // Set invalid subprotocol - wsClient.dialOptions = append(wsClient.dialOptions, func(dialer *websocket.Dialer) { + wsClient.AddOption(func(dialer *websocket.Dialer) { dialer.Subprotocols = []string{"unsupportedSubProto"} }) // Test @@ -574,9 +580,8 @@ func TestUnsupportedSubprotocol(t *testing.T) { } func TestSetServerTimeoutConfig(t *testing.T) { - var wsServer *Server disconnected := make(chan bool) - wsServer = NewWebsocketServer(t, nil) + wsServer := newWebsocketServer(t, nil) wsServer.SetNewClientHandler(func(ws Channel) { }) wsServer.SetDisconnectedClientHandler(func(ws Channel) { @@ -596,7 +601,7 @@ func TestSetServerTimeoutConfig(t *testing.T) { assert.Equal(t, wsServer.timeoutConfig.PingWait, pingWait) assert.Equal(t, wsServer.timeoutConfig.WriteWait, writeWait) // Run test - wsClient := NewWebsocketClient(t, nil) + wsClient := newWebsocketClient(t, nil) host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} err := wsClient.Start(u.String()) @@ -609,9 +614,8 @@ func TestSetServerTimeoutConfig(t *testing.T) { } func TestSetClientTimeoutConfig(t *testing.T) { - var wsServer *Server disconnected := make(chan bool) - wsServer = NewWebsocketServer(t, nil) + wsServer := newWebsocketServer(t, nil) wsServer.SetNewClientHandler(func(ws Channel) { }) wsServer.SetDisconnectedClientHandler(func(ws Channel) { @@ -620,9 +624,9 @@ func TestSetClientTimeoutConfig(t *testing.T) { }) // Start server go wsServer.Start(serverPort, serverPath) - time.Sleep(500 * time.Millisecond) + time.Sleep(200 * time.Millisecond) // Run test - wsClient := NewWebsocketClient(t, nil) + wsClient := newWebsocketClient(t, nil) host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} // Set client timeout @@ -658,6 +662,124 @@ func TestSetClientTimeoutConfig(t *testing.T) { wsServer.Stop() } +func TestServerErrors(t *testing.T) { + triggerC := make(chan bool, 1) + finishC := make(chan bool, 1) + wsServer := newWebsocketServer(t, nil) + wsServer.SetNewClientHandler(func(ws Channel) { + triggerC <- true + }) + // Intercept errors asynchronously + assert.Nil(t, wsServer.errC) + go func() { + for { + select { + case err, ok := <-wsServer.Errors(): + triggerC <- true + if ok { + assert.Error(t, err) + } + case _, _ = <-finishC: + return + } + } + }() + wsServer.SetMessageHandler(func(ws Channel, data []byte) error { + return fmt.Errorf("this is a dummy error") + }) + // Will trigger an out-of-bound error + time.Sleep(50 * time.Millisecond) + wsServer.Stop() + r, _ := <-triggerC + assert.True(t, r) + // Start server for real + wsServer.httpServer = &http.Server{} + go wsServer.Start(serverPort, serverPath) + time.Sleep(200 * time.Millisecond) + // Create and connect client + wsClient := newWebsocketClient(t, nil) + host := fmt.Sprintf("localhost:%v", serverPort) + u := url.URL{Scheme: "ws", Host: host, Path: testPath} + err := wsClient.Start(u.String()) + require.NoError(t, err) + // Wait for new client callback + r, _ = <-triggerC + require.True(t, r) + // Send a dummy message and expect error on server side + err = wsClient.Write([]byte("dummy message")) + require.NoError(t, err) + r, _ = <-triggerC + assert.True(t, r) + // Send unexpected close message and wait for error to be thrown + err = wsClient.webSocket.connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseUnsupportedData, "")) + assert.NoError(t, err) + r, _ = <-triggerC + // Stop and wait for errors channel cleanup + wsServer.Stop() + r, _ = <-triggerC + assert.True(t, r) + close(finishC) +} + +func TestClientErrors(t *testing.T) { + triggerC := make(chan bool, 1) + finishC := make(chan bool, 1) + wsServer := newWebsocketServer(t, nil) + wsServer.SetNewClientHandler(func(ws Channel) { + triggerC <- true + }) + wsClient := newWebsocketClient(t, nil) + wsClient.SetMessageHandler(func(data []byte) error { + return fmt.Errorf("this is a dummy error") + }) + // Intercept errors asynchronously + assert.Nil(t, wsClient.errC) + go func() { + for { + select { + case err, ok := <-wsClient.Errors(): + triggerC <- true + if ok { + assert.Error(t, err) + } + case _, _ = <-finishC: + return + } + } + }() + go wsServer.Start(serverPort, serverPath) + time.Sleep(200 * time.Millisecond) + // Attempt to write a message without being connected + err := wsClient.Write([]byte("dummy message")) + require.Error(t, err) + // Connect client + host := fmt.Sprintf("localhost:%v", serverPort) + u := url.URL{Scheme: "ws", Host: host, Path: testPath} + err = wsClient.Start(u.String()) + require.NoError(t, err) + // Wait for new client callback + r, _ := <-triggerC + require.True(t, r) + // Send a dummy message and expect error on client side + err = wsServer.Write(testPath, []byte("dummy message")) + require.NotNil(t, t, err) + r, _ = <-triggerC + assert.True(t, r) + // Send unexpected close message and wait for error to be thrown + conn := wsServer.connections[testPath] + require.NotNil(t, conn) + err = conn.connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseUnsupportedData, "")) + assert.NoError(t, err) + r, _ = <-triggerC + require.True(t, r) + // Stop server and client and wait for errors channel cleanup + wsServer.Stop() + wsClient.Stop() + r, _ = <-triggerC + require.True(t, r) + close(finishC) +} + // Utility functions func createCACertificate(certificateFilename string, keyFilename string) (*x509.Certificate, *ecdsa.PrivateKey, error) {