Skip to content

Commit

Permalink
refactor: update tests and code for websockets via stdlib
Browse files Browse the repository at this point in the history
  • Loading branch information
domdom82 authored and geofffranks committed Apr 29, 2024
1 parent 437047d commit fea57f9
Show file tree
Hide file tree
Showing 12 changed files with 290 additions and 53 deletions.
2 changes: 1 addition & 1 deletion handlers/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func IsWebSocketUpgrade(request *http.Request) bool {
func upgradeHeader(request *http.Request) string {
// handle multiple Connection field-values, either in a comma-separated string or multiple field-headers
for _, v := range request.Header[http.CanonicalHeaderKey("Connection")] {
// upgrade should be case insensitive per RFC6455 4.2.1
// upgrade should be case-insensitive per RFC6455 4.2.1
if strings.Contains(strings.ToLower(v), "upgrade") {
return request.Header.Get("Upgrade")
}
Expand Down
54 changes: 54 additions & 0 deletions handlers/hop_by_hop.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package handlers

import (
"net/http"
"strings"

"code.cloudfoundry.org/gorouter/config"
"code.cloudfoundry.org/gorouter/logger"
)

type HopByHop struct {
cfg *config.Config
logger logger.Logger
}

// NewHopByHop creates a new handler that sanitizes hop-by-hop headers based on the HopByHopHeadersToFilter config
func NewHopByHop(cfg *config.Config, logger logger.Logger) *HopByHop {
return &HopByHop{
logger: logger,
cfg: cfg,
}
}

func (h *HopByHop) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
h.SanitizeRequestConnection(r)
next(rw, r)
}

func (h *HopByHop) SanitizeRequestConnection(r *http.Request) {
if len(h.cfg.HopByHopHeadersToFilter) == 0 {
return
}
connections := r.Header.Values("Connection")
for index, connection := range connections {
if connection != "" {
values := strings.Split(connection, ",")
connectionHeader := []string{}
for i := range values {
trimmedValue := strings.TrimSpace(values[i])
found := false
for _, item := range h.cfg.HopByHopHeadersToFilter {
if strings.ToLower(item) == strings.ToLower(trimmedValue) {
found = true
break
}
}
if !found {
connectionHeader = append(connectionHeader, trimmedValue)
}
}
r.Header[http.CanonicalHeaderKey("Connection")][index] = strings.Join(connectionHeader, ", ")
}
}
}
159 changes: 159 additions & 0 deletions handlers/hop_by_hop_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package handlers_test

import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"

"code.cloudfoundry.org/gorouter/config"
"code.cloudfoundry.org/gorouter/handlers"
logger_fakes "code.cloudfoundry.org/gorouter/logger/fakes"
"code.cloudfoundry.org/gorouter/route"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/urfave/negroni/v3"
)

var _ = Describe("HopByHop", func() {
var (
handler *negroni.Negroni

resp http.ResponseWriter
req *http.Request
rawPath string
header http.Header
result *http.Response
responseBody []byte
requestBody *bytes.Buffer

cfg *config.Config
fakeLogger *logger_fakes.FakeLogger
hopByHop *handlers.HopByHop

nextCalled bool
)

nextHandler := negroni.HandlerFunc(func(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
_, err := io.ReadAll(req.Body)
Expect(err).NotTo(HaveOccurred())

rw.WriteHeader(http.StatusTeapot)
for name, values := range req.Header {
for _, value := range values {
rw.Header().Set(name, value)
}
}

rw.Write([]byte("I'm a little teapot, short and stout."))

if next != nil {
next(rw, req)
}

nextCalled = true
})

handleRequest := func() {
var err error
handler.ServeHTTP(resp, req)

result = resp.(*httptest.ResponseRecorder).Result()
responseBody, err = io.ReadAll(result.Body)
Expect(err).NotTo(HaveOccurred())
result.Body.Close()
}

BeforeEach(func() {
cfg = &config.Config{
HopByHopHeadersToFilter: make([]string, 0),
LoadBalance: config.LOAD_BALANCE_RR,
}
requestBody = bytes.NewBufferString("What are you?")
rawPath = "/"
header = http.Header{}
resp = httptest.NewRecorder()
})

JustBeforeEach(func() {
fakeLogger = new(logger_fakes.FakeLogger)
handler = negroni.New()
hopByHop = handlers.NewHopByHop(cfg, fakeLogger)
handler.Use(hopByHop)
handler.Use(nextHandler)

nextCalled = false

var err error
req, err = http.NewRequest("GET", "http://example.com"+rawPath, requestBody)
Expect(err).NotTo(HaveOccurred())

req.Header = header
reqInfo := &handlers.RequestInfo{
RoutePool: route.NewPool(&route.PoolOpts{}),
}
reqInfo.RoutePool.Put(route.NewEndpoint(&route.EndpointOpts{
AppId: "fake-app",
Host: "fake-host",
Port: 1234,
PrivateInstanceId: "fake-instance",
}))
req = req.WithContext(context.WithValue(req.Context(), handlers.RequestInfoCtxKey, reqInfo))
})

Context("when HopByHopHeadersToFilter is empty", func() {
BeforeEach(func() {
header.Add("Connection", "X-Forwarded-Proto")
})

It("does not touch headers listed in the Connection header", func() {
handleRequest()
Expect(resp.Header().Get("Connection")).To(ContainSubstring("X-Forwarded-Proto"))
Expect(result.StatusCode).To(Equal(http.StatusTeapot))
Expect(result.Status).To(Equal("418 I'm a teapot"))
Expect(string(responseBody)).To(Equal("I'm a little teapot, short and stout."))

})
It("calls the next handler", func() {
handleRequest()
Expect(nextCalled).To(BeTrue())
})
It("doesn't set the reqInfo's RouteEndpoint", func() {
handleRequest()
reqInfo, err := handlers.ContextRequestInfo(req)
Expect(err).NotTo(HaveOccurred())

Expect(reqInfo.RouteEndpoint).To(BeNil())
})
})

Context("when HopByHopHeadersToFilter is set", func() {
BeforeEach(func() {
cfg.HopByHopHeadersToFilter = append(cfg.HopByHopHeadersToFilter, "X-Forwarded-Proto")
header.Add("Connection", "X-Forwarded-Proto")
})

It("removes the headers listed in the Connection header", func() {
handleRequest()
Expect(resp.Header().Get("Connection")).To(BeEmpty())
Expect(result.StatusCode).To(Equal(http.StatusTeapot))
Expect(result.Status).To(Equal("418 I'm a teapot"))
Expect(string(responseBody)).To(Equal("I'm a little teapot, short and stout."))

})
It("calls the next handler", func() {
handleRequest()
Expect(nextCalled).To(BeTrue())
})
It("doesn't set the reqInfo's RouteEndpoint", func() {
handleRequest()
reqInfo, err := handlers.ContextRequestInfo(req)
Expect(err).NotTo(HaveOccurred())

Expect(reqInfo.RouteEndpoint).To(BeNil())
})
})

})
3 changes: 2 additions & 1 deletion handlers/httpstartstop.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type httpStartStopHandler struct {
logger logger.Logger
}

// NewHTTPStartStop creates a new handler that handles emitting frontent
// NewHTTPStartStop creates a new handler that handles emitting frontend
// HTTP StartStop events
func NewHTTPStartStop(emitter dropsonde.EventEmitter, logger logger.Logger) negroni.Handler {
return &httpStartStopHandler{
Expand Down Expand Up @@ -62,6 +62,7 @@ func (hh *httpStartStopHandler) ServeHTTP(rw http.ResponseWriter, r *http.Reques
envelope, err := emitter.Wrap(startStopEvent, hh.emitter.Origin())
if err != nil {
logger.Info("failed-to-create-startstop-envelope", zap.Error(err))
return
}

info, err := ContextRequestInfo(r)
Expand Down
20 changes: 14 additions & 6 deletions integration/gdpr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"code.cloudfoundry.org/gorouter/test_util"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes"
)

// Involves scrubbing client IPs, for more info on GDPR: https://www.eugdpr.org/
Expand Down Expand Up @@ -61,7 +60,8 @@ var _ = Describe("GDPR", func() {
Expect(f).NotTo(ContainSubstring("192.168.0.1"))
})

It("omits x-forwarded-for from stdout", func() {
It("omits x-forwarded-for headers for websockets", func() {
testState.EnableAccessLog()
testState.cfg.Status.Pass = "pass"
testState.cfg.Status.User = "user"
testState.cfg.Status.Routes.Port = 6705
Expand Down Expand Up @@ -96,8 +96,12 @@ var _ = Describe("GDPR", func() {

x.Close()

Eventually(gbytes.BufferReader(testState.gorouterSession.Out)).Should(gbytes.Say(`"X-Forwarded-For":"-"`))
Expect(testState.gorouterSession.Out.Contents()).ToNot(ContainSubstring("192.168.0.1"))
Eventually(func() ([]byte, error) {
return os.ReadFile(testState.AccessLogFilePath())
}).Should(ContainSubstring(`x_forwarded_for:"-"`))
f, err := os.ReadFile(testState.AccessLogFilePath())
Expect(err).NotTo(HaveOccurred())
Expect(f).NotTo(ContainSubstring("192.168.0.1"))
})
})

Expand Down Expand Up @@ -127,7 +131,8 @@ var _ = Describe("GDPR", func() {
}).Should(ContainSubstring(`"foo-agent" "-"`))
})

It("omits RemoteAddr from stdout", func() {
It("omits RemoteAddr in log for websockets", func() {
testState.EnableAccessLog()
testState.cfg.Status.Pass = "pass"
testState.cfg.Status.User = "user"
testState.cfg.Status.Routes.Port = 6706
Expand All @@ -151,6 +156,7 @@ var _ = Describe("GDPR", func() {
req := test_util.NewRequest("GET", "ws-app."+test_util.LocalhostDNS, "", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "upgrade")
req.Header.Set("User-Agent", "foo-agent")
x.WriteRequest(req)

resp, _ := x.ReadResponse()
Expand All @@ -161,7 +167,9 @@ var _ = Describe("GDPR", func() {

x.Close()

Eventually(gbytes.BufferReader(testState.gorouterSession.Out)).Should(gbytes.Say(`"RemoteAddr":"-"`))
Eventually(func() ([]byte, error) {
return os.ReadFile(testState.AccessLogFilePath())
}).Should(ContainSubstring(`"foo-agent" "-"`))
})
})
})
23 changes: 11 additions & 12 deletions integration/tls_to_backends_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,13 @@ var _ = Describe("TLS to backends", func() {
assertWebsocketSuccess(wsApp)
})

It("closes connections with backends that respond with non 101-status code", func() {
wsApp := test.NewHangingWebSocketApp([]route.Uri{"ws-app." + test_util.LocalhostDNS}, testState.cfg.Port, testState.mbusClient, "")
// this test mandates RFC 6455 - https://datatracker.ietf.org/doc/html/rfc6455#section-4
// where it is stated that:
// "(...) If the status code received from the server is not 101, the
// client handles the response per HTTP [RFC2616] procedures."
// Which means the proxy must treat non-101 responses as regular HTTP [ and not close the connection per se ]
It("does not close connections with backends that respond with non 101-status code", func() {
wsApp := test.NewNotUpgradingWebSocketApp([]route.Uri{"ws-app." + test_util.LocalhostDNS}, testState.cfg.Port, testState.mbusClient, "")
wsApp.Register()
wsApp.Listen()

Expand All @@ -104,18 +109,12 @@ var _ = Describe("TLS to backends", func() {

resp, err := http.ReadResponse(x.Reader, &http.Request{})
Expect(err).NotTo(HaveOccurred())
resp.Body.Close()

Expect(resp.StatusCode).To(Equal(404))

// client-side conn should have been closed
// we verify this by trying to read from it, and checking that
// - the read does not block
// - the read returns no data
// - the read returns an error EOF
n, err := conn.Read(make([]byte, 100))
Expect(n).To(Equal(0))
Expect(err).To(Equal(io.EOF))
data, err := io.ReadAll(resp.Body)
resp.Body.Close()

Expect(string(data)).To(ContainSubstring("beginning of the response body goes here"))

x.Close()
})
Expand Down
1 change: 1 addition & 0 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ func NewProxy(
logger,
errorWriter,
))
n.Use(handlers.NewHopByHop(cfg, logger))
n.Use(&handlers.XForwardedProto{
SkipSanitization: SkipSanitizeXFP(routeServiceHandler.(*handlers.RouteService)),
ForceForwardedProtoHttps: p.config.ForceForwardedProtoHttps,
Expand Down
Loading

0 comments on commit fea57f9

Please sign in to comment.