diff --git a/retry.go b/retry.go index daf538f..8870a0a 100644 --- a/retry.go +++ b/retry.go @@ -136,38 +136,6 @@ func DoWithData[T any](retryableFunc RetryableFuncWithData[T], opts ...Option) ( return emptyT, err } - // Setting attempts to 0 means we'll retry until we succeed - var lastErr error - if config.attempts == 0 { - for { - t, err := retryableFunc() - if err == nil { - return t, nil - } - - if !IsRecoverable(err) { - return emptyT, err - } - - if !config.retryIf(err) { - return emptyT, err - } - - lastErr = err - - config.onRetry(n, err) - n++ - select { - case <-config.timer.After(delay(config, n, err)): - case <-config.context.Done(): - if config.wrapContextErrorWithLastError { - return emptyT, Error{context.Cause(config.context), lastErr} - } - return emptyT, context.Cause(config.context) - } - } - } - errorLog := Error{} attemptsForError := make(map[error]uint, len(config.attemptsForError)) @@ -184,6 +152,10 @@ shouldRetry: errorLog = append(errorLog, unpackUnrecoverable(err)) + if !IsRecoverable(err) { + return emptyT, err + } + if !config.retryIf(err) { break } @@ -198,6 +170,7 @@ shouldRetry: } } + // Setting attempts to 0 means we'll retry until we succeed // if this is last attempt - don't wait if n == config.attempts-1 { break shouldRetry diff --git a/retry_test.go b/retry_test.go index f3a2465..1d6e45e 100644 --- a/retry_test.go +++ b/retry_test.go @@ -105,6 +105,7 @@ func TestRetryIf_ZeroAttempts(t *testing.T) { return err.Error() != "special" }), Delay(time.Nanosecond), + LastErrorOnly(true), Attempts(0), ) assert.Error(t, err) @@ -216,7 +217,6 @@ func TestLastErrorOnly(t *testing.T) { func TestUnrecoverableError(t *testing.T) { attempts := 0 testErr := errors.New("error") - expectedErr := Error{testErr} err := Do( func() error { attempts++ @@ -224,8 +224,8 @@ func TestUnrecoverableError(t *testing.T) { }, Attempts(2), ) - assert.Equal(t, expectedErr, err) - assert.Equal(t, testErr, errors.Unwrap(err)) + assert.Error(t, err) + assert.Equal(t, Unrecoverable(testErr), err) assert.Equal(t, 1, attempts, "unrecoverable error broke the loop") } @@ -465,6 +465,7 @@ func TestContext(t *testing.T) { cancel() } }), + LastErrorOnly(true), Context(ctx), Attempts(0), )