Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v9] Clone TLS configuration for WebSocket dialer #19425

Merged
merged 2 commits into from
Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 133 additions & 11 deletions lib/srv/app/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
Expand Down Expand Up @@ -108,6 +109,10 @@ type suiteConfig struct {
Apps types.Apps
// ServerStreamer is the auth server audit events streamer.
ServerStreamer events.Streamer
// ValidateRequest is a function that will validate the request received by the application.
ValidateRequest func(*Suite, *http.Request)
// EnableHTTP2 defines if the test server will support HTTP2.
EnableHTTP2 bool
// CloudImporter will use the given cloud importer for the app server.
CloudImporter labels.Importer
// AppLabels are the labels assigned to the application.
Expand Down Expand Up @@ -193,10 +198,33 @@ func SetUpSuiteWithConfig(t *testing.T, config suiteConfig) *Suite {
s.message = uuid.New().String()

s.testhttp = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, s.message)
if strings.ToLower(r.Header.Get("upgrade")) == "websocket" {
upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
ws, err := upgrader.Upgrade(w, r, nil)
require.NoError(t, err)

err = ws.WriteMessage(websocket.TextMessage, []byte(s.message))
require.NoError(t, err)
} else {
fmt.Fprintln(w, s.message)
}

if config.ValidateRequest != nil {
config.ValidateRequest(s, r)
}
}))
s.testhttp.Config.TLSConfig = &tls.Config{Time: s.clock.Now}
s.testhttp.Start()
if config.EnableHTTP2 {
s.testhttp.EnableHTTP2 = true
// Add NextProtos to support both protocols: h2, http/1.1
s.testhttp.Config.TLSConfig.NextProtos = []string{"h2", "http/1.1"}
s.testhttp.StartTLS()
} else {
s.testhttp.Start()
}

// Extract the hostport that the in-memory HTTP server is running on.
u, err := url.Parse(s.testhttp.URL)
Expand All @@ -214,9 +242,10 @@ func SetUpSuiteWithConfig(t *testing.T, config suiteConfig) *Suite {
Name: "foo",
Labels: appLabels,
}, types.AppSpecV3{
URI: s.testhttp.URL,
PublicAddr: "foo.example.com",
DynamicLabels: types.LabelsToV2(dynamicLabels),
URI: s.testhttp.URL,
PublicAddr: "foo.example.com",
InsecureSkipVerify: true,
DynamicLabels: types.LabelsToV2(dynamicLabels),
})
require.NoError(t, err)
appAWS, err := types.NewAppV3(types.Metadata{
Expand Down Expand Up @@ -349,8 +378,9 @@ func TestStart(t *testing.T) {
Name: "foo",
Labels: staticLabels,
}, types.AppSpecV3{
URI: s.testhttp.URL,
PublicAddr: "foo.example.com",
URI: s.testhttp.URL,
PublicAddr: "foo.example.com",
InsecureSkipVerify: true,
DynamicLabels: map[string]types.CommandLabelV2{
dynamicLabelName: {
Period: dynamicLabelPeriod,
Expand Down Expand Up @@ -557,6 +587,42 @@ func TestHandleConnection(t *testing.T) {
})
}

// TestHandleConnectionHTTP2WS given a server that supports HTTP2, make a
// request and then connect to WebSocket, ensuring that both succeed.
//
// This test guarantees the server is capable of handing requests and websockets
// in different HTTP versions.
func TestHandleConnectionHTTP2WS(t *testing.T) {
s := SetUpSuiteWithConfig(t, suiteConfig{
EnableHTTP2: true,
ValidateRequest: func(s *Suite, r *http.Request) {
// Differentiate WebSocket requests.
if strings.ToLower(r.Header.Get("upgrade")) == "websocket" {
// Expect WS requests to be using http 1.
require.Equal(t, 1, r.ProtoMajor)
return
}

// Expect http requests to be using h2.
require.Equal(t, 2, r.ProtoMajor)
},
})

// First, make the request. This will be using HTTP2.
s.checkHTTPResponse(t, s.clientCertificate, func(resp *http.Response) {
require.Equal(t, resp.StatusCode, http.StatusOK)
buf, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, strings.TrimSpace(string(buf)), s.message)
})

// Second, make the WebSocket connection. This will be using HTTP/1.1
s.checkWSResponse(t, s.clientCertificate, func(messageType int, message string) {
require.Equal(t, websocket.TextMessage, messageType)
require.Equal(t, s.message, message)
})
}

// TestAuthorize verifies that only authorized requests are handled.
func TestAuthorize(t *testing.T) {
tests := []struct {
Expand Down Expand Up @@ -821,11 +887,67 @@ func (s *Suite) checkHTTPResponse(t *testing.T, clientCert tls.Certificate, chec
checkResp(resp)
require.NoError(t, resp.Body.Close())

// Close should not trigger an error.
require.NoError(t, s.appServer.Close())
// Close should not trigger an error. Closing the connection is enough to
// get out of the HandleConnection routine.
require.NoError(t, pw.Close())

// Wait for the request routine to finish.
wg.Wait()
}

// checkWSResponse checks expected websocket response.
func (s *Suite) checkWSResponse(t *testing.T, clientCert tls.Certificate, checkMessage func(messageType int, message string)) {
pr, pw := net.Pipe()
defer pw.Close()
defer pr.Close()

dialer := websocket.Dialer{
NetDial: func(_, _ string) (net.Conn, error) {
return pr, nil
},
TLSClientConfig: &tls.Config{
// RootCAs is a pool of host certificates used to verify the identity of
// the server this client is connecting to.
RootCAs: s.hostCertPool,
// Certificates is the user's application specific certificate.
Certificates: []tls.Certificate{clientCert},
// Time defines the time anchor for certificate validation
Time: s.clock.Now,
},
}

var wg sync.WaitGroup
wg.Add(1)

// Handle the connection in another goroutine.
go func() {
s.appServer.HandleConnection(pw)
wg.Done()
}()

// Issue request.
ws, resp, err := dialer.Dial("wss://"+constants.APIDomain, http.Header{})
require.NoError(t, err)

// Check response.
require.Equal(t, resp.StatusCode, http.StatusSwitchingProtocols)
require.NoError(t, resp.Body.Close())

// Read websocket message
messageType, message, err := ws.ReadMessage()
require.NoError(t, err)

// Check message
checkMessage(messageType, string(message))

// This should not trigger an error.
require.NoError(t, ws.Close())

// Close should not trigger an error. Closing the connection is enough to
// get out of the HandleConnection routine.
require.NoError(t, pw.Close())

// Wait for the application server to actually stop serving before
// closing the test. This will make sure the server removes the listeners
// Wait for the request routine to finish.
wg.Wait()
}

Expand Down
2 changes: 1 addition & 1 deletion lib/srv/app/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func newTransport(ctx context.Context, c *transportConfig) (*transport, error) {
c: c,
uri: uri,
tr: tr,
ws: newWebsocketTransport(uri, tr.TLSClientConfig),
ws: newWebsocketTransport(uri, tr.TLSClientConfig.Clone()),
}, nil
}

Expand Down