Skip to content

Commit

Permalink
Merge pull request #98 from KianYang-Lee/feature/replace-circuit-brea…
Browse files Browse the repository at this point in the history
…ker-with-library

Replace circuit breaker with library
  • Loading branch information
ksysoev committed Jun 25, 2024
2 parents 48d0786 + a6be774 commit 88991aa
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 112 deletions.
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ go 1.22.1
require (
github.com/google/uuid v1.5.0
github.com/ksysoev/ratestor v0.1.0
github.com/sony/gobreaker/v2 v2.0.0
github.com/stretchr/testify v1.9.0
golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e
golang.org/x/sync v0.7.0
nhooyr.io/websocket v1.8.11
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
golang.org/x/sync v0.7.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ github.com/ksysoev/ratestor v0.1.0 h1:zAlHYNXHyfwj78TnjUF6FHyYwkMcZxxxGul2DRhF4/
github.com/ksysoev/ratestor v0.1.0/go.mod h1:ZJ3MX2d9JtBetKh9WMLvn0ESotKPJnl9rEU/qetyObk=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sony/gobreaker/v2 v2.0.0 h1:23AaR4JQ65y4rz8JWMzgXw2gKOykZ/qfqYunll4OwJ4=
github.com/sony/gobreaker/v2 v2.0.0/go.mod h1:8JnRUz80DJ1/ne8M8v7nmTs2713i58nIt4s7XcGe/DI=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
Expand Down
92 changes: 24 additions & 68 deletions middleware/request/circuit_breaker.go
Original file line number Diff line number Diff line change
@@ -1,101 +1,57 @@
package request

import (
"errors"
"fmt"
"sync"
"time"

"github.com/ksysoev/wasabi"
"github.com/ksysoev/wasabi/dispatch"
"github.com/sony/gobreaker/v2"
)

type CircuitBreakerState uint8

var (
ErrCircuitBreakerOpen = fmt.Errorf("circuit breaker is open")
)

const (
Closed CircuitBreakerState = iota
Open
)

// NewCircuitBreakerMiddleware creates a new circuit breaker middleware with the specified parameters.
// It returns a function that wraps the provided `wasabi.RequestHandler` and implements the circuit breaker logic.
// The circuit breaker monitors the number of errors and successful requests within a given time period.
// If the number of errors exceeds the threshold, the circuit breaker switches to the "Open" state and rejects subsequent requests.
// After a specified number of successful requests, the circuit breaker switches back to the "Closed" state.
// The circuit breaker uses a lock to ensure thread safety.
// The `treshold` parameter specifies the maximum number of errors allowed within the time period.
// After a set amount of period, the circuit breaker switches to the
// "Semi-open" state.
// If request succeeds in "Semi-open" state, the state will be changed to
// "Closed", else back to "Open".
// The `threshold` parameter specifies the maximum number of errors allowed within the time period.
// The `period` parameter specifies the duration of the time period.
// The `recoverAfter` parameter specifies the number of successful requests required to recover from the "Open" state.
// The returned function can be used as middleware in a Wasabi server.
func NewCircuitBreakerMiddleware(treshold uint, period time.Duration, recoverAfter uint) func(next wasabi.RequestHandler) wasabi.RequestHandler {
var errorCounter, successCounter uint

intervalEnds := time.Now().Add(period)
state := Closed

lock := &sync.RWMutex{}
semiOpenLock := &sync.Mutex{}
func NewCircuitBreakerMiddleware(threshold uint32, period time.Duration) func(next wasabi.RequestHandler) wasabi.RequestHandler {
var st gobreaker.Settings
st.Timeout = period
st.ReadyToTrip = func(counts gobreaker.Counts) bool {
return counts.ConsecutiveFailures >= threshold
}
cb := gobreaker.NewCircuitBreaker[any](st)

return func(next wasabi.RequestHandler) wasabi.RequestHandler {
return dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error {
lock.RLock()
currentState := state
lock.RUnlock()

switch currentState {
case Closed:
err := next.Handle(conn, req)
if err == nil {
return nil
}

lock.Lock()
defer lock.Unlock()

now := time.Now()
if intervalEnds.Before(time.Now()) {
intervalEnds = now.Add(period)
errorCounter = 0
}

errorCounter++
if errorCounter >= treshold {
state = Open
}

return err
case Open:
if !semiOpenLock.TryLock() {
return ErrCircuitBreakerOpen
}

defer semiOpenLock.Unlock()

_, err := cb.Execute(func() (any, error) {
err := next.Handle(conn, req)

lock.Lock()
defer lock.Unlock()

if err != nil {
successCounter = 0
return err
return nil, err
}

successCounter++

if successCounter >= recoverAfter {
state = Closed
errorCounter = 0
successCounter = 0
return struct{}{}, nil
})
if err != nil {
if errors.Is(err, gobreaker.ErrOpenState) {
return ErrCircuitBreakerOpen
}

return nil
default:
panic("Unknown state of circuit breaker")
return err
}

return nil
})
}
}
78 changes: 35 additions & 43 deletions middleware/request/circuit_breaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,19 @@ import (
)

func TestNewCircuitBreakerMiddleware_ClosedState(t *testing.T) {
treshold := uint(3)
threshold := uint32(3)
period := time.Second
recoverAfter := uint(1)

// Create a mock request handler
mockHandler := dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error { return nil })
mockRequest := mocks.NewMockRequest(t)
mockConn := mocks.NewMockConnection(t)

// Create the circuit breaker middleware
middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler)
middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler)

// Test the Closed state
for i := uint(0); i < treshold+1; i++ {
for i := uint32(0); i < threshold+1; i++ {
err := middleware.Handle(mockConn, mockRequest)
if err != nil {
t.Errorf("Expected no error, but got %v", err)
Expand All @@ -33,9 +32,8 @@ func TestNewCircuitBreakerMiddleware_ClosedState(t *testing.T) {
}

func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) {
treshold := uint(1)
threshold := uint32(1)
period := time.Second
recoverAfter := uint(1)

testError := fmt.Errorf("test error")

Expand All @@ -49,7 +47,7 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) {
mockConn := mocks.NewMockConnection(t)

// Create the circuit breaker middleware
middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler)
middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler)

// Bring the circuit breaker to the Open state
err := middleware.Handle(mockConn, mockRequest)
Expand All @@ -64,6 +62,8 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) {
go func() {
results <- middleware.Handle(mockConn, mockRequest)
}()
// Wait out circuit break to change state
time.Sleep(period)
}

OpenErrorCount := 0
Expand Down Expand Up @@ -98,9 +98,8 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) {
}

