Skip to content

Commit

Permalink
Tidy mocking
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom-Newton committed Dec 21, 2023
1 parent 8e9a76e commit 50f64c1
Showing 1 changed file with 24 additions and 26 deletions.
50 changes: 24 additions & 26 deletions flytepropeller/pkg/controller/workflow/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"strconv"

"testing"
"time"

Expand Down Expand Up @@ -517,15 +516,15 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) {

h := &nodemocks.NodeHandler{}
h.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)
handleMockCall := h.OnHandleMatch(mock.Anything, mock.Anything)
handleMockCall.RunFn = func(args mock.Arguments) {
if args[1].(*nodes.NodeExecContext).Node().IsStartNode() {
handleMockCall.ReturnArguments = mock.Arguments{handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil}
} else {
executionError := core.ExecutionError{Code: "code", Message: "message", ErrorUri: "uri"}
handleMockCall.ReturnArguments = mock.Arguments{handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailureErr(&executionError, nil)), nil}
}
}

startNodeMatcher := mock.MatchedBy(func(nodeExecContext *nodes.NodeExecContext) bool {
return nodeExecContext.Node().IsStartNode()
})
h.OnHandleMatch(mock.Anything, startNodeMatcher).Return(handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil)
h.OnHandleMatch(mock.Anything, mock.Anything).Return(handler.DoTransition(
handler.TransitionTypeEphemeral,
handler.PhaseInfoFailureErr(&core.ExecutionError{Code: "code", Message: "message", ErrorUri: "uri"}, nil)), nil,
)

h.OnFinalizeMatch(mock.Anything, mock.Anything).Return(nil)
h.OnFinalizeRequired().Return(false)
Expand All @@ -534,25 +533,24 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) {
handlerFactory.OnSetupMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)
handlerFactory.OnGetHandlerMatch(mock.Anything).Return(h, nil)


tests := []struct {
name string
onFailurePolicy v1alpha1.WorkflowOnFailurePolicy
clearPreviousError bool
expectedRoundsToFail int
name string
onFailurePolicy v1alpha1.WorkflowOnFailurePolicy
clearPreviousError bool
expectedRoundsToFail int
expectedNodesWithErrorsCount int
expectedFailedNodesCount int
}{
{"failImidiately", v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_IMMEDIATELY), false, 6, 1, 1},
{"failImidiately clearPreviousError", v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_IMMEDIATELY), true, 6, 1, 1},
{"failAfterExecutableNodesComplete", v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_AFTER_EXECUTABLE_NODES_COMPLETE), false, 12, 2, 2},
{"failAfterExecutableNodesComplete clearPreviousError", v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_AFTER_EXECUTABLE_NODES_COMPLETE), true, 12, 1, 2},
}
expectedFailedNodesCount int
}{
{"failImidiately", v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_IMMEDIATELY), false, 6, 1, 1},
{"failImidiately clearPreviousError", v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_IMMEDIATELY), true, 6, 1, 1},
{"failAfterExecutableNodesComplete", v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_AFTER_EXECUTABLE_NODES_COMPLETE), false, 12, 2, 2},
{"failAfterExecutableNodesComplete clearPreviousError", v1alpha1.WorkflowOnFailurePolicy(core.WorkflowMetadata_FAIL_AFTER_EXECUTABLE_NODES_COMPLETE), true, 12, 1, 2},
}

wJSON, err := yamlutils.ReadYamlFileAsJSON("testdata/benchmark_wf.yaml")
assert.NoError(t, err)
for _, test := range tests {

t.Run(test.name, func(t *testing.T) {
nodeConfig := config.GetConfig().NodeConfig
nodeConfig.ClearPreviousError = test.clearPreviousError
Expand All @@ -569,7 +567,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) {
}
if assert.NoError(t, json.Unmarshal(wJSON, w)) {
// For benchmark workflow, we will run into the first failure on round 6

for i := 0; i < test.expectedRoundsToFail; i++ {
t.Run(fmt.Sprintf("Round[%d]", i), func(t *testing.T) {
err := executor.HandleFlyteWorkflow(ctx, w)
Expand All @@ -581,7 +579,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) {
v.ResetDirty()
}
fmt.Printf("\n")

if i == test.expectedRoundsToFail-1 {
assert.Equal(t, v1alpha1.WorkflowPhaseFailed, w.Status.Phase)
} else if i == test.expectedRoundsToFail-2 {
Expand Down

0 comments on commit 50f64c1

Please sign in to comment.