diff --git a/handlers/helpers.go b/handlers/helpers.go index 3cf5bba29..4337a1811 100644 --- a/handlers/helpers.go +++ b/handlers/helpers.go @@ -54,7 +54,7 @@ func IsWebSocketUpgrade(request *http.Request) bool { func upgradeHeader(request *http.Request) string { // handle multiple Connection field-values, either in a comma-separated string or multiple field-headers for _, v := range request.Header[http.CanonicalHeaderKey("Connection")] { - // upgrade should be case insensitive per RFC6455 4.2.1 + // upgrade should be case-insensitive per RFC6455 4.2.1 if strings.Contains(strings.ToLower(v), "upgrade") { return request.Header.Get("Upgrade") } diff --git a/handlers/hop_by_hop.go b/handlers/hop_by_hop.go new file mode 100644 index 000000000..15d2a665c --- /dev/null +++ b/handlers/hop_by_hop.go @@ -0,0 +1,54 @@ +package handlers + +import ( + "net/http" + "strings" + + "code.cloudfoundry.org/gorouter/config" + "code.cloudfoundry.org/gorouter/logger" +) + +type HopByHop struct { + cfg *config.Config + logger logger.Logger +} + +// NewHopByHop creates a new handler that sanitizes hop-by-hop headers based on the HopByHopHeadersToFilter config +func NewHopByHop(cfg *config.Config, logger logger.Logger) *HopByHop { + return &HopByHop{ + logger: logger, + cfg: cfg, + } +} + +func (h *HopByHop) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + h.SanitizeRequestConnection(r) + next(rw, r) +} + +func (h *HopByHop) SanitizeRequestConnection(r *http.Request) { + if len(h.cfg.HopByHopHeadersToFilter) == 0 { + return + } + connections := r.Header.Values("Connection") + for index, connection := range connections { + if connection != "" { + values := strings.Split(connection, ",") + connectionHeader := []string{} + for i := range values { + trimmedValue := strings.TrimSpace(values[i]) + found := false + for _, item := range h.cfg.HopByHopHeadersToFilter { + if strings.ToLower(item) == strings.ToLower(trimmedValue) { + found = true + break + } + } + if !found { + connectionHeader = append(connectionHeader, trimmedValue) + } + } + r.Header[http.CanonicalHeaderKey("Connection")][index] = strings.Join(connectionHeader, ", ") + } + } +} diff --git a/handlers/hop_by_hop_test.go b/handlers/hop_by_hop_test.go new file mode 100644 index 000000000..cdb2a72ef --- /dev/null +++ b/handlers/hop_by_hop_test.go @@ -0,0 +1,159 @@ +package handlers_test + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + + "code.cloudfoundry.org/gorouter/config" + "code.cloudfoundry.org/gorouter/handlers" + logger_fakes "code.cloudfoundry.org/gorouter/logger/fakes" + "code.cloudfoundry.org/gorouter/route" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/urfave/negroni/v3" +) + +var _ = Describe("HopByHop", func() { + var ( + handler *negroni.Negroni + + resp http.ResponseWriter + req *http.Request + rawPath string + header http.Header + result *http.Response + responseBody []byte + requestBody *bytes.Buffer + + cfg *config.Config + fakeLogger *logger_fakes.FakeLogger + hopByHop *handlers.HopByHop + + nextCalled bool + ) + + nextHandler := negroni.HandlerFunc(func(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) { + _, err := io.ReadAll(req.Body) + Expect(err).NotTo(HaveOccurred()) + + rw.WriteHeader(http.StatusTeapot) + for name, values := range req.Header { + for _, value := range values { + rw.Header().Set(name, value) + } + } + + rw.Write([]byte("I'm a little teapot, short and stout.")) + + if next != nil { + next(rw, req) + } + + nextCalled = true + }) + + handleRequest := func() { + var err error + handler.ServeHTTP(resp, req) + + result = resp.(*httptest.ResponseRecorder).Result() + responseBody, err = io.ReadAll(result.Body) + Expect(err).NotTo(HaveOccurred()) + result.Body.Close() + } + + BeforeEach(func() { + cfg = &config.Config{ + HopByHopHeadersToFilter: make([]string, 0), + LoadBalance: config.LOAD_BALANCE_RR, + } + requestBody = bytes.NewBufferString("What are you?") + rawPath = "/" + header = http.Header{} + resp = httptest.NewRecorder() + }) + + JustBeforeEach(func() { + fakeLogger = new(logger_fakes.FakeLogger) + handler = negroni.New() + hopByHop = handlers.NewHopByHop(cfg, fakeLogger) + handler.Use(hopByHop) + handler.Use(nextHandler) + + nextCalled = false + + var err error + req, err = http.NewRequest("GET", "http://example.com"+rawPath, requestBody) + Expect(err).NotTo(HaveOccurred()) + + req.Header = header + reqInfo := &handlers.RequestInfo{ + RoutePool: route.NewPool(&route.PoolOpts{}), + } + reqInfo.RoutePool.Put(route.NewEndpoint(&route.EndpointOpts{ + AppId: "fake-app", + Host: "fake-host", + Port: 1234, + PrivateInstanceId: "fake-instance", + })) + req = req.WithContext(context.WithValue(req.Context(), handlers.RequestInfoCtxKey, reqInfo)) + }) + + Context("when HopByHopHeadersToFilter is empty", func() { + BeforeEach(func() { + header.Add("Connection", "X-Forwarded-Proto") + }) + + It("does not touch headers listed in the Connection header", func() { + handleRequest() + Expect(resp.Header().Get("Connection")).To(ContainSubstring("X-Forwarded-Proto")) + Expect(result.StatusCode).To(Equal(http.StatusTeapot)) + Expect(result.Status).To(Equal("418 I'm a teapot")) + Expect(string(responseBody)).To(Equal("I'm a little teapot, short and stout.")) + + }) + It("calls the next handler", func() { + handleRequest() + Expect(nextCalled).To(BeTrue()) + }) + It("doesn't set the reqInfo's RouteEndpoint", func() { + handleRequest() + reqInfo, err := handlers.ContextRequestInfo(req) + Expect(err).NotTo(HaveOccurred()) + + Expect(reqInfo.RouteEndpoint).To(BeNil()) + }) + }) + + Context("when HopByHopHeadersToFilter is set", func() { + BeforeEach(func() { + cfg.HopByHopHeadersToFilter = append(cfg.HopByHopHeadersToFilter, "X-Forwarded-Proto") + header.Add("Connection", "X-Forwarded-Proto") + }) + + It("removes the headers listed in the Connection header", func() { + handleRequest() + Expect(resp.Header().Get("Connection")).To(BeEmpty()) + Expect(result.StatusCode).To(Equal(http.StatusTeapot)) + Expect(result.Status).To(Equal("418 I'm a teapot")) + Expect(string(responseBody)).To(Equal("I'm a little teapot, short and stout.")) + + }) + It("calls the next handler", func() { + handleRequest() + Expect(nextCalled).To(BeTrue()) + }) + It("doesn't set the reqInfo's RouteEndpoint", func() { + handleRequest() + reqInfo, err := handlers.ContextRequestInfo(req) + Expect(err).NotTo(HaveOccurred()) + + Expect(reqInfo.RouteEndpoint).To(BeNil()) + }) + }) + +}) diff --git a/handlers/httpstartstop.go b/handlers/httpstartstop.go index 278289460..00cf8d529 100644 --- a/handlers/httpstartstop.go +++ b/handlers/httpstartstop.go @@ -22,7 +22,7 @@ type httpStartStopHandler struct { logger logger.Logger } -// NewHTTPStartStop creates a new handler that handles emitting frontent +// NewHTTPStartStop creates a new handler that handles emitting frontend // HTTP StartStop events func NewHTTPStartStop(emitter dropsonde.EventEmitter, logger logger.Logger) negroni.Handler { return &httpStartStopHandler{ @@ -62,6 +62,7 @@ func (hh *httpStartStopHandler) ServeHTTP(rw http.ResponseWriter, r *http.Reques envelope, err := emitter.Wrap(startStopEvent, hh.emitter.Origin()) if err != nil { logger.Info("failed-to-create-startstop-envelope", zap.Error(err)) + return } info, err := ContextRequestInfo(r) diff --git a/integration/gdpr_test.go b/integration/gdpr_test.go index 0d01fcf02..1b81f57a9 100644 --- a/integration/gdpr_test.go +++ b/integration/gdpr_test.go @@ -13,7 +13,6 @@ import ( "code.cloudfoundry.org/gorouter/test_util" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "github.com/onsi/gomega/gbytes" ) // Involves scrubbing client IPs, for more info on GDPR: https://www.eugdpr.org/ @@ -61,7 +60,8 @@ var _ = Describe("GDPR", func() { Expect(f).NotTo(ContainSubstring("192.168.0.1")) }) - It("omits x-forwarded-for from stdout", func() { + It("omits x-forwarded-for headers for websockets", func() { + testState.EnableAccessLog() testState.cfg.Status.Pass = "pass" testState.cfg.Status.User = "user" testState.cfg.Status.Routes.Port = 6705 @@ -96,8 +96,12 @@ var _ = Describe("GDPR", func() { x.Close() - Eventually(gbytes.BufferReader(testState.gorouterSession.Out)).Should(gbytes.Say(`"X-Forwarded-For":"-"`)) - Expect(testState.gorouterSession.Out.Contents()).ToNot(ContainSubstring("192.168.0.1")) + Eventually(func() ([]byte, error) { + return os.ReadFile(testState.AccessLogFilePath()) + }).Should(ContainSubstring(`x_forwarded_for:"-"`)) + f, err := os.ReadFile(testState.AccessLogFilePath()) + Expect(err).NotTo(HaveOccurred()) + Expect(f).NotTo(ContainSubstring("192.168.0.1")) }) }) @@ -127,7 +131,8 @@ var _ = Describe("GDPR", func() { }).Should(ContainSubstring(`"foo-agent" "-"`)) }) - It("omits RemoteAddr from stdout", func() { + It("omits RemoteAddr in log for websockets", func() { + testState.EnableAccessLog() testState.cfg.Status.Pass = "pass" testState.cfg.Status.User = "user" testState.cfg.Status.Routes.Port = 6706 @@ -151,6 +156,7 @@ var _ = Describe("GDPR", func() { req := test_util.NewRequest("GET", "ws-app."+test_util.LocalhostDNS, "", nil) req.Header.Set("Upgrade", "websocket") req.Header.Set("Connection", "upgrade") + req.Header.Set("User-Agent", "foo-agent") x.WriteRequest(req) resp, _ := x.ReadResponse() @@ -161,7 +167,9 @@ var _ = Describe("GDPR", func() { x.Close() - Eventually(gbytes.BufferReader(testState.gorouterSession.Out)).Should(gbytes.Say(`"RemoteAddr":"-"`)) + Eventually(func() ([]byte, error) { + return os.ReadFile(testState.AccessLogFilePath()) + }).Should(ContainSubstring(`"foo-agent" "-"`)) }) }) }) diff --git a/integration/tls_to_backends_test.go b/integration/tls_to_backends_test.go index 8793c5cb3..a39bebad8 100644 --- a/integration/tls_to_backends_test.go +++ b/integration/tls_to_backends_test.go @@ -81,8 +81,13 @@ var _ = Describe("TLS to backends", func() { assertWebsocketSuccess(wsApp) }) - It("closes connections with backends that respond with non 101-status code", func() { - wsApp := test.NewHangingWebSocketApp([]route.Uri{"ws-app." + test_util.LocalhostDNS}, testState.cfg.Port, testState.mbusClient, "") + // this test mandates RFC 6455 - https://datatracker.ietf.org/doc/html/rfc6455#section-4 + // where it is stated that: + // "(...) If the status code received from the server is not 101, the + // client handles the response per HTTP [RFC2616] procedures." + // Which means the proxy must treat non-101 responses as regular HTTP [ and not close the connection per se ] + It("does not close connections with backends that respond with non 101-status code", func() { + wsApp := test.NewNotUpgradingWebSocketApp([]route.Uri{"ws-app." + test_util.LocalhostDNS}, testState.cfg.Port, testState.mbusClient, "") wsApp.Register() wsApp.Listen() @@ -104,18 +109,12 @@ var _ = Describe("TLS to backends", func() { resp, err := http.ReadResponse(x.Reader, &http.Request{}) Expect(err).NotTo(HaveOccurred()) - resp.Body.Close() - Expect(resp.StatusCode).To(Equal(404)) - // client-side conn should have been closed - // we verify this by trying to read from it, and checking that - // - the read does not block - // - the read returns no data - // - the read returns an error EOF - n, err := conn.Read(make([]byte, 100)) - Expect(n).To(Equal(0)) - Expect(err).To(Equal(io.EOF)) + data, err := io.ReadAll(resp.Body) + resp.Body.Close() + + Expect(string(data)).To(ContainSubstring("beginning of the response body goes here")) x.Close() }) diff --git a/proxy/proxy.go b/proxy/proxy.go index e09b77667..81da8634b 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -170,6 +170,7 @@ func NewProxy( logger, errorWriter, )) + n.Use(handlers.NewHopByHop(cfg, logger)) n.Use(&handlers.XForwardedProto{ SkipSanitization: SkipSanitizeXFP(routeServiceHandler.(*handlers.RouteService)), ForceForwardedProtoHttps: p.config.ForceForwardedProtoHttps, diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index e4e56faba..8adefcbd8 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -2444,12 +2444,12 @@ var _ = Describe("Proxy", func() { req, err := http.ReadRequest(conn.Reader) Expect(err).NotTo(HaveOccurred()) - done <- req.Header.Get("Upgrade") == "WebsockeT" && - req.Header.Get("Connection") == "UpgradE" + done <- req.Header.Get("Upgrade") == "Websocket" && + req.Header.Get("Connection") == "Upgrade" resp := test_util.NewResponse(http.StatusSwitchingProtocols) - resp.Header.Set("Upgrade", "WebsockeT") - resp.Header.Set("Connection", "UpgradE") + resp.Header.Set("Upgrade", "Websocket") + resp.Header.Set("Connection", "Upgrade") conn.WriteResponse(resp) @@ -2462,8 +2462,8 @@ var _ = Describe("Proxy", func() { conn := dialProxy(proxyServer) req := test_util.NewRequest("GET", "ws", "/chat", nil) - req.Header.Set("Upgrade", "WebsockeT") - req.Header.Set("Connection", "UpgradE") + req.Header.Set("Upgrade", "Websocket") + req.Header.Set("Connection", "Upgrade") conn.WriteRequest(req) @@ -2473,8 +2473,8 @@ var _ = Describe("Proxy", func() { resp, _ := conn.ReadResponse() Expect(resp.StatusCode).To(Equal(http.StatusSwitchingProtocols)) - Expect(resp.Header.Get("Upgrade")).To(Equal("WebsockeT")) - Expect(resp.Header.Get("Connection")).To(Equal("UpgradE")) + Expect(resp.Header.Get("Upgrade")).To(Equal("Websocket")) + Expect(resp.Header.Get("Connection")).To(Equal("Upgrade")) conn.WriteLine("hello from client") conn.CheckLine("hello from server") @@ -2490,12 +2490,12 @@ var _ = Describe("Proxy", func() { req, err := http.ReadRequest(conn.Reader) Expect(err).NotTo(HaveOccurred()) - done <- req.Header.Get("Upgrade") == "WebsockeT" && - req.Header.Get("Connection") == "UpgradE" + done <- req.Header.Get("Upgrade") == "Websocket" && + req.Header.Get("Connection") == "Upgrade" resp := test_util.NewResponse(http.StatusSwitchingProtocols) - resp.Header.Set("Upgrade", "WebsockeT") - resp.Header.Set("Connection", "UpgradE") + resp.Header.Set("Upgrade", "Websocket") + resp.Header.Set("Connection", "Upgrade") conn.WriteResponse(resp) @@ -2508,8 +2508,8 @@ var _ = Describe("Proxy", func() { conn := dialProxy(proxyServer) req := test_util.NewRequest("GET", "ws", "/chat", nil) - req.Header.Set("Upgrade", "WebsockeT") - req.Header.Set("Connection", "UpgradE") + req.Header.Set("Upgrade", "Websocket") + req.Header.Set("Connection", "Upgrade") conn.WriteRequest(req) @@ -2519,8 +2519,8 @@ var _ = Describe("Proxy", func() { resp, _ := conn.ReadResponse() Expect(resp.StatusCode).To(Equal(http.StatusSwitchingProtocols)) - Expect(resp.Header.Get("Upgrade")).To(Equal("WebsockeT")) - Expect(resp.Header.Get("Connection")).To(Equal("UpgradE")) + Expect(resp.Header.Get("Upgrade")).To(Equal("Websocket")) + Expect(resp.Header.Get("Connection")).To(Equal("Upgrade")) conn.WriteLine("hello from client") conn.CheckLine("hello from server") @@ -2535,8 +2535,10 @@ var _ = Describe("Proxy", func() { req, err := http.ReadRequest(conn.Reader) Expect(err).NotTo(HaveOccurred()) + // RFC 7230, section 6.1: Remove headers listed in the "Connection" header. + // Only "Upgrade" will be added again by httputil.ReverseProxy done <- req.Header.Get("Upgrade") == "Websocket" && - req.Header.Get("Connection") == "keep-alive, Upgrade" + req.Header.Get("Connection") == "Upgrade" resp := test_util.NewResponse(http.StatusSwitchingProtocols) resp.Header.Set("Upgrade", "Websocket") @@ -2581,9 +2583,10 @@ var _ = Describe("Proxy", func() { req, err := http.ReadRequest(conn.Reader) Expect(err).NotTo(HaveOccurred()) + // RFC 7230, section 6.1: Remove headers listed in the "Connection" header. + // Only "Upgrade" will be added again by httputil.ReverseProxy done <- req.Header.Get("Upgrade") == "Websocket" && - req.Header[http.CanonicalHeaderKey("Connection")][0] == "keep-alive" && - req.Header[http.CanonicalHeaderKey("Connection")][1] == "Upgrade" + req.Header.Get("Connection") == "Upgrade" resp := test_util.NewResponse(http.StatusSwitchingProtocols) resp.Header.Set("Upgrade", "Websocket") diff --git a/proxy/round_tripper/proxy_round_tripper.go b/proxy/round_tripper/proxy_round_tripper.go index fc9f97468..2280e496d 100644 --- a/proxy/round_tripper/proxy_round_tripper.go +++ b/proxy/round_tripper/proxy_round_tripper.go @@ -304,12 +304,25 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response responseWriterMu.Lock() defer responseWriterMu.Unlock() rt.errorHandler.HandleError(reqInfo.ProxyResponseWriter, err) + if handlers.IsWebSocketUpgrade(request) { + rt.combinedReporter.CaptureWebSocketFailure() + } return nil, err } // Round trip was successful at this point reqInfo.RoundTripSuccessful = true + // Set status code for access log + if res != nil { + reqInfo.ProxyResponseWriter.SetStatus(res.StatusCode) + } + + // Write metric for ws upgrades + if handlers.IsWebSocketUpgrade(request) { + rt.combinedReporter.CaptureWebSocketUpdate() + } + // Record the times from the last attempt, but only if it succeeded. reqInfo.DnsStartedAt = trace.DnsStart() reqInfo.DnsFinishedAt = trace.DnsDone() @@ -358,7 +371,7 @@ func (rt *roundTripper) backendRoundTrip(request *http.Request, endpoint *route. } func (rt *roundTripper) timedRoundTrip(tr http.RoundTripper, request *http.Request, logger logger.Logger) (*http.Response, error) { - if rt.config.EndpointTimeout <= 0 { + if rt.config.EndpointTimeout <= 0 || handlers.IsWebSocketUpgrade(request) { return tr.RoundTrip(request) } diff --git a/proxy/round_tripper/proxy_round_tripper_test.go b/proxy/round_tripper/proxy_round_tripper_test.go index d66d7f1cd..16d514f79 100644 --- a/proxy/round_tripper/proxy_round_tripper_test.go +++ b/proxy/round_tripper/proxy_round_tripper_test.go @@ -26,7 +26,6 @@ import ( "code.cloudfoundry.org/gorouter/handlers" "code.cloudfoundry.org/gorouter/metrics/fakes" "code.cloudfoundry.org/gorouter/proxy/fails" - "code.cloudfoundry.org/gorouter/proxy/handler" "code.cloudfoundry.org/gorouter/proxy/round_tripper" "code.cloudfoundry.org/gorouter/proxy/utils" "code.cloudfoundry.org/gorouter/route" @@ -663,7 +662,7 @@ var _ = Describe("ProxyRoundTripper", func() { It("returns a 502 Bad Gateway response", func() { backendRes, err := proxyRoundTripper.RoundTrip(req) Expect(backendRes).To(BeNil()) - Expect(err).To(Equal(handler.NoEndpointsAvailable)) + Expect(err).To(Equal(round_tripper.NoEndpointsAvailable)) Expect(reqInfo.RouteEndpoint).To(BeNil()) Expect(reqInfo.RoundTripSuccessful).To(BeFalse()) @@ -673,7 +672,7 @@ var _ = Describe("ProxyRoundTripper", func() { proxyRoundTripper.RoundTrip(req) Expect(errorHandler.HandleErrorCallCount()).To(Equal(1)) _, err := errorHandler.HandleErrorArgsForCall(0) - Expect(err).To(Equal(handler.NoEndpointsAvailable)) + Expect(err).To(Equal(round_tripper.NoEndpointsAvailable)) }) It("logs a message with `select-endpoint-failed`", func() { @@ -685,21 +684,21 @@ var _ = Describe("ProxyRoundTripper", func() { It("does not capture any routing requests to the backend", func() { _, err := proxyRoundTripper.RoundTrip(req) - Expect(err).To(Equal(handler.NoEndpointsAvailable)) + Expect(err).To(Equal(round_tripper.NoEndpointsAvailable)) Expect(combinedReporter.CaptureRoutingRequestCallCount()).To(Equal(0)) }) It("does not log anything about route services", func() { _, err := proxyRoundTripper.RoundTrip(req) - Expect(err).To(Equal(handler.NoEndpointsAvailable)) + Expect(err).To(Equal(round_tripper.NoEndpointsAvailable)) Expect(logger.Buffer()).ToNot(gbytes.Say(`route-service`)) }) It("does not report the endpoint failure", func() { _, err := proxyRoundTripper.RoundTrip(req) - Expect(err).To(MatchError(handler.NoEndpointsAvailable)) + Expect(err).To(MatchError(round_tripper.NoEndpointsAvailable)) Expect(logger.Buffer()).ToNot(gbytes.Say(`backend-endpoint-failed`)) }) diff --git a/router/router_test.go b/router/router_test.go index 08ebdc7d3..fcfac810a 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -1137,7 +1137,7 @@ var _ = Describe("Router", func() { }) }) - It("websockets do not terminate", func() { + It("websocket connections are not affected by EndpointTimeout", func() { app := test.NewWebSocketApp( []route.Uri{"ws-app." + test_util.LocalhostDNS}, config.Port, diff --git a/test/websocket_app.go b/test/websocket_app.go index c386fb812..77f0c2436 100644 --- a/test/websocket_app.go +++ b/test/websocket_app.go @@ -4,6 +4,7 @@ import ( "bytes" "io" "net/http" + "strings" "time" nats "github.com/nats-io/nats.go" @@ -22,7 +23,7 @@ func NewWebSocketApp(urls []route.Uri, rPort uint16, mbusClient *nats.Conn, dela defer ginkgo.GinkgoRecover() Expect(r.Header.Get("Upgrade")).To(Equal("websocket")) - Expect(r.Header.Get("Connection")).To(Equal("upgrade")) + Expect(strings.ToLower(r.Header.Get("Connection"))).To(Equal("upgrade")) conn, _, err := w.(http.Hijacker).Hijack() Expect(err).ToNot(HaveOccurred()) @@ -49,7 +50,7 @@ func NewFailingWebSocketApp(urls []route.Uri, rPort uint16, mbusClient *nats.Con defer ginkgo.GinkgoRecover() Expect(r.Header.Get("Upgrade")).To(Equal("websocket")) - Expect(r.Header.Get("Connection")).To(Equal("upgrade")) + Expect(strings.ToLower(r.Header.Get("Connection"))).To(Equal("upgrade")) conn, _, err := w.(http.Hijacker).Hijack() Expect(err).ToNot(HaveOccurred()) @@ -60,13 +61,13 @@ func NewFailingWebSocketApp(urls []route.Uri, rPort uint16, mbusClient *nats.Con return app } -func NewHangingWebSocketApp(urls []route.Uri, rPort uint16, mbusClient *nats.Conn, routeServiceUrl string) *common.TestApp { +func NewNotUpgradingWebSocketApp(urls []route.Uri, rPort uint16, mbusClient *nats.Conn, routeServiceUrl string) *common.TestApp { app := common.NewTestApp(urls, rPort, mbusClient, nil, routeServiceUrl) app.AddHandler("/", func(w http.ResponseWriter, r *http.Request) { defer ginkgo.GinkgoRecover() Expect(r.Header.Get("Upgrade")).To(Equal("websocket")) - Expect(r.Header.Get("Connection")).To(Equal("upgrade")) + Expect(strings.ToLower(r.Header.Get("Connection"))).To(Equal("upgrade")) conn, _, err := w.(http.Hijacker).Hijack() Expect(err).ToNot(HaveOccurred()) @@ -81,10 +82,9 @@ func NewHangingWebSocketApp(urls []route.Uri, rPort uint16, mbusClient *nats.Con bytes.NewBufferString("\r\nbeginning of the response body goes here\r\n\r\n"), bytes.NewBuffer(make([]byte, 10024)), // bigger than the internal buffer of the http stdlib bytes.NewBufferString("\r\nmore response here, probably won't be seen by client\r\n"), - &test_util.HangingReadCloser{}), + ), ) x.WriteResponse(resp) - panic("you won't get here in a test") }) return app