Skip to content

Commit bbb7c69

Browse files
Fix preparing of POST/PUT requests not taking into account request body
1 parent 40b0cad commit bbb7c69

File tree

2 files changed

+79
-11
lines changed

2 files changed

+79
-11
lines changed

client.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,14 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
687687
}
688688
}
689689

690+
// First attempt was already signed
691+
if attempt > 1 && c.PrepareRetry != nil {
692+
if err := c.PrepareRetry(req.Request); err != nil {
693+
prepareErr = err
694+
break
695+
}
696+
}
697+
690698
if c.RequestLogHook != nil {
691699
switch v := logger.(type) {
692700
case LeveledLogger:
@@ -778,12 +786,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
778786
httpreq := *req.Request
779787
req.Request = &httpreq
780788

781-
if c.PrepareRetry != nil {
782-
if err := c.PrepareRetry(req.Request); err != nil {
783-
prepareErr = err
784-
break
785-
}
786-
}
787789
}
788790

789791
// this is the closest we have to success criteria

client_test.go

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ package retryablehttp
66
import (
77
"bytes"
88
"context"
9+
"crypto/sha256"
10+
"encoding/base64"
911
"errors"
1012
"fmt"
1113
"io"
@@ -372,6 +374,21 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) {
372374
client.PrepareRetry = func(req *http.Request) error {
373375
prepareChecks++
374376
req.Header.Set("foo", strconv.Itoa(prepareChecks))
377+
378+
// if the method is POST or PUT, set a header based on request body content
379+
if req.Method == "POST" || req.Method == "PUT" {
380+
bodyBytes, err := io.ReadAll(req.Body)
381+
if err != nil {
382+
t.Fatalf("could not read request body: %s", err)
383+
}
384+
preparedBody := string(bodyBytes)
385+
386+
if len(preparedBody) > 0 {
387+
sum := sha256.Sum256([]byte(preparedBody))
388+
contentHash := base64.StdEncoding.EncodeToString(sum[:])
389+
req.Header.Set("content_hash", contentHash)
390+
}
391+
}
375392
return nil
376393
}
377394

@@ -384,43 +401,50 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) {
384401
var shouldSucceed bool
385402
tests := []struct {
386403
name string
404+
method string
405+
requestBody string
387406
handler ResponseHandlerFunc
388407
expectedChecks int // often 2x number of attempts since we check twice
389408
expectedPrepareChecks int
390409
err string
391410
}{
392411
{
393412
name: "nil handler",
413+
method: http.MethodGet,
394414
handler: nil,
395415
expectedChecks: 1,
396416
expectedPrepareChecks: 0,
397417
},
398418
{
399-
name: "handler always succeeds",
419+
name: "handler always succeeds",
420+
method: http.MethodGet,
400421
handler: func(*http.Response) error {
401422
return nil
402423
},
403424
expectedChecks: 2,
404425
expectedPrepareChecks: 0,
405426
},
406427
{
407-
name: "handler always fails in a retryable way",
428+
name: "handler always fails in a retryable way",
429+
method: http.MethodGet,
408430
handler: func(*http.Response) error {
409431
return errors.New("retryable failure")
410432
},
411433
expectedChecks: 6,
412434
expectedPrepareChecks: 2,
413435
},
414436
{
415-
name: "handler always fails in a nonretryable way",
437+
name: "handler always fails in a nonretryable way",
438+
method: http.MethodGet,
416439
handler: func(*http.Response) error {
417440
return errors.New("nonretryable failure")
418441
},
419442
expectedChecks: 2,
420443
expectedPrepareChecks: 0,
421444
},
422445
{
423-
name: "handler succeeds on second attempt",
446+
name: "handler succeeds on second attempt",
447+
method: http.MethodGet,
424448
handler: func(*http.Response) error {
425449
if shouldSucceed {
426450
return nil
@@ -431,15 +455,51 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) {
431455
expectedChecks: 4,
432456
expectedPrepareChecks: 1,
433457
},
458+
{
459+
name: "POST - handler succeeds on second attempt, using body for PrepareRetry",
460+
method: http.MethodPost,
461+
requestBody: "dummy data",
462+
handler: func(response *http.Response) error {
463+
if shouldSucceed {
464+
return nil
465+
}
466+
shouldSucceed = true
467+
return errors.New("retryable failure")
468+
},
469+
expectedChecks: 4,
470+
expectedPrepareChecks: 1,
471+
},
472+
{
473+
name: "PUT - handler succeeds on second attempt, using body for PrepareRetry",
474+
method: http.MethodPut,
475+
requestBody: "dummy data",
476+
handler: func(response *http.Response) error {
477+
if shouldSucceed {
478+
return nil
479+
}
480+
shouldSucceed = true
481+
return errors.New("retryable failure")
482+
},
483+
expectedChecks: 4,
484+
expectedPrepareChecks: 1,
485+
},
434486
}
435487

436488
for _, tt := range tests {
437489
t.Run(tt.name, func(t *testing.T) {
438490
checks = 0
439491
prepareChecks = 0
440492
shouldSucceed = false
493+
var req *Request
494+
var err error
495+
441496
// Create the request
442-
req, err := NewRequest("GET", ts.URL, nil)
497+
if tt.requestBody != "" {
498+
req, err = NewRequest(tt.method, ts.URL, strings.NewReader(tt.requestBody))
499+
} else {
500+
req, err = NewRequest(tt.method, ts.URL, nil)
501+
}
502+
443503
if err != nil {
444504
t.Fatalf("err: %v", err)
445505
}
@@ -470,6 +530,12 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) {
470530
t.Fatalf("expected changes in request header 'foo' '%s', but got '%s'", expectedHeader, header)
471531
}
472532

533+
if tt.method == "POST" || tt.method == "PUT" {
534+
headerFromContent := req.Request.Header.Get("content_hash")
535+
if headerFromContent == "" {
536+
t.Fatalf("expected 'content_hash' header to exist, but it does not")
537+
}
538+
}
473539
})
474540
}
475541
}

0 commit comments

Comments
 (0)