Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide a custom retryablehttp.ErrorHandler to prevent losing the godo.ErrorResponse. #628

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 60 additions & 22 deletions godo.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"reflect"
Expand Down Expand Up @@ -178,6 +178,9 @@ type ErrorResponse struct {

// RequestID returned from the API, useful to contact support.
RequestID string `json:"request_id"`

// Attempts is the number of times the request was attempted when retries are enabled.
Attempts int
}

// Rate contains the rate limit for the current client.
Expand Down Expand Up @@ -311,6 +314,8 @@ func New(httpClient *http.Client, opts ...ClientOpt) (*Client, error) {
// By default this is nil and does not log.
retryableClient.Logger = c.RetryConfig.Logger

retryableClient.ErrorHandler = retryableErrorHandler

// if timeout is set, it is maintained before overwriting client with StandardClient()
retryableClient.HTTPClient.Timeout = c.HTTPClient.Timeout

Expand Down Expand Up @@ -474,28 +479,21 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res

resp, err := DoRequestWithClient(ctx, c.HTTPClient, req)
if err != nil {
// If we received a *url.Error from retryablehttp's `Do` method
// that already wraps a *godo.ErrorResponse, unwrap it to
// prevent a double nested error.
if urlErr, ok := err.(*url.Error); ok {
if _, ok := urlErr.Err.(*ErrorResponse); ok {
return nil, errors.Unwrap(err)
}
}
return nil, err
}
if c.onRequestCompleted != nil {
c.onRequestCompleted(req, resp)
}

defer func() {
// Ensure the response body is fully read and closed
// before we reconnect, so that we reuse the same TCPConnection.
// Close the previous response's body. But read at least some of
// the body so if it's small the underlying TCP connection will be
// re-used. No need to check for errors: if it fails, the Transport
// won't reuse it anyway.
const maxBodySlurpSize = 2 << 10
if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize {
io.CopyN(ioutil.Discard, resp.Body, maxBodySlurpSize)
}

if rerr := resp.Body.Close(); err == nil {
err = rerr
}
}()
defer drainBody(resp)

response := newResponse(resp)
c.ratemtx.Lock()
Expand Down Expand Up @@ -524,6 +522,21 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res
return response, err
}

// Ensure the response body is fully read and closed
// before we reconnect, so that we reuse the same TCPConnection.
// Close the previous response's body. But read at least some of
// the body so if it's small the underlying TCP connection will be
// re-used. No need to check for errors: if it fails, the Transport
// won't reuse it anyway.
func drainBody(resp *http.Response) {
const maxBodySlurpSize = 2 << 10
if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize {
io.CopyN(io.Discard, resp.Body, maxBodySlurpSize)
}

resp.Body.Close()
}

// DoRequest submits an HTTP request.
func DoRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
return DoRequestWithClient(ctx, http.DefaultClient, req)
Expand All @@ -539,12 +552,17 @@ func DoRequestWithClient(
}

func (r *ErrorResponse) Error() string {
var attempted string
if r.Attempts > 0 {
attempted = fmt.Sprintf("; giving up after %d attempt(s)", r.Attempts)
}

if r.RequestID != "" {
return fmt.Sprintf("%v %v: %d (request %q) %v",
r.Response.Request.Method, r.Response.Request.URL, r.Response.StatusCode, r.RequestID, r.Message)
return fmt.Sprintf("%v %v: %d (request %q) %v%s",
r.Response.Request.Method, r.Response.Request.URL, r.Response.StatusCode, r.RequestID, r.Message, attempted)
}
return fmt.Sprintf("%v %v: %d %v",
r.Response.Request.Method, r.Response.Request.URL, r.Response.StatusCode, r.Message)
return fmt.Sprintf("%v %v: %d %v%s",
r.Response.Request.Method, r.Response.Request.URL, r.Response.StatusCode, r.Message, attempted)
}

// CheckResponse checks the API response for errors, and returns them if present. A response is considered an
Expand All @@ -557,7 +575,7 @@ func CheckResponse(r *http.Response) error {
}

