diff --git a/common/backoff/retry.go b/common/backoff/retry.go index c0b41fb07..bc4a513a6 100644 --- a/common/backoff/retry.go +++ b/common/backoff/retry.go @@ -21,6 +21,7 @@ package backoff import ( + "context" "sync" "time" ) @@ -86,11 +87,12 @@ func NewConcurrentRetrier(retryPolicy RetryPolicy) *ConcurrentRetrier { } // Retry function can be used to wrap any call with retry logic using the passed in policy -func Retry(operation Operation, policy RetryPolicy, isRetryable IsRetryable) error { +func Retry(ctx context.Context, operation Operation, policy RetryPolicy, isRetryable IsRetryable) error { var err error var next time.Duration r := NewRetrier(policy, SystemClock) +Retry_Loop: for { // operation completed successfully. No need to retry. if err = operation(); err == nil { @@ -106,6 +108,18 @@ func Retry(operation Operation, policy RetryPolicy, isRetryable IsRetryable) err return err } + // check if ctx is done + if ctxDone := ctx.Done(); ctxDone != nil { + timer := time.NewTimer(next) + select { + case <-ctxDone: + return err + case <-timer.C: + continue Retry_Loop + } + } + + // ctx is not cancellable time.Sleep(next) } } diff --git a/common/backoff/retry_test.go b/common/backoff/retry_test.go index 73eed86e8..b9be6fe70 100644 --- a/common/backoff/retry_test.go +++ b/common/backoff/retry_test.go @@ -21,6 +21,7 @@ package backoff import ( + "context" "fmt" "testing" "time" @@ -62,11 +63,35 @@ func (s *RetrySuite) TestRetrySuccess() { policy.SetMaximumInterval(5 * time.Millisecond) policy.SetMaximumAttempts(10) - err := Retry(op, policy, nil) + err := Retry(context.Background(), op, policy, nil) s.NoError(err) s.Equal(5, i) } +func (s *RetrySuite) TestNoRetryAfterContextDone() { + i := 0 + op := func() error { + i++ + + if i == 5 { + return nil + } + + return &someError{} + } + + policy := NewExponentialRetryPolicy(1 * time.Millisecond) + policy.SetMaximumInterval(5 * time.Millisecond) + policy.SetMaximumAttempts(10) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*5) + defer cancel() + + err := Retry(ctx, op, policy, nil) + s.Error(err) + s.True(i >= 2) // verify that we did retried +} + func (s *RetrySuite) TestRetryFailed() { i := 0 op := func() error { @@ -83,7 +108,7 @@ func (s *RetrySuite) TestRetryFailed() { policy.SetMaximumInterval(5 * time.Millisecond) policy.SetMaximumAttempts(5) - err := Retry(op, policy, nil) + err := Retry(context.Background(), op, policy, nil) s.Error(err) } @@ -111,7 +136,7 @@ func (s *RetrySuite) TestIsRetryableSuccess() { policy.SetMaximumInterval(5 * time.Millisecond) policy.SetMaximumAttempts(10) - err := Retry(op, policy, isRetryable) + err := Retry(context.Background(), op, policy, isRetryable) s.NoError(err, "Retry count: %v", i) s.Equal(5, i) } @@ -132,7 +157,7 @@ func (s *RetrySuite) TestIsRetryableFailure() { policy.SetMaximumInterval(5 * time.Millisecond) policy.SetMaximumAttempts(10) - err := Retry(op, policy, IgnoreErrors([]error{&someError{}})) + err := Retry(context.Background(), op, policy, IgnoreErrors([]error{&someError{}})) s.Error(err) s.Equal(1, i) } diff --git a/internal_task_handlers.go b/internal_task_handlers.go index e3b1c3e5b..259908713 100644 --- a/internal_task_handlers.go +++ b/internal_task_handlers.go @@ -1042,7 +1042,7 @@ func recordActivityHeartbeat( Identity: common.StringPtr(identity)} var heartbeatResponse *s.RecordActivityTaskHeartbeatResponse - heartbeatErr := backoff.Retry( + heartbeatErr := backoff.Retry(ctx, func() error { tchCtx, cancel := newTChannelContext(ctx) defer cancel() diff --git a/internal_task_pollers.go b/internal_task_pollers.go index edb8aece4..d1bf6a761 100644 --- a/internal_task_pollers.go +++ b/internal_task_pollers.go @@ -161,11 +161,12 @@ func (wtp *workflowTaskPoller) ProcessTask(task interface{}) error { } wtp.metricsScope.Timer(metrics.DecisionExecutionLatency).Record(time.Now().Sub(executionStartTime)) + ctx := context.Background() responseStartTime := time.Now() // Respond task completion. - err = backoff.Retry( + err = backoff.Retry(ctx, func() error { - tchCtx, cancel := newTChannelContext(context.Background()) + tchCtx, cancel := newTChannelContext(ctx) defer cancel() var err1 error switch request := completedRequest.(type) { @@ -255,7 +256,7 @@ func newGetHistoryPageFunc( metricsScope.Counter(metrics.WorkflowGetHistoryCounter).Inc(1) startTime := time.Now() var resp *s.GetWorkflowExecutionHistoryResponse - err := backoff.Retry( + err := backoff.Retry(ctx, func() error { tchCtx, cancel := newTChannelContext(ctx) defer cancel() @@ -401,17 +402,17 @@ func reportActivityComplete(ctx context.Context, service m.TChanWorkflowService, var reportErr error switch request := request.(type) { case *s.RespondActivityTaskCanceledRequest: - reportErr = backoff.Retry( + reportErr = backoff.Retry(ctx, func() error { return service.RespondActivityTaskCanceled(tchCtx, request) }, serviceOperationRetryPolicy, isServiceTransientError) case *s.RespondActivityTaskFailedRequest: - reportErr = backoff.Retry( + reportErr = backoff.Retry(ctx, func() error { return service.RespondActivityTaskFailed(tchCtx, request) }, serviceOperationRetryPolicy, isServiceTransientError) case *s.RespondActivityTaskCompletedRequest: - reportErr = backoff.Retry( + reportErr = backoff.Retry(ctx, func() error { return service.RespondActivityTaskCompleted(tchCtx, request) }, serviceOperationRetryPolicy, isServiceTransientError) diff --git a/internal_worker.go b/internal_worker.go index 7c09ca2cd..b11c0d208 100644 --- a/internal_worker.go +++ b/internal_worker.go @@ -162,9 +162,9 @@ func ensureRequiredParams(params *workerExecutionParameters) { // It returns an error, if the server returns an EntityNotExist or BadRequest error // On any other transient error, this method will just return success func verifyDomainExist(client m.TChanWorkflowService, domain string, logger *zap.Logger) error { - + ctx := context.Background() descDomainOp := func() error { - tchCtx, cancel := newTChannelContext(context.Background()) + tchCtx, cancel := newTChannelContext(ctx) defer cancel() _, err := client.DescribeDomain(tchCtx, &shared.DescribeDomainRequest{Name: &domain}) if err != nil { @@ -187,7 +187,7 @@ func verifyDomainExist(client m.TChanWorkflowService, domain string, logger *zap } // exponential backoff retry for upto a minute - return backoff.Retry(descDomainOp, serviceOperationRetryPolicy, isServiceTransientError) + return backoff.Retry(ctx, descDomainOp, serviceOperationRetryPolicy, isServiceTransientError) } func newWorkflowWorkerInternal( diff --git a/internal_workflow_client.go b/internal_workflow_client.go index 1fd325743..231f46176 100644 --- a/internal_workflow_client.go +++ b/internal_workflow_client.go @@ -117,7 +117,7 @@ func (wc *workflowClient) StartWorkflow( var response *s.StartWorkflowExecutionResponse // Start creating workflow request. - err = backoff.Retry( + err = backoff.Retry(ctx, func() error { tchCtx, cancel := newTChannelContext(ctx) defer cancel() @@ -162,7 +162,7 @@ func (wc *workflowClient) SignalWorkflow(ctx context.Context, workflowID string, Identity: common.StringPtr(wc.identity), } - return backoff.Retry( + return backoff.Retry(ctx, func() error { tchCtx, cancel := newTChannelContext(ctx) defer cancel() @@ -181,7 +181,7 @@ func (wc *workflowClient) CancelWorkflow(ctx context.Context, workflowID string, Identity: common.StringPtr(wc.identity), } - return backoff.Retry( + return backoff.Retry(ctx, func() error { tchCtx, cancel := newTChannelContext(ctx) defer cancel() @@ -203,7 +203,7 @@ func (wc *workflowClient) TerminateWorkflow(ctx context.Context, workflowID stri Identity: common.StringPtr(wc.identity), } - err := backoff.Retry( + err := backoff.Retry(ctx, func() error { tchCtx, cancel := newTChannelContext(ctx) defer cancel() @@ -231,7 +231,7 @@ GetHistoryLoop: } var response *s.GetWorkflowExecutionHistoryResponse - err := backoff.Retry( + err := backoff.Retry(ctx, func() error { var err1 error tchCtx, cancel := newTChannelContext(ctx) @@ -352,7 +352,7 @@ func (wc *workflowClient) ListClosedWorkflow(ctx context.Context, request *s.Lis request.Domain = common.StringPtr(wc.domain) } var response *s.ListClosedWorkflowExecutionsResponse - err := backoff.Retry( + err := backoff.Retry(ctx, func() error { var err1 error tchCtx, cancel := newTChannelContext(ctx) @@ -376,7 +376,7 @@ func (wc *workflowClient) ListOpenWorkflow(ctx context.Context, request *s.ListO request.Domain = common.StringPtr(wc.domain) } var response *s.ListOpenWorkflowExecutionsResponse - err := backoff.Retry( + err := backoff.Retry(ctx, func() error { var err1 error tchCtx, cancel := newTChannelContext(ctx) @@ -423,7 +423,7 @@ func (wc *workflowClient) QueryWorkflow(ctx context.Context, workflowID string, } var resp *s.QueryWorkflowResponse - err := backoff.Retry( + err := backoff.Retry(ctx, func() error { tchCtx, cancel := newTChannelContext(ctx) defer cancel() @@ -444,7 +444,7 @@ func (wc *workflowClient) QueryWorkflow(ctx context.Context, workflowID string, // - BadRequestError // - InternalServiceError func (dc *domainClient) Register(ctx context.Context, request *s.RegisterDomainRequest) error { - return backoff.Retry( + return backoff.Retry(ctx, func() error { tchCtx, cancel := newTChannelContext(ctx) defer cancel() @@ -465,7 +465,7 @@ func (dc *domainClient) Describe(ctx context.Context, name string) (*s.DomainInf } var response *s.DescribeDomainResponse - err := backoff.Retry( + err := backoff.Retry(ctx, func() error { tchCtx, cancel := newTChannelContext(ctx) defer cancel() @@ -493,7 +493,7 @@ func (dc *domainClient) Update(ctx context.Context, name string, domainInfo *s.U Configuration: domainConfig, } - return backoff.Retry( + return backoff.Retry(ctx, func() error { tchCtx, cancel := newTChannelContext(ctx) defer cancel()