Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: replace circuit breaker with library #98

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ go 1.22.1
require (
github.com/google/uuid v1.5.0
github.com/ksysoev/ratestor v0.1.0
github.com/sony/gobreaker/v2 v2.0.0
github.com/stretchr/testify v1.9.0
golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e
golang.org/x/sync v0.7.0
nhooyr.io/websocket v1.8.11
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
golang.org/x/sync v0.7.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ github.com/ksysoev/ratestor v0.1.0 h1:zAlHYNXHyfwj78TnjUF6FHyYwkMcZxxxGul2DRhF4/
github.com/ksysoev/ratestor v0.1.0/go.mod h1:ZJ3MX2d9JtBetKh9WMLvn0ESotKPJnl9rEU/qetyObk=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sony/gobreaker/v2 v2.0.0 h1:23AaR4JQ65y4rz8JWMzgXw2gKOykZ/qfqYunll4OwJ4=
github.com/sony/gobreaker/v2 v2.0.0/go.mod h1:8JnRUz80DJ1/ne8M8v7nmTs2713i58nIt4s7XcGe/DI=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
Expand Down
84 changes: 21 additions & 63 deletions middleware/request/circuit_breaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package request

import (
"fmt"
"sync"
"time"

"github.com/ksysoev/wasabi"
"github.com/ksysoev/wasabi/dispatch"
"github.com/sony/gobreaker/v2"
)

type CircuitBreakerState uint8
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably this type is not needed anymore.

Expand All @@ -24,78 +24,36 @@ const (
// It returns a function that wraps the provided `wasabi.RequestHandler` and implements the circuit breaker logic.
// The circuit breaker monitors the number of errors and successful requests within a given time period.
// If the number of errors exceeds the threshold, the circuit breaker switches to the "Open" state and rejects subsequent requests.
// After a specified number of successful requests, the circuit breaker switches back to the "Closed" state.
// The circuit breaker uses a lock to ensure thread safety.
// The `treshold` parameter specifies the maximum number of errors allowed within the time period.
// After a set amount of period, the circuit breaker switches to the
// "Semi-open" state.
// If request succeeds in "Semi-open" state, the state will be changed to
// "Closed", else back to "Open".
// The `threshold` parameter specifies the maximum number of errors allowed within the time period.
// The `period` parameter specifies the duration of the time period.
// The `recoverAfter` parameter specifies the number of successful requests required to recover from the "Open" state.
// The returned function can be used as middleware in a Wasabi server.
func NewCircuitBreakerMiddleware(treshold uint, period time.Duration, recoverAfter uint) func(next wasabi.RequestHandler) wasabi.RequestHandler {
var errorCounter, successCounter uint

intervalEnds := time.Now().Add(period)
state := Closed

lock := &sync.RWMutex{}
semiOpenLock := &sync.Mutex{}
func NewCircuitBreakerMiddleware(threshold uint, period time.Duration) func(next wasabi.RequestHandler) wasabi.RequestHandler {
Copy link
Owner

@ksysoev ksysoev Jun 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's update type of threshold to uint32 to be in consistent.

var st gobreaker.Settings
st.Timeout = period
st.ReadyToTrip = func(counts gobreaker.Counts) bool {
return counts.ConsecutiveFailures >= uint32(threshold)
}
cb := gobreaker.NewCircuitBreaker[any](st)

return func(next wasabi.RequestHandler) wasabi.RequestHandler {
return dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error {
lock.RLock()
currentState := state
lock.RUnlock()

switch currentState {
case Closed:
_, err := cb.Execute(func() (any, error) {
err := next.Handle(conn, req)
if err == nil {
return nil
}

lock.Lock()
defer lock.Unlock()

now := time.Now()
if intervalEnds.Before(time.Now()) {
intervalEnds = now.Add(period)
errorCounter = 0
}

errorCounter++
if errorCounter >= treshold {
state = Open
}

return err
case Open:
if !semiOpenLock.TryLock() {
return ErrCircuitBreakerOpen
}

defer semiOpenLock.Unlock()

err := next.Handle(conn, req)

lock.Lock()
defer lock.Unlock()

if err != nil {
successCounter = 0
return err
return nil, err
}

successCounter++

if successCounter >= recoverAfter {
state = Closed
errorCounter = 0
successCounter = 0
}

return nil
default:
panic("Unknown state of circuit breaker")
return struct{}{}, nil
})
if err != nil {
return err
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you check for gobreaker.ErrOpenState error and return instead ErrCircuitBreakerOpen to hide dependency on the gobreaker... to have possibility to replace library in the future.

}

return nil
})
}
}
105 changes: 49 additions & 56 deletions middleware/request/circuit_breaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,23 @@ import (
"github.com/ksysoev/wasabi"
"github.com/ksysoev/wasabi/dispatch"
"github.com/ksysoev/wasabi/mocks"
"github.com/sony/gobreaker/v2"
)

func TestNewCircuitBreakerMiddleware_ClosedState(t *testing.T) {
treshold := uint(3)
threshold := uint(3)
period := time.Second
recoverAfter := uint(1)

// Create a mock request handler
mockHandler := dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error { return nil })
mockRequest := mocks.NewMockRequest(t)
mockConn := mocks.NewMockConnection(t)

// Create the circuit breaker middleware
middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler)
middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler)

// Test the Closed state
for i := uint(0); i < treshold+1; i++ {
for i := uint(0); i < threshold+1; i++ {
err := middleware.Handle(mockConn, mockRequest)
if err != nil {
t.Errorf("Expected no error, but got %v", err)
Expand All @@ -33,9 +33,8 @@ func TestNewCircuitBreakerMiddleware_ClosedState(t *testing.T) {
}

func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) {
treshold := uint(1)
threshold := uint(1)
period := time.Second
recoverAfter := uint(1)

testError := fmt.Errorf("test error")

Expand All @@ -49,7 +48,7 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) {
mockConn := mocks.NewMockConnection(t)

// Create the circuit breaker middleware
middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler)
middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler)

// Bring the circuit breaker to the Open state
err := middleware.Handle(mockConn, mockRequest)
Expand All @@ -64,6 +63,8 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) {
go func() {
results <- middleware.Handle(mockConn, mockRequest)
}()
// Wait out circuit break to change state
time.Sleep(period)
}

OpenErrorCount := 0
Expand All @@ -72,12 +73,12 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) {
for i := 0; i < 2; i++ {
select {
case err := <-results:
if err != ErrCircuitBreakerOpen && err != testError {
t.Errorf("Expected error %v, but got %v", ErrCircuitBreakerOpen, err)
if err != gobreaker.ErrOpenState && err != testError {
t.Errorf("Expected error %v, but got %v", gobreaker.ErrOpenState, err)
continue
}

if err == ErrCircuitBreakerOpen {
if err == gobreaker.ErrOpenState {
OpenErrorCount++
} else if err == testError {
TestErrorCount++
Expand All @@ -89,7 +90,7 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) {
}

if OpenErrorCount != 1 {
t.Errorf("Expected 1 ErrCircuitBreakerOpen error, but got %d", OpenErrorCount)
t.Errorf("Expected 1 gobreaker.ErrOpenState error, but got %d", OpenErrorCount)
}

if TestErrorCount != 1 {
Expand All @@ -98,9 +99,8 @@ func TestNewCircuitBreakerMiddleware_OpenState(t *testing.T) {
}

func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
treshold := uint(1)
threshold := uint(1)
period := time.Second
recoverAfter := uint(1)

testError := fmt.Errorf("test error")

Expand All @@ -116,7 +116,7 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
mockConn := mocks.NewMockConnection(t)

// Create the circuit breaker middleware
middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler)
middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler)

// Bring the circuit breaker to the Open state
err := middleware.Handle(mockConn, mockRequest)
Expand All @@ -134,17 +134,19 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
go func() {
results <- middleware.Handle(mockConn, mockRequest)
}()
// Wait out circuit breaker to change state
time.Sleep(period)
}

for i := 0; i < 2; i++ {
select {
case err := <-results:
if err != ErrCircuitBreakerOpen && err != nil {
t.Errorf("Expected error %v, but got %v", ErrCircuitBreakerOpen, err)
if err != gobreaker.ErrOpenState && err != nil {
t.Errorf("Expected error %v, but got %v", gobreaker.ErrOpenState, err)
continue
}

if err == ErrCircuitBreakerOpen {
if err == gobreaker.ErrOpenState {
OpenErrorCount++
} else if err == nil {
SuccessCount++
Expand All @@ -156,7 +158,7 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
}

if OpenErrorCount != 1 {
t.Errorf("Expected 1 ErrCircuitBreakerOpen error, but got %d", OpenErrorCount)
t.Errorf("Expected 1 gobreaker.ErrOpenState error, but got %d", OpenErrorCount)
}

if SuccessCount != 1 {
Expand All @@ -177,12 +179,12 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
for i := 0; i < 2; i++ {
select {
case err := <-results:
if err != ErrCircuitBreakerOpen && err != nil {
t.Errorf("Expected error %v, but got %v", ErrCircuitBreakerOpen, err)
if err != gobreaker.ErrOpenState && err != nil {
t.Errorf("Expected error %v, but got %v", gobreaker.ErrOpenState, err)
continue
}

if err == ErrCircuitBreakerOpen {
if err == gobreaker.ErrOpenState {
OpenErrorCount++
} else if err == nil {
SuccessCount++
Expand All @@ -194,7 +196,7 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
}

if OpenErrorCount != 0 {
t.Errorf("Expected 0 ErrCircuitBreakerOpen error, but got %d", OpenErrorCount)
t.Errorf("Expected 0 gobreaker.ErrOpenState error, but got %d", OpenErrorCount)
}

if SuccessCount != 2 {
Expand All @@ -203,9 +205,8 @@ func TestNewCircuitBreakerMiddleware_SemiOpenState(t *testing.T) {
}

func TestNewCircuitBreakerMiddleware_ResetMeasureInterval(t *testing.T) {
treshold := uint(2)
threshold := uint(2)
period := 20 * time.Millisecond
recoverAfter := uint(1)

testError := fmt.Errorf("test error")

Expand All @@ -221,58 +222,50 @@ func TestNewCircuitBreakerMiddleware_ResetMeasureInterval(t *testing.T) {
mockConn := mocks.NewMockConnection(t)

// Create the circuit breaker middleware
middleware := NewCircuitBreakerMiddleware(treshold, period, recoverAfter)(mockHandler)
middleware := NewCircuitBreakerMiddleware(threshold, period)(mockHandler)

// Bring the circuit breaker to the Open state

if err := middleware.Handle(mockConn, mockRequest); err != testError {
t.Errorf("Expected error %v, but got %v", testError, err)
for i := uint(0); i < threshold; i++ {
if err := middleware.Handle(mockConn, mockRequest); err != testError {
t.Errorf("Expected error %v, but got %v", testError, err)
}
}

// Confirm that the circuit breaker is now in the Semi-open state
time.Sleep(period)

if err := middleware.Handle(mockConn, mockRequest); err != testError {
t.Errorf("Expected error %v, but got %v", testError, err)
}

// Confirm that the circuit breaker is now in the Closed state

errorToReturn = nil
results := make(chan error)

for i := 0; i < 2; i++ {
go func() {
results <- middleware.Handle(mockConn, mockRequest)
}()
}
go func() {
results <- middleware.Handle(mockConn, mockRequest)
}()

OpenErrorCount := 0
SuccessCount := 0

for i := 0; i < 2; i++ {
select {
case err := <-results:
if err != ErrCircuitBreakerOpen && err != nil {
t.Errorf("Expected error %v, but got %v", ErrCircuitBreakerOpen, err)
continue
}

if err == ErrCircuitBreakerOpen {
OpenErrorCount++
} else if err == nil {
SuccessCount++
}
select {
case err := <-results:
if err != gobreaker.ErrOpenState && err != nil {
t.Errorf("Expected error %v, but got %v", gobreaker.ErrOpenState, err)
}

case <-time.After(100 * time.Millisecond):
t.Fatal("Expected error, but got none")
if err == gobreaker.ErrOpenState {
OpenErrorCount++
} else if err == nil {
SuccessCount++
}

case <-time.After(100 * time.Millisecond):
t.Fatal("Expected error, but got none")
}

if OpenErrorCount != 0 {
t.Errorf("Expected 0 ErrCircuitBreakerOpen error, but got %d", OpenErrorCount)
t.Errorf("Expected 0 gobreaker.ErrOpenState error, but got %d", OpenErrorCount)
}

if SuccessCount != 2 {
t.Errorf("Expected 2 test error, but got %d", SuccessCount)
if SuccessCount != 1 {
t.Errorf("Expected 1 test error, but got %d", SuccessCount)
}
}
Loading