diff --git a/pkg/internal/tests/models/retry_test.go b/pkg/internal/tests/models/retry_test.go new file mode 100644 index 0000000..6c13893 --- /dev/null +++ b/pkg/internal/tests/models/retry_test.go @@ -0,0 +1,110 @@ +package test + +import ( + "encoding/json" + "log" + "net/http" + "net/http/httptest" + "testing" + "time" + + wx "github.com/IBM/watsonx-go/pkg/models" +) + +// TestRetryWithSuccessOnFirstRequest tests the retry mechanism with a server that always returns a 200 status code. +func TestRetryWithSuccessOnFirstRequest(t *testing.T) { + type ResponseType struct { + Content string `json:"content"` + Status int `json:"status"` + } + + expectedResponse := ResponseType{Content: "success"} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"content":"success"}`)) + })) + defer server.Close() + + var retryCount uint = 0 + var expectedRetries uint = 0 + + sendRequest := func() (*http.Response, error) { + return http.Get(server.URL + "/success") + } + + resp, err := wx.Retry( + sendRequest, + wx.WithOnRetry(func(n uint, err error) { + retryCount = n + log.Printf("Retrying request after error: %v", err) + }), + ) + + if err != nil { + t.Errorf("Expected nil, got error: %v", err) + } + + if retryCount != expectedRetries { + t.Errorf("Expected 0 retries, but got %d", retryCount) + } + + defer resp.Body.Close() + var response ResponseType + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + t.Errorf("Failed to unmarshal response body: %v", err) + } + + if response != expectedResponse { + t.Errorf("Expected response %v, but got %v", expectedResponse, response) + } +} + +// TestRetryWithNoSuccessStatusOnAnyRequest tests the retry mechanism with a server that always returns a 429 status code. +func TestRetryWithNoSuccessStatusOnAnyRequest(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTooManyRequests) + })) + defer server.Close() + + var backoffTime = 2 * time.Second + var retryCount uint = 0 + var expectedRetries uint = 3 + + sendRequest := func() (*http.Response, error) { + return http.Get(server.URL + "/notfound") + } + + startTime := time.Now() + + resp, err := wx.Retry( + sendRequest, + wx.WithBackoff(backoffTime), + wx.WithOnRetry(func(n uint, err error) { + retryCount = n + log.Printf("Retrying request after error: %v", err) + }), + ) + + endTime := time.Now() + + elapsedTime := endTime.Sub(startTime) + expectedMinimumTime := backoffTime * time.Duration(expectedRetries) + + if err == nil { + t.Errorf("Expected error, got nil") + } + + if resp != nil { + defer resp.Body.Close() + t.Errorf("Expected nil response, got %v", resp.Body) + } + + if retryCount != expectedRetries { + t.Errorf("Expected 3 retries, but got %d", retryCount) + } + + if elapsedTime < expectedMinimumTime { + t.Errorf("Expected minimum time of %v, but got %v", expectedMinimumTime, elapsedTime) + } +} diff --git a/pkg/models/client.go b/pkg/models/client.go index 8dfa047..7104a04 100644 --- a/pkg/models/client.go +++ b/pkg/models/client.go @@ -3,7 +3,6 @@ package models import ( "errors" "fmt" - "net/http" "net/url" "os" ) @@ -62,7 +61,7 @@ func NewClient(options ...ClientOption) (*Client, error) { apiKey: opts.apiKey, projectID: opts.projectID, - httpClient: &http.Client{}, + httpClient: NewHttpClient(), } err := m.RefreshToken() diff --git a/pkg/models/embedding.go b/pkg/models/embedding.go index d1694f6..5436875 100644 --- a/pkg/models/embedding.go +++ b/pkg/models/embedding.go @@ -4,8 +4,6 @@ import ( "bytes" "encoding/json" "errors" - "fmt" - "io" "net/http" "time" ) @@ -85,23 +83,17 @@ func (m *Client) generateEmbeddingRequest(payload EmbeddingPayload) (embeddingRe } req, err := http.NewRequest(http.MethodPost, embeddingUrl, bytes.NewBuffer(payloadJSON)) + if err != nil { + return embeddingResponse{}, err + } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+m.token.value) - res, err := m.httpClient.Do(req) + res, err := m.httpClient.DoWithRetry(req) if err != nil { return embeddingResponse{}, err } - - statusCode := res.StatusCode - if statusCode != http.StatusOK { - body, err := io.ReadAll(res.Body) - if err != nil { - return embeddingResponse{}, fmt.Errorf("request failed with status code %d", statusCode) - } - return embeddingResponse{}, fmt.Errorf("request failed with status code %d and error %s", statusCode, body) - } defer res.Body.Close() var embeddingRes embeddingResponse diff --git a/pkg/models/generate.go b/pkg/models/generate.go index 55168e0..17c09f8 100644 --- a/pkg/models/generate.go +++ b/pkg/models/generate.go @@ -5,8 +5,6 @@ import ( "bytes" "encoding/json" "errors" - "fmt" - "io" "log" "net/http" "strings" @@ -105,20 +103,10 @@ func (m *Client) generateTextRequest(payload GenerateTextPayload) (generateTextR req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+m.token.value) - res, err := m.httpClient.Do(req) + res, err := m.httpClient.DoWithRetry(req) if err != nil { return generateTextResponse{}, err } - - statusCode := res.StatusCode - - if statusCode < 200 || statusCode >= 300 { - body, err := io.ReadAll(res.Body) - if err != nil { - return generateTextResponse{}, fmt.Errorf("request failed with status code %d", statusCode) - } - return generateTextResponse{}, fmt.Errorf("request failed with status code %d and error %s", statusCode, body) - } defer res.Body.Close() var generateRes generateTextResponse @@ -197,23 +185,13 @@ func (m *Client) generateTextStreamRequest(payload GenerateTextPayload) (<-chan req.Header.Set("Authorization", "Bearer "+m.token.value) req.Header.Set("Accept", "text/event-stream") - res, err := m.httpClient.Do(req) + res, err := m.httpClient.DoWithRetry(req) if err != nil { log.Println("error making request: ", err) return } - defer res.Body.Close() - - if res.StatusCode != http.StatusOK { - body, err := io.ReadAll(res.Body) - if err != nil { - log.Printf("request failed with status code %d", res.StatusCode) - } else { - log.Printf("request failed with status code %d and error %s", res.StatusCode, body) - } - return - } + defer res.Body.Close() scanner := bufio.NewScanner(res.Body) for scanner.Scan() { line := scanner.Text() diff --git a/pkg/models/retry.go b/pkg/models/retry.go new file mode 100644 index 0000000..7c1e276 --- /dev/null +++ b/pkg/models/retry.go @@ -0,0 +1,165 @@ +package models + +import ( + "context" + "errors" + "math/rand" + "net/http" + "time" +) + +// OnRetryFunc is a function type that is called on each retry attempt. +type OnRetryFunc func(attempt uint, err error) + +// Timer interface to abstract time-based operations for retries. +type Timer interface { + After(time.Duration) <-chan time.Time +} + +// RetryIfFunc determines whether a retry should be attempted based on the error. +type RetryIfFunc func(error) bool + +// RetryConfig contains configuration options for the retry mechanism. +type RetryConfig struct { + retries uint + backoff time.Duration + maxJitter time.Duration + onRetry OnRetryFunc + retryIf RetryIfFunc + timer Timer + context context.Context +} + +// RetryOption is a function type for modifying RetryConfig options. +type RetryOption func(*RetryConfig) + +// timerImpl implements the Timer interface using time.After. +type timerImpl struct{} + +func (t timerImpl) After(d time.Duration) <-chan time.Time { + return time.After(d) +} + +// newDefaultRetryConfig creates a default RetryConfig with sensible defaults. +func newDefaultRetryConfig() *RetryConfig { + return &RetryConfig{ + retries: 3, + backoff: 1 * time.Second, + maxJitter: 1 * time.Second, + onRetry: func(n uint, err error) {}, // no-op onRetry by default + retryIf: func(err error) bool { return err != nil }, // retry on any error by default + timer: &timerImpl{}, + context: context.Background(), + } +} + +// RetryableFuncWithResponse represents a function that returns an HTTP response or an error. +type RetryableFuncWithResponse func() (*http.Response, error) + +// Retry retries the provided retryableFunc according to the retry configuration options. +func Retry(retryableFunc RetryableFuncWithResponse, options ...RetryOption) (*http.Response, error) { + opts := newDefaultRetryConfig() + + for _, opt := range options { + if opt != nil { + opt(opts) + } + } + + var lastErr error + for n := uint(0); n < opts.retries; n++ { + if err := opts.context.Err(); err != nil { + return nil, err + } + + resp, err := retryableFunc() + if err == nil && resp != nil && resp.StatusCode == http.StatusOK { + return resp, nil + } + + if err == nil && resp != nil { + err = errors.New(resp.Status) + } + + if !opts.retryIf(err) { + return nil, err + } + + lastErr = err + opts.onRetry(n+1, err) + + backoffDuration := opts.backoff + if opts.maxJitter > 0 { + jitter := time.Duration(rand.Int63n(int64(opts.maxJitter))) + backoffDuration += jitter + } + + select { + case <-opts.timer.After(backoffDuration): + case <-opts.context.Done(): + return nil, opts.context.Err() + } + } + + return nil, lastErr +} + +// WithRetries sets the number of retries for the retry configuration. +func WithRetries(retries uint) RetryOption { + return func(cfg *RetryConfig) { + cfg.retries = retries + } +} + +// WithBackoff sets the backoff duration between retries. +func WithBackoff(backoff time.Duration) RetryOption { + return func(cfg *RetryConfig) { + cfg.backoff = backoff + } +} + +// WithMaxJitter sets the maximum jitter duration to add to the backoff. +func WithMaxJitter(maxJitter time.Duration) RetryOption { + return func(cfg *RetryConfig) { + cfg.maxJitter = maxJitter + } +} + +// WithOnRetry sets the callback function to execute on each retry. +func WithOnRetry(onRetry OnRetryFunc) RetryOption { + return func(cfg *RetryConfig) { + cfg.onRetry = onRetry + } +} + +// WithRetryIf sets the condition to determine whether to retry based on the error. +func WithRetryIf(retryIf RetryIfFunc) RetryOption { + return func(cfg *RetryConfig) { + cfg.retryIf = retryIf + } +} + +// Custom wrapper for http.Client that implements the Doer interface. +// - Do +// - DoWithRetry +type HttpClient struct { + httpClient *http.Client +} + +func NewHttpClient() *HttpClient { + return &HttpClient{ + httpClient: &http.Client{}, + } +} + +func (c *HttpClient) Do(req *http.Request) (*http.Response, error) { + return c.httpClient.Do(req) +} + +func (c *HttpClient) DoWithRetry(req *http.Request) (*http.Response, error) { + return Retry( + func() (*http.Response, error) { + return c.httpClient.Do(req) + }, + ) +} diff --git a/pkg/models/types.go b/pkg/models/types.go index c29e141..f0436ab 100644 --- a/pkg/models/types.go +++ b/pkg/models/types.go @@ -32,4 +32,5 @@ const ( type Doer interface { Do(req *http.Request) (*http.Response, error) + DoWithRetry(req *http.Request) (*http.Response, error) }