-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: implement retry module Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com> * feat: return body so it can be used Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com> * fix: remove withContext functional creator Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com> * feat: add retry strategy to embedding request Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com> * fix: return http response instead of the body Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com> * fix: use http response instead of the body Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com> * feat: use retry strategy in the text and text stream requests Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com> * fix: avoid null pointer exception Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com> * fix: error check when create new request in embeddings Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com> * feat: add http client wrapper with retry and include DoWithRetry method in Doer interface Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com> * feat: use the http wrapper in the client constructor method Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com> * feat: make requests with DoWithRetry Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com> * fix: set 1 second default value for maxJitter Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com> --------- Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com>
- Loading branch information
1 parent
36dd5c2
commit ab5e97d
Showing
6 changed files
with
284 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
package test | ||
|
||
import ( | ||
"encoding/json" | ||
"log" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
"time" | ||
|
||
wx "github.com/IBM/watsonx-go/pkg/models" | ||
) | ||
|
||
// TestRetryWithSuccessOnFirstRequest tests the retry mechanism with a server that always returns a 200 status code. | ||
func TestRetryWithSuccessOnFirstRequest(t *testing.T) { | ||
type ResponseType struct { | ||
Content string `json:"content"` | ||
Status int `json:"status"` | ||
} | ||
|
||
expectedResponse := ResponseType{Content: "success"} | ||
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
w.Write([]byte(`{"content":"success"}`)) | ||
})) | ||
defer server.Close() | ||
|
||
var retryCount uint = 0 | ||
var expectedRetries uint = 0 | ||
|
||
sendRequest := func() (*http.Response, error) { | ||
return http.Get(server.URL + "/success") | ||
} | ||
|
||
resp, err := wx.Retry( | ||
sendRequest, | ||
wx.WithOnRetry(func(n uint, err error) { | ||
retryCount = n | ||
log.Printf("Retrying request after error: %v", err) | ||
}), | ||
) | ||
|
||
if err != nil { | ||
t.Errorf("Expected nil, got error: %v", err) | ||
} | ||
|
||
if retryCount != expectedRetries { | ||
t.Errorf("Expected 0 retries, but got %d", retryCount) | ||
} | ||
|
||
defer resp.Body.Close() | ||
var response ResponseType | ||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { | ||
t.Errorf("Failed to unmarshal response body: %v", err) | ||
} | ||
|
||
if response != expectedResponse { | ||
t.Errorf("Expected response %v, but got %v", expectedResponse, response) | ||
} | ||
} | ||
|
||
// TestRetryWithNoSuccessStatusOnAnyRequest tests the retry mechanism with a server that always returns a 429 status code. | ||
func TestRetryWithNoSuccessStatusOnAnyRequest(t *testing.T) { | ||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusTooManyRequests) | ||
})) | ||
defer server.Close() | ||
|
||
var backoffTime = 2 * time.Second | ||
var retryCount uint = 0 | ||
var expectedRetries uint = 3 | ||
|
||
sendRequest := func() (*http.Response, error) { | ||
return http.Get(server.URL + "/notfound") | ||
} | ||
|
||
startTime := time.Now() | ||
|
||
resp, err := wx.Retry( | ||
sendRequest, | ||
wx.WithBackoff(backoffTime), | ||
wx.WithOnRetry(func(n uint, err error) { | ||
retryCount = n | ||
log.Printf("Retrying request after error: %v", err) | ||
}), | ||
) | ||
|
||
endTime := time.Now() | ||
|
||
elapsedTime := endTime.Sub(startTime) | ||
expectedMinimumTime := backoffTime * time.Duration(expectedRetries) | ||
|
||
if err == nil { | ||
t.Errorf("Expected error, got nil") | ||
} | ||
|
||
if resp != nil { | ||
defer resp.Body.Close() | ||
t.Errorf("Expected nil response, got %v", resp.Body) | ||
} | ||
|
||
if retryCount != expectedRetries { | ||
t.Errorf("Expected 3 retries, but got %d", retryCount) | ||
} | ||
|
||
if elapsedTime < expectedMinimumTime { | ||
t.Errorf("Expected minimum time of %v, but got %v", expectedMinimumTime, elapsedTime) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
package models | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"math/rand" | ||
"net/http" | ||
"time" | ||
) | ||
|
||
// OnRetryFunc is a function type that is called on each retry attempt. | ||
type OnRetryFunc func(attempt uint, err error) | ||
|
||
// Timer interface to abstract time-based operations for retries. | ||
type Timer interface { | ||
After(time.Duration) <-chan time.Time | ||
} | ||
|
||
// RetryIfFunc determines whether a retry should be attempted based on the error. | ||
type RetryIfFunc func(error) bool | ||
|
||
// RetryConfig contains configuration options for the retry mechanism. | ||
type RetryConfig struct { | ||
retries uint | ||
backoff time.Duration | ||
maxJitter time.Duration | ||
onRetry OnRetryFunc | ||
retryIf RetryIfFunc | ||
timer Timer | ||
context context.Context | ||
} | ||
|
||
// RetryOption is a function type for modifying RetryConfig options. | ||
type RetryOption func(*RetryConfig) | ||
|
||
// timerImpl implements the Timer interface using time.After. | ||
type timerImpl struct{} | ||
|
||
func (t timerImpl) After(d time.Duration) <-chan time.Time { | ||
return time.After(d) | ||
} | ||
|
||
// newDefaultRetryConfig creates a default RetryConfig with sensible defaults. | ||
func newDefaultRetryConfig() *RetryConfig { | ||
return &RetryConfig{ | ||
retries: 3, | ||
backoff: 1 * time.Second, | ||
maxJitter: 1 * time.Second, | ||
onRetry: func(n uint, err error) {}, // no-op onRetry by default | ||
retryIf: func(err error) bool { return err != nil }, // retry on any error by default | ||
timer: &timerImpl{}, | ||
context: context.Background(), | ||
} | ||
} | ||
|
||
// RetryableFuncWithResponse represents a function that returns an HTTP response or an error. | ||
type RetryableFuncWithResponse func() (*http.Response, error) | ||
|
||
// Retry retries the provided retryableFunc according to the retry configuration options. | ||
func Retry(retryableFunc RetryableFuncWithResponse, options ...RetryOption) (*http.Response, error) { | ||
opts := newDefaultRetryConfig() | ||
|
||
for _, opt := range options { | ||
if opt != nil { | ||
opt(opts) | ||
} | ||
} | ||
|
||
var lastErr error | ||
for n := uint(0); n < opts.retries; n++ { | ||
if err := opts.context.Err(); err != nil { | ||
return nil, err | ||
} | ||
|
||
resp, err := retryableFunc() | ||
if err == nil && resp != nil && resp.StatusCode == http.StatusOK { | ||
return resp, nil | ||
} | ||
|
||
if err == nil && resp != nil { | ||
err = errors.New(resp.Status) | ||
} | ||
|
||
if !opts.retryIf(err) { | ||
return nil, err | ||
} | ||
|
||
lastErr = err | ||
opts.onRetry(n+1, err) | ||
|
||
backoffDuration := opts.backoff | ||
if opts.maxJitter > 0 { | ||
jitter := time.Duration(rand.Int63n(int64(opts.maxJitter))) | ||
backoffDuration += jitter | ||
} | ||
|
||
select { | ||
case <-opts.timer.After(backoffDuration): | ||
case <-opts.context.Done(): | ||
return nil, opts.context.Err() | ||
} | ||
} | ||
|
||
return nil, lastErr | ||
} | ||
|
||
// WithRetries sets the number of retries for the retry configuration. | ||
func WithRetries(retries uint) RetryOption { | ||
return func(cfg *RetryConfig) { | ||
cfg.retries = retries | ||
} | ||
} | ||
|
||
// WithBackoff sets the backoff duration between retries. | ||
func WithBackoff(backoff time.Duration) RetryOption { | ||
return func(cfg *RetryConfig) { | ||
cfg.backoff = backoff | ||
} | ||
} | ||
|
||
// WithMaxJitter sets the maximum jitter duration to add to the backoff. | ||
func WithMaxJitter(maxJitter time.Duration) RetryOption { | ||
return func(cfg *RetryConfig) { | ||
cfg.maxJitter = maxJitter | ||
} | ||
} | ||
|
||
// WithOnRetry sets the callback function to execute on each retry. | ||
func WithOnRetry(onRetry OnRetryFunc) RetryOption { | ||
return func(cfg *RetryConfig) { | ||
cfg.onRetry = onRetry | ||
} | ||
} | ||
|
||
// WithRetryIf sets the condition to determine whether to retry based on the error. | ||
func WithRetryIf(retryIf RetryIfFunc) RetryOption { | ||
return func(cfg *RetryConfig) { | ||
cfg.retryIf = retryIf | ||
} | ||
} | ||
|
||
// Custom wrapper for http.Client that implements the Doer interface. | ||
// - Do | ||
// - DoWithRetry | ||
type HttpClient struct { | ||
httpClient *http.Client | ||
} | ||
|
||
func NewHttpClient() *HttpClient { | ||
return &HttpClient{ | ||
httpClient: &http.Client{}, | ||
} | ||
} | ||
|
||
func (c *HttpClient) Do(req *http.Request) (*http.Response, error) { | ||
return c.httpClient.Do(req) | ||
} | ||
|
||
func (c *HttpClient) DoWithRetry(req *http.Request) (*http.Response, error) { | ||
return Retry( | ||
func() (*http.Response, error) { | ||
return c.httpClient.Do(req) | ||
}, | ||
) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters