diff --git a/go.mod b/go.mod index 6d61ae6..90fe0ad 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/google/uuid v1.5.0 github.com/jellydator/ttlcache/v3 v3.2.0 github.com/ksysoev/ratestor v0.1.0 + github.com/sony/gobreaker/v2 v2.0.0 github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e golang.org/x/sync v0.7.0 diff --git a/go.sum b/go.sum index 01e1670..d15e2d4 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/ksysoev/ratestor v0.1.0 h1:zAlHYNXHyfwj78TnjUF6FHyYwkMcZxxxGul2DRhF4/ github.com/ksysoev/ratestor v0.1.0/go.mod h1:ZJ3MX2d9JtBetKh9WMLvn0ESotKPJnl9rEU/qetyObk= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sony/gobreaker/v2 v2.0.0 h1:23AaR4JQ65y4rz8JWMzgXw2gKOykZ/qfqYunll4OwJ4= +github.com/sony/gobreaker/v2 v2.0.0/go.mod h1:8JnRUz80DJ1/ne8M8v7nmTs2713i58nIt4s7XcGe/DI= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= diff --git a/middleware/request/circuit_breaker.go b/middleware/request/circuit_breaker.go index 2514ccc..53acd6e 100644 --- a/middleware/request/circuit_breaker.go +++ b/middleware/request/circuit_breaker.go @@ -1,101 +1,57 @@ package request import ( + "errors" "fmt" - "sync" "time" "github.com/ksysoev/wasabi" "github.com/ksysoev/wasabi/dispatch" + "github.com/sony/gobreaker/v2" ) -type CircuitBreakerState uint8 - var ( ErrCircuitBreakerOpen = fmt.Errorf("circuit breaker is open") ) -const ( - Closed CircuitBreakerState = iota - Open -) - // NewCircuitBreakerMiddleware creates a new circuit breaker middleware with the specified parameters. // It returns a function that wraps the provided `wasabi.RequestHandler` and implements the circuit breaker logic. // The circuit breaker monitors the number of errors and successful requests within a given time period. // If the number of errors exceeds the threshold, the circuit breaker switches to the "Open" state and rejects subsequent requests. -// After a specified number of successful requests, the circuit breaker switches back to the "Closed" state. -// The circuit breaker uses a lock to ensure thread safety. -// The `treshold` parameter specifies the maximum number of errors allowed within the time period. +// After a set amount of period, the circuit breaker switches to the +// "Semi-open" state. +// If request succeeds in "Semi-open" state, the state will be changed to +// "Closed", else back to "Open". +// The `threshold` parameter specifies the maximum number of errors allowed within the time period. // The `period` parameter specifies the duration of the time period. -// The `recoverAfter` parameter specifies the number of successful requests required to recover from the "Open" state. // The returned function can be used as middleware in a Wasabi server. -func NewCircuitBreakerMiddleware(treshold uint, period time.Duration, recoverAfter uint) func(next wasabi.RequestHandler) wasabi.RequestHandler { - var errorCounter, successCounter uint - - intervalEnds := time.Now().Add(period) - state := Closed - - lock := &sync.RWMutex{} - semiOpenLock := &sync.Mutex{} +func NewCircuitBreakerMiddleware(threshold uint32, period time.Duration) func(next wasabi.RequestHandler) wasabi.RequestHandler { + var st gobreaker.Settings + st.Timeout = period + st.ReadyToTrip = func(counts gobreaker.Counts) bool { + return counts.ConsecutiveFailures >= threshold + } + cb := gobreaker.NewCircuitBreaker[any](st) return func(next wasabi.RequestHandler) wasabi.RequestHandler { return dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error { - lock.RLock() - currentState := state - lock.RUnlock() - - switch currentState { - case Closed: - err := next.Handle(conn, req) - if err == nil { - return nil - } - - lock.Lock() - defer lock.Unlock() - - now := time.Now() - if intervalEnds.Before(time.Now()) { - intervalEnds = now.Add(period) - errorCounter = 0 - } - - errorCounter++ - if errorCounter >= treshold { - state = Open - } - - return err - case Open: - if !semiOpenLock.TryLock() { - return ErrCircuitBreakerOpen - } - - defer semiOpenLock.Unlock() - + _, err := cb.Execute(func() (any, error) { err := next.Handle(conn, req) - - lock.Lock() - defer lock.Unlock() - if err != nil { - successCounter = 0 - return err + return nil, err } - successCounter++ - - if successCounter >= recoverAfter { - state = Closed - errorCounter = 0 - successCounter = 0 + return struct{}{}, nil + }) + if err != nil { + if errors.Is(err, gobreaker.ErrOpenState) { + return ErrCircuitBreakerOpen } - return nil - default: - panic("Unknown state of circuit breaker") + return err } + + return nil }) } } diff --git a/middleware/request/circuit_breaker_test.go b/middleware/request/circuit_breaker_test.go index 79fcb33..12f9c4d 100644 --- a/middleware/request/circuit_breaker_test.go +++ b/middleware/request/circuit_breaker_test.go @@ -11,9 +11,8 @@ import ( ) func TestNewCircuitBreakerMiddleware_ClosedState(t *testing.T) { - treshold := uint(3) + threshold := uint32(3) period := time.Second - recoverAfter := uint(1) // Create a mock request handler mockHandler := dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error { return nil }) @@ -21,10 +20,10 @@ func TestNewCircuitBreakerMiddleware_ClosedState(t *testing.T) { mockConn := mocks.NewMockConnection(t) // Create the circuit breaker middleware - middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler) + middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler) // Test the Closed state - for i := uint(0); i < treshold+1; i++ { + for i := uint32(0); i < threshold+1; i++ { err := middleware.Handle(mockConn, mockRequest) if err != nil { t.Errorf("Expected no error, but got %v", err) @@ -33,9 +32,8 @@ func TestNewCircuitBreakerMiddleware_ClosedState(t *testing.T) { } func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) { - treshold := uint(1) + threshold := uint32(1) period := time.Second - recoverAfter := uint(1) testError := fmt.Errorf("test error") @@ -49,7 +47,7 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) { mockConn := mocks.NewMockConnection(t) // Create the circuit breaker middleware - middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler) + middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler) // Bring the circuit breaker to the Open state err := middleware.Handle(mockConn, mockRequest) @@ -64,6 +62,8 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) { go func() { results <- middleware.Handle(mockConn, mockRequest) }() + // Wait out circuit break to change state + time.Sleep(period) } OpenErrorCount := 0 @@ -98,9 +98,8 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) { } func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) { - treshold := uint(1) + threshold := uint32(1) period := time.Second - recoverAfter := uint(1) testError := fmt.Errorf("test error") @@ -116,7 +115,7 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) { mockConn := mocks.NewMockConnection(t) // Create the circuit breaker middleware - middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler) + middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler) // Bring the circuit breaker to the Open state err := middleware.Handle(mockConn, mockRequest) @@ -134,6 +133,8 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) { go func() { results <- middleware.Handle(mockConn, mockRequest) }() + // Wait out circuit breaker to change state + time.Sleep(period) } for i := 0; i < 2; i++ { @@ -203,9 +204,8 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) { } func TestNewCircuitBreakerMiddleware_ResetMeasureInterval(t *testing.T) { - treshold := uint(2) + threshold := uint32(2) period := 20 * time.Millisecond - recoverAfter := uint(1) testError := fmt.Errorf("test error") @@ -221,58 +221,50 @@ func TestNewCircuitBreakerMiddleware_ResetMeasureInterval(t *testing.T) { mockConn := mocks.NewMockConnection(t) // Create the circuit breaker middleware - middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler) + middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler) // Bring the circuit breaker to the Open state - if err := middleware.Handle(mockConn, mockRequest); err != testError { - t.Errorf("Expected error %v, but got %v", testError, err) + for i := uint32(0); i < threshold; i++ { + if err := middleware.Handle(mockConn, mockRequest); err != testError { + t.Errorf("Expected error %v, but got %v", testError, err) + } } + // Confirm that the circuit breaker is now in the Semi-open state time.Sleep(period) - if err := middleware.Handle(mockConn, mockRequest); err != testError { - t.Errorf("Expected error %v, but got %v", testError, err) - } - - // Confirm that the circuit breaker is now in the Closed state - errorToReturn = nil results := make(chan error) - for i := 0; i < 2; i++ { - go func() { - results <- middleware.Handle(mockConn, mockRequest) - }() - } + go func() { + results <- middleware.Handle(mockConn, mockRequest) + }() OpenErrorCount := 0 SuccessCount := 0 - for i := 0; i < 2; i++ { - select { - case err := <-results: - if err != ErrCircuitBreakerOpen && err != nil { - t.Errorf("Expected error %v, but got %v", ErrCircuitBreakerOpen, err) - continue - } - - if err == ErrCircuitBreakerOpen { - OpenErrorCount++ - } else if err == nil { - SuccessCount++ - } + select { + case err := <-results: + if err != ErrCircuitBreakerOpen && err != nil { + t.Errorf("Expected error %v, but got %v", ErrCircuitBreakerOpen, err) + } - case <-time.After(100 * time.Millisecond): - t.Fatal("Expected error, but got none") + if err == ErrCircuitBreakerOpen { + OpenErrorCount++ + } else if err == nil { + SuccessCount++ } + + case <-time.After(100 * time.Millisecond): + t.Fatal("Expected error, but got none") } if OpenErrorCount != 0 { t.Errorf("Expected 0 ErrCircuitBreakerOpen error, but got %d", OpenErrorCount) } - if SuccessCount != 2 { - t.Errorf("Expected 2 test error, but got %d", SuccessCount) + if SuccessCount != 1 { + t.Errorf("Expected 1 test error, but got %d", SuccessCount) } }