diff --git a/client.go b/client.go index f40d241..ae3e9d0 100644 --- a/client.go +++ b/client.go @@ -368,11 +368,24 @@ type Backoff func(min, max time.Duration, attemptNum int, resp *http.Response) t // attempted. If overriding this, be sure to close the body if needed. type ErrorHandler func(resp *http.Response, err error, numTries int) (*http.Response, error) +type HTTPClient interface { + // Do performs an HTTP request and returns an HTTP response. + Do(*http.Request) (*http.Response, error) + // Done is called when the client is no longer needed. + Done() +} + +type HTTPClientFactory interface { + // New returns an HTTP client to use for a request, including retries. + New() HTTPClient +} + // Client is used to make HTTP requests. It adds additional functionality // like automatic retries to tolerate minor outages. type Client struct { - HTTPClient *http.Client // Internal HTTP client. - Logger interface{} // Customer logger instance. Can be either Logger or LeveledLogger + HTTPClient *http.Client // Internal HTTP client. This field is used if set, otherwise HTTPClientFactory is used. + HTTPClientFactory HTTPClientFactory + Logger interface{} // Customer logger instance. Can be either Logger or LeveledLogger RetryWaitMin time.Duration // Minimum time to wait RetryWaitMax time.Duration // Maximum time to wait @@ -397,19 +410,18 @@ type Client struct { ErrorHandler ErrorHandler loggerInit sync.Once - clientInit sync.Once } // NewClient creates a new Client with default settings. func NewClient() *Client { return &Client{ - HTTPClient: cleanhttp.DefaultPooledClient(), - Logger: defaultLogger, - RetryWaitMin: defaultRetryWaitMin, - RetryWaitMax: defaultRetryWaitMax, - RetryMax: defaultRetryMax, - CheckRetry: DefaultRetryPolicy, - Backoff: DefaultBackoff, + HTTPClientFactory: &CleanPooledClientFactory{}, + Logger: defaultLogger, + RetryWaitMin: defaultRetryWaitMin, + RetryWaitMax: defaultRetryWaitMax, + RetryMax: defaultRetryMax, + CheckRetry: DefaultRetryPolicy, + Backoff: DefaultBackoff, } } @@ -573,12 +585,6 @@ func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Respo // Do wraps calling an HTTP method with retries. func (c *Client) Do(req *Request) (*http.Response, error) { - c.clientInit.Do(func() { - if c.HTTPClient == nil { - c.HTTPClient = cleanhttp.DefaultPooledClient() - } - }) - logger := c.logger() if logger != nil { @@ -590,6 +596,9 @@ func (c *Client) Do(req *Request) (*http.Response, error) { } } + httpClient := c.getHTTPClient() + defer httpClient.Done() + var resp *http.Response var attempt int var shouldRetry bool @@ -603,7 +612,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) { if req.body != nil { body, err := req.body() if err != nil { - c.HTTPClient.CloseIdleConnections() return resp, err } if c, ok := body.(io.ReadCloser); ok { @@ -625,7 +633,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) { } // Attempt the request - resp, doErr = c.HTTPClient.Do(req.Request) + + resp, doErr = httpClient.Do(req.Request) // Check if we should continue with retries. shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr) @@ -694,7 +703,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) { select { case <-req.Context().Done(): timer.Stop() - c.HTTPClient.CloseIdleConnections() return nil, req.Context().Err() case <-timer.C: } @@ -710,8 +718,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) { return resp, nil } - defer c.HTTPClient.CloseIdleConnections() - var err error if checkErr != nil { err = checkErr @@ -758,6 +764,19 @@ func (c *Client) drainBody(body io.ReadCloser) { } } +func (c *Client) getHTTPClient() HTTPClient { + if c.HTTPClient != nil { + return &idleConnectionsClosingClient{ + httpClient: c.HTTPClient, + } + } + clientFactory := c.HTTPClientFactory + if clientFactory == nil { + clientFactory = &CleanPooledClientFactory{} + } + return clientFactory.New() +} + // Get is a shortcut for doing a GET request without making a new client. func Get(url string) (*http.Response, error) { return defaultClient.Get(url) @@ -820,3 +839,29 @@ func (c *Client) StandardClient() *http.Client { Transport: &RoundTripper{Client: c}, } } + +var ( + _ HTTPClientFactory = &CleanPooledClientFactory{} + _ HTTPClient = &idleConnectionsClosingClient{} +) + +type CleanPooledClientFactory struct { +} + +func (f *CleanPooledClientFactory) New() HTTPClient { + return &idleConnectionsClosingClient{ + httpClient: cleanhttp.DefaultPooledClient(), + } +} + +type idleConnectionsClosingClient struct { + httpClient *http.Client +} + +func (c *idleConnectionsClosingClient) Do(req *http.Request) (*http.Response, error) { + return c.httpClient.Do(req) +} + +func (c *idleConnectionsClosingClient) Done() { + c.httpClient.CloseIdleConnections() +}