diff --git a/go.mod b/go.mod index 24fe4a9..0fc1eb3 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,10 @@ go 1.22.1 require ( github.com/google/uuid v1.5.0 github.com/ksysoev/ratestor v0.1.0 + github.com/sony/gobreaker/v2 v2.0.0 github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e + golang.org/x/sync v0.7.0 nhooyr.io/websocket v1.8.11 ) @@ -14,6 +16,5 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect - golang.org/x/sync v0.7.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 67a5853..46f1ab2 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/ksysoev/ratestor v0.1.0 h1:zAlHYNXHyfwj78TnjUF6FHyYwkMcZxxxGul2DRhF4/ github.com/ksysoev/ratestor v0.1.0/go.mod h1:ZJ3MX2d9JtBetKh9WMLvn0ESotKPJnl9rEU/qetyObk= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sony/gobreaker/v2 v2.0.0 h1:23AaR4JQ65y4rz8JWMzgXw2gKOykZ/qfqYunll4OwJ4= +github.com/sony/gobreaker/v2 v2.0.0/go.mod h1:8JnRUz80DJ1/ne8M8v7nmTs2713i58nIt4s7XcGe/DI= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= diff --git a/middleware/request/circuit_breaker.go b/middleware/request/circuit_breaker.go index 2514ccc..8008eb1 100644 --- a/middleware/request/circuit_breaker.go +++ b/middleware/request/circuit_breaker.go @@ -2,11 +2,11 @@ package request import ( "fmt" - "sync" "time" "github.com/ksysoev/wasabi" "github.com/ksysoev/wasabi/dispatch" + "github.com/sony/gobreaker/v2" ) type CircuitBreakerState uint8 @@ -24,78 +24,35 @@ const ( // It returns a function that wraps the provided `wasabi.RequestHandler` and implements the circuit breaker logic. // The circuit breaker monitors the number of errors and successful requests within a given time period. // If the number of errors exceeds the threshold, the circuit breaker switches to the "Open" state and rejects subsequent requests. -// After a specified number of successful requests, the circuit breaker switches back to the "Closed" state. -// The circuit breaker uses a lock to ensure thread safety. -// The `treshold` parameter specifies the maximum number of errors allowed within the time period. +// After a set amount of period, the circuit breaker switches to the +// "Semi-open" state. +// If request succeeds in "Semi-open" state, the state will be changed to +// "Closed", else back to "Open". +// The `threshold` parameter specifies the maximum number of errors allowed within the time period. // The `period` parameter specifies the duration of the time period. -// The `recoverAfter` parameter specifies the number of successful requests required to recover from the "Open" state. // The returned function can be used as middleware in a Wasabi server. -func NewCircuitBreakerMiddleware(treshold uint, period time.Duration, recoverAfter uint) func(next wasabi.RequestHandler) wasabi.RequestHandler { - var errorCounter, successCounter uint - - intervalEnds := time.Now().Add(period) - state := Closed - - lock := &sync.RWMutex{} - semiOpenLock := &sync.Mutex{} +func NewCircuitBreakerMiddleware(threshold uint, period time.Duration) func(next wasabi.RequestHandler) wasabi.RequestHandler { + var st gobreaker.Settings + st.Timeout = period + st.ReadyToTrip = func(counts gobreaker.Counts) bool { + return counts.ConsecutiveFailures >= uint32(threshold) + } + cb := gobreaker.NewCircuitBreaker[any](st) return func(next wasabi.RequestHandler) wasabi.RequestHandler { return dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error { - lock.RLock() - currentState := state - lock.RUnlock() - - switch currentState { - case Closed: - err := next.Handle(conn, req) - if err == nil { - return nil - } - - lock.Lock() - defer lock.Unlock() - - now := time.Now() - if intervalEnds.Before(time.Now()) { - intervalEnds = now.Add(period) - errorCounter = 0 - } - - errorCounter++ - if errorCounter >= treshold { - state = Open - } - - return err - case Open: - if !semiOpenLock.TryLock() { - return ErrCircuitBreakerOpen - } - - defer semiOpenLock.Unlock() - + _, err := cb.Execute(func() (any, error) { err := next.Handle(conn, req) - - lock.Lock() - defer lock.Unlock() - if err != nil { - successCounter = 0 - return err + return nil, err } - - successCounter++ - - if successCounter >= recoverAfter { - state = Closed - errorCounter = 0 - successCounter = 0 - } - - return nil - default: - panic("Unknown state of circuit breaker") + return struct{}{}, nil + }) + if err != nil { + return err } + return nil }) } + } diff --git a/middleware/request/circuit_breaker_test.go b/middleware/request/circuit_breaker_test.go index 79fcb33..07b5dce 100644 --- a/middleware/request/circuit_breaker_test.go +++ b/middleware/request/circuit_breaker_test.go @@ -8,12 +8,12 @@ import ( "github.com/ksysoev/wasabi" "github.com/ksysoev/wasabi/dispatch" "github.com/ksysoev/wasabi/mocks" + "github.com/sony/gobreaker/v2" ) func TestNewCircuitBreakerMiddleware_ClosedState(t *testing.T) { - treshold := uint(3) + threshold := uint(3) period := time.Second - recoverAfter := uint(1) // Create a mock request handler mockHandler := dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error { return nil }) @@ -21,10 +21,10 @@ func TestNewCircuitBreakerMiddleware_ClosedState(t *testing.T) { mockConn := mocks.NewMockConnection(t) // Create the circuit breaker middleware - middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler) + middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler) // Test the Closed state - for i := uint(0); i < treshold+1; i++ { + for i := uint(0); i < threshold+1; i++ { err := middleware.Handle(mockConn, mockRequest) if err != nil { t.Errorf("Expected no error, but got %v", err) @@ -33,9 +33,8 @@ func TestNewCircuitBreakerMiddleware_ClosedState(t *testing.T) { } func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) { - treshold := uint(1) + threshold := uint(1) period := time.Second - recoverAfter := uint(1) testError := fmt.Errorf("test error") @@ -49,7 +48,7 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) { mockConn := mocks.NewMockConnection(t) // Create the circuit breaker middleware - middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler) + middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler) // Bring the circuit breaker to the Open state err := middleware.Handle(mockConn, mockRequest) @@ -64,6 +63,8 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) { go func() { results <- middleware.Handle(mockConn, mockRequest) }() + // Wait out circuit break to change state + time.Sleep(period) } OpenErrorCount := 0 @@ -72,12 +73,12 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) { for i := 0; i < 2; i++ { select { case err := <-results: - if err != ErrCircuitBreakerOpen && err != testError { - t.Errorf("Expected error %v, but got %v", ErrCircuitBreakerOpen, err) + if err != gobreaker.ErrOpenState && err != testError { + t.Errorf("Expected error %v, but got %v", gobreaker.ErrOpenState, err) continue } - if err == ErrCircuitBreakerOpen { + if err == gobreaker.ErrOpenState { OpenErrorCount++ } else if err == testError { TestErrorCount++ @@ -89,7 +90,7 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) { } if OpenErrorCount != 1 { - t.Errorf("Expected 1 ErrCircuitBreakerOpen error, but got %d", OpenErrorCount) + t.Errorf("Expected 1 gobreaker.ErrOpenState error, but got %d", OpenErrorCount) } if TestErrorCount != 1 { @@ -98,9 +99,8 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) { } func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) { - treshold := uint(1) + threshold := uint(1) period := time.Second - recoverAfter := uint(1) testError := fmt.Errorf("test error") @@ -116,7 +116,7 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) { mockConn := mocks.NewMockConnection(t) // Create the circuit breaker middleware - middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler) + middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler) // Bring the circuit breaker to the Open state err := middleware.Handle(mockConn, mockRequest) @@ -134,17 +134,19 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) { go func() { results <- middleware.Handle(mockConn, mockRequest) }() + // Wait out circuit breaker to change state + time.Sleep(period) } for i := 0; i < 2; i++ { select { case err := <-results: - if err != ErrCircuitBreakerOpen && err != nil { - t.Errorf("Expected error %v, but got %v", ErrCircuitBreakerOpen, err) + if err != gobreaker.ErrOpenState && err != nil { + t.Errorf("Expected error %v, but got %v", gobreaker.ErrOpenState, err) continue } - if err == ErrCircuitBreakerOpen { + if err == gobreaker.ErrOpenState { OpenErrorCount++ } else if err == nil { SuccessCount++ @@ -156,7 +158,7 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) { } if OpenErrorCount != 1 { - t.Errorf("Expected 1 ErrCircuitBreakerOpen error, but got %d", OpenErrorCount) + t.Errorf("Expected 1 gobreaker.ErrOpenState error, but got %d", OpenErrorCount) } if SuccessCount != 1 { @@ -177,12 +179,12 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) { for i := 0; i < 2; i++ { select { case err := <-results: - if err != ErrCircuitBreakerOpen && err != nil { - t.Errorf("Expected error %v, but got %v", ErrCircuitBreakerOpen, err) + if err != gobreaker.ErrOpenState && err != nil { + t.Errorf("Expected error %v, but got %v", gobreaker.ErrOpenState, err) continue } - if err == ErrCircuitBreakerOpen { + if err == gobreaker.ErrOpenState { OpenErrorCount++ } else if err == nil { SuccessCount++ @@ -194,7 +196,7 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) { } if OpenErrorCount != 0 { - t.Errorf("Expected 0 ErrCircuitBreakerOpen error, but got %d", OpenErrorCount) + t.Errorf("Expected 0 gobreaker.ErrOpenState error, but got %d", OpenErrorCount) } if SuccessCount != 2 { @@ -203,9 +205,8 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) { } func TestNewCircuitBreakerMiddleware_ResetMeasureInterval(t *testing.T) { - treshold := uint(2) + threshold := uint(2) period := 20 * time.Millisecond - recoverAfter := uint(1) testError := fmt.Errorf("test error") @@ -221,58 +222,51 @@ func TestNewCircuitBreakerMiddleware_ResetMeasureInterval(t *testing.T) { mockConn := mocks.NewMockConnection(t) // Create the circuit breaker middleware - middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler) + middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler) // Bring the circuit breaker to the Open state - if err := middleware.Handle(mockConn, mockRequest); err != testError { - t.Errorf("Expected error %v, but got %v", testError, err) + for i := uint(0); i < threshold; i++ { + if err := middleware.Handle(mockConn, mockRequest); err != testError { + t.Errorf("Expected error %v, but got %v", testError, err) + } } + // Confirm that the circuit breaker is now in the Semi-open state time.Sleep(period) - if err := middleware.Handle(mockConn, mockRequest); err != testError { - t.Errorf("Expected error %v, but got %v", testError, err) - } - - // Confirm that the circuit breaker is now in the Closed state - errorToReturn = nil results := make(chan error) - for i := 0; i < 2; i++ { - go func() { - results <- middleware.Handle(mockConn, mockRequest) - }() - } + go func() { + results <- middleware.Handle(mockConn, mockRequest) + }() OpenErrorCount := 0 SuccessCount := 0 - for i := 0; i < 2; i++ { - select { - case err := <-results: - if err != ErrCircuitBreakerOpen && err != nil { - t.Errorf("Expected error %v, but got %v", ErrCircuitBreakerOpen, err) - continue - } - - if err == ErrCircuitBreakerOpen { - OpenErrorCount++ - } else if err == nil { - SuccessCount++ - } + select { + case err := <-results: + fmt.Println(err) + if err != gobreaker.ErrOpenState && err != nil { + t.Errorf("Expected error %v, but got %v", gobreaker.ErrOpenState, err) + } - case <-time.After(100 * time.Millisecond): - t.Fatal("Expected error, but got none") + if err == gobreaker.ErrOpenState { + OpenErrorCount++ + } else if err == nil { + SuccessCount++ } + + case <-time.After(100 * time.Millisecond): + t.Fatal("Expected error, but got none") } if OpenErrorCount != 0 { - t.Errorf("Expected 0 ErrCircuitBreakerOpen error, but got %d", OpenErrorCount) + t.Errorf("Expected 0 gobreaker.ErrOpenState error, but got %d", OpenErrorCount) } - if SuccessCount != 2 { - t.Errorf("Expected 2 test error, but got %d", SuccessCount) + if SuccessCount != 1 { + t.Errorf("Expected 1 test error, but got %d", SuccessCount) } }