Skip to content

Commit 1643719

Browse files
authored
Merge pull request #210 from tomclegg/noretry-header-cert
Fix default retry policy for certificate verification errors and bad request headers
2 parents 4fb315e + eb08cce commit 1643719

File tree

4 files changed

+96
-5
lines changed

4 files changed

+96
-5
lines changed

cert_error_go119.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
//go:build !go1.20
5+
// +build !go1.20
6+
7+
package retryablehttp
8+
9+
import "crypto/x509"
10+
11+
func isCertError(err error) bool {
12+
_, ok := err.(x509.UnknownAuthorityError)
13+
return ok
14+
}

cert_error_go120.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
//go:build go1.20
5+
// +build go1.20
6+
7+
package retryablehttp
8+
9+
import "crypto/tls"
10+
11+
func isCertError(err error) bool {
12+
_, ok := err.(*tls.CertificateVerificationError)
13+
return ok
14+
}

client.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ package retryablehttp
2727
import (
2828
"bytes"
2929
"context"
30-
"crypto/x509"
3130
"fmt"
3231
"io"
3332
"log"
@@ -62,6 +61,10 @@ var (
6261
// limit the size we consume to respReadLimit.
6362
respReadLimit = int64(4096)
6463

64+
// timeNow sets the function that returns the current time.
65+
// This defaults to time.Now. Changes to this should only be done in tests.
66+
timeNow = time.Now
67+
6568
// A regular expression to match the error returned by net/http when the
6669
// configured number of redirects is exhausted. This error isn't typed
6770
// specifically so we resort to matching on the error string.
@@ -72,9 +75,10 @@ var (
7275
// specifically so we resort to matching on the error string.
7376
schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`)
7477

75-
// timeNow sets the function that returns the current time.
76-
// This defaults to time.Now. Changes to this should only be done in tests.
77-
timeNow = time.Now
78+
// A regular expression to match the error returned by net/http when a
79+
// request header or value is invalid. This error isn't typed
80+
// specifically so we resort to matching on the error string.
81+
invalidHeaderErrorRe = regexp.MustCompile(`invalid header`)
7882

7983
// A regular expression to match the error returned by net/http when the
8084
// TLS certificate is not trusted. This error isn't typed
@@ -501,11 +505,16 @@ func baseRetryPolicy(resp *http.Response, err error) (bool, error) {
501505
return false, v
502506
}
503507

508+
// Don't retry if the error was due to an invalid header.
509+
if invalidHeaderErrorRe.MatchString(v.Error()) {
510+
return false, v
511+
}
512+
504513
// Don't retry if the error was due to TLS cert verification failure.
505514
if notTrustedErrorRe.MatchString(v.Error()) {
506515
return false, v
507516
}
508-
if _, ok := v.Err.(x509.UnknownAuthorityError); ok {
517+
if isCertError(v.Err) {
509518
return false, v
510519
}
511520
}

client_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,60 @@ func TestClient_DefaultRetryPolicy_invalidscheme(t *testing.T) {
935935
}
936936
}
937937

938+
func TestClient_DefaultRetryPolicy_invalidheadername(t *testing.T) {
939+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
940+
w.WriteHeader(200)
941+
}))
942+
defer ts.Close()
943+
944+
attempts := 0
945+
client := NewClient()
946+
client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
947+
attempts++
948+
return DefaultRetryPolicy(context.TODO(), resp, err)
949+
}
950+
951+
req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
952+
if err != nil {
953+
t.Fatalf("err: %v", err)
954+
}
955+
req.Header.Set("Header-Name-\033", "header value")
956+
_, err = client.StandardClient().Do(req)
957+
if err == nil {
958+
t.Fatalf("expected header error, got nil")
959+
}
960+
if attempts != 1 {
961+
t.Fatalf("expected 1 attempt, got %d", attempts)
962+
}
963+
}
964+
965+
func TestClient_DefaultRetryPolicy_invalidheadervalue(t *testing.T) {
966+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
967+
w.WriteHeader(200)
968+
}))
969+
defer ts.Close()
970+
971+
attempts := 0
972+
client := NewClient()
973+
client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
974+
attempts++
975+
return DefaultRetryPolicy(context.TODO(), resp, err)
976+
}
977+
978+
req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
979+
if err != nil {
980+
t.Fatalf("err: %v", err)
981+
}
982+
req.Header.Set("Header-Name", "bad header value \033")
983+
_, err = client.StandardClient().Do(req)
984+
if err == nil {
985+
t.Fatalf("expected header value error, got nil")
986+
}
987+
if attempts != 1 {
988+
t.Fatalf("expected 1 attempt, got %d", attempts)
989+
}
990+
}
991+
938992
func TestClient_CheckRetryStop(t *testing.T) {
939993
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
940994
http.Error(w, "test_500_body", http.StatusInternalServerError)

0 commit comments

Comments
 (0)