func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
treshold := uint(1)
threshold := uint32(1)
period := time.Second
recoverAfter := uint(1)

testError := fmt.Errorf("test error")

Expand All @@ -116,7 +115,7 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
mockConn := mocks.NewMockConnection(t)

// Create the circuit breaker middleware
middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler)
middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler)

// Bring the circuit breaker to the Open state
err := middleware.Handle(mockConn, mockRequest)
Expand All @@ -134,6 +133,8 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
go func() {
results <- middleware.Handle(mockConn, mockRequest)
}()
// Wait out circuit breaker to change state
time.Sleep(period)
}

for i := 0; i < 2; i++ {
Expand Down Expand Up @@ -203,9 +204,8 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
}

func TestNewCircuitBreakerMiddleware_ResetMeasureInterval(t *testing.T) {
treshold := uint(2)
threshold := uint32(2)
period := 20 * time.Millisecond
recoverAfter := uint(1)

testError := fmt.Errorf("test error")

Expand All @@ -221,58 +221,50 @@ func TestNewCircuitBreakerMiddleware_ResetMeasureInterval(t *testing.T) {
mockConn := mocks.NewMockConnection(t)

// Create the circuit breaker middleware
middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler)
middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler)

// Bring the circuit breaker to the Open state

if err := middleware.Handle(mockConn, mockRequest); err != testError {
t.Errorf("Expected error %v, but got %v", testError, err)
for i := uint32(0); i < threshold; i++ {
if err := middleware.Handle(mockConn, mockRequest); err != testError {
t.Errorf("Expected error %v, but got %v", testError, err)
}
}

// Confirm that the circuit breaker is now in the Semi-open state
time.Sleep(period)

if err := middleware.Handle(mockConn, mockRequest); err != testError {
t.Errorf("Expected error %v, but got %v", testError, err)
}

// Confirm that the circuit breaker is now in the Closed state

errorToReturn = nil
results := make(chan error)

for i := 0; i < 2; i++ {
go func() {
results <- middleware.Handle(mockConn, mockRequest)
}()
}
go func() {
results <- middleware.Handle(mockConn, mockRequest)
}()

OpenErrorCount := 0
SuccessCount := 0

for i := 0; i < 2; i++ {
select {
case err := <-results:
if err != ErrCircuitBreakerOpen && err != nil {
t.Errorf("Expected error %v, but got %v", ErrCircuitBreakerOpen, err)
continue
}

if err == ErrCircuitBreakerOpen {
OpenErrorCount++
} else if err == nil {
SuccessCount++
}
select {
case err := <-results:
if err != ErrCircuitBreakerOpen && err != nil {
t.Errorf("Expected error %v, but got %v", ErrCircuitBreakerOpen, err)
}

case <-time.After(100 * time.Millisecond):
t.Fatal("Expected error, but got none")
if err == ErrCircuitBreakerOpen {
OpenErrorCount++
} else if err == nil {
SuccessCount++
}

case <-time.After(100 * time.Millisecond):
t.Fatal("Expected error, but got none")
}

if OpenErrorCount != 0 {
t.Errorf("Expected 0 ErrCircuitBreakerOpen error, but got %d", OpenErrorCount)
}

if SuccessCount != 2 {
t.Errorf("Expected 2 test error, but got %d", SuccessCount)
if SuccessCount != 1 {
t.Errorf("Expected 1 test error, but got %d", SuccessCount)
}
}

0 comments on commit 88991aa

Please sign in to comment.