Skip to content

Commit

Permalink
fix(app): clone tls configuration for websocket dialer (#19396)
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielcorado committed Dec 16, 2022
1 parent 6100659 commit b78c4da
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 12 deletions.
144 changes: 133 additions & 11 deletions lib/srv/app/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
Expand Down Expand Up @@ -109,6 +110,10 @@ type suiteConfig struct {
Apps types.Apps
// ServerStreamer is the auth server audit events streamer.
ServerStreamer events.Streamer
// ValidateRequest is a function that will validate the request received by the application.
ValidateRequest func(*Suite, *http.Request)
// EnableHTTP2 defines if the test server will support HTTP2.
EnableHTTP2 bool
// CloudImporter will use the given cloud importer for the app server.
CloudImporter labels.Importer
// AppLabels are the labels assigned to the application.
Expand Down Expand Up @@ -193,10 +198,33 @@ func SetUpSuiteWithConfig(t *testing.T, config suiteConfig) *Suite {
s.message = uuid.New().String()

s.testhttp = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, s.message)
if strings.ToLower(r.Header.Get("upgrade")) == "websocket" {
upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
ws, err := upgrader.Upgrade(w, r, nil)
require.NoError(t, err)

err = ws.WriteMessage(websocket.TextMessage, []byte(s.message))
require.NoError(t, err)
} else {
fmt.Fprintln(w, s.message)
}

if config.ValidateRequest != nil {
config.ValidateRequest(s, r)
}
}))
s.testhttp.Config.TLSConfig = &tls.Config{Time: s.clock.Now}
s.testhttp.Start()
if config.EnableHTTP2 {
s.testhttp.EnableHTTP2 = true
// Add NextProtos to support both protocols: h2, http/1.1
s.testhttp.Config.TLSConfig.NextProtos = []string{"h2", "http/1.1"}
s.testhttp.StartTLS()
} else {
s.testhttp.Start()
}

// Extract the hostport that the in-memory HTTP server is running on.
u, err := url.Parse(s.testhttp.URL)
Expand All @@ -214,9 +242,10 @@ func SetUpSuiteWithConfig(t *testing.T, config suiteConfig) *Suite {
Name: "foo",
Labels: appLabels,
}, types.AppSpecV3{
URI: s.testhttp.URL,
PublicAddr: "foo.example.com",
DynamicLabels: types.LabelsToV2(dynamicLabels),
URI: s.testhttp.URL,
PublicAddr: "foo.example.com",
InsecureSkipVerify: true,
DynamicLabels: types.LabelsToV2(dynamicLabels),
})
require.NoError(t, err)
appAWS, err := types.NewAppV3(types.Metadata{
Expand Down Expand Up @@ -349,8 +378,9 @@ func TestStart(t *testing.T) {
Name: "foo",
Labels: staticLabels,
}, types.AppSpecV3{
URI: s.testhttp.URL,
PublicAddr: "foo.example.com",
URI: s.testhttp.URL,
PublicAddr: "foo.example.com",
InsecureSkipVerify: true,
DynamicLabels: map[string]types.CommandLabelV2{
dynamicLabelName: {
Period: dynamicLabelPeriod,
Expand Down Expand Up @@ -557,6 +587,42 @@ func TestHandleConnection(t *testing.T) {
})
}

// TestHandleConnectionHTTP2WS given a server that supports HTTP2, make a
// request and then connect to WebSocket, ensuring that both succeed.
//
// This test guarantees the server is capable of handing requests and websockets
// in different HTTP versions.
func TestHandleConnectionHTTP2WS(t *testing.T) {
s := SetUpSuiteWithConfig(t, suiteConfig{
EnableHTTP2: true,
ValidateRequest: func(s *Suite, r *http.Request) {
// Differentiate WebSocket requests.
if strings.ToLower(r.Header.Get("upgrade")) == "websocket" {
// Expect WS requests to be using http 1.
require.Equal(t, 1, r.ProtoMajor)
return
}

// Expect http requests to be using h2.
require.Equal(t, 2, r.ProtoMajor)
},
})

// First, make the request. This will be using HTTP2.
s.checkHTTPResponse(t, s.clientCertificate, func(resp *http.Response) {
require.Equal(t, resp.StatusCode, http.StatusOK)
buf, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, strings.TrimSpace(string(buf)), s.message)
})

// Second, make the WebSocket connection. This will be using HTTP/1.1
s.checkWSResponse(t, s.clientCertificate, func(messageType int, message string) {
require.Equal(t, websocket.TextMessage, messageType)
require.Equal(t, s.message, message)
})
}

// TestAuthorize verifies that only authorized requests are handled.
func TestAuthorize(t *testing.T) {
tests := []struct {
Expand Down Expand Up @@ -821,11 +887,67 @@ func (s *Suite) checkHTTPResponse(t *testing.T, clientCert tls.Certificate, chec
checkResp(resp)
require.NoError(t, resp.Body.Close())

// Close should not trigger an error.
require.NoError(t, s.appServer.Close())
// Close should not trigger an error. Closing the connection is enough to
// get out of the HandleConnection routine.
require.NoError(t, pw.Close())

// Wait for the request routine to finish.
wg.Wait()
}

// checkWSResponse checks expected websocket response.
func (s *Suite) checkWSResponse(t *testing.T, clientCert tls.Certificate, checkMessage func(messageType int, message string)) {
pr, pw := net.Pipe()
defer pw.Close()
defer pr.Close()

dialer := websocket.Dialer{
NetDial: func(_, _ string) (net.Conn, error) {
return pr, nil
},
TLSClientConfig: &tls.Config{
// RootCAs is a pool of host certificates used to verify the identity of
// the server this client is connecting to.
RootCAs: s.hostCertPool,
// Certificates is the user's application specific certificate.
Certificates: []tls.Certificate{clientCert},
// Time defines the time anchor for certificate validation
Time: s.clock.Now,
},
}

var wg sync.WaitGroup
wg.Add(1)

// Handle the connection in another goroutine.
go func() {
s.appServer.HandleConnection(pw)
wg.Done()
}()

// Issue request.
ws, resp, err := dialer.Dial("wss://"+constants.APIDomain, http.Header{})
require.NoError(t, err)

// Check response.
require.Equal(t, resp.StatusCode, http.StatusSwitchingProtocols)
require.NoError(t, resp.Body.Close())

// Read websocket message
messageType, message, err := ws.ReadMessage()
require.NoError(t, err)

// Check message
checkMessage(messageType, string(message))

// This should not trigger an error.
require.NoError(t, ws.Close())

// Close should not trigger an error. Closing the connection is enough to
// get out of the HandleConnection routine.
require.NoError(t, pw.Close())

// Wait for the application server to actually stop serving before
// closing the test. This will make sure the server removes the listeners
// Wait for the request routine to finish.
wg.Wait()
}

Expand Down
2 changes: 1 addition & 1 deletion lib/srv/app/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func newTransport(ctx context.Context, c *transportConfig) (*transport, error) {
c: c,
uri: uri,
tr: tr,
ws: newWebsocketTransport(uri, tr.TLSClientConfig, c),
ws: newWebsocketTransport(uri, tr.TLSClientConfig.Clone(), c),
}, nil
}

Expand Down

0 comments on commit b78c4da

Please sign in to comment.