errorResponse := &ErrorResponse{Response: r}
data, err := ioutil.ReadAll(r.Body)
data, err := io.ReadAll(r.Body)
if err == nil && len(data) > 0 {
err := json.Unmarshal(data, errorResponse)
if err != nil {
Expand All @@ -572,6 +590,26 @@ func CheckResponse(r *http.Response) error {
return errorResponse
}

// retryableErrorHandler implements a retryablehttp.ErrorHandler to provide
// errors that are consistent with a godo.ErrorResponse.
func retryableErrorHandler(resp *http.Response, err error, numTries int) (*http.Response, error) {
// When a custom retryablehttp.ErrorHandler is provided, it is the responsibility
// of the handler to close the response body.
defer drainBody(resp)

if err == nil {
err = CheckResponse(resp)
if _, ok := err.(*ErrorResponse); ok {
err.(*ErrorResponse).Attempts = numTries
}

return nil, err
}

return nil, fmt.Errorf("%s %s giving up after %d attempt(s): %w",
resp.Request.Method, resp.Request.URL, numTries, err)
}

func (r Rate) String() string {
return Stringify(r)
}
Expand Down
132 changes: 123 additions & 9 deletions godo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
Expand Down Expand Up @@ -387,6 +388,7 @@ func TestDo_redirectLoop(t *testing.T) {
if err == nil {
t.Error("Expected error to be returned.")
}

if err, ok := err.(*url.Error); !ok {
t.Errorf("Expected a URL error; got %#v.", err)
}
Expand All @@ -406,7 +408,7 @@ func TestCheckResponse(t *testing.T) {
input: &http.Response{
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Body: ioutil.NopCloser(strings.NewReader(`{"message":"m",
Body: io.NopCloser(strings.NewReader(`{"message":"m",
"errors": [{"resource": "r", "field": "f", "code": "c"}]}`)),
},
expected: &ErrorResponse{
Expand All @@ -418,7 +420,7 @@ func TestCheckResponse(t *testing.T) {
input: &http.Response{
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Body: ioutil.NopCloser(strings.NewReader(`{"message":"m", "request_id": "dead-beef",
Body: io.NopCloser(strings.NewReader(`{"message":"m", "request_id": "dead-beef",
"errors": [{"resource": "r", "field": "f", "code": "c"}]}`)),
},
expected: &ErrorResponse{
Expand All @@ -432,7 +434,7 @@ func TestCheckResponse(t *testing.T) {
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Header: testHeaders,
Body: ioutil.NopCloser(strings.NewReader(`{"message":"m",
Body: io.NopCloser(strings.NewReader(`{"message":"m",
"errors": [{"resource": "r", "field": "f", "code": "c"}]}`)),
},
expected: &ErrorResponse{
Expand All @@ -448,7 +450,7 @@ func TestCheckResponse(t *testing.T) {
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Header: testHeaders,
Body: ioutil.NopCloser(strings.NewReader(`{"message":"m", "request_id": "dead-beef-body",
Body: io.NopCloser(strings.NewReader(`{"message":"m", "request_id": "dead-beef-body",
"errors": [{"resource": "r", "field": "f", "code": "c"}]}`)),
},
expected: &ErrorResponse{
Expand All @@ -463,7 +465,7 @@ func TestCheckResponse(t *testing.T) {
input: &http.Response{
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Body: ioutil.NopCloser(strings.NewReader("")),
Body: io.NopCloser(strings.NewReader("")),
},
expected: &ErrorResponse{},
},
Expand All @@ -484,6 +486,118 @@ func TestCheckResponse(t *testing.T) {
}
}

func TestRetryableErrorHandler(t *testing.T) {
testHeaders := make(http.Header, 1)
testHeaders.Set("x-request-id", "dead-beef")

tests := []struct {
title string
input *http.Response
count int
expected *ErrorResponse
}{
{
title: "default (no request_id)",
input: &http.Response{
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(
strings.NewReader(`{"id": "bad_request", "message":"broken"}`)),
},
expected: &ErrorResponse{
Message: "broken",
},
},
{
title: "request_id in body",
input: &http.Response{
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(
strings.NewReader(`{"id": "bad_request", "message":"broken", "request_id": "dead-beef"}`)),
},
expected: &ErrorResponse{
Message: "broken",
RequestID: "dead-beef",
},
},
{
title: "request_id in header",
input: &http.Response{
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Header: testHeaders,
Body: io.NopCloser(
strings.NewReader(`{"id": "bad_request", "message":"broken"}`)),
},
expected: &ErrorResponse{
Message: "broken",
RequestID: "dead-beef",
},
},
// This tests that the ID in the body takes precedence to ensure we maintain the current
// behavior. In practice, the IDs in the header and body should always be the same.
{
title: "request_id in both",
input: &http.Response{
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Header: testHeaders,
Body: io.NopCloser(
strings.NewReader(`{"id": "bad_request", "message":"broken", "request_id": "dead-beef-body"}`)),
},
expected: &ErrorResponse{
Message: "broken",
RequestID: "dead-beef-body",
},
},
// ensure that we properly handle API errors that do not contain a
// response body
{
title: "no body",
input: &http.Response{
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(strings.NewReader("")),
},
expected: &ErrorResponse{},
},
{
title: "with retries",
input: &http.Response{
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(
strings.NewReader(`{"id": "bad_request", "message":"broken", "request_id": "dead-beef"}`)),
},
count: 5,
expected: &ErrorResponse{
Message: "broken",
RequestID: "dead-beef",
Attempts: 5,
},
},
}

for _, tt := range tests {
t.Run(tt.title, func(t *testing.T) {
_, err := retryableErrorHandler(tt.input, nil, tt.count)
if err == nil {
t.Fatalf("Expected error response.")
}

if _, ok := err.(*ErrorResponse); !ok {
t.Fatalf("Expected a godo.ErrorResponse error response; go: %s", reflect.TypeOf(err))
}

tt.expected.Response = tt.input
if !reflect.DeepEqual(err, tt.expected) {
t.Errorf("Error = %#v, expected %#v", err, tt.expected)
}
})
}
}

func TestErrorResponse_Error(t *testing.T) {
res := &http.Response{Request: &http.Request{}}
err := ErrorResponse{Message: "m", Response: res}
Expand Down Expand Up @@ -615,6 +729,7 @@ func TestWithRetryAndBackoffs(t *testing.T) {
url, _ := url.Parse(server.URL)
mux.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(500)
w.Write([]byte(`{"id": "bad_request", "message": "broken"}`))
})

tokenSrc := oauth2.StaticTokenSource(&oauth2.Token{
Expand Down Expand Up @@ -645,11 +760,10 @@ func TestWithRetryAndBackoffs(t *testing.T) {
t.Fatalf("err: %v", err)
}

expectingErr := "giving up after 4 attempt(s)"
// Send the request.
expectingErr := fmt.Sprintf("GET %s/foo: 500 broken; giving up after 4 attempt(s)", url)
_, err = client.Do(context.Background(), req, nil)
if err == nil || !strings.HasSuffix(err.Error(), expectingErr) {
t.Fatalf("expected giving up error, got: %#v", err)
if err == nil || (err.Error() != expectingErr) {
t.Fatalf("expected giving up error, got: %s", err.Error())
}

}
Expand Down
Loading