diff --git a/lib/srv/app/server_test.go b/lib/srv/app/server_test.go index e8895bf2c014..9c0ce35740e5 100644 --- a/lib/srv/app/server_test.go +++ b/lib/srv/app/server_test.go @@ -49,6 +49,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" + "github.com/gorilla/websocket" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" "go.uber.org/atomic" @@ -108,6 +109,10 @@ type suiteConfig struct { Apps types.Apps // ServerStreamer is the auth server audit events streamer. ServerStreamer events.Streamer + // ValidateRequest is a function that will validate the request received by the application. + ValidateRequest func(*Suite, *http.Request) + // EnableHTTP2 defines if the test server will support HTTP2. + EnableHTTP2 bool // CloudImporter will use the given cloud importer for the app server. CloudImporter labels.Importer // AppLabels are the labels assigned to the application. @@ -193,10 +198,33 @@ func SetUpSuiteWithConfig(t *testing.T, config suiteConfig) *Suite { s.message = uuid.New().String() s.testhttp = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, s.message) + if strings.ToLower(r.Header.Get("upgrade")) == "websocket" { + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + ws, err := upgrader.Upgrade(w, r, nil) + require.NoError(t, err) + + err = ws.WriteMessage(websocket.TextMessage, []byte(s.message)) + require.NoError(t, err) + } else { + fmt.Fprintln(w, s.message) + } + + if config.ValidateRequest != nil { + config.ValidateRequest(s, r) + } })) s.testhttp.Config.TLSConfig = &tls.Config{Time: s.clock.Now} - s.testhttp.Start() + if config.EnableHTTP2 { + s.testhttp.EnableHTTP2 = true + // Add NextProtos to support both protocols: h2, http/1.1 + s.testhttp.Config.TLSConfig.NextProtos = []string{"h2", "http/1.1"} + s.testhttp.StartTLS() + } else { + s.testhttp.Start() + } // Extract the hostport that the in-memory HTTP server is running on. u, err := url.Parse(s.testhttp.URL) @@ -214,9 +242,10 @@ func SetUpSuiteWithConfig(t *testing.T, config suiteConfig) *Suite { Name: "foo", Labels: appLabels, }, types.AppSpecV3{ - URI: s.testhttp.URL, - PublicAddr: "foo.example.com", - DynamicLabels: types.LabelsToV2(dynamicLabels), + URI: s.testhttp.URL, + PublicAddr: "foo.example.com", + InsecureSkipVerify: true, + DynamicLabels: types.LabelsToV2(dynamicLabels), }) require.NoError(t, err) appAWS, err := types.NewAppV3(types.Metadata{ @@ -349,8 +378,9 @@ func TestStart(t *testing.T) { Name: "foo", Labels: staticLabels, }, types.AppSpecV3{ - URI: s.testhttp.URL, - PublicAddr: "foo.example.com", + URI: s.testhttp.URL, + PublicAddr: "foo.example.com", + InsecureSkipVerify: true, DynamicLabels: map[string]types.CommandLabelV2{ dynamicLabelName: { Period: dynamicLabelPeriod, @@ -557,6 +587,42 @@ func TestHandleConnection(t *testing.T) { }) } +// TestHandleConnectionHTTP2WS given a server that supports HTTP2, make a +// request and then connect to WebSocket, ensuring that both succeed. +// +// This test guarantees the server is capable of handing requests and websockets +// in different HTTP versions. +func TestHandleConnectionHTTP2WS(t *testing.T) { + s := SetUpSuiteWithConfig(t, suiteConfig{ + EnableHTTP2: true, + ValidateRequest: func(s *Suite, r *http.Request) { + // Differentiate WebSocket requests. + if strings.ToLower(r.Header.Get("upgrade")) == "websocket" { + // Expect WS requests to be using http 1. + require.Equal(t, 1, r.ProtoMajor) + return + } + + // Expect http requests to be using h2. + require.Equal(t, 2, r.ProtoMajor) + }, + }) + + // First, make the request. This will be using HTTP2. + s.checkHTTPResponse(t, s.clientCertificate, func(resp *http.Response) { + require.Equal(t, resp.StatusCode, http.StatusOK) + buf, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, strings.TrimSpace(string(buf)), s.message) + }) + + // Second, make the WebSocket connection. This will be using HTTP/1.1 + s.checkWSResponse(t, s.clientCertificate, func(messageType int, message string) { + require.Equal(t, websocket.TextMessage, messageType) + require.Equal(t, s.message, message) + }) +} + // TestAuthorize verifies that only authorized requests are handled. func TestAuthorize(t *testing.T) { tests := []struct { @@ -821,11 +887,67 @@ func (s *Suite) checkHTTPResponse(t *testing.T, clientCert tls.Certificate, chec checkResp(resp) require.NoError(t, resp.Body.Close()) - // Close should not trigger an error. - require.NoError(t, s.appServer.Close()) + // Close should not trigger an error. Closing the connection is enough to + // get out of the HandleConnection routine. + require.NoError(t, pw.Close()) + + // Wait for the request routine to finish. + wg.Wait() +} + +// checkWSResponse checks expected websocket response. +func (s *Suite) checkWSResponse(t *testing.T, clientCert tls.Certificate, checkMessage func(messageType int, message string)) { + pr, pw := net.Pipe() + defer pw.Close() + defer pr.Close() + + dialer := websocket.Dialer{ + NetDial: func(_, _ string) (net.Conn, error) { + return pr, nil + }, + TLSClientConfig: &tls.Config{ + // RootCAs is a pool of host certificates used to verify the identity of + // the server this client is connecting to. + RootCAs: s.hostCertPool, + // Certificates is the user's application specific certificate. + Certificates: []tls.Certificate{clientCert}, + // Time defines the time anchor for certificate validation + Time: s.clock.Now, + }, + } + + var wg sync.WaitGroup + wg.Add(1) + + // Handle the connection in another goroutine. + go func() { + s.appServer.HandleConnection(pw) + wg.Done() + }() + + // Issue request. + ws, resp, err := dialer.Dial("wss://"+constants.APIDomain, http.Header{}) + require.NoError(t, err) + + // Check response. + require.Equal(t, resp.StatusCode, http.StatusSwitchingProtocols) + require.NoError(t, resp.Body.Close()) + + // Read websocket message + messageType, message, err := ws.ReadMessage() + require.NoError(t, err) + + // Check message + checkMessage(messageType, string(message)) + + // This should not trigger an error. + require.NoError(t, ws.Close()) + + // Close should not trigger an error. Closing the connection is enough to + // get out of the HandleConnection routine. + require.NoError(t, pw.Close()) - // Wait for the application server to actually stop serving before - // closing the test. This will make sure the server removes the listeners + // Wait for the request routine to finish. wg.Wait() } diff --git a/lib/srv/app/transport.go b/lib/srv/app/transport.go index 48fad39748fe..29620617c24f 100644 --- a/lib/srv/app/transport.go +++ b/lib/srv/app/transport.go @@ -115,7 +115,7 @@ func newTransport(ctx context.Context, c *transportConfig) (*transport, error) { c: c, uri: uri, tr: tr, - ws: newWebsocketTransport(uri, tr.TLSClientConfig), + ws: newWebsocketTransport(uri, tr.TLSClientConfig.Clone()), }, nil }