Skip to content

Commit

Permalink
feat: add request retry (#17)
Browse files Browse the repository at this point in the history
* feat: implement retry module

Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com>

* feat: return body so it can be used

Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com>

* fix: remove withContext functional creator

Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com>

* feat: add retry strategy to embedding request

Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com>

* fix: return http response instead of the body

Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com>

* fix: use http response instead of the body

Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com>

* feat: use retry strategy in the text and text stream requests

Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com>

* fix: avoid null pointer exception

Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com>

* fix: error check when create new request in embeddings

Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com>

* feat: add http client wrapper with retry and include DoWithRetry method in Doer interface

Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com>

* feat: use the http wrapper in the client constructor method

Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com>

* feat: make requests with DoWithRetry

Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com>

* fix: set 1 second default value for maxJitter

Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com>

---------

Signed-off-by: PabloSanchi <pablo.sanchi.herrera@gmail.com>
  • Loading branch information
PabloSanchi authored Sep 3, 2024
1 parent 36dd5c2 commit ab5e97d
Show file tree
Hide file tree
Showing 6 changed files with 284 additions and 39 deletions.
110 changes: 110 additions & 0 deletions pkg/internal/tests/models/retry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package test

import (
"encoding/json"
"log"
"net/http"
"net/http/httptest"
"testing"
"time"

wx "github.com/IBM/watsonx-go/pkg/models"
)

// TestRetryWithSuccessOnFirstRequest tests the retry mechanism with a server that always returns a 200 status code.
func TestRetryWithSuccessOnFirstRequest(t *testing.T) {
type ResponseType struct {
Content string `json:"content"`
Status int `json:"status"`
}

expectedResponse := ResponseType{Content: "success"}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"content":"success"}`))
}))
defer server.Close()

var retryCount uint = 0
var expectedRetries uint = 0

sendRequest := func() (*http.Response, error) {
return http.Get(server.URL + "/success")
}

resp, err := wx.Retry(
sendRequest,
wx.WithOnRetry(func(n uint, err error) {
retryCount = n
log.Printf("Retrying request after error: %v", err)
}),
)

if err != nil {
t.Errorf("Expected nil, got error: %v", err)
}

if retryCount != expectedRetries {
t.Errorf("Expected 0 retries, but got %d", retryCount)
}

defer resp.Body.Close()
var response ResponseType
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
t.Errorf("Failed to unmarshal response body: %v", err)
}

if response != expectedResponse {
t.Errorf("Expected response %v, but got %v", expectedResponse, response)
}
}

// TestRetryWithNoSuccessStatusOnAnyRequest tests the retry mechanism with a server that always returns a 429 status code.
func TestRetryWithNoSuccessStatusOnAnyRequest(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTooManyRequests)
}))
defer server.Close()

var backoffTime = 2 * time.Second
var retryCount uint = 0
var expectedRetries uint = 3

sendRequest := func() (*http.Response, error) {
return http.Get(server.URL + "/notfound")
}

startTime := time.Now()

resp, err := wx.Retry(
sendRequest,
wx.WithBackoff(backoffTime),
wx.WithOnRetry(func(n uint, err error) {
retryCount = n
log.Printf("Retrying request after error: %v", err)
}),
)

endTime := time.Now()

elapsedTime := endTime.Sub(startTime)
expectedMinimumTime := backoffTime * time.Duration(expectedRetries)

if err == nil {
t.Errorf("Expected error, got nil")
}

if resp != nil {
defer resp.Body.Close()
t.Errorf("Expected nil response, got %v", resp.Body)
}

if retryCount != expectedRetries {
t.Errorf("Expected 3 retries, but got %d", retryCount)
}

if elapsedTime < expectedMinimumTime {
t.Errorf("Expected minimum time of %v, but got %v", expectedMinimumTime, elapsedTime)
}
}
3 changes: 1 addition & 2 deletions pkg/models/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package models
import (
"errors"
"fmt"
"net/http"
"net/url"
"os"
)
Expand Down Expand Up @@ -62,7 +61,7 @@ func NewClient(options ...ClientOption) (*Client, error) {
apiKey: opts.apiKey,
projectID: opts.projectID,

httpClient: &http.Client{},
httpClient: NewHttpClient(),
}

err := m.RefreshToken()
Expand Down
16 changes: 4 additions & 12 deletions pkg/models/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
)
Expand Down Expand Up @@ -85,23 +83,17 @@ func (m *Client) generateEmbeddingRequest(payload EmbeddingPayload) (embeddingRe
}

req, err := http.NewRequest(http.MethodPost, embeddingUrl, bytes.NewBuffer(payloadJSON))
if err != nil {
return embeddingResponse{}, err
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+m.token.value)

res, err := m.httpClient.Do(req)
res, err := m.httpClient.DoWithRetry(req)
if err != nil {
return embeddingResponse{}, err
}

statusCode := res.StatusCode
if statusCode != http.StatusOK {
body, err := io.ReadAll(res.Body)
if err != nil {
return embeddingResponse{}, fmt.Errorf("request failed with status code %d", statusCode)
}
return embeddingResponse{}, fmt.Errorf("request failed with status code %d and error %s", statusCode, body)
}
defer res.Body.Close()

var embeddingRes embeddingResponse
Expand Down
28 changes: 3 additions & 25 deletions pkg/models/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
Expand Down Expand Up @@ -105,20 +103,10 @@ func (m *Client) generateTextRequest(payload GenerateTextPayload) (generateTextR
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+m.token.value)

res, err := m.httpClient.Do(req)
res, err := m.httpClient.DoWithRetry(req)
if err != nil {
return generateTextResponse{}, err
}

statusCode := res.StatusCode

if statusCode < 200 || statusCode >= 300 {
body, err := io.ReadAll(res.Body)
if err != nil {
return generateTextResponse{}, fmt.Errorf("request failed with status code %d", statusCode)
}
return generateTextResponse{}, fmt.Errorf("request failed with status code %d and error %s", statusCode, body)
}
defer res.Body.Close()

var generateRes generateTextResponse
Expand Down Expand Up @@ -197,23 +185,13 @@ func (m *Client) generateTextStreamRequest(payload GenerateTextPayload) (<-chan
req.Header.Set("Authorization", "Bearer "+m.token.value)
req.Header.Set("Accept", "text/event-stream")

res, err := m.httpClient.Do(req)
res, err := m.httpClient.DoWithRetry(req)
if err != nil {
log.Println("error making request: ", err)
return
}
defer res.Body.Close()

if res.StatusCode != http.StatusOK {
body, err := io.ReadAll(res.Body)
if err != nil {
log.Printf("request failed with status code %d", res.StatusCode)
} else {
log.Printf("request failed with status code %d and error %s", res.StatusCode, body)
}
return
}

defer res.Body.Close()
scanner := bufio.NewScanner(res.Body)
for scanner.Scan() {
line := scanner.Text()
Expand Down
165 changes: 165 additions & 0 deletions pkg/models/retry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package models

import (
"context"
"errors"
"math/rand"
"net/http"
"time"
)

// OnRetryFunc is a function type that is called on each retry attempt.
type OnRetryFunc func(attempt uint, err error)

// Timer interface to abstract time-based operations for retries.
type Timer interface {
After(time.Duration) <-chan time.Time
}

// RetryIfFunc determines whether a retry should be attempted based on the error.
type RetryIfFunc func(error) bool

// RetryConfig contains configuration options for the retry mechanism.
type RetryConfig struct {
retries uint
backoff time.Duration
maxJitter time.Duration
onRetry OnRetryFunc
retryIf RetryIfFunc
timer Timer
context context.Context
}

// RetryOption is a function type for modifying RetryConfig options.
type RetryOption func(*RetryConfig)

// timerImpl implements the Timer interface using time.After.
type timerImpl struct{}

func (t timerImpl) After(d time.Duration) <-chan time.Time {
return time.After(d)
}

// newDefaultRetryConfig creates a default RetryConfig with sensible defaults.
func newDefaultRetryConfig() *RetryConfig {
return &RetryConfig{
retries: 3,
backoff: 1 * time.Second,
maxJitter: 1 * time.Second,
onRetry: func(n uint, err error) {}, // no-op onRetry by default
retryIf: func(err error) bool { return err != nil }, // retry on any error by default
timer: &timerImpl{},
context: context.Background(),
}
}

// RetryableFuncWithResponse represents a function that returns an HTTP response or an error.
type RetryableFuncWithResponse func() (*http.Response, error)

// Retry retries the provided retryableFunc according to the retry configuration options.
func Retry(retryableFunc RetryableFuncWithResponse, options ...RetryOption) (*http.Response, error) {
opts := newDefaultRetryConfig()

for _, opt := range options {
if opt != nil {
opt(opts)
}
}

var lastErr error
for n := uint(0); n < opts.retries; n++ {
if err := opts.context.Err(); err != nil {
return nil, err
}

resp, err := retryableFunc()
if err == nil && resp != nil && resp.StatusCode == http.StatusOK {
return resp, nil
}

if err == nil && resp != nil {
err = errors.New(resp.Status)
}

if !opts.retryIf(err) {
return nil, err
}

lastErr = err
opts.onRetry(n+1, err)

backoffDuration := opts.backoff
if opts.maxJitter > 0 {
jitter := time.Duration(rand.Int63n(int64(opts.maxJitter)))
backoffDuration += jitter
}

select {
case <-opts.timer.After(backoffDuration):
case <-opts.context.Done():
return nil, opts.context.Err()
}
}

return nil, lastErr
}

// WithRetries sets the number of retries for the retry configuration.
func WithRetries(retries uint) RetryOption {
return func(cfg *RetryConfig) {
cfg.retries = retries
}
}

// WithBackoff sets the backoff duration between retries.
func WithBackoff(backoff time.Duration) RetryOption {
return func(cfg *RetryConfig) {
cfg.backoff = backoff
}
}

// WithMaxJitter sets the maximum jitter duration to add to the backoff.
func WithMaxJitter(maxJitter time.Duration) RetryOption {
return func(cfg *RetryConfig) {
cfg.maxJitter = maxJitter
}
}

// WithOnRetry sets the callback function to execute on each retry.
func WithOnRetry(onRetry OnRetryFunc) RetryOption {
return func(cfg *RetryConfig) {
cfg.onRetry = onRetry
}
}

// WithRetryIf sets the condition to determine whether to retry based on the error.
func WithRetryIf(retryIf RetryIfFunc) RetryOption {
return func(cfg *RetryConfig) {
cfg.retryIf = retryIf
}
}

// Custom wrapper for http.Client that implements the Doer interface.
// - Do
// - DoWithRetry
type HttpClient struct {
httpClient *http.Client
}

func NewHttpClient() *HttpClient {
return &HttpClient{
httpClient: &http.Client{},
}
}

func (c *HttpClient) Do(req *http.Request) (*http.Response, error) {
return c.httpClient.Do(req)
}

func (c *HttpClient) DoWithRetry(req *http.Request) (*http.Response, error) {
return Retry(
func() (*http.Response, error) {
return c.httpClient.Do(req)
},
)
}
1 change: 1 addition & 0 deletions pkg/models/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ const (

type Doer interface {
Do(req *http.Request) (*http.Response, error)
DoWithRetry(req *http.Request) (*http.Response, error)
}

0 comments on commit ab5e97d

Please sign in to comment.