diff --git a/middleware/request/circuit_breaker.go b/middleware/request/circuit_breaker.go index 581a86e..b814ac0 100644 --- a/middleware/request/circuit_breaker.go +++ b/middleware/request/circuit_breaker.go @@ -31,26 +31,31 @@ const ( // The `treshold` parameter specifies the maximum number of consecutive errors allowed before opening the circuit. // The `period` parameter specifies the duration of time after which the circuit breaker transitions to the semi-open state. // The returned function is a middleware that can be used with the `wasabi` framework. -func NewCircuitBreakerMiddleware(treshold int, period time.Duration) func(next wasabi.RequestHandler) wasabi.RequestHandler { +func NewCircuitBreakerMiddleware(treshold uint, period time.Duration, recoverAfter uint) func(next wasabi.RequestHandler) wasabi.RequestHandler { + var errorCounter, successCounter uint + intervalEnds := time.Now().Add(period) state := Closed - errorCounter := 0 + lock := &sync.RWMutex{} semiOpenLock := &sync.Mutex{} return func(next wasabi.RequestHandler) wasabi.RequestHandler { return dispatch.RequestHandlerFunc(func(conn wasabi.Connection, req wasabi.Request) error { lock.RLock() + currentState := state + lock.RUnlock() - switch state { + switch currentState { case Closed: err := next.Handle(conn, req) if err == nil { return nil } - lock.RUnlock() lock.Lock() + defer lock.Unlock() + now := time.Now() if intervalEnds.Before(time.Now()) { intervalEnds = now.Add(period) @@ -61,10 +66,9 @@ func NewCircuitBreakerMiddleware(treshold int, period time.Duration) func(next w if errorCounter >= treshold { state = Open } - lock.Unlock() + return err case Open: - lock.RUnlock() if !semiOpenLock.TryLock() { return ErrCircuitBreakerOpen } @@ -72,13 +76,24 @@ func NewCircuitBreakerMiddleware(treshold int, period time.Duration) func(next w defer semiOpenLock.Unlock() err := next.Handle(conn, req) - if err == nil { - lock.Lock() + + lock.Lock() + defer lock.Unlock() + + if err != nil { + successCounter = 0 + return err + } + + successCounter++ + + if successCounter >= recoverAfter { state = Closed errorCounter = 0 + successCounter = 0 } - return err + return nil default: panic("Unknown state of circuit breaker") }