diff --git a/client/httpclient_test.go b/client/httpclient_test.go index a3845c45..fc670411 100644 --- a/client/httpclient_test.go +++ b/client/httpclient_test.go @@ -3,13 +3,17 @@ package client import ( "compress/gzip" "context" + "errors" "io" "net/http" + "net/http/httptest" + "net/url" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "google.golang.org/protobuf/proto" "github.com/open-telemetry/opamp-go/client/internal" @@ -223,3 +227,87 @@ func TestHTTPClientStartWithZeroHeartbeatInterval(t *testing.T) { // Shutdown the Server. srv.Close() } + +func mockRedirectHTTP(t testing.TB, viaLen int, err error) *checkRedirectMock { + m := &checkRedirectMock{ + t: t, + viaLen: viaLen, + http: true, + } + m.On("CheckRedirect", mock.Anything, mock.Anything, mock.Anything).Return(err) + return m +} + +func TestRedirectHTTP(t *testing.T) { + redirectee := internal.StartMockServer(t) + tests := []struct { + Name string + Redirector *httptest.Server + ExpError bool + MockRedirect *checkRedirectMock + }{ + { + Name: "simple redirect", + Redirector: redirectServer("http://"+redirectee.Endpoint, 302), + }, + { + Name: "check redirect", + Redirector: redirectServer("http://"+redirectee.Endpoint, 302), + MockRedirect: mockRedirectHTTP(t, 1, nil), + }, + { + Name: "check redirect returns error", + Redirector: redirectServer("http://"+redirectee.Endpoint, 302), + MockRedirect: mockRedirectHTTP(t, 1, errors.New("hello")), + ExpError: true, + }, + } + + for _, test := range tests { + t.Run(test.Name, func(t *testing.T) { + var connectErr atomic.Value + var connected atomic.Value + + settings := &types.StartSettings{ + Callbacks: types.Callbacks{ + OnConnect: func(ctx context.Context) { + connected.Store(1) + }, + OnConnectFailed: func(ctx context.Context, err error) { + connectErr.Store(err) + }, + }, + } + if test.MockRedirect != nil { + settings.Callbacks = types.Callbacks{ + OnConnect: func(ctx context.Context) { + connected.Store(1) + }, + OnConnectFailed: func(ctx context.Context, err error) { + connectErr.Store(err) + }, + CheckRedirect: test.MockRedirect.CheckRedirect, + } + } + reURL, _ := url.Parse(test.Redirector.URL) // err can't be non-nil + settings.OpAMPServerURL = reURL.String() + client := NewHTTP(nil) + prepareClient(t, settings, client) + + err := client.Start(context.Background(), *settings) + if err != nil { + t.Fatal(err) + } + defer client.Stop(context.Background()) + // Wait for connection to be established. + eventually(t, func() bool { + return connected.Load() != nil || connectErr.Load() != nil + }) + if test.ExpError && connectErr.Load() == nil { + t.Error("expected non-nil error") + } else if err := connectErr.Load(); !test.ExpError && err != nil { + t.Fatal(err) + } + }) + } +} diff --git a/client/internal/httpsender.go b/client/internal/httpsender.go index 502bf7e4..a97e1311 100644 --- a/client/internal/httpsender.go +++ b/client/internal/httpsender.go @@ -98,6 +98,14 @@ func (h *HTTPSender) Run( h.callbacks = callbacks h.receiveProcessor = newReceivedProcessor(h.logger, callbacks, h, clientSyncedState, packagesStateProvider, capabilities, packageSyncMutex) + // we need to detect if the redirect was ever set, if not, we want default behaviour + if callbacks.CheckRedirect != nil { + h.client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + // viaResp only non-nil for ws client + return callbacks.CheckRedirect(req, via, nil) + } + } + for { pollingTimer := time.NewTimer(time.Millisecond * time.Duration(atomic.LoadInt64(&h.pollingIntervalMs))) select { diff --git a/client/types/callbacks.go b/client/types/callbacks.go index a5dc02ce..48d5f832 100644 --- a/client/types/callbacks.go +++ b/client/types/callbacks.go @@ -2,6 +2,7 @@ package types import ( "context" + "net/http" "github.com/open-telemetry/opamp-go/protobufs" ) @@ -116,6 +117,19 @@ type Callbacks struct { // OnCommand is called when the Server requests that the connected Agent perform a command. OnCommand func(ctx context.Context, command *protobufs.ServerToAgentCommand) error + + // CheckRedirect is called before following a redirect, allowing the client + // the opportunity to observe the redirect chain, and optionally terminate + // following redirects early. + // + // CheckRedirect is intended to be similar, although not exactly equivalent, + // to net/http.Client's CheckRedirect feature. Unlike in net/http, the via + // parameter is a slice of HTTP responses, instead of requests. This gives + // an opportunity to users to know what the exact response headers and + // status were. The request itself can be obtained from the response. + // + // The responses in the via parameter are passed with their bodies closed. + CheckRedirect func(req *http.Request, viaReq []*http.Request, via []*http.Response) error } func (c *Callbacks) SetDefaults() { diff --git a/client/wsclient.go b/client/wsclient.go index 6219a28f..f19d8ab4 100644 --- a/client/wsclient.go +++ b/client/wsclient.go @@ -48,6 +48,12 @@ type wsClient struct { // Network connection timeout used for the WebSocket closing handshake. // This field is currently only modified during testing. connShutdownTimeout time.Duration + + // responseChain is used for the "via" argument in CheckRedirect. + // It is appended to with every redirect followed, and zeroed on a succesful + // connection. responseChain should only be referred to by the goroutine that + // runs tryConnectOnce and its synchronous callees. + responseChain []*http.Response } // NewWebSocket creates a new OpAMP Client that uses WebSocket transport. @@ -151,11 +157,77 @@ func (c *wsClient) SendCustomMessage(message *protobufs.CustomMessage) (messageS return c.common.SendCustomMessage(message) } +func viaReq(resps []*http.Response) []*http.Request { + reqs := make([]*http.Request, 0, len(resps)) + for _, resp := range resps { + reqs = append(reqs, resp.Request) + } + return reqs +} + +// handleRedirect checks a failed websocket upgrade response for a 3xx response +// and a Location header. If found, it sets the URL to the location found in the +// header so that it is tried on the next retry, instead of the current URL. +func (c *wsClient) handleRedirect(ctx context.Context, resp *http.Response) error { + // append to the responseChain so that subsequent redirects will have access + c.responseChain = append(c.responseChain, resp) + + // very liberal handling of 3xx that largely ignores HTTP semantics + redirect, err := resp.Location() + if err != nil { + c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err) + return err + } + + // It's slightly tricky to make CheckRedirect work. The WS HTTP request is + // formed within the websocket library. To work around that, copy the + // previous request, available in the response, and set the URL to the new + // location. It should then result in the same URL that the websocket + // library will form. + nextRequest := resp.Request.Clone(ctx) + nextRequest.URL = redirect + + // if CheckRedirect results in an error, it gets returned, terminating + // redirection. As with stdlib, the error is wrapped in url.Error. + if c.common.Callbacks.CheckRedirect != nil { + if err := c.common.Callbacks.CheckRedirect(nextRequest, viaReq(c.responseChain), c.responseChain); err != nil { + return &url.Error{ + Op: "Get", + URL: nextRequest.URL.String(), + Err: err, + } + } + } + + // rewrite the scheme for the sake of tolerance + if redirect.Scheme == "http" { + redirect.Scheme = "ws" + } else if redirect.Scheme == "https" { + redirect.Scheme = "wss" + } + c.common.Logger.Debugf(ctx, "%d redirect to %s", resp.StatusCode, redirect) + + // Set the URL to the redirect, so that it connects to it on the + // next cycle. + c.url = redirect + + return nil +} + // Try to connect once. Returns an error if connection fails and optional retryAfter // duration to indicate to the caller to retry after the specified time as instructed // by the Server. func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinternal.OptionalDuration, err error) { var resp *http.Response + var redirecting bool + defer func() { + if err != nil && !redirecting { + c.responseChain = nil + if !c.common.IsStopping() { + c.common.Callbacks.OnConnectFailed(ctx, err) + } + } + }() conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.getHeader()) if err != nil { if !c.common.IsStopping() { @@ -164,22 +236,10 @@ func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinterna if resp != nil { duration := sharedinternal.ExtractRetryAfterHeader(resp) if resp.StatusCode >= 300 && resp.StatusCode < 400 { - // very liberal handling of 3xx that largely ignores HTTP semantics - redirect, err := resp.Location() - if err != nil { - c.common.Logger.Errorf(ctx, "%d redirect, but no valid location: %s", resp.StatusCode, err) + redirecting = true + if err := c.handleRedirect(ctx, resp); err != nil { return duration, err } - // rewrite the scheme for the sake of tolerance - if redirect.Scheme == "http" { - redirect.Scheme = "ws" - } else if redirect.Scheme == "https" { - redirect.Scheme = "wss" - } - c.common.Logger.Debugf(ctx, "%d redirect to %s", resp.StatusCode, redirect) - // Set the URL to the redirect, so that it connects to it on the - // next cycle. - c.url = redirect } else { c.common.Logger.Errorf(ctx, "Server responded with status=%v", resp.Status) } diff --git a/client/wsclient_test.go b/client/wsclient_test.go index cc9fd87d..436ceb55 100644 --- a/client/wsclient_test.go +++ b/client/wsclient_test.go @@ -2,6 +2,7 @@ package client import ( "context" + "errors" "fmt" "net/http" "net/http/httptest" @@ -13,6 +14,7 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" @@ -322,12 +324,54 @@ func errServer() *httptest.Server { })) } +type checkRedirectMock struct { + mock.Mock + t testing.TB + viaLen int + http bool +} + +func (c *checkRedirectMock) CheckRedirect(req *http.Request, viaReq []*http.Request, via []*http.Response) error { + if req == nil { + c.t.Error("nil request in CheckRedirect") + return errors.New("nil request in CheckRedirect") + } + if len(viaReq) > c.viaLen { + c.t.Error("viaReq should be shorter than viaLen") + } + if !c.http { + // websocket transport + if len(via) > c.viaLen { + c.t.Error("via should be shorter than viaLen") + } + } + if !c.http && len(via) > 0 { + location, err := via[len(via)-1].Location() + if err != nil { + c.t.Error(err) + } + // the URL of the request should match the location header of the last response + assert.Equal(c.t, req.URL, location, "request URL should equal the location in the response") + } + return c.Called(req, via).Error(0) +} + +func mockRedirect(t testing.TB, viaLen int, err error) *checkRedirectMock { + m := &checkRedirectMock{ + t: t, + viaLen: viaLen, + } + m.On("CheckRedirect", mock.Anything, mock.Anything, mock.Anything).Return(err) + return m +} + func TestRedirectWS(t *testing.T) { redirectee := internal.StartMockServer(t) tests := []struct { - Name string - Redirector *httptest.Server - ExpError bool + Name string + Redirector *httptest.Server + ExpError bool + MockRedirect *checkRedirectMock }{ { Name: "redirect ws scheme", @@ -342,6 +386,17 @@ func TestRedirectWS(t *testing.T) { Redirector: errServer(), ExpError: true, }, + { + Name: "check redirect", + Redirector: redirectServer("ws://"+redirectee.Endpoint, 302), + MockRedirect: mockRedirect(t, 1, nil), + }, + { + Name: "check redirect returns error", + Redirector: redirectServer("ws://"+redirectee.Endpoint, 302), + MockRedirect: mockRedirect(t, 1, errors.New("hello")), + ExpError: true, + }, } for _, test := range tests { @@ -366,6 +421,9 @@ func TestRedirectWS(t *testing.T) { }, }, } + if test.MockRedirect != nil { + settings.Callbacks.CheckRedirect = test.MockRedirect.CheckRedirect + } reURL, err := url.Parse(test.Redirector.URL) assert.NoError(t, err) reURL.Scheme = "ws" @@ -388,10 +446,69 @@ func TestRedirectWS(t *testing.T) { // Stop the client. err = client.Stop(context.Background()) assert.NoError(t, err) + + if test.MockRedirect != nil { + test.MockRedirect.AssertCalled(t, "CheckRedirect", mock.Anything, mock.Anything) + } }) } } +func TestRedirectWSFollowChain(t *testing.T) { + // test that redirect following is recursive + redirectee := internal.StartMockServer(t) + middle := redirectServer("http://"+redirectee.Endpoint, 302) + middleURL, err := url.Parse(middle.URL) + if err != nil { + // unlikely + t.Fatal(err) + } + redirector := redirectServer("http://"+middleURL.Host, 302) + + var conn atomic.Value + redirectee.OnWSConnect = func(c *websocket.Conn) { + conn.Store(c) + } + + // Start an OpAMP/WebSocket client. + var connected int64 + var connectErr atomic.Value + mr := mockRedirect(t, 2, nil) + settings := types.StartSettings{ + Callbacks: types.Callbacks{ + OnConnect: func(ctx context.Context) { + atomic.StoreInt64(&connected, 1) + }, + OnConnectFailed: func(ctx context.Context, err error) { + if err != websocket.ErrBadHandshake { + connectErr.Store(err) + } + }, + CheckRedirect: mr.CheckRedirect, + }, + } + reURL, err := url.Parse(redirector.URL) + if err != nil { + // unlikely + t.Fatal(err) + } + reURL.Scheme = "ws" + settings.OpAMPServerURL = reURL.String() + client := NewWebSocket(nil) + startClient(t, settings, client) + + // Wait for connection to be established. + eventually(t, func() bool { + return conn.Load() != nil || connectErr.Load() != nil || client.lastInternalErr.Load() != nil + }) + + assert.True(t, connectErr.Load() == nil) + + // Stop the client. + err = client.Stop(context.Background()) + assert.NoError(t, err) +} + func TestHandlesStopBeforeStart(t *testing.T) { client := NewWebSocket(nil) require.Error(t, client.Stop(context.Background())) diff --git a/go.mod b/go.mod index 2742c8a9..4b9746d1 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.22 require ( github.com/cenkalti/backoff/v4 v4.3.0 + github.com/google/go-cmp v0.5.6 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/stretchr/testify v1.10.0 @@ -12,8 +13,8 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/google/go-cmp v0.5.6 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 3390120c..ea122d10 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=