diff --git a/handlers/helpers.go b/handlers/helpers.go index 3cf5bba2..4337a181 100644 --- a/handlers/helpers.go +++ b/handlers/helpers.go @@ -54,7 +54,7 @@ func IsWebSocketUpgrade(request *http.Request) bool { func upgradeHeader(request *http.Request) string { // handle multiple Connection field-values, either in a comma-separated string or multiple field-headers for _, v := range request.Header[http.CanonicalHeaderKey("Connection")] { - // upgrade should be case insensitive per RFC6455 4.2.1 + // upgrade should be case-insensitive per RFC6455 4.2.1 if strings.Contains(strings.ToLower(v), "upgrade") { return request.Header.Get("Upgrade") } diff --git a/handlers/hop_by_hop.go b/handlers/hop_by_hop.go new file mode 100644 index 00000000..b9312380 --- /dev/null +++ b/handlers/hop_by_hop.go @@ -0,0 +1,54 @@ +package handlers + +import ( + "net/http" + "strings" + + "code.cloudfoundry.org/gorouter/config" + "code.cloudfoundry.org/gorouter/logger" +) + +type HopByHop struct { + cfg *config.Config + logger logger.Logger +} + +// NewHopByHop creates a new handler that sanitizes hop-by-hop headers based on the HopByHopHeadersToFilter config +func NewHopByHop(cfg *config.Config, logger logger.Logger) *HopByHop { + return &HopByHop{ + logger: logger, + cfg: cfg, + } +} + +func (h *HopByHop) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + h.SanitizeRequestConnection(r) + next(rw, r) +} + +func (h *HopByHop) SanitizeRequestConnection(r *http.Request) { + if len(h.cfg.HopByHopHeadersToFilter) == 0 { + return + } + connections := r.Header.Values("Connection") + for index, connection := range connections { + if connection != "" { + values := strings.Split(connection, ",") + connectionHeader := []string{} + for i := range values { + trimmedValue := strings.TrimSpace(values[i]) + found := false + for _, item := range h.cfg.HopByHopHeadersToFilter { + if strings.EqualFold(item, trimmedValue) { + found = true + break + } + } + if !found { + connectionHeader = append(connectionHeader, trimmedValue) + } + } + r.Header[http.CanonicalHeaderKey("Connection")][index] = strings.Join(connectionHeader, ", ") + } + } +} diff --git a/handlers/hop_by_hop_test.go b/handlers/hop_by_hop_test.go new file mode 100644 index 00000000..cdb2a72e --- /dev/null +++ b/handlers/hop_by_hop_test.go @@ -0,0 +1,159 @@ +package handlers_test + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + + "code.cloudfoundry.org/gorouter/config" + "code.cloudfoundry.org/gorouter/handlers" + logger_fakes "code.cloudfoundry.org/gorouter/logger/fakes" + "code.cloudfoundry.org/gorouter/route" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/urfave/negroni/v3" +) + +var _ = Describe("HopByHop", func() { + var ( + handler *negroni.Negroni + + resp http.ResponseWriter + req *http.Request + rawPath string + header http.Header + result *http.Response + responseBody []byte + requestBody *bytes.Buffer + + cfg *config.Config + fakeLogger *logger_fakes.FakeLogger + hopByHop *handlers.HopByHop + + nextCalled bool + ) + + nextHandler := negroni.HandlerFunc(func(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) { + _, err := io.ReadAll(req.Body) + Expect(err).NotTo(HaveOccurred()) + + rw.WriteHeader(http.StatusTeapot) + for name, values := range req.Header { + for _, value := range values { + rw.Header().Set(name, value) + } + } + + rw.Write([]byte("I'm a little teapot, short and stout.")) + + if next != nil { + next(rw, req) + } + + nextCalled = true + }) + + handleRequest := func() { + var err error + handler.ServeHTTP(resp, req) + + result = resp.(*httptest.ResponseRecorder).Result() + responseBody, err = io.ReadAll(result.Body) + Expect(err).NotTo(HaveOccurred()) + result.Body.Close() + } + + BeforeEach(func() { + cfg = &config.Config{ + HopByHopHeadersToFilter: make([]string, 0), + LoadBalance: config.LOAD_BALANCE_RR, + } + requestBody = bytes.NewBufferString("What are you?") + rawPath = "/" + header = http.Header{} + resp = httptest.NewRecorder() + }) + + JustBeforeEach(func() { + fakeLogger = new(logger_fakes.FakeLogger) + handler = negroni.New() + hopByHop = handlers.NewHopByHop(cfg, fakeLogger) + handler.Use(hopByHop) + handler.Use(nextHandler) + + nextCalled = false + + var err error + req, err = http.NewRequest("GET", "http://example.com"+rawPath, requestBody) + Expect(err).NotTo(HaveOccurred()) + + req.Header = header + reqInfo := &handlers.RequestInfo{ + RoutePool: route.NewPool(&route.PoolOpts{}), + } + reqInfo.RoutePool.Put(route.NewEndpoint(&route.EndpointOpts{ + AppId: "fake-app", + Host: "fake-host", + Port: 1234, + PrivateInstanceId: "fake-instance", + })) + req = req.WithContext(context.WithValue(req.Context(), handlers.RequestInfoCtxKey, reqInfo)) + }) + + Context("when HopByHopHeadersToFilter is empty", func() { + BeforeEach(func() { + header.Add("Connection", "X-Forwarded-Proto") + }) + + It("does not touch headers listed in the Connection header", func() { + handleRequest() + Expect(resp.Header().Get("Connection")).To(ContainSubstring("X-Forwarded-Proto")) + Expect(result.StatusCode).To(Equal(http.StatusTeapot)) + Expect(result.Status).To(Equal("418 I'm a teapot")) + Expect(string(responseBody)).To(Equal("I'm a little teapot, short and stout.")) + + }) + It("calls the next handler", func() { + handleRequest() + Expect(nextCalled).To(BeTrue()) + }) + It("doesn't set the reqInfo's RouteEndpoint", func() { + handleRequest() + reqInfo, err := handlers.ContextRequestInfo(req) + Expect(err).NotTo(HaveOccurred()) + + Expect(reqInfo.RouteEndpoint).To(BeNil()) + }) + }) + + Context("when HopByHopHeadersToFilter is set", func() { + BeforeEach(func() { + cfg.HopByHopHeadersToFilter = append(cfg.HopByHopHeadersToFilter, "X-Forwarded-Proto") + header.Add("Connection", "X-Forwarded-Proto") + }) + + It("removes the headers listed in the Connection header", func() { + handleRequest() + Expect(resp.Header().Get("Connection")).To(BeEmpty()) + Expect(result.StatusCode).To(Equal(http.StatusTeapot)) + Expect(result.Status).To(Equal("418 I'm a teapot")) + Expect(string(responseBody)).To(Equal("I'm a little teapot, short and stout.")) + + }) + It("calls the next handler", func() { + handleRequest() + Expect(nextCalled).To(BeTrue()) + }) + It("doesn't set the reqInfo's RouteEndpoint", func() { + handleRequest() + reqInfo, err := handlers.ContextRequestInfo(req) + Expect(err).NotTo(HaveOccurred()) + + Expect(reqInfo.RouteEndpoint).To(BeNil()) + }) + }) + +}) diff --git a/handlers/httpstartstop.go b/handlers/httpstartstop.go index e0909977..a4559e71 100644 --- a/handlers/httpstartstop.go +++ b/handlers/httpstartstop.go @@ -21,7 +21,7 @@ type httpStartStopHandler struct { logger logger.Logger } -// NewHTTPStartStop creates a new handler that handles emitting frontent +// NewHTTPStartStop creates a new handler that handles emitting frontend // HTTP StartStop events func NewHTTPStartStop(emitter dropsonde.EventEmitter, logger logger.Logger) negroni.Handler { return &httpStartStopHandler{ @@ -61,6 +61,7 @@ func (hh *httpStartStopHandler) ServeHTTP(rw http.ResponseWriter, r *http.Reques envelope, err := emitter.Wrap(startStopEvent, hh.emitter.Origin()) if err != nil { logger.Info("failed-to-create-startstop-envelope", zap.Error(err)) + return } endpoint, _ := GetEndpoint(r.Context()) diff --git a/integration/gdpr_test.go b/integration/gdpr_test.go index 0d01fcf0..1b81f57a 100644 --- a/integration/gdpr_test.go +++ b/integration/gdpr_test.go @@ -13,7 +13,6 @@ import ( "code.cloudfoundry.org/gorouter/test_util" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "github.com/onsi/gomega/gbytes" ) // Involves scrubbing client IPs, for more info on GDPR: https://www.eugdpr.org/ @@ -61,7 +60,8 @@ var _ = Describe("GDPR", func() { Expect(f).NotTo(ContainSubstring("192.168.0.1")) }) - It("omits x-forwarded-for from stdout", func() { + It("omits x-forwarded-for headers for websockets", func() { + testState.EnableAccessLog() testState.cfg.Status.Pass = "pass" testState.cfg.Status.User = "user" testState.cfg.Status.Routes.Port = 6705 @@ -96,8 +96,12 @@ var _ = Describe("GDPR", func() { x.Close() - Eventually(gbytes.BufferReader(testState.gorouterSession.Out)).Should(gbytes.Say(`"X-Forwarded-For":"-"`)) - Expect(testState.gorouterSession.Out.Contents()).ToNot(ContainSubstring("192.168.0.1")) + Eventually(func() ([]byte, error) { + return os.ReadFile(testState.AccessLogFilePath()) + }).Should(ContainSubstring(`x_forwarded_for:"-"`)) + f, err := os.ReadFile(testState.AccessLogFilePath()) + Expect(err).NotTo(HaveOccurred()) + Expect(f).NotTo(ContainSubstring("192.168.0.1")) }) }) @@ -127,7 +131,8 @@ var _ = Describe("GDPR", func() { }).Should(ContainSubstring(`"foo-agent" "-"`)) }) - It("omits RemoteAddr from stdout", func() { + It("omits RemoteAddr in log for websockets", func() { + testState.EnableAccessLog() testState.cfg.Status.Pass = "pass" testState.cfg.Status.User = "user" testState.cfg.Status.Routes.Port = 6706 @@ -151,6 +156,7 @@ var _ = Describe("GDPR", func() { req := test_util.NewRequest("GET", "ws-app."+test_util.LocalhostDNS, "", nil) req.Header.Set("Upgrade", "websocket") req.Header.Set("Connection", "upgrade") + req.Header.Set("User-Agent", "foo-agent") x.WriteRequest(req) resp, _ := x.ReadResponse() @@ -161,7 +167,9 @@ var _ = Describe("GDPR", func() { x.Close() - Eventually(gbytes.BufferReader(testState.gorouterSession.Out)).Should(gbytes.Say(`"RemoteAddr":"-"`)) + Eventually(func() ([]byte, error) { + return os.ReadFile(testState.AccessLogFilePath()) + }).Should(ContainSubstring(`"foo-agent" "-"`)) }) }) }) diff --git a/integration/tls_to_backends_test.go b/integration/tls_to_backends_test.go index 8793c5cb..a627c234 100644 --- a/integration/tls_to_backends_test.go +++ b/integration/tls_to_backends_test.go @@ -81,8 +81,13 @@ var _ = Describe("TLS to backends", func() { assertWebsocketSuccess(wsApp) }) - It("closes connections with backends that respond with non 101-status code", func() { - wsApp := test.NewHangingWebSocketApp([]route.Uri{"ws-app." + test_util.LocalhostDNS}, testState.cfg.Port, testState.mbusClient, "") + // this test mandates RFC 6455 - https://datatracker.ietf.org/doc/html/rfc6455#section-4 + // where it is stated that: + // "(...) If the status code received from the server is not 101, the + // client handles the response per HTTP [RFC2616] procedures." + // Which means the proxy must treat non-101 responses as regular HTTP [ and not close the connection per se ] + It("does not close connections with backends that respond with non 101-status code", func() { + wsApp := test.NewNotUpgradingWebSocketApp([]route.Uri{"ws-app." + test_util.LocalhostDNS}, testState.cfg.Port, testState.mbusClient, "") wsApp.Register() wsApp.Listen() @@ -104,18 +109,13 @@ var _ = Describe("TLS to backends", func() { resp, err := http.ReadResponse(x.Reader, &http.Request{}) Expect(err).NotTo(HaveOccurred()) - resp.Body.Close() - Expect(resp.StatusCode).To(Equal(404)) - // client-side conn should have been closed - // we verify this by trying to read from it, and checking that - // - the read does not block - // - the read returns no data - // - the read returns an error EOF - n, err := conn.Read(make([]byte, 100)) - Expect(n).To(Equal(0)) - Expect(err).To(Equal(io.EOF)) + data, err := io.ReadAll(resp.Body) + Expect(err).To(Not(HaveOccurred())) + resp.Body.Close() + + Expect(string(data)).To(ContainSubstring("beginning of the response body goes here")) x.Close() }) diff --git a/proxy/handler/forwarder.go b/proxy/handler/forwarder.go deleted file mode 100644 index 75b1298a..00000000 --- a/proxy/handler/forwarder.go +++ /dev/null @@ -1,84 +0,0 @@ -package handler - -import ( - "bufio" - "bytes" - "errors" - "fmt" - "io" - "net/http" - "time" - - "code.cloudfoundry.org/gorouter/logger" - "code.cloudfoundry.org/gorouter/proxy/utils" - "github.com/uber-go/zap" -) - -type Forwarder struct { - BackendReadTimeout time.Duration - Logger logger.Logger -} - -// ForwardIO sets up websocket forwarding with a backend -// -// It returns after one of the connections closes. -// -// If the backend response code is not 101 Switching Protocols, then -// ForwardIO will return immediately, allowing the caller to close the connections. -func (f *Forwarder) ForwardIO(clientConn, backendConn io.ReadWriter) (int, error) { - done := make(chan bool, 2) - - copy := func(dst io.Writer, src io.Reader) { - // don't care about errors here - _, _ = io.Copy(dst, src) - done <- true - } - - headerBytes := &bytes.Buffer{} - teedReader := io.TeeReader(backendConn, headerBytes) - - resp, err := utils.ReadResponseWithTimeout(bufio.NewReader(teedReader), nil, f.BackendReadTimeout) - if err != nil { - f.Logger.Error("websocket-forwardio", zap.Error(err)) - // we have to write our own HTTP header since we didn't get one from the backend - _, writeErr := clientConn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n")) - if writeErr != nil { - f.Logger.Error("websocket-client-write", zap.Error(writeErr)) - } - return http.StatusBadGateway, err - } - - // as long as we got a valid response from the backend, - // we always write the header... - _, err = io.Copy(clientConn, headerBytes) - if err != nil { - f.Logger.Error("websocket-client-write", zap.Error(err)) - // we got a status code from the backend, - // - // we don't know for sure that this got back to the client - // but there isn't much we can do about that at this point - // - // return it so we can log it in access logs - return resp.StatusCode, err - } - - if !isValidWebsocketResponse(resp) { - errMsg := fmt.Sprintf("backend responded with non-101 status code: %d", resp.StatusCode) - err = errors.New(errMsg) - f.Logger.Error("websocket-backend", zap.Error(err)) - return resp.StatusCode, err - } - - // only now do we start copying body data - go copy(clientConn, backendConn) - go copy(backendConn, clientConn) - - // Note: this blocks until the entire websocket activity completes - <-done - return http.StatusSwitchingProtocols, nil -} - -func isValidWebsocketResponse(resp *http.Response) bool { - ok := resp.StatusCode == http.StatusSwitchingProtocols - return ok -} diff --git a/proxy/handler/forwarder_test.go b/proxy/handler/forwarder_test.go deleted file mode 100644 index 166d92bd..00000000 --- a/proxy/handler/forwarder_test.go +++ /dev/null @@ -1,178 +0,0 @@ -package handler_test - -import ( - "bytes" - "errors" - "io" - "net/http" - "runtime" - "sync" - "time" - - "github.com/onsi/gomega/gbytes" - - "code.cloudfoundry.org/gorouter/proxy/handler" - "code.cloudfoundry.org/gorouter/proxy/utils" - "code.cloudfoundry.org/gorouter/test_util" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" -) - -var _ = Describe("Forwarder", func() { - var clientConn, backendConn *MockReadWriter - var forwarder *handler.Forwarder - var logger *test_util.TestZapLogger - - buildFakeBackend := func(statusString string, responseBody io.Reader) *MockReadWriter { - fakeBackend := io.MultiReader(bytes.NewBufferString("HTTP/1.1 "+statusString+"\r\n\r\n"), responseBody) - return NewMockConn(fakeBackend) - } - - BeforeEach(func() { - logger = test_util.NewTestZapLogger("test") - forwarder = &handler.Forwarder{ - BackendReadTimeout: time.Second, - Logger: logger, - } - clientConn = NewMockConn(bytes.NewReader([]byte("some client data"))) - }) - - Context("when the backend gives a valid websocket response", func() { - BeforeEach(func() { - fakeResponseBody := io.MultiReader(bytes.NewBufferString("some websocket data"), &test_util.HangingReadCloser{}) - backendConn = buildFakeBackend("101 Switching Protocols", fakeResponseBody) - }) - - It("returns the status code that the backend responded with", func() { - code, err := forwarder.ForwardIO(clientConn, backendConn) - Expect(code).To(Equal(http.StatusSwitchingProtocols)) - Expect(err).To(BeNil()) - }) - - It("always copies the full response header to the client conn, before it returns", func() { - forwarder.ForwardIO(clientConn, backendConn) - Expect(clientConn.GetWrittenBytes()).To(HavePrefix("HTTP/1.1 101 Switching Protocols")) - }) - - It("eventually writes all the response data", func() { - backendConn = buildFakeBackend("101 Switching Protocols", bytes.NewBufferString("some websocket data")) - code, err := forwarder.ForwardIO(clientConn, backendConn) - Expect(code).To(Equal(http.StatusSwitchingProtocols)) - Expect(err).To(BeNil()) - Eventually(clientConn.GetWrittenBytes).Should(ContainSubstring("some websocket data")) - }) - }) - - Context("when the backend response has a non-101 status code", func() { - BeforeEach(func() { - backendConn = buildFakeBackend("200 OK", &test_util.HangingReadCloser{}) - }) - - It("immediately returns the code, without waiting for either connection to close", func() { - code, err := forwarder.ForwardIO(clientConn, backendConn) - Expect(code).To(Equal(http.StatusOK)) - Expect(err).To(MatchError("backend responded with non-101 status code: 200")) - }) - - It("always copies the full response header to the client conn, before it returns", func() { - forwarder.ForwardIO(clientConn, backendConn) - Expect(clientConn.GetWrittenBytes()).To(HavePrefix("HTTP/1.1 200 OK")) - }) - }) - - Context("when the backend response is not a valid HTTP response", func() { - BeforeEach(func() { - backendConn = buildFakeBackend("banana", bytes.NewBufferString("bad data")) - }) - - It("returns code 502 and logs the error", func() { - code, err := forwarder.ForwardIO(clientConn, backendConn) - Expect(err).Should(MatchError("malformed HTTP status code \"banana\"")) - Expect(code).To(Equal(http.StatusBadGateway)) - Expect(logger.Buffer()).To(gbytes.Say(`websocket-forwardio`)) - Expect(clientConn.GetWrittenBytes()).To(HavePrefix("HTTP/1.1 502 Bad Gateway\r\n\r\n")) - }) - - Context("when the bytes cannot be written to the client connection", func() { - BeforeEach(func() { - clientConn.WriteError("banana") - }) - It("returns code 502 and logs the error", func() { - code, err := forwarder.ForwardIO(clientConn, backendConn) - Expect(err).Should(MatchError("malformed HTTP status code \"banana\"")) - Expect(code).To(Equal(http.StatusBadGateway)) - Expect(logger.Buffer()).To(gbytes.Say(`websocket-forwardio`)) - Expect(logger.Buffer()).To(gbytes.Say(`websocket-client-write.*banana`)) - }) - }) - }) - - Context("when the backend hangs indefinitely on reading the header", func() { - BeforeEach(func() { - backendConn = NewMockConn(&test_util.HangingReadCloser{}) - }) - - It("times out after some time and logs the timeout", func() { - code, err := forwarder.ForwardIO(clientConn, backendConn) - Expect(code).To(Equal(http.StatusBadGateway)) - Expect(err).To(MatchError(utils.TimeoutError{})) - Expect(logger.Buffer()).To(gbytes.Say(`timeout waiting for http response from backend`)) - }) - }) - - Context("when the backend responds after BackendReadTimeout", func() { - var ( - sleepDuration time.Duration - ) - - BeforeEach(func() { - forwarder.BackendReadTimeout = 10 * time.Millisecond - sleepDuration = 100 * time.Millisecond - backendConn = NewMockConn(&test_util.SlowReadCloser{SleepDuration: sleepDuration}) - }) - - It("does not leak goroutines", func() { - beforeGoroutineCount := runtime.NumGoroutine() - Expect(forwarder.ForwardIO(clientConn, backendConn)).To(Equal(http.StatusBadGateway)) - - Eventually(func() int { - return runtime.NumGoroutine() - }).Should(BeNumerically("<=", beforeGoroutineCount)) - }) - }) -}) - -func NewMockConn(fakeBackend io.Reader) *MockReadWriter { - return &MockReadWriter{ - buffer: &bytes.Buffer{}, - Reader: fakeBackend, - } -} - -type MockReadWriter struct { - io.Reader - sync.Mutex - buffer *bytes.Buffer - writeError error -} - -func (m *MockReadWriter) WriteError(err string) { - m.writeError = errors.New(err) -} - -func (m *MockReadWriter) Write(buffer []byte) (int, error) { - if m.writeError != nil { - return 0, m.writeError - } - time.Sleep(100 * time.Millisecond) // simulate some network delay - m.Lock() - defer m.Unlock() - return m.buffer.Write(buffer) -} - -func (m *MockReadWriter) GetWrittenBytes() string { - m.Lock() - defer m.Unlock() - return m.buffer.String() -} diff --git a/proxy/handler/init_test.go b/proxy/handler/init_test.go deleted file mode 100644 index eb51a6b9..00000000 --- a/proxy/handler/init_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package handler_test - -import ( - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" - - "testing" -) - -func TestHandler(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Handler Suite") -} diff --git a/proxy/handler/request_handler.go b/proxy/handler/request_handler.go deleted file mode 100644 index b0a2e356..00000000 --- a/proxy/handler/request_handler.go +++ /dev/null @@ -1,346 +0,0 @@ -package handler - -import ( - "bufio" - "crypto/tls" - "errors" - "fmt" - "net" - "net/http" - "strconv" - "strings" - "time" - - router_http "code.cloudfoundry.org/gorouter/common/http" - "code.cloudfoundry.org/gorouter/errorwriter" - "code.cloudfoundry.org/gorouter/handlers" - "code.cloudfoundry.org/gorouter/logger" - "code.cloudfoundry.org/gorouter/metrics" - "code.cloudfoundry.org/gorouter/proxy/utils" - "code.cloudfoundry.org/gorouter/route" - "github.com/uber-go/zap" -) - -var NoEndpointsAvailable = errors.New("No endpoints available") - -type RequestHandler struct { - logger logger.Logger - errorWriter errorwriter.ErrorWriter - reporter metrics.ProxyReporter - - request *http.Request - response utils.ProxyResponseWriter - - endpointDialTimeout time.Duration - websocketDialTimeout time.Duration - maxAttempts int - - tlsConfigTemplate *tls.Config - - forwarder *Forwarder - disableXFFLogging bool - disableSourceIPLogging bool - hopByHopHeadersToFilter []string -} - -func NewRequestHandler( - request *http.Request, - response utils.ProxyResponseWriter, - r metrics.ProxyReporter, - logger logger.Logger, - errorWriter errorwriter.ErrorWriter, - endpointDialTimeout time.Duration, - websocketDialTimeout time.Duration, - maxAttempts int, - tlsConfig *tls.Config, - hopByHopHeadersToFilter []string, - opts ...func(*RequestHandler), -) *RequestHandler { - reqHandler := &RequestHandler{ - errorWriter: errorWriter, - reporter: r, - request: request, - response: response, - endpointDialTimeout: endpointDialTimeout, - websocketDialTimeout: websocketDialTimeout, - maxAttempts: maxAttempts, - tlsConfigTemplate: tlsConfig, - hopByHopHeadersToFilter: hopByHopHeadersToFilter, - } - - for _, option := range opts { - option(reqHandler) - } - - requestLogger := setupLogger(reqHandler.disableXFFLogging, reqHandler.disableSourceIPLogging, request, logger) - reqHandler.forwarder = &Forwarder{ - BackendReadTimeout: websocketDialTimeout, - Logger: requestLogger, - } - reqHandler.logger = requestLogger - - return reqHandler -} - -func setupLogger(disableXFFLogging, disableSourceIPLogging bool, request *http.Request, logger logger.Logger) logger.Logger { - fields := []zap.Field{ - zap.String("RemoteAddr", request.RemoteAddr), - zap.String("Host", request.Host), - zap.String("Path", request.URL.Path), - zap.Object("X-Forwarded-For", request.Header["X-Forwarded-For"]), - zap.Object("X-Forwarded-Proto", request.Header["X-Forwarded-Proto"]), - zap.Object("X-Vcap-Request-Id", request.Header["X-Vcap-Request-Id"]), - } - // Specific indexes below is to preserve the schema in the log line - if disableSourceIPLogging { - fields[0] = zap.String("RemoteAddr", "-") - } - - if disableXFFLogging { - fields[3] = zap.Object("X-Forwarded-For", "-") - } - - l := logger.Session("request-handler").With(fields...) - return l -} - -func DisableXFFLogging(t bool) func(*RequestHandler) { - return func(h *RequestHandler) { - h.disableXFFLogging = t - } -} - -func DisableSourceIPLogging(t bool) func(*RequestHandler) { - return func(h *RequestHandler) { - h.disableSourceIPLogging = t - } -} - -func (h *RequestHandler) Logger() logger.Logger { - return h.logger -} - -func (h *RequestHandler) HandleBadGateway(err error, request *http.Request) { - h.reporter.CaptureBadGateway() - - handlers.AddRouterErrorHeader(h.response, "endpoint_failure") - - h.errorWriter.WriteError(h.response, http.StatusBadGateway, "Registered endpoint failed to handle the request.", h.logger) - h.response.Done() -} - -func (h *RequestHandler) HandleTcpRequest(iter route.EndpointIterator) { - h.logger.Info("handling-tcp-request", zap.String("Upgrade", "tcp")) - - onConnectionFailed := func(err error) { h.logger.Error("tcp-connection-failed", zap.Error(err)) } - backendStatusCode, err := h.serveTcp(iter, nil, onConnectionFailed) - if err != nil { - h.logger.Error("tcp-request-failed", zap.Error(err)) - h.errorWriter.WriteError(h.response, http.StatusBadGateway, "TCP forwarding to endpoint failed.", h.logger) - return - } - h.response.SetStatus(backendStatusCode) -} - -func (h *RequestHandler) HandleWebSocketRequest(iter route.EndpointIterator) { - h.logger.Info("handling-websocket-request", zap.String("Upgrade", "websocket")) - - onConnectionSucceeded := func(connection net.Conn, endpoint *route.Endpoint) error { - h.setupRequest(endpoint) - err := h.request.Write(connection) - if err != nil { - return err - } - return nil - } - onConnectionFailed := func(err error) { h.logger.Error("websocket-connection-failed", zap.Error(err)) } - - backendStatusCode, err := h.serveTcp(iter, onConnectionSucceeded, onConnectionFailed) - - if err != nil { - h.logger.Error("websocket-request-failed", zap.Error(err)) - h.errorWriter.WriteError(h.response, http.StatusBadGateway, "WebSocket request to endpoint failed.", h.logger) - h.reporter.CaptureWebSocketFailure() - return - } - - h.response.SetStatus(backendStatusCode) - h.reporter.CaptureWebSocketUpdate() -} - -func (h *RequestHandler) SanitizeRequestConnection() { - if len(h.hopByHopHeadersToFilter) == 0 { - return - } - connections := h.request.Header.Values("Connection") - for index, connection := range connections { - if connection != "" { - values := strings.Split(connection, ",") - connectionHeader := []string{} - for i := range values { - trimmedValue := strings.TrimSpace(values[i]) - found := false - for _, item := range h.hopByHopHeadersToFilter { - if strings.EqualFold(item, trimmedValue) { - found = true - break - } - } - if !found { - connectionHeader = append(connectionHeader, trimmedValue) - } - } - h.request.Header[http.CanonicalHeaderKey("Connection")][index] = strings.Join(connectionHeader, ", ") - } - } -} - -type connSuccessCB func(net.Conn, *route.Endpoint) error -type connFailureCB func(error) - -var nilConnSuccessCB = func(net.Conn, *route.Endpoint) error { return nil } -var nilConnFailureCB = func(error) {} - -func (h *RequestHandler) serveTcp( - iter route.EndpointIterator, - onConnectionSucceeded connSuccessCB, - onConnectionFailed connFailureCB, -) (int, error) { - var err error - var backendConnection net.Conn - var endpoint *route.Endpoint - - if onConnectionSucceeded == nil { - onConnectionSucceeded = nilConnSuccessCB - } - if onConnectionFailed == nil { - onConnectionFailed = nilConnFailureCB - } - - reqInfo, err := handlers.ContextRequestInfo(h.request) - if err != nil { - return 0, err - } - // httptrace.ClientTrace only works for Transports, so we have to do the tracing manually - var dialStartedAt, dialFinishedAt, tlsHandshakeStartedAt, tlsHandshakeFinishedAt time.Time - - retry := 0 - for { - endpoint = iter.Next(retry) - if endpoint == nil { - err = NoEndpointsAvailable - h.HandleBadGateway(err, h.request) - return 0, err - } - - iter.PreRequest(endpoint) - - dialStartedAt = time.Now() - backendConnection, err = net.DialTimeout("tcp", endpoint.CanonicalAddr(), h.endpointDialTimeout) - dialFinishedAt = time.Now() - if endpoint.IsTLS() { - tlsConfigLocal := utils.TLSConfigWithServerName(endpoint.ServerCertDomainSAN, h.tlsConfigTemplate, false) - tlsBackendConnection := tls.Client(backendConnection, tlsConfigLocal) - tlsHandshakeStartedAt = time.Now() - err = tlsBackendConnection.Handshake() - tlsHandshakeFinishedAt = time.Now() - backendConnection = tlsBackendConnection - } - - if err == nil { - defer iter.PostRequest(endpoint) - break - } else { - iter.PostRequest(endpoint) - } - - reqInfo.FailedAttempts++ - reqInfo.LastFailedAttemptFinishedAt = time.Now() - - iter.EndpointFailed(err) - onConnectionFailed(err) - - retry++ - if retry == h.maxAttempts { - return 0, err - } - } - if backendConnection == nil { - return 0, nil - } - defer backendConnection.Close() - - err = onConnectionSucceeded(backendConnection, endpoint) - if err != nil { - return 0, err - } - - client, _, err := h.hijack() - if err != nil { - return 0, err - } - defer client.Close() - - // Round trip was successful at this point - reqInfo.RoundTripSuccessful = true - - // Record the times from the last attempt, but only if it succeeded. - reqInfo.DialStartedAt = dialStartedAt - reqInfo.DialFinishedAt = dialFinishedAt - reqInfo.TlsHandshakeStartedAt = tlsHandshakeStartedAt - reqInfo.TlsHandshakeFinishedAt = tlsHandshakeFinishedAt - - // Any status code has already been sent to the client, - // but this is the value that gets written to the access logs - backendStatusCode, err := h.forwarder.ForwardIO(client, backendConnection) - - // add X-Cf-RouterError header to improve traceability in access log - if err != nil { - errMsg := fmt.Sprintf("endpoint_failure (%s)", err.Error()) - handlers.AddRouterErrorHeader(h.response, errMsg) - } - - return backendStatusCode, nil -} - -func (h *RequestHandler) setupRequest(endpoint *route.Endpoint) { - h.setRequestURL(endpoint.CanonicalAddr()) - h.setRequestXForwardedFor() - SetRequestXRequestStart(h.request) -} - -func (h *RequestHandler) setRequestURL(addr string) { - h.request.URL.Scheme = "http" - h.request.URL.Host = addr -} - -func (h *RequestHandler) setRequestXForwardedFor() { - if clientIP, _, err := net.SplitHostPort(h.request.RemoteAddr); err == nil { - // If we aren't the first proxy retain prior - // X-Forwarded-For information as a comma+space - // separated list and fold multiple headers into one. - if prior, ok := h.request.Header["X-Forwarded-For"]; ok { - clientIP = strings.Join(prior, ", ") + ", " + clientIP - } - h.request.Header.Set("X-Forwarded-For", clientIP) - } -} - -func SetRequestXRequestStart(request *http.Request) { - if _, ok := request.Header[http.CanonicalHeaderKey("X-Request-Start")]; !ok { - request.Header.Set("X-Request-Start", strconv.FormatInt(time.Now().UnixNano()/1e6, 10)) - } -} - -func SetRequestXCfInstanceId(request *http.Request, endpoint *route.Endpoint) { - value := endpoint.PrivateInstanceId - if value == "" { - value = endpoint.CanonicalAddr() - } - - request.Header.Set(router_http.CfInstanceIdHeader, value) -} - -func (h *RequestHandler) hijack() (client net.Conn, io *bufio.ReadWriter, err error) { - return h.response.Hijack() -} diff --git a/proxy/handler/request_handler_test.go b/proxy/handler/request_handler_test.go deleted file mode 100644 index 1d4b20b0..00000000 --- a/proxy/handler/request_handler_test.go +++ /dev/null @@ -1,240 +0,0 @@ -package handler_test - -import ( - "crypto/tls" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "time" - - "code.cloudfoundry.org/gorouter/errorwriter" - metric "code.cloudfoundry.org/gorouter/metrics/fakes" - "code.cloudfoundry.org/gorouter/proxy/handler" - "code.cloudfoundry.org/gorouter/proxy/utils" - iter "code.cloudfoundry.org/gorouter/route/fakes" - "code.cloudfoundry.org/gorouter/test_util" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" - "github.com/onsi/gomega/gbytes" -) - -var _ = Describe("RequestHandler", func() { - var ( - rh *handler.RequestHandler - logger *test_util.TestZapLogger - ew = errorwriter.NewPlaintextErrorWriter() - req *http.Request - pr utils.ProxyResponseWriter - ) - - BeforeEach(func() { - logger = test_util.NewTestZapLogger("test") - pr = utils.NewProxyResponseWriter(httptest.NewRecorder()) - }) - - Context("when disableLogForwardedFor is set to true", func() { - BeforeEach(func() { - req = &http.Request{ - RemoteAddr: "downtown-nino-brown", - Host: "gersh", - URL: &url.URL{ - Path: "/foo", - }, - Header: http.Header{ - "X-Forwarded-For": []string{"1.1.1.1"}, - }, - } - rh = handler.NewRequestHandler( - req, pr, - &metric.FakeProxyReporter{}, logger, ew, - time.Second*2, time.Second*2, 3, &tls.Config{}, - nil, - handler.DisableXFFLogging(true), - ) - }) - Describe("HandleBadGateway", func() { - It("does not include the X-Forwarded-For header in log output", func() { - rh.HandleBadGateway(nil, req) - Eventually(logger.Buffer()).Should(gbytes.Say(`"X-Forwarded-For":"-"`)) - }) - }) - - Describe("HandleTCPRequest", func() { - It("does not include the X-Forwarded-For header in log output", func() { - rh.HandleTcpRequest(&iter.FakeEndpointIterator{}) - Eventually(logger.Buffer()).Should(gbytes.Say(`"X-Forwarded-For":"-"`)) - }) - - Context("when serveTcp returns an error", func() { - It("does not include X-Forwarded-For in log output", func() { - i := &iter.FakeEndpointIterator{} - i.NextReturns(nil) - rh.HandleTcpRequest(i) - Eventually(logger.Buffer()).Should(gbytes.Say("tcp-request-failed")) - Eventually(logger.Buffer()).Should(gbytes.Say(`"X-Forwarded-For":"-"`)) - }) - }) - }) - - Describe("HandleTCPRequest", func() { - It("does not include the X-Forwarded-For header in log output", func() { - rh.HandleWebSocketRequest(&iter.FakeEndpointIterator{}) - Eventually(logger.Buffer()).Should(gbytes.Say(`"X-Forwarded-For":"-"`)) - }) - }) - }) - - Context("when disableLogSourceIP is set to true", func() { - BeforeEach(func() { - req = &http.Request{ - RemoteAddr: "downtown-nino-brown", - Host: "gersh", - URL: &url.URL{ - Path: "/foo", - }, - } - rh = handler.NewRequestHandler( - req, pr, - &metric.FakeProxyReporter{}, logger, ew, - time.Second*2, time.Second*2, 3, &tls.Config{}, - nil, - handler.DisableSourceIPLogging(true), - ) - }) - Describe("HandleBadGateway", func() { - It("does not include the RemoteAddr header in log output", func() { - rh.HandleBadGateway(nil, req) - Eventually(logger.Buffer()).Should(gbytes.Say(`"RemoteAddr":"-"`)) - }) - }) - - Describe("HandleTCPRequest", func() { - It("does not include the RemoteAddr header in log output", func() { - rh.HandleTcpRequest(&iter.FakeEndpointIterator{}) - Eventually(logger.Buffer()).Should(gbytes.Say(`"RemoteAddr":"-"`)) - }) - - Context("when serveTcp returns an error", func() { - It("does not include RemoteAddr in log output", func() { - i := &iter.FakeEndpointIterator{} - i.NextReturns(nil) - rh.HandleTcpRequest(i) - Eventually(logger.Buffer()).Should(gbytes.Say("tcp-request-failed")) - Eventually(logger.Buffer()).Should(gbytes.Say(`"RemoteAddr":"-"`)) - }) - }) - }) - - Describe("HandleTCPRequest", func() { - It("does not include the RemoteAddr header in log output", func() { - rh.HandleWebSocketRequest(&iter.FakeEndpointIterator{}) - Eventually(logger.Buffer()).Should(gbytes.Say(`"RemoteAddr":"-"`)) - }) - }) - }) - - Context("when connection header has forbidden values", func() { - var hopByHopHeadersToFilter []string - BeforeEach(func() { - hopByHopHeadersToFilter = []string{ - "X-Forwarded-For", - "X-Forwarded-Proto", - "B3", - "X-B3", - "X-B3-SpanID", - "X-B3-TraceID", - "X-Request-Start", - "X-Forwarded-Client-Cert", - } - }) - Context("For a single Connection header", func() { - BeforeEach(func() { - req = &http.Request{ - RemoteAddr: "downtown-nino-brown", - Host: "gersh", - URL: &url.URL{ - Path: "/foo", - }, - Header: http.Header{}, - } - values := []string{ - "Content-Type", - "User-Agent", - "X-Forwarded-Proto", - "Accept", - "X-B3-Spanid", - "X-B3-Traceid", - "B3", - "X-Request-Start", - "Cookie", - "X-Cf-Applicationid", - "X-Cf-Instanceid", - "X-Cf-Instanceindex", - "X-Vcap-Request-Id", - } - req.Header.Add("Connection", strings.Join(values, ", ")) - rh = handler.NewRequestHandler( - req, pr, - &metric.FakeProxyReporter{}, logger, ew, - time.Second*2, time.Second*2, 3, &tls.Config{}, - hopByHopHeadersToFilter, - handler.DisableSourceIPLogging(true), - ) - }) - Describe("SanitizeRequestConnection", func() { - It("Filters hop-by-hop headers", func() { - rh.SanitizeRequestConnection() - Expect(req.Header.Get("Connection")).To(Equal("Content-Type, User-Agent, Accept, Cookie, X-Cf-Applicationid, X-Cf-Instanceid, X-Cf-Instanceindex, X-Vcap-Request-Id")) - }) - }) - }) - Context("For multiple Connection headers", func() { - BeforeEach(func() { - req = &http.Request{ - RemoteAddr: "downtown-nino-brown", - Host: "gersh", - URL: &url.URL{ - Path: "/foo", - }, - Header: http.Header{}, - } - req.Header.Add("Connection", strings.Join([]string{ - "Content-Type", - "X-B3-Spanid", - "X-B3-Traceid", - "X-Request-Start", - "Cookie", - "X-Cf-Instanceid", - "X-Vcap-Request-Id", - }, ", ")) - req.Header.Add("Connection", strings.Join([]string{ - "Content-Type", - "User-Agent", - "X-Forwarded-Proto", - "Accept", - "X-B3-Spanid", - "X-Cf-Applicationid", - "X-Cf-Instanceindex", - }, ", ")) - rh = handler.NewRequestHandler( - req, pr, - &metric.FakeProxyReporter{}, logger, ew, - time.Second*2, time.Second*2, 3, &tls.Config{}, - hopByHopHeadersToFilter, - handler.DisableSourceIPLogging(true), - ) - }) - Describe("SanitizeRequestConnection", func() { - It("Filters hop-by-hop headers", func() { - rh.SanitizeRequestConnection() - headers := req.Header.Values("Connection") - Expect(len(headers)).To(Equal(2)) - Expect(headers[0]).To(Equal("Content-Type, Cookie, X-Cf-Instanceid, X-Vcap-Request-Id")) - Expect(headers[1]).To(Equal("Content-Type, User-Agent, Accept, X-Cf-Applicationid, X-Cf-Instanceindex")) - }) - }) - }) - }) -}) diff --git a/proxy/proxy.go b/proxy/proxy.go index 3cd8b887..81da8634 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httputil" "net/url" + "strconv" "strings" "time" @@ -24,11 +25,9 @@ import ( "code.cloudfoundry.org/gorouter/logger" "code.cloudfoundry.org/gorouter/metrics" "code.cloudfoundry.org/gorouter/proxy/fails" - "code.cloudfoundry.org/gorouter/proxy/handler" "code.cloudfoundry.org/gorouter/proxy/round_tripper" "code.cloudfoundry.org/gorouter/proxy/utils" "code.cloudfoundry.org/gorouter/registry" - "code.cloudfoundry.org/gorouter/route" "code.cloudfoundry.org/gorouter/routeservice" ) @@ -171,6 +170,7 @@ func NewProxy( logger, errorWriter, )) + n.Use(handlers.NewHopByHop(cfg, logger)) n.Use(&handlers.XForwardedProto{ SkipSanitization: SkipSanitizeXFP(routeServiceHandler.(*handlers.RouteService)), ForceForwardedProtoHttps: p.config.ForceForwardedProtoHttps, @@ -227,49 +227,11 @@ func (p *proxy) ServeHTTP(responseWriter http.ResponseWriter, request *http.Requ if err != nil { logger.Panic("request-info-err", zap.Error(err)) } - handler := handler.NewRequestHandler( - request, - proxyWriter, - p.reporter, - p.logger, - p.errorWriter, - p.config.EndpointDialTimeout, - p.config.WebsocketDialTimeout, - p.config.Backends.MaxAttempts, - p.backendTLSConfig, - p.config.HopByHopHeadersToFilter, - handler.DisableXFFLogging(p.config.Logging.DisableLogForwardedFor), - handler.DisableSourceIPLogging(p.config.Logging.DisableLogSourceIP), - ) if reqInfo.RoutePool == nil { logger.Panic("request-info-err", zap.Error(errors.New("failed-to-access-RoutePool"))) } - nestedIterator, err := handlers.EndpointIteratorForRequest(logger, request, p.config.LoadBalance, p.config.StickySessionCookieNames, p.config.StickySessionsForAuthNegotiate, p.config.LoadBalanceAZPreference, p.config.Zone) - if err != nil { - logger.Panic("request-info-err", zap.Error(err)) - } - - endpointIterator := &wrappedIterator{ - nested: nestedIterator, - - afterNext: func(endpoint *route.Endpoint) { - if endpoint != nil { - reqInfo.RouteEndpoint = endpoint - p.reporter.CaptureRoutingRequest(endpoint) - } - }, - } - - handler.SanitizeRequestConnection() - if handlers.IsWebSocketUpgrade(request) { - reqInfo.AppRequestStartedAt = time.Now() - handler.HandleWebSocketRequest(endpointIterator) - reqInfo.AppRequestFinishedAt = time.Now() - return - } - reqInfo.AppRequestStartedAt = time.Now() next(responseWriter, request) reqInfo.AppRequestFinishedAt = time.Now() @@ -298,10 +260,16 @@ func (p *proxy) setupProxyRequest(target *http.Request) { } target.URL.RawQuery = "" - handler.SetRequestXRequestStart(target) + setRequestXRequestStart(target) target.Header.Del(router_http.CfAppInstance) } +func setRequestXRequestStart(request *http.Request) { + if _, ok := request.Header[http.CanonicalHeaderKey("X-Request-Start")]; !ok { + request.Header.Set("X-Request-Start", strconv.FormatInt(time.Now().UnixNano()/1e6, 10)) + } +} + func escapePathAndPreserveSlashes(unescaped string) string { parts := strings.Split(unescaped, "/") escapedPath := "" @@ -313,26 +281,3 @@ func escapePathAndPreserveSlashes(unescaped string) string { return escapedPath } - -type wrappedIterator struct { - nested route.EndpointIterator - afterNext func(*route.Endpoint) -} - -func (i *wrappedIterator) Next(attempt int) *route.Endpoint { - e := i.nested.Next(attempt) - if i.afterNext != nil { - i.afterNext(e) - } - return e -} - -func (i *wrappedIterator) EndpointFailed(err error) { - i.nested.EndpointFailed(err) -} -func (i *wrappedIterator) PreRequest(e *route.Endpoint) { - i.nested.PreRequest(e) -} -func (i *wrappedIterator) PostRequest(e *route.Endpoint) { - i.nested.PostRequest(e) -} diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index e4e56fab..8adefcbd 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -2444,12 +2444,12 @@ var _ = Describe("Proxy", func() { req, err := http.ReadRequest(conn.Reader) Expect(err).NotTo(HaveOccurred()) - done <- req.Header.Get("Upgrade") == "WebsockeT" && - req.Header.Get("Connection") == "UpgradE" + done <- req.Header.Get("Upgrade") == "Websocket" && + req.Header.Get("Connection") == "Upgrade" resp := test_util.NewResponse(http.StatusSwitchingProtocols) - resp.Header.Set("Upgrade", "WebsockeT") - resp.Header.Set("Connection", "UpgradE") + resp.Header.Set("Upgrade", "Websocket") + resp.Header.Set("Connection", "Upgrade") conn.WriteResponse(resp) @@ -2462,8 +2462,8 @@ var _ = Describe("Proxy", func() { conn := dialProxy(proxyServer) req := test_util.NewRequest("GET", "ws", "/chat", nil) - req.Header.Set("Upgrade", "WebsockeT") - req.Header.Set("Connection", "UpgradE") + req.Header.Set("Upgrade", "Websocket") + req.Header.Set("Connection", "Upgrade") conn.WriteRequest(req) @@ -2473,8 +2473,8 @@ var _ = Describe("Proxy", func() { resp, _ := conn.ReadResponse() Expect(resp.StatusCode).To(Equal(http.StatusSwitchingProtocols)) - Expect(resp.Header.Get("Upgrade")).To(Equal("WebsockeT")) - Expect(resp.Header.Get("Connection")).To(Equal("UpgradE")) + Expect(resp.Header.Get("Upgrade")).To(Equal("Websocket")) + Expect(resp.Header.Get("Connection")).To(Equal("Upgrade")) conn.WriteLine("hello from client") conn.CheckLine("hello from server") @@ -2490,12 +2490,12 @@ var _ = Describe("Proxy", func() { req, err := http.ReadRequest(conn.Reader) Expect(err).NotTo(HaveOccurred()) - done <- req.Header.Get("Upgrade") == "WebsockeT" && - req.Header.Get("Connection") == "UpgradE" + done <- req.Header.Get("Upgrade") == "Websocket" && + req.Header.Get("Connection") == "Upgrade" resp := test_util.NewResponse(http.StatusSwitchingProtocols) - resp.Header.Set("Upgrade", "WebsockeT") - resp.Header.Set("Connection", "UpgradE") + resp.Header.Set("Upgrade", "Websocket") + resp.Header.Set("Connection", "Upgrade") conn.WriteResponse(resp) @@ -2508,8 +2508,8 @@ var _ = Describe("Proxy", func() { conn := dialProxy(proxyServer) req := test_util.NewRequest("GET", "ws", "/chat", nil) - req.Header.Set("Upgrade", "WebsockeT") - req.Header.Set("Connection", "UpgradE") + req.Header.Set("Upgrade", "Websocket") + req.Header.Set("Connection", "Upgrade") conn.WriteRequest(req) @@ -2519,8 +2519,8 @@ var _ = Describe("Proxy", func() { resp, _ := conn.ReadResponse() Expect(resp.StatusCode).To(Equal(http.StatusSwitchingProtocols)) - Expect(resp.Header.Get("Upgrade")).To(Equal("WebsockeT")) - Expect(resp.Header.Get("Connection")).To(Equal("UpgradE")) + Expect(resp.Header.Get("Upgrade")).To(Equal("Websocket")) + Expect(resp.Header.Get("Connection")).To(Equal("Upgrade")) conn.WriteLine("hello from client") conn.CheckLine("hello from server") @@ -2535,8 +2535,10 @@ var _ = Describe("Proxy", func() { req, err := http.ReadRequest(conn.Reader) Expect(err).NotTo(HaveOccurred()) + // RFC 7230, section 6.1: Remove headers listed in the "Connection" header. + // Only "Upgrade" will be added again by httputil.ReverseProxy done <- req.Header.Get("Upgrade") == "Websocket" && - req.Header.Get("Connection") == "keep-alive, Upgrade" + req.Header.Get("Connection") == "Upgrade" resp := test_util.NewResponse(http.StatusSwitchingProtocols) resp.Header.Set("Upgrade", "Websocket") @@ -2581,9 +2583,10 @@ var _ = Describe("Proxy", func() { req, err := http.ReadRequest(conn.Reader) Expect(err).NotTo(HaveOccurred()) + // RFC 7230, section 6.1: Remove headers listed in the "Connection" header. + // Only "Upgrade" will be added again by httputil.ReverseProxy done <- req.Header.Get("Upgrade") == "Websocket" && - req.Header[http.CanonicalHeaderKey("Connection")][0] == "keep-alive" && - req.Header[http.CanonicalHeaderKey("Connection")][1] == "Upgrade" + req.Header.Get("Connection") == "Upgrade" resp := test_util.NewResponse(http.StatusSwitchingProtocols) resp.Header.Set("Upgrade", "Websocket") diff --git a/proxy/round_tripper/proxy_round_tripper.go b/proxy/round_tripper/proxy_round_tripper.go index 56f468bb..2280e496 100644 --- a/proxy/round_tripper/proxy_round_tripper.go +++ b/proxy/round_tripper/proxy_round_tripper.go @@ -1,6 +1,7 @@ package round_tripper import ( + router_http "code.cloudfoundry.org/gorouter/common/http" "context" "errors" "fmt" @@ -21,7 +22,6 @@ import ( "code.cloudfoundry.org/gorouter/logger" "code.cloudfoundry.org/gorouter/metrics" "code.cloudfoundry.org/gorouter/proxy/fails" - "code.cloudfoundry.org/gorouter/proxy/handler" "code.cloudfoundry.org/gorouter/proxy/utils" "code.cloudfoundry.org/gorouter/route" "code.cloudfoundry.org/gorouter/routeservice" @@ -40,6 +40,8 @@ const ( AuthNegotiateHeaderCookieMaxAgeInSeconds = 60 ) +var NoEndpointsAvailable = errors.New("No endpoints available") + //go:generate counterfeiter -o fakes/fake_proxy_round_tripper.go . ProxyRoundTripper type ProxyRoundTripper interface { http.RoundTripper @@ -302,12 +304,25 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response responseWriterMu.Lock() defer responseWriterMu.Unlock() rt.errorHandler.HandleError(reqInfo.ProxyResponseWriter, err) + if handlers.IsWebSocketUpgrade(request) { + rt.combinedReporter.CaptureWebSocketFailure() + } return nil, err } // Round trip was successful at this point reqInfo.RoundTripSuccessful = true + // Set status code for access log + if res != nil { + reqInfo.ProxyResponseWriter.SetStatus(res.StatusCode) + } + + // Write metric for ws upgrades + if handlers.IsWebSocketUpgrade(request) { + rt.combinedReporter.CaptureWebSocketUpdate() + } + // Record the times from the last attempt, but only if it succeeded. reqInfo.DnsStartedAt = trace.DnsStart() reqInfo.DnsFinishedAt = trace.DnsDone() @@ -341,7 +356,7 @@ func (rt *roundTripper) backendRoundTrip(request *http.Request, endpoint *route. request.URL.Host = endpoint.CanonicalAddr() request.Header.Set("X-CF-ApplicationID", endpoint.ApplicationId) request.Header.Set("X-CF-InstanceIndex", endpoint.PrivateInstanceIndex) - handler.SetRequestXCfInstanceId(request, endpoint) + setRequestXCfInstanceId(request, endpoint) // increment connection stats iter.PreRequest(endpoint) @@ -356,7 +371,7 @@ func (rt *roundTripper) backendRoundTrip(request *http.Request, endpoint *route. } func (rt *roundTripper) timedRoundTrip(tr http.RoundTripper, request *http.Request, logger logger.Logger) (*http.Response, error) { - if rt.config.EndpointTimeout <= 0 { + if rt.config.EndpointTimeout <= 0 || handlers.IsWebSocketUpgrade(request) { return tr.RoundTrip(request) } @@ -368,7 +383,7 @@ func (rt *roundTripper) timedRoundTrip(tr http.RoundTripper, request *http.Reque vrid := request.Header.Get(handlers.VcapRequestIdHeader) go func() { <-reqCtx.Done() - if reqCtx.Err() == context.DeadlineExceeded { + if errors.Is(reqCtx.Err(), context.DeadlineExceeded) { logger.Error("backend-request-timeout", zap.Error(reqCtx.Err()), zap.String("vcap_request_id", vrid)) } cancel() @@ -386,12 +401,21 @@ func (rt *roundTripper) timedRoundTrip(tr http.RoundTripper, request *http.Reque func (rt *roundTripper) selectEndpoint(iter route.EndpointIterator, request *http.Request, attempt int) (*route.Endpoint, error) { endpoint := iter.Next(attempt) if endpoint == nil { - return nil, handler.NoEndpointsAvailable + return nil, NoEndpointsAvailable } return endpoint, nil } +func setRequestXCfInstanceId(request *http.Request, endpoint *route.Endpoint) { + value := endpoint.PrivateInstanceId + if value == "" { + value = endpoint.CanonicalAddr() + } + + request.Header.Set(router_http.CfInstanceIdHeader, value) +} + func setupStickySession( response *http.Response, endpoint *route.Endpoint, diff --git a/proxy/round_tripper/proxy_round_tripper_test.go b/proxy/round_tripper/proxy_round_tripper_test.go index 15376399..113e060f 100644 --- a/proxy/round_tripper/proxy_round_tripper_test.go +++ b/proxy/round_tripper/proxy_round_tripper_test.go @@ -26,7 +26,6 @@ import ( "code.cloudfoundry.org/gorouter/handlers" "code.cloudfoundry.org/gorouter/metrics/fakes" "code.cloudfoundry.org/gorouter/proxy/fails" - "code.cloudfoundry.org/gorouter/proxy/handler" "code.cloudfoundry.org/gorouter/proxy/round_tripper" "code.cloudfoundry.org/gorouter/proxy/utils" "code.cloudfoundry.org/gorouter/route" @@ -663,7 +662,7 @@ var _ = Describe("ProxyRoundTripper", func() { It("returns a 502 Bad Gateway response", func() { backendRes, err := proxyRoundTripper.RoundTrip(req) Expect(backendRes).To(BeNil()) - Expect(err).To(Equal(handler.NoEndpointsAvailable)) + Expect(err).To(Equal(round_tripper.NoEndpointsAvailable)) Expect(reqInfo.RouteEndpoint).To(BeNil()) Expect(reqInfo.RoundTripSuccessful).To(BeFalse()) @@ -673,7 +672,7 @@ var _ = Describe("ProxyRoundTripper", func() { proxyRoundTripper.RoundTrip(req) Expect(errorHandler.HandleErrorCallCount()).To(Equal(1)) _, err := errorHandler.HandleErrorArgsForCall(0) - Expect(err).To(Equal(handler.NoEndpointsAvailable)) + Expect(err).To(Equal(round_tripper.NoEndpointsAvailable)) }) It("logs a message with `select-endpoint-failed`", func() { @@ -685,21 +684,21 @@ var _ = Describe("ProxyRoundTripper", func() { It("does not capture any routing requests to the backend", func() { _, err := proxyRoundTripper.RoundTrip(req) - Expect(err).To(Equal(handler.NoEndpointsAvailable)) + Expect(err).To(Equal(round_tripper.NoEndpointsAvailable)) Expect(combinedReporter.CaptureRoutingRequestCallCount()).To(Equal(0)) }) It("does not log anything about route services", func() { _, err := proxyRoundTripper.RoundTrip(req) - Expect(err).To(Equal(handler.NoEndpointsAvailable)) + Expect(err).To(Equal(round_tripper.NoEndpointsAvailable)) Expect(logger.Buffer()).ToNot(gbytes.Say(`route-service`)) }) It("does not report the endpoint failure", func() { _, err := proxyRoundTripper.RoundTrip(req) - Expect(err).To(MatchError(handler.NoEndpointsAvailable)) + Expect(err).To(MatchError(round_tripper.NoEndpointsAvailable)) Expect(logger.Buffer()).ToNot(gbytes.Say(`backend-endpoint-failed`)) }) diff --git a/router/router_test.go b/router/router_test.go index bad719ff..5d905cf3 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -1137,7 +1137,7 @@ var _ = Describe("Router", func() { }) }) - It("websockets do not terminate", func() { + It("websocket connections are not affected by EndpointTimeout", func() { app := test.NewWebSocketApp( []route.Uri{"ws-app." + test_util.LocalhostDNS}, config.Port, diff --git a/test/websocket_app.go b/test/websocket_app.go index c386fb81..77f0c243 100644 --- a/test/websocket_app.go +++ b/test/websocket_app.go @@ -4,6 +4,7 @@ import ( "bytes" "io" "net/http" + "strings" "time" nats "github.com/nats-io/nats.go" @@ -22,7 +23,7 @@ func NewWebSocketApp(urls []route.Uri, rPort uint16, mbusClient *nats.Conn, dela defer ginkgo.GinkgoRecover() Expect(r.Header.Get("Upgrade")).To(Equal("websocket")) - Expect(r.Header.Get("Connection")).To(Equal("upgrade")) + Expect(strings.ToLower(r.Header.Get("Connection"))).To(Equal("upgrade")) conn, _, err := w.(http.Hijacker).Hijack() Expect(err).ToNot(HaveOccurred()) @@ -49,7 +50,7 @@ func NewFailingWebSocketApp(urls []route.Uri, rPort uint16, mbusClient *nats.Con defer ginkgo.GinkgoRecover() Expect(r.Header.Get("Upgrade")).To(Equal("websocket")) - Expect(r.Header.Get("Connection")).To(Equal("upgrade")) + Expect(strings.ToLower(r.Header.Get("Connection"))).To(Equal("upgrade")) conn, _, err := w.(http.Hijacker).Hijack() Expect(err).ToNot(HaveOccurred()) @@ -60,13 +61,13 @@ func NewFailingWebSocketApp(urls []route.Uri, rPort uint16, mbusClient *nats.Con return app } -func NewHangingWebSocketApp(urls []route.Uri, rPort uint16, mbusClient *nats.Conn, routeServiceUrl string) *common.TestApp { +func NewNotUpgradingWebSocketApp(urls []route.Uri, rPort uint16, mbusClient *nats.Conn, routeServiceUrl string) *common.TestApp { app := common.NewTestApp(urls, rPort, mbusClient, nil, routeServiceUrl) app.AddHandler("/", func(w http.ResponseWriter, r *http.Request) { defer ginkgo.GinkgoRecover() Expect(r.Header.Get("Upgrade")).To(Equal("websocket")) - Expect(r.Header.Get("Connection")).To(Equal("upgrade")) + Expect(strings.ToLower(r.Header.Get("Connection"))).To(Equal("upgrade")) conn, _, err := w.(http.Hijacker).Hijack() Expect(err).ToNot(HaveOccurred()) @@ -81,10 +82,9 @@ func NewHangingWebSocketApp(urls []route.Uri, rPort uint16, mbusClient *nats.Con bytes.NewBufferString("\r\nbeginning of the response body goes here\r\n\r\n"), bytes.NewBuffer(make([]byte, 10024)), // bigger than the internal buffer of the http stdlib bytes.NewBufferString("\r\nmore response here, probably won't be seen by client\r\n"), - &test_util.HangingReadCloser{}), + ), ) x.WriteResponse(resp) - panic("you won't get here in a test") }) return app