@@ -6,6 +6,8 @@ package retryablehttp
6
6
import (
7
7
"bytes"
8
8
"context"
9
+ "crypto/sha256"
10
+ "encoding/base64"
9
11
"errors"
10
12
"fmt"
11
13
"io"
@@ -372,6 +374,21 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) {
372
374
client .PrepareRetry = func (req * http.Request ) error {
373
375
prepareChecks ++
374
376
req .Header .Set ("foo" , strconv .Itoa (prepareChecks ))
377
+
378
+ // if the method is POST or PUT, set a header based on request body content
379
+ if req .Method == "POST" || req .Method == "PUT" {
380
+ bodyBytes , err := io .ReadAll (req .Body )
381
+ if err != nil {
382
+ t .Fatalf ("could not read request body: %s" , err )
383
+ }
384
+ preparedBody := string (bodyBytes )
385
+
386
+ if len (preparedBody ) > 0 {
387
+ sum := sha256 .Sum256 ([]byte (preparedBody ))
388
+ contentHash := base64 .StdEncoding .EncodeToString (sum [:])
389
+ req .Header .Set ("content_hash" , contentHash )
390
+ }
391
+ }
375
392
return nil
376
393
}
377
394
@@ -384,43 +401,50 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) {
384
401
var shouldSucceed bool
385
402
tests := []struct {
386
403
name string
404
+ method string
405
+ requestBody string
387
406
handler ResponseHandlerFunc
388
407
expectedChecks int // often 2x number of attempts since we check twice
389
408
expectedPrepareChecks int
390
409
err string
391
410
}{
392
411
{
393
412
name : "nil handler" ,
413
+ method : http .MethodGet ,
394
414
handler : nil ,
395
415
expectedChecks : 1 ,
396
416
expectedPrepareChecks : 0 ,
397
417
},
398
418
{
399
- name : "handler always succeeds" ,
419
+ name : "handler always succeeds" ,
420
+ method : http .MethodGet ,
400
421
handler : func (* http.Response ) error {
401
422
return nil
402
423
},
403
424
expectedChecks : 2 ,
404
425
expectedPrepareChecks : 0 ,
405
426
},
406
427
{
407
- name : "handler always fails in a retryable way" ,
428
+ name : "handler always fails in a retryable way" ,
429
+ method : http .MethodGet ,
408
430
handler : func (* http.Response ) error {
409
431
return errors .New ("retryable failure" )
410
432
},
411
433
expectedChecks : 6 ,
412
434
expectedPrepareChecks : 2 ,
413
435
},
414
436
{
415
- name : "handler always fails in a nonretryable way" ,
437
+ name : "handler always fails in a nonretryable way" ,
438
+ method : http .MethodGet ,
416
439
handler : func (* http.Response ) error {
417
440
return errors .New ("nonretryable failure" )
418
441
},
419
442
expectedChecks : 2 ,
420
443
expectedPrepareChecks : 0 ,
421
444
},
422
445
{
423
- name : "handler succeeds on second attempt" ,
446
+ name : "handler succeeds on second attempt" ,
447
+ method : http .MethodGet ,
424
448
handler : func (* http.Response ) error {
425
449
if shouldSucceed {
426
450
return nil
@@ -431,15 +455,51 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) {
431
455
expectedChecks : 4 ,
432
456
expectedPrepareChecks : 1 ,
433
457
},
458
+ {
459
+ name : "POST - handler succeeds on second attempt, using body for PrepareRetry" ,
460
+ method : http .MethodPost ,
461
+ requestBody : "dummy data" ,
462
+ handler : func (response * http.Response ) error {
463
+ if shouldSucceed {
464
+ return nil
465
+ }
466
+ shouldSucceed = true
467
+ return errors .New ("retryable failure" )
468
+ },
469
+ expectedChecks : 4 ,
470
+ expectedPrepareChecks : 1 ,
471
+ },
472
+ {
473
+ name : "PUT - handler succeeds on second attempt, using body for PrepareRetry" ,
474
+ method : http .MethodPut ,
475
+ requestBody : "dummy data" ,
476
+ handler : func (response * http.Response ) error {
477
+ if shouldSucceed {
478
+ return nil
479
+ }
480
+ shouldSucceed = true
481
+ return errors .New ("retryable failure" )
482
+ },
483
+ expectedChecks : 4 ,
484
+ expectedPrepareChecks : 1 ,
485
+ },
434
486
}
435
487
436
488
for _ , tt := range tests {
437
489
t .Run (tt .name , func (t * testing.T ) {
438
490
checks = 0
439
491
prepareChecks = 0
440
492
shouldSucceed = false
493
+ var req * Request
494
+ var err error
495
+
441
496
// Create the request
442
- req , err := NewRequest ("GET" , ts .URL , nil )
497
+ if tt .requestBody != "" {
498
+ req , err = NewRequest (tt .method , ts .URL , strings .NewReader (tt .requestBody ))
499
+ } else {
500
+ req , err = NewRequest (tt .method , ts .URL , nil )
501
+ }
502
+
443
503
if err != nil {
444
504
t .Fatalf ("err: %v" , err )
445
505
}
@@ -470,6 +530,12 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) {
470
530
t .Fatalf ("expected changes in request header 'foo' '%s', but got '%s'" , expectedHeader , header )
471
531
}
472
532
533
+ if tt .method == "POST" || tt .method == "PUT" {
534
+ headerFromContent := req .Request .Header .Get ("content_hash" )
535
+ if headerFromContent == "" {
536
+ t .Fatalf ("expected 'content_hash' header to exist, but it does not" )
537
+ }
538
+ }
473
539
})
474
540
}
475
541
}
0 commit comments