Skip to content

Commit

Permalink
adding context as first parameter to all methods on our client (#239)
Browse files Browse the repository at this point in the history
  • Loading branch information
yiminc authored Sep 29, 2017
1 parent ab0bf9b commit 815dd18
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 83 deletions.
33 changes: 17 additions & 16 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package cadence

import (
"context"
"time"

"github.com/uber-go/tally"
Expand All @@ -40,14 +41,14 @@ type (
// StartWorkflow starts a workflow execution
// The user can use this to start using a function or workflow type name.
// Either by
// StartWorkflow(options, "workflowTypeName", input)
// StartWorkflow(ctx, options, "workflowTypeName", input)
// or
// StartWorkflow(options, workflowExecuteFn, arg1, arg2, arg3)
// StartWorkflow(ctx, options, workflowExecuteFn, arg1, arg2, arg3)
// The errors it can return:
// - EntityNotExistsError
// - BadRequestError
// - WorkflowExecutionAlreadyStartedError
StartWorkflow(options StartWorkflowOptions, workflow interface{}, args ...interface{}) (*WorkflowExecution, error)
StartWorkflow(ctx context.Context, options StartWorkflowOptions, workflow interface{}, args ...interface{}) (*WorkflowExecution, error)

// SignalWorkflow sends a signals to a workflow in execution
// - workflow ID of the workflow.
Expand All @@ -56,7 +57,7 @@ type (
// The errors it can return:
// - EntityNotExistsError
// - InternalServiceError
SignalWorkflow(workflowID string, runID string, signalName string, arg interface{}) error
SignalWorkflow(ctx context.Context, workflowID string, runID string, signalName string, arg interface{}) error

// CancelWorkflow cancels a workflow in execution
// - workflow ID of the workflow.
Expand All @@ -65,7 +66,7 @@ type (
// - EntityNotExistsError
// - BadRequestError
// - InternalServiceError
CancelWorkflow(workflowID string, runID string) error
CancelWorkflow(ctx context.Context, workflowID string, runID string) error

// TerminateWorkflow terminates a workflow execution.
// workflowID is required, other parameters are optional.
Expand All @@ -75,7 +76,7 @@ type (
// - EntityNotExistsError
// - BadRequestError
// - InternalServiceError
TerminateWorkflow(workflowID string, runID string, reason string, details []byte) error
TerminateWorkflow(ctx context.Context, workflowID string, runID string, reason string, details []byte) error

// GetWorkflowHistory gets history of a particular workflow.
// - workflow ID of the workflow.
Expand All @@ -84,7 +85,7 @@ type (
// - EntityNotExistsError
// - BadRequestError
// - InternalServiceError
GetWorkflowHistory(workflowID string, runID string) (*s.History, error)
GetWorkflowHistory(ctx context.Context, workflowID string, runID string) (*s.History, error)

// GetWorkflowStackTrace gets a stack trace of all goroutines of a particular workflow.
// atDecisionTaskCompletedEventID is the eventID of the CompleteDecisionTask event at which stack trace should be taken.
Expand All @@ -94,7 +95,7 @@ type (
// - EntityNotExistsError
// - BadRequestError
// - InternalServiceError
GetWorkflowStackTrace(workflowID string, runID string, atDecisionTaskCompletedEventID int64) (string, error)
GetWorkflowStackTrace(ctx context.Context, workflowID string, runID string, atDecisionTaskCompletedEventID int64) (string, error)

// CompleteActivity reports activity completed.
// activity Execute method can return cadence.ErrActivityResultPending to
Expand All @@ -109,28 +110,28 @@ type (
// To fail the activity with an error.
// CompleteActivity(token, nil, NewErrorWithDetails("reason", details)
// The activity can fail with below errors ErrorWithDetails, TimeoutError, CanceledError.
CompleteActivity(taskToken []byte, result interface{}, err error) error
CompleteActivity(ctx context.Context, taskToken []byte, result interface{}, err error) error

// RecordActivityHeartbeat records heartbeat for an activity.
// details - is the progress you want to record along with heart beat for this activity.
// The errors it can return:
// - EntityNotExistsError
// - InternalServiceError
RecordActivityHeartbeat(taskToken []byte, details ...interface{}) error
RecordActivityHeartbeat(ctx context.Context, taskToken []byte, details ...interface{}) error

// ListClosedWorkflow gets closed workflow executions based on request filters
// The errors it can return:
// - BadRequestError
// - InternalServiceError
// - EntityNotExistError
ListClosedWorkflow(request *s.ListClosedWorkflowExecutionsRequest) (*s.ListClosedWorkflowExecutionsResponse, error)
ListClosedWorkflow(ctx context.Context, request *s.ListClosedWorkflowExecutionsRequest) (*s.ListClosedWorkflowExecutionsResponse, error)

// ListClosedWorkflow gets open workflow executions based on request filters
// The errors it can return:
// - BadRequestError
// - InternalServiceError
// - EntityNotExistError
ListOpenWorkflow(request *s.ListOpenWorkflowExecutionsRequest) (*s.ListOpenWorkflowExecutionsResponse, error)
ListOpenWorkflow(ctx context.Context, request *s.ListOpenWorkflowExecutionsRequest) (*s.ListOpenWorkflowExecutionsResponse, error)

// QueryWorkflow queries a given workflow execution and returns the query result synchronously. Parameter workflowID
// and queryType are required, other parameters are optional. The workflowID and runID (optional) identify the
Expand All @@ -150,7 +151,7 @@ type (
// - InternalServiceError
// - EntityNotExistError
// - QueryFailError
QueryWorkflow(workflowID string, runID string, queryType string, args ...interface{}) (EncodedValue, error)
QueryWorkflow(ctx context.Context, workflowID string, runID string, queryType string, args ...interface{}) (EncodedValue, error)
}

// ClientOptions are optional parameters for Client creation.
Expand Down Expand Up @@ -191,7 +192,7 @@ type (
// - DomainAlreadyExistsError
// - BadRequestError
// - InternalServiceError
Register(request *s.RegisterDomainRequest) error
Register(ctx context.Context, request *s.RegisterDomainRequest) error

// Describe a domain. The domain has two part of information.
// DomainInfo - Which has Name, Status, Description, Owner Email.
Expand All @@ -200,7 +201,7 @@ type (
// - EntityNotExistsError
// - BadRequestError
// - InternalServiceError
Describe(name string) (*s.DomainInfo, *s.DomainConfiguration, error)
Describe(ctx context.Context, name string) (*s.DomainInfo, *s.DomainConfiguration, error)

// Update a domain. The domain has two part of information.
// UpdateDomainInfo - To update domain Description and Owner Email.
Expand All @@ -209,7 +210,7 @@ type (
// - EntityNotExistsError
// - BadRequestError
// - InternalServiceError
Update(name string, domainInfo *s.UpdateDomainInfo, domainConfig *s.DomainConfiguration) error
Update(ctx context.Context, name string, domainInfo *s.UpdateDomainInfo, domainConfig *s.DomainConfiguration) error
}
)

Expand Down
7 changes: 4 additions & 3 deletions internal_task_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@ func (i *cadenceInvoker) Heartbeat(details []byte) error {

func (i *cadenceInvoker) internalHeartBeat(details []byte) (bool, error) {
isActivityCancelled := false
err := recordActivityHeartbeat(i.service, i.identity, i.taskToken, details, i.retryPolicy)
err := recordActivityHeartbeat(context.Background(), i.service, i.identity, i.taskToken, details, i.retryPolicy)

switch err.(type) {
case *CanceledError:
Expand Down Expand Up @@ -1030,6 +1030,7 @@ func createNewDecision(decisionType s.DecisionType) *s.Decision {
}

func recordActivityHeartbeat(
ctx context.Context,
service m.TChanWorkflowService,
identity string,
taskToken, details []byte,
Expand All @@ -1043,11 +1044,11 @@ func recordActivityHeartbeat(
var heartbeatResponse *s.RecordActivityTaskHeartbeatResponse
heartbeatErr := backoff.Retry(
func() error {
ctx, cancel := newTChannelContext()
tchCtx, cancel := newTChannelContext(ctx)
defer cancel()

var err error
heartbeatResponse, err = service.RecordActivityTaskHeartbeat(ctx, request)
heartbeatResponse, err = service.RecordActivityTaskHeartbeat(tchCtx, request)
return err
}, retryPolicy, isServiceTransientError)

Expand Down
4 changes: 2 additions & 2 deletions internal_task_handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ func (t *TaskHandlersTestSuite) TestGetWorkflowStackTraceByID() {
domain := "testDomain"
workflowClient := NewClient(service, domain, nil)

dump, err := workflowClient.GetWorkflowStackTrace("id1", "runId1", 0)
dump, err := workflowClient.GetWorkflowStackTrace(context.Background(), "id1", "runId1", 0)
t.NoError(err)
t.NotNil(dump)
t.True(strings.Contains(dump, ".Receive]"))
Expand Down Expand Up @@ -643,7 +643,7 @@ func (t *TaskHandlersTestSuite) TestGetWorkflowStackTraceByIDAndDecisionTaskComp
domain := "testDomain"
workflowClient := NewClient(service, domain, nil)

dump, err := workflowClient.GetWorkflowStackTrace("id1", "runId1", 5)
dump, err := workflowClient.GetWorkflowStackTrace(context.Background(), "id1", "runId1", 5)
t.NoError(err)
t.NotNil(dump)
t.True(strings.Contains(dump, ".Receive]"))
Expand Down
33 changes: 17 additions & 16 deletions internal_task_pollers.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,19 +165,19 @@ func (wtp *workflowTaskPoller) ProcessTask(task interface{}) error {
// Respond task completion.
err = backoff.Retry(
func() error {
ctx, cancel := newTChannelContext()
tchCtx, cancel := newTChannelContext(context.Background())
defer cancel()
var err1 error
switch request := completedRequest.(type) {
case *s.RespondDecisionTaskCompletedRequest:
err1 = wtp.service.RespondDecisionTaskCompleted(ctx, request)
err1 = wtp.service.RespondDecisionTaskCompleted(tchCtx, request)
if err1 != nil {
traceLog(func() {
wtp.logger.Debug("RespondDecisionTaskCompleted failed.", zap.Error(err1))
})
}
case *s.RespondQueryTaskCompletedRequest:
err1 = wtp.service.RespondQueryTaskCompleted(ctx, request)
err1 = wtp.service.RespondQueryTaskCompleted(tchCtx, request)
if err1 != nil {
traceLog(func() {
wtp.logger.Debug("RespondQueryTaskCompleted failed.", zap.Error(err1))
Expand Down Expand Up @@ -217,10 +217,10 @@ func (wtp *workflowTaskPoller) poll() (*workflowTask, error) {
Identity: common.StringPtr(wtp.identity),
}

ctx, cancel := newTChannelContext(tchanTimeout(pollTaskServiceTimeOut), tchanRetryOption(retryNeverOptions))
tchCtx, cancel := newTChannelContext(context.Background(), tchanTimeout(pollTaskServiceTimeOut), tchanRetryOption(retryNeverOptions))
defer cancel()

response, err := wtp.service.PollForDecisionTask(ctx, request)
response, err := wtp.service.PollForDecisionTask(tchCtx, request)
if err != nil {
if isServiceTransientError(err) {
wtp.metricsScope.Counter(metrics.DecisionPollTransientFailedCounter).Inc(1)
Expand All @@ -236,14 +236,15 @@ func (wtp *workflowTaskPoller) poll() (*workflowTask, error) {
}

execution := response.GetWorkflowExecution()
iterator := newGetHistoryPageFunc(wtp.service, wtp.domain, execution, math.MaxInt64, wtp.metricsScope)
iterator := newGetHistoryPageFunc(context.Background(), wtp.service, wtp.domain, execution, math.MaxInt64, wtp.metricsScope)
task := &workflowTask{task: response, getHistoryPageFunc: iterator, pollStartTime: startTime}
wtp.metricsScope.Counter(metrics.DecisionPollSucceedCounter).Inc(1)
wtp.metricsScope.Timer(metrics.DecisionPollLatency).Record(time.Now().Sub(startTime))
return task, nil
}

func newGetHistoryPageFunc(
ctx context.Context,
service m.TChanWorkflowService,
domain string,
execution *s.WorkflowExecution,
Expand All @@ -256,11 +257,11 @@ func newGetHistoryPageFunc(
var resp *s.GetWorkflowExecutionHistoryResponse
err := backoff.Retry(
func() error {
ctx, cancel := newTChannelContext()
tchCtx, cancel := newTChannelContext(ctx)
defer cancel()

var err1 error
resp, err1 = service.GetWorkflowExecutionHistory(ctx, &s.GetWorkflowExecutionHistoryRequest{
resp, err1 = service.GetWorkflowExecutionHistory(tchCtx, &s.GetWorkflowExecutionHistoryRequest{
Domain: common.StringPtr(domain),
Execution: execution,
NextPageToken: nextPageToken,
Expand Down Expand Up @@ -317,10 +318,10 @@ func (atp *activityTaskPoller) poll() (*activityTask, error) {
Identity: common.StringPtr(atp.identity),
}

ctx, cancel := newTChannelContext(tchanTimeout(pollTaskServiceTimeOut), tchanRetryOption(retryNeverOptions))
tchCtx, cancel := newTChannelContext(context.Background(), tchanTimeout(pollTaskServiceTimeOut), tchanRetryOption(retryNeverOptions))
defer cancel()

response, err := atp.service.PollForActivityTask(ctx, request)
response, err := atp.service.PollForActivityTask(tchCtx, request)
if err != nil {
if isServiceTransientError(err) {
atp.metricsScope.Counter(metrics.ActivityPollTransientFailedCounter).Inc(1)
Expand Down Expand Up @@ -375,7 +376,7 @@ func (atp *activityTaskPoller) ProcessTask(task interface{}) error {
}

responseStartTime := time.Now()
reportErr := reportActivityComplete(atp.service, request, atp.metricsScope)
reportErr := reportActivityComplete(context.Background(), atp.service, request, atp.metricsScope)
if reportErr != nil {
atp.metricsScope.Counter(metrics.ActivityResponseFailedCounter).Inc(1)
traceLog(func() {
Expand All @@ -389,30 +390,30 @@ func (atp *activityTaskPoller) ProcessTask(task interface{}) error {
return nil
}

func reportActivityComplete(service m.TChanWorkflowService, request interface{}, metricsScope tally.Scope) error {
func reportActivityComplete(ctx context.Context, service m.TChanWorkflowService, request interface{}, metricsScope tally.Scope) error {
if request == nil {
// nothing to report
return nil
}

ctx, cancel := newTChannelContext()
tchCtx, cancel := newTChannelContext(ctx)
defer cancel()
var reportErr error
switch request := request.(type) {
case *s.RespondActivityTaskCanceledRequest:
reportErr = backoff.Retry(
func() error {
return service.RespondActivityTaskCanceled(ctx, request)
return service.RespondActivityTaskCanceled(tchCtx, request)
}, serviceOperationRetryPolicy, isServiceTransientError)
case *s.RespondActivityTaskFailedRequest:
reportErr = backoff.Retry(
func() error {
return service.RespondActivityTaskFailed(ctx, request)
return service.RespondActivityTaskFailed(tchCtx, request)
}, serviceOperationRetryPolicy, isServiceTransientError)
case *s.RespondActivityTaskCompletedRequest:
reportErr = backoff.Retry(
func() error {
return service.RespondActivityTaskCompleted(ctx, request)
return service.RespondActivityTaskCompleted(tchCtx, request)
}, serviceOperationRetryPolicy, isServiceTransientError)
}
if reportErr == nil {
Expand Down
5 changes: 4 additions & 1 deletion internal_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,11 @@ func tchanRetryOption(retryOpt *tchannel.RetryOptions) func(builder *tchannel.Co
}

// newTChannelContext - Get a tchannel context
func newTChannelContext(options ...func(builder *tchannel.ContextBuilder)) (tchannel.ContextWithHeaders, context.CancelFunc) {
func newTChannelContext(ctx context.Context, options ...func(builder *tchannel.ContextBuilder)) (tchannel.ContextWithHeaders, context.CancelFunc) {
builder := tchannel.NewContextBuilder(defaultRPCTimeout)
if ctx != nil {
builder.SetParentContext(ctx)
}
builder.SetRetryOptions(retryDefaultOptions)
builder.AddHeader(versionHeaderName, LibraryVersion)
for _, opt := range options {
Expand Down
4 changes: 2 additions & 2 deletions internal_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,9 @@ func ensureRequiredParams(params *workerExecutionParameters) {
func verifyDomainExist(client m.TChanWorkflowService, domain string, logger *zap.Logger) error {

descDomainOp := func() error {
ctx, cancel := newTChannelContext()
tchCtx, cancel := newTChannelContext(context.Background())
defer cancel()
_, err := client.DescribeDomain(ctx, &shared.DescribeDomainRequest{Name: &domain})
_, err := client.DescribeDomain(tchCtx, &shared.DescribeDomainRequest{Name: &domain})
if err != nil {
if _, ok := err.(*shared.EntityNotExistsError); ok {
logger.Error("domain does not exist", zap.String("domain", domain), zap.Error(err))
Expand Down
2 changes: 1 addition & 1 deletion internal_worker_interfaces_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (s *InterfacesTestSuite) TestInterface() {
DecisionTaskStartToCloseTimeout: 10 * time.Second,
}
workflowClient := NewClient(service, domain, nil)
wfExecution, err := workflowClient.StartWorkflow(workflowOptions, "workflowType")
wfExecution, err := workflowClient.StartWorkflow(context.Background(), workflowOptions, "workflowType")
s.NoError(err)
fmt.Printf("Started workflow: %v \n", wfExecution)
}
10 changes: 5 additions & 5 deletions internal_worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,13 +369,13 @@ func TestCompleteActivity(t *testing.T) {
failedRequest = args.Get(1).(*s.RespondActivityTaskFailedRequest)
})

wfClient.CompleteActivity([]byte("task-token"), nil, nil)
wfClient.CompleteActivity(context.Background(), []byte("task-token"), nil, nil)
require.NotNil(t, completedRequest)

wfClient.CompleteActivity([]byte("task-token"), nil, NewCanceledError())
wfClient.CompleteActivity(context.Background(), []byte("task-token"), nil, NewCanceledError())
require.NotNil(t, canceledRequest)

wfClient.CompleteActivity([]byte("task-token"), nil, errors.New(""))
wfClient.CompleteActivity(context.Background(), []byte("task-token"), nil, errors.New(""))
require.NotNil(t, failedRequest)
}

Expand All @@ -391,8 +391,8 @@ func TestRecordActivityHeartbeat(t *testing.T) {
heartbeatRequest = args.Get(1).(*s.RecordActivityTaskHeartbeatRequest)
})

wfClient.RecordActivityHeartbeat(nil)
wfClient.RecordActivityHeartbeat(nil, "testStack", "customerObjects", 4)
wfClient.RecordActivityHeartbeat(context.Background(), nil)
wfClient.RecordActivityHeartbeat(context.Background(), nil, "testStack", "customerObjects", 4)
require.NotNil(t, heartbeatRequest)
}

Expand Down
Loading

0 comments on commit 815dd18

Please sign in to comment.