Skip to content

Commit

Permalink
Merge pull request #93 from ksysoev/89-retry-middleware
Browse files Browse the repository at this point in the history
Add retry middleware for handling request retries
  • Loading branch information
ksysoev authored Jun 22, 2024
2 parents 16225d8 + c2217ad commit 7973326
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 0 deletions.
37 changes: 37 additions & 0 deletions middleware/request/retry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package request

import (
"time"

"github.com/ksysoev/wasabi"
"github.com/ksysoev/wasabi/dispatch"
)

// NewRetryMiddleware returns a new retry middleware that wraps the provided `next` request handler.
// The middleware retries the request a maximum of `maxRetries` times with a delay of `interval` between each retry.
// If the request succeeds at any retry, the middleware returns `nil`. If all retries fail, it returns the last error encountered.
func NewRetryMiddleware(maxRetries int, interval time.Duration) func(next wasabi.RequestHandler) wasabi.RequestHandler {
return func(next wasabi.RequestHandler) wasabi.RequestHandler {
return dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error {
var err error
ticker := time.NewTicker(interval)
defer ticker.Stop()
for i := 0; i < maxRetries; i++ {
err = next.Handle(conn, req)
if err == nil {
return nil
}

ticker.Reset(interval)

select {
case <-req.Context().Done():
return req.Context().Err()
case <-ticker.C:
}
}

return err
})
}
}
71 changes: 71 additions & 0 deletions middleware/request/retry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package request

import (
"context"
"fmt"
"testing"
"time"

"github.com/ksysoev/wasabi"
"github.com/ksysoev/wasabi/dispatch"
"github.com/ksysoev/wasabi/mocks"
)

func TestNewRetryMiddleware(t *testing.T) {
maxRetries := 3
interval := time.Microsecond
middleware := NewRetryMiddleware(maxRetries, interval)

// Create a mock request handler
mockHandler := dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error {
return fmt.Errorf("mock error")
})

ctx := context.Background()

// Create a mock connection and request
mockConn := mocks.NewMockConnection(t)
mockReq := mocks.NewMockRequest(t)

mockReq.EXPECT().Context().Return(ctx)

// Test with successful request
mockHandlerSuccess := dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error {
return nil
})

if err := middleware(mockHandlerSuccess).Handle(mockConn, mockReq); err != nil {
t.Errorf("Expected no error, but got %v", err)
}

// Test with failed request
if err := middleware(mockHandler).Handle(mockConn, mockReq); err == nil {
t.Error("Expected error, but got nil")
}
}

func TestNewRetryMiddleware_CancelledContext(t *testing.T) {
maxRetries := 3
interval := time.Microsecond
middleware := NewRetryMiddleware(maxRetries, interval)

// Create a mock request handler
mockHandler := dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error {
return fmt.Errorf("mock error")
})

ctx, cancel := context.WithCancel(context.Background())
cancel()

// Create a mock connection and request
mockConn := mocks.NewMockConnection(t)
mockReq := mocks.NewMockRequest(t)

mockReq.EXPECT().Context().Return(ctx)

// Test with failed request
err := middleware(mockHandler).Handle(mockConn, mockReq)
if err != context.Canceled {
t.Errorf("Expected error to be context.Canceled, but got %v", err)
}
}

0 comments on commit 7973326

Please sign in to comment.