Skip to content

Commit

Permalink
feat: replace circuit breaker with library
Browse files Browse the repository at this point in the history
Previous implementation of circuit breaker was simplistic. It is replaced
with a third-party library "gobreaker/v2". The function signature
"recoverAfter" was removed to fit the implementation of the library.
Test cases for circuit breaker were modified accordingly.
  • Loading branch information
KianYang-Lee committed Jun 24, 2024
1 parent 7973326 commit c60d235
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 121 deletions.
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ go 1.22.1
require (
github.com/google/uuid v1.5.0
github.com/ksysoev/ratestor v0.1.0
github.com/sony/gobreaker/v2 v2.0.0
github.com/stretchr/testify v1.9.0
golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e
golang.org/x/sync v0.7.0
nhooyr.io/websocket v1.8.11
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
golang.org/x/sync v0.7.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ github.com/ksysoev/ratestor v0.1.0 h1:zAlHYNXHyfwj78TnjUF6FHyYwkMcZxxxGul2DRhF4/
github.com/ksysoev/ratestor v0.1.0/go.mod h1:ZJ3MX2d9JtBetKh9WMLvn0ESotKPJnl9rEU/qetyObk=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sony/gobreaker/v2 v2.0.0 h1:23AaR4JQ65y4rz8JWMzgXw2gKOykZ/qfqYunll4OwJ4=
github.com/sony/gobreaker/v2 v2.0.0/go.mod h1:8JnRUz80DJ1/ne8M8v7nmTs2713i58nIt4s7XcGe/DI=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
Expand Down
85 changes: 21 additions & 64 deletions middleware/request/circuit_breaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package request

import (
"fmt"
"sync"
"time"

"github.com/ksysoev/wasabi"
"github.com/ksysoev/wasabi/dispatch"
"github.com/sony/gobreaker/v2"
)

type CircuitBreakerState uint8
Expand All @@ -24,78 +24,35 @@ const (
// It returns a function that wraps the provided `wasabi.RequestHandler` and implements the circuit breaker logic.
// The circuit breaker monitors the number of errors and successful requests within a given time period.
// If the number of errors exceeds the threshold, the circuit breaker switches to the "Open" state and rejects subsequent requests.
// After a specified number of successful requests, the circuit breaker switches back to the "Closed" state.
// The circuit breaker uses a lock to ensure thread safety.
// The `treshold` parameter specifies the maximum number of errors allowed within the time period.
// After a set amount of period, the circuit breaker switches to the
// "Semi-open" state.
// If request succeeds in "Semi-open" state, the state will be changed to
// "Closed", else back to "Open".
// The `threshold` parameter specifies the maximum number of errors allowed within the time period.
// The `period` parameter specifies the duration of the time period.
// The `recoverAfter` parameter specifies the number of successful requests required to recover from the "Open" state.
// The returned function can be used as middleware in a Wasabi server.
func NewCircuitBreakerMiddleware(treshold uint, period time.Duration, recoverAfter uint) func(next wasabi.RequestHandler) wasabi.RequestHandler {
var errorCounter, successCounter uint

intervalEnds := time.Now().Add(period)
state := Closed

lock := &sync.RWMutex{}
semiOpenLock := &sync.Mutex{}
func NewCircuitBreakerMiddleware(threshold uint, period time.Duration) func(next wasabi.RequestHandler) wasabi.RequestHandler {
var st gobreaker.Settings
st.Timeout = period
st.ReadyToTrip = func(counts gobreaker.Counts) bool {
return counts.ConsecutiveFailures >= uint32(threshold)
}
cb := gobreaker.NewCircuitBreaker[any](st)

return func(next wasabi.RequestHandler) wasabi.RequestHandler {
return dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error {
lock.RLock()
currentState := state
lock.RUnlock()

switch currentState {
case Closed:
err := next.Handle(conn, req)
if err == nil {
return nil
}

lock.Lock()
defer lock.Unlock()

now := time.Now()
if intervalEnds.Before(time.Now()) {
intervalEnds = now.Add(period)
errorCounter = 0
}

errorCounter++
if errorCounter >= treshold {
state = Open
}

return err
case Open:
if !semiOpenLock.TryLock() {
return ErrCircuitBreakerOpen
}

defer semiOpenLock.Unlock()

_, err := cb.Execute(func() (any, error) {
err := next.Handle(conn, req)

lock.Lock()
defer lock.Unlock()

if err != nil {
successCounter = 0
return err
return nil, err
}

successCounter++

if successCounter >= recoverAfter {
state = Closed
errorCounter = 0
successCounter = 0
}

return nil
default:
panic("Unknown state of circuit breaker")
return struct{}{}, nil
})
if err != nil {
return err
}
return nil
})
}

Check failure on line 57 in middleware/request/circuit_breaker.go

View workflow job for this annotation

GitHub Actions / tests (1.22.x)

unnecessary trailing newline (whitespace)
}

Check failure on line 58 in middleware/request/circuit_breaker.go

View workflow job for this annotation

GitHub Actions / tests (1.22.x)

block should not end with a whitespace (or comment) (wsl)
106 changes: 50 additions & 56 deletions middleware/request/circuit_breaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,23 @@ import (
"github.com/ksysoev/wasabi"
"github.com/ksysoev/wasabi/dispatch"
"github.com/ksysoev/wasabi/mocks"
"github.com/sony/gobreaker/v2"
)

func TestNewCircuitBreakerMiddleware_ClosedState(t *testing.T) {
treshold := uint(3)
threshold := uint(3)
period := time.Second
recoverAfter := uint(1)

// Create a mock request handler
mockHandler := dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error { return nil })
mockRequest := mocks.NewMockRequest(t)
mockConn := mocks.NewMockConnection(t)

// Create the circuit breaker middleware
middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler)
middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler)

// Test the Closed state
for i := uint(0); i < treshold+1; i++ {
for i := uint(0); i < threshold+1; i++ {
err := middleware.Handle(mockConn, mockRequest)
if err != nil {
t.Errorf("Expected no error, but got %v", err)
Expand All @@ -33,9 +33,8 @@ func TestNewCircuitBreakerMiddleware_ClosedState(t *testing.T) {
}

func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) {
treshold := uint(1)
threshold := uint(1)
period := time.Second
recoverAfter := uint(1)

testError := fmt.Errorf("test error")

Expand All @@ -49,7 +48,7 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) {
mockConn := mocks.NewMockConnection(t)

// Create the circuit breaker middleware
middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler)
middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler)

// Bring the circuit breaker to the Open state
err := middleware.Handle(mockConn, mockRequest)
Expand All @@ -64,6 +63,8 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) {
go func() {
results <- middleware.Handle(mockConn, mockRequest)
}()
// Wait out circuit break to change state
time.Sleep(period)
}

OpenErrorCount := 0
Expand All @@ -72,12 +73,12 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) {
for i := 0; i < 2; i++ {
select {
case err := <-results:
if err != ErrCircuitBreakerOpen && err != testError {
t.Errorf("Expected error %v, but got %v", ErrCircuitBreakerOpen, err)
if err != gobreaker.ErrOpenState && err != testError {
t.Errorf("Expected error %v, but got %v", gobreaker.ErrOpenState, err)
continue
}

if err == ErrCircuitBreakerOpen {
if err == gobreaker.ErrOpenState {
OpenErrorCount++
} else if err == testError {
TestErrorCount++
Expand All @@ -89,7 +90,7 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) {
}

if OpenErrorCount != 1 {
t.Errorf("Expected 1 ErrCircuitBreakerOpen error, but got %d", OpenErrorCount)
t.Errorf("Expected 1 gobreaker.ErrOpenState error, but got %d", OpenErrorCount)
}

if TestErrorCount != 1 {
Expand All @@ -98,9 +99,8 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) {
}

func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
treshold := uint(1)
threshold := uint(1)
period := time.Second
recoverAfter := uint(1)

testError := fmt.Errorf("test error")

Expand All @@ -116,7 +116,7 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
mockConn := mocks.NewMockConnection(t)

// Create the circuit breaker middleware
middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler)
middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler)

// Bring the circuit breaker to the Open state
err := middleware.Handle(mockConn, mockRequest)
Expand All @@ -134,17 +134,19 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
go func() {
results <- middleware.Handle(mockConn, mockRequest)
}()
// Wait out circuit breaker to change state
time.Sleep(period)
}

for i := 0; i < 2; i++ {
select {
case err := <-results:
if err != ErrCircuitBreakerOpen && err != nil {
t.Errorf("Expected error %v, but got %v", ErrCircuitBreakerOpen, err)
if err != gobreaker.ErrOpenState && err != nil {
t.Errorf("Expected error %v, but got %v", gobreaker.ErrOpenState, err)
continue
}

if err == ErrCircuitBreakerOpen {
if err == gobreaker.ErrOpenState {
OpenErrorCount++
} else if err == nil {
SuccessCount++
Expand All @@ -156,7 +158,7 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
}

if OpenErrorCount != 1 {
t.Errorf("Expected 1 ErrCircuitBreakerOpen error, but got %d", OpenErrorCount)
t.Errorf("Expected 1 gobreaker.ErrOpenState error, but got %d", OpenErrorCount)
}

if SuccessCount != 1 {
Expand All @@ -177,12 +179,12 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
for i := 0; i < 2; i++ {
select {
case err := <-results:
if err != ErrCircuitBreakerOpen && err != nil {
t.Errorf("Expected error %v, but got %v", ErrCircuitBreakerOpen, err)
if err != gobreaker.ErrOpenState && err != nil {
t.Errorf("Expected error %v, but got %v", gobreaker.ErrOpenState, err)
continue
}

if err == ErrCircuitBreakerOpen {
if err == gobreaker.ErrOpenState {
OpenErrorCount++
} else if err == nil {
SuccessCount++
Expand All @@ -194,7 +196,7 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
}

if OpenErrorCount != 0 {
t.Errorf("Expected 0 ErrCircuitBreakerOpen error, but got %d", OpenErrorCount)
t.Errorf("Expected 0 gobreaker.ErrOpenState error, but got %d", OpenErrorCount)
}

if SuccessCount != 2 {
Expand All @@ -203,9 +205,8 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
}

func TestNewCircuitBreakerMiddleware_ResetMeasureInterval(t *testing.T) {
treshold := uint(2)
threshold := uint(2)
period := 20 * time.Millisecond
recoverAfter := uint(1)

testError := fmt.Errorf("test error")

Expand All @@ -221,58 +222,51 @@ func TestNewCircuitBreakerMiddleware_ResetMeasureInterval(t *testing.T) {
mockConn := mocks.NewMockConnection(t)

// Create the circuit breaker middleware
middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler)
middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler)

// Bring the circuit breaker to the Open state

if err := middleware.Handle(mockConn, mockRequest); err != testError {
t.Errorf("Expected error %v, but got %v", testError, err)
for i := uint(0); i < threshold; i++ {
if err := middleware.Handle(mockConn, mockRequest); err != testError {
t.Errorf("Expected error %v, but got %v", testError, err)
}
}

// Confirm that the circuit breaker is now in the Semi-open state
time.Sleep(period)

if err := middleware.Handle(mockConn, mockRequest); err != testError {
t.Errorf("Expected error %v, but got %v", testError, err)
}

// Confirm that the circuit breaker is now in the Closed state

errorToReturn = nil
results := make(chan error)

for i := 0; i < 2; i++ {
go func() {
results <- middleware.Handle(mockConn, mockRequest)
}()
}
go func() {
results <- middleware.Handle(mockConn, mockRequest)
}()

OpenErrorCount := 0
SuccessCount := 0

for i := 0; i < 2; i++ {
select {
case err := <-results:
if err != ErrCircuitBreakerOpen && err != nil {
t.Errorf("Expected error %v, but got %v", ErrCircuitBreakerOpen, err)
continue
}

if err == ErrCircuitBreakerOpen {
OpenErrorCount++
} else if err == nil {
SuccessCount++
}
select {
case err := <-results:
fmt.Println(err)
if err != gobreaker.ErrOpenState && err != nil {

Check failure on line 251 in middleware/request/circuit_breaker_test.go

View workflow job for this annotation

GitHub Actions / tests (1.22.x)

if statements should only be cuddled with assignments (wsl)
t.Errorf("Expected error %v, but got %v", gobreaker.ErrOpenState, err)
}

case <-time.After(100 * time.Millisecond):
t.Fatal("Expected error, but got none")
if err == gobreaker.ErrOpenState {
OpenErrorCount++
} else if err == nil {
SuccessCount++
}

case <-time.After(100 * time.Millisecond):
t.Fatal("Expected error, but got none")
}

if OpenErrorCount != 0 {
t.Errorf("Expected 0 ErrCircuitBreakerOpen error, but got %d", OpenErrorCount)
t.Errorf("Expected 0 gobreaker.ErrOpenState error, but got %d", OpenErrorCount)
}

if SuccessCount != 2 {
t.Errorf("Expected 2 test error, but got %d", SuccessCount)
if SuccessCount != 1 {
t.Errorf("Expected 1 test error, but got %d", SuccessCount)
}
}

0 comments on commit c60d235

Please sign in to comment.