From cd9fbb9720f0785de3c753b9174db6fbfbf11480 Mon Sep 17 00:00:00 2001 From: droctothorpe Date: Wed, 11 Sep 2024 12:40:28 -0400 Subject: [PATCH 1/8] fix(backend): implement subdag output resolution Signed-off-by: droctothorpe Co-authored-by: zazulam Co-authored-by: CarterFendley --- backend/src/v2/driver/driver.go | 111 ++++++++++++++++++++++++++++-- backend/src/v2/metadata/client.go | 16 ++++- 2 files changed, 117 insertions(+), 10 deletions(-) diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index aeeda9b6a48d..9d579e80d0a9 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -17,10 +17,11 @@ import ( "context" "encoding/json" "fmt" - "github.com/kubeflow/pipelines/backend/src/v2/objectstore" "strconv" "time" + "github.com/kubeflow/pipelines/backend/src/v2/objectstore" + "github.com/golang/glog" "github.com/golang/protobuf/ptypes/timestamp" "github.com/google/uuid" @@ -125,6 +126,8 @@ func RootDAG(ctx context.Context, opts Options, mlmd *metadata.Client) (executio err = fmt.Errorf("driver.RootDAG(%s) failed: %w", opts.info(), err) } }() + b, _ := json.Marshal(opts) + glog.V(4).Info("RootDAG opts: ", string(b)) err = validateRootDAG(opts) if err != nil { return nil, err @@ -230,6 +233,8 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl err = fmt.Errorf("driver.Container(%s) failed: %w", opts.info(), err) } }() + b, _ := json.Marshal(opts) + glog.V(4).Info("Container opts: ", string(b)) err = validateContainer(opts) if err != nil { return nil, err @@ -699,6 +704,8 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E err = fmt.Errorf("driver.DAG(%s) failed: %w", opts.info(), err) } }() + b, _ := json.Marshal(opts) + glog.V(4).Info("DAG opts: ", string(b)) err = validateDAG(opts) if err != nil { return nil, err @@ -749,6 +756,27 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E ecfg.ParentDagID = dag.Execution.GetID() ecfg.IterationIndex = iterationIndex ecfg.NotTriggered = !execution.WillTrigger() + + outputParameters := opts.Component.GetDag().GetOutputs().GetParameters() + glog.V(4).Info("outputParameters: ", outputParameters) + for _, value := range outputParameters { + outputParameterKey := value.GetValueFromParameter().OutputParameterKey + producerSubTask := value.GetValueFromParameter().ProducerSubtask + glog.V(4).Info("outputParameterKey: ", outputParameterKey) + glog.V(4).Info("producerSubtask: ", producerSubTask) + + outputParameterMap := map[string]interface{}{ + "output_parameter_key": outputParameterKey, + "producer_subtask": producerSubTask, + } + + outputParameterStruct, _ := structpb.NewValue(outputParameterMap) + + ecfg.OutputParameters = map[string]*structpb.Value{ + value.GetValueFromParameter().OutputParameterKey: outputParameterStruct, + } + } + if opts.Task.GetArtifactIterator() != nil { return execution, fmt.Errorf("ArtifactIterator is not implemented") } @@ -793,6 +821,12 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E ecfg.IterationCount = &count execution.IterationCount = &count } + + glog.V(4).Info("pipeline: ", pipeline) + b, _ = json.Marshal(*ecfg) + glog.V(4).Info("ecfg: ", string(b)) + glog.V(4).Infof("dag: %v", dag) + // TODO(Bobgy): change execution state to pending, because this is driver, execution hasn't started. 
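+	// At this point ecfg.OutputParameters holds, for each DAG output parameter,
+	// a Struct value of the form {"output_parameter_key": ..., "producer_subtask": ...}
+	// built in the loop above. CreateExecution below persists it as an MLMD
+	// custom property so that downstream drivers can follow the reference to
+	// the sub-task that actually produces the value.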
createdExecution, err := mlmd.CreateExecution(ctx, pipeline, ecfg) if err != nil { @@ -939,6 +973,8 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, err = fmt.Errorf("failed to resolve inputs: %w", err) } }() + glog.V(4).Infof("dag: %v", dag) + glog.V(4).Infof("task: %v", task) inputParams, _, err := dag.Execution.GetParameters() if err != nil { return nil, err @@ -1112,10 +1148,31 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, if err != nil { return nil, err } + // TODO: Make this recursive. + for _, v := range tasks { + if v.GetExecution().GetType() == "system.DAGExecution" { + glog.V(4).Infof("Found a task, %v, with an execution type of system.DAGExecution. Adding its tasks to the task list.", v.TaskName()) + dag, err := mlmd.GetDAG(ctx, v.GetExecution().GetId()) + if err != nil { + return nil, err + } + subdagTasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline) + if err != nil { + return nil, err + } + for k, v := range subdagTasks { + tasks[k] = v + } + } + } tasksCache = tasks + return tasks, nil } + for name, paramSpec := range task.GetInputs().GetParameters() { + glog.V(4).Infof("name: %v", name) + glog.V(4).Infof("paramSpec: %v", paramSpec) paramError := func(err error) error { return fmt.Errorf("resolving input parameter %s with spec %s: %w", name, paramSpec, err) } @@ -1131,8 +1188,11 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, } inputs.ParameterValues[name] = v + // This is the case where we are consuming an output parameter from an + // upstream task. That task can be a container or a DAG. case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter: taskOutput := paramSpec.GetTaskOutputParameter() + glog.V(4).Info("taskOutput: ", taskOutput) if taskOutput.GetProducerTask() == "" { return nil, paramError(fmt.Errorf("producer task is empty")) } @@ -1143,19 +1203,56 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, if err != nil { return nil, paramError(err) } + + // The producer is the task that produces the output that we need to + // consume. producer, ok := tasks[taskOutput.GetProducerTask()] - if !ok { - return nil, paramError(fmt.Errorf("cannot find producer task %q", taskOutput.GetProducerTask())) - } - _, outputs, err := producer.GetParameters() + + glog.V(4).Info("producer: ", producer) + + // Get the producer's outputs. + _, producerOutputs, err := producer.GetParameters() if err != nil { return nil, paramError(fmt.Errorf("get producer output parameters: %w", err)) } - param, ok := outputs[taskOutput.GetOutputParameterKey()] + glog.V(4).Info("producer output parameters: ", producerOutputs) + // Deserialize them. + var producerOutputsMap map[string]string + b, err := producerOutputs["Output"].GetStructValue().MarshalJSON() + if err != nil { + return nil, err + } + json.Unmarshal(b, &producerOutputsMap) + glog.V(4).Info("producerOutputsMap: ", producerOutputsMap) + + // If the producer's output includes a producer subtask, which means + // that the producer is a DAG that is getting its output from one of + // the tasks in the DAG, then we want to roll up the output from the + // producer subtask to the producer, so that the downstream logic + // can retrieve it appropriately. 
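+			// For example, if the producer's deserialized output map is
+			//   {"output_parameter_key": "Output", "producer_subtask": "inner-task"},
+			// the parameter is looked up on the sub-task "inner-task" rather than
+			// on the DAG itself ("inner-task" is an illustrative name, not one
+			// read from the spec).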
+ if producerSubTask, ok := producerOutputsMap["producer_subtask"]; ok { + glog.V(4).Infof( + "Overriding producer task, %v, output with producer_subtask, %v, output.", + producer.TaskName(), + producerSubTask, + ) + _, producerOutputs, err = tasks[producerSubTask].GetParameters() + if err != nil { + return nil, err + } + glog.V(4).Info("producerSubTask output parameters: ", producerOutputs) + // The only reason we're updating this is to make the downstream + // logging more accurate. + taskOutput.ProducerTask = producerOutputsMap["producer_subtask"] + } + + // Grab the value of the producer output. + producerOutputValue, ok := producerOutputs[taskOutput.GetOutputParameterKey()] if !ok { return nil, paramError(fmt.Errorf("cannot find output parameter key %q in producer task %q", taskOutput.GetOutputParameterKey(), taskOutput.GetProducerTask())) } - inputs.ParameterValues[name] = param + // Update the input to be the producer output value. + inputs.ParameterValues[name] = producerOutputValue case *pipelinespec.TaskInputsSpec_InputParameterSpec_RuntimeValue: runtimeValue := paramSpec.GetRuntimeValue() switch t := runtimeValue.Value.(type) { diff --git a/backend/src/v2/metadata/client.go b/backend/src/v2/metadata/client.go index a292c1fe6430..1a670c9d12e1 100644 --- a/backend/src/v2/metadata/client.go +++ b/backend/src/v2/metadata/client.go @@ -21,14 +21,15 @@ import ( "encoding/json" "errors" "fmt" - "github.com/kubeflow/pipelines/backend/src/common/util" - "github.com/kubeflow/pipelines/backend/src/v2/objectstore" "path" "strconv" "strings" "sync" "time" + "github.com/kubeflow/pipelines/backend/src/common/util" + "github.com/kubeflow/pipelines/backend/src/v2/objectstore" + "github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" "github.com/golang/glog" @@ -134,6 +135,7 @@ type ExecutionConfig struct { NotTriggered bool // optional, not triggered executions will have CANCELED state. ParentDagID int64 // parent DAG execution ID. Only the root DAG does not have a parent DAG. InputParameters map[string]*structpb.Value + OutputParameters map[string]*structpb.Value InputArtifactIDs map[string][]int64 IterationIndex *int // Index of the iteration. @@ -448,6 +450,8 @@ func getArtifactName(eventPath *pb.Event_Path) (string, error) { func (c *Client) PublishExecution(ctx context.Context, execution *Execution, outputParameters map[string]*structpb.Value, outputArtifacts []*OutputArtifact, state pb.Execution_State) error { e := execution.execution e.LastKnownState = state.Enum() + glog.V(4).Infof("outputParameters: %v", outputParameters) + glog.V(4).Infof("outputArtifacts: %v", outputArtifacts) if outputParameters != nil { // Record output parameters. 
@@ -576,7 +580,13 @@ func (c *Client) CreateExecution(ctx context.Context, pipeline *Pipeline, config }, }} } - + if config.OutputParameters != nil { + e.CustomProperties[keyOutputs] = &pb.Value{Value: &pb.Value_StructValue{ + StructValue: &structpb.Struct{ + Fields: config.OutputParameters, + }, + }} + } req := &pb.PutExecutionRequest{ Execution: e, Contexts: []*pb.Context{pipeline.pipelineCtx, pipeline.pipelineRunCtx}, From b496a3312f3ee5d7ffbe37326b9de73b150a5410 Mon Sep 17 00:00:00 2001 From: droctothorpe Date: Wed, 11 Sep 2024 16:44:55 -0400 Subject: [PATCH 2/8] Add support for subdags of subdags Signed-off-by: droctothorpe Co-authored-by: zazulam Co-authored-by: CarterFendley --- backend/src/v2/driver/driver.go | 167 ++++++++++++++++++-------------- 1 file changed, 92 insertions(+), 75 deletions(-) diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index 9d579e80d0a9..f916b4f058d0 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -967,6 +967,44 @@ func validateNonRoot(opts Options) error { return nil } +// getDAGTasks gets all the tasks associated with the specified DAG and all of +// its subDAGs. +func getDAGTasks( + ctx context.Context, + dag *metadata.DAG, + pipeline *metadata.Pipeline, + mlmd *metadata.Client, + flattenedTasks map[string]*metadata.Execution, +) (map[string]*metadata.Execution, error) { + if flattenedTasks == nil { + flattenedTasks = make(map[string]*metadata.Execution) + } + currentExecutionTasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline) + if err != nil { + return nil, err + } + for k, v := range currentExecutionTasks { + flattenedTasks[k] = v + } + for _, v := range currentExecutionTasks { + if v.GetExecution().GetType() == "system.DAGExecution" { + glog.V(4).Infof("Found a task, %v, with an execution type of system.DAGExecution. Adding its tasks to the task list.", v.TaskName()) + subDAG, err := mlmd.GetDAG(ctx, v.GetExecution().GetId()) + if err != nil { + return nil, err + } + // Pass the subDAG into a recursive call to getDAGTasks and update + // tasks to include the subDAG's tasks. + flattenedTasks, err = getDAGTasks(ctx, subDAG, pipeline, mlmd, flattenedTasks) + if err != nil { + return nil, err + } + } + } + + return flattenedTasks, nil +} + func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, pipeline *metadata.Pipeline, task *pipelinespec.PipelineTaskSpec, inputsSpec *pipelinespec.ComponentInputsSpec, mlmd *metadata.Client, expr *expression.Expr) (inputs *pipelinespec.ExecutorInput_Inputs, err error) { defer func() { if err != nil { @@ -1138,37 +1176,6 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, } return inputs, nil } - // get executions in context on demand - var tasksCache map[string]*metadata.Execution - getDAGTasks := func() (map[string]*metadata.Execution, error) { - if tasksCache != nil { - return tasksCache, nil - } - tasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline) - if err != nil { - return nil, err - } - // TODO: Make this recursive. - for _, v := range tasks { - if v.GetExecution().GetType() == "system.DAGExecution" { - glog.V(4).Infof("Found a task, %v, with an execution type of system.DAGExecution. 
Adding its tasks to the task list.", v.TaskName()) - dag, err := mlmd.GetDAG(ctx, v.GetExecution().GetId()) - if err != nil { - return nil, err - } - subdagTasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline) - if err != nil { - return nil, err - } - for k, v := range subdagTasks { - tasks[k] = v - } - } - } - tasksCache = tasks - - return tasks, nil - } for name, paramSpec := range task.GetInputs().GetParameters() { glog.V(4).Infof("name: %v", name) @@ -1199,60 +1206,70 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, if taskOutput.GetOutputParameterKey() == "" { return nil, paramError(fmt.Errorf("output parameter key is empty")) } - tasks, err := getDAGTasks() + tasks, err := getDAGTasks(ctx, dag, pipeline, mlmd, nil) if err != nil { return nil, paramError(err) } - // The producer is the task that produces the output that we need to - // consume. - producer, ok := tasks[taskOutput.GetProducerTask()] + // If the producer is a DAG, AND its output / producer subtask is + // ALSO a DAG, then we need to cycle through this loop until we + // arrive at a non-DAG subtask and essentially bubble up that + // non-DAG subtask so that its value can be consumed. + producerSubTaskMaybeDAG := true + for producerSubTaskMaybeDAG { + // The producer is the task that produces the output that we need to + // consume. + producer := tasks[taskOutput.GetProducerTask()] - glog.V(4).Info("producer: ", producer) + glog.V(4).Info("producer: ", producer) - // Get the producer's outputs. - _, producerOutputs, err := producer.GetParameters() - if err != nil { - return nil, paramError(fmt.Errorf("get producer output parameters: %w", err)) - } - glog.V(4).Info("producer output parameters: ", producerOutputs) - // Deserialize them. - var producerOutputsMap map[string]string - b, err := producerOutputs["Output"].GetStructValue().MarshalJSON() - if err != nil { - return nil, err - } - json.Unmarshal(b, &producerOutputsMap) - glog.V(4).Info("producerOutputsMap: ", producerOutputsMap) - - // If the producer's output includes a producer subtask, which means - // that the producer is a DAG that is getting its output from one of - // the tasks in the DAG, then we want to roll up the output from the - // producer subtask to the producer, so that the downstream logic - // can retrieve it appropriately. - if producerSubTask, ok := producerOutputsMap["producer_subtask"]; ok { - glog.V(4).Infof( - "Overriding producer task, %v, output with producer_subtask, %v, output.", - producer.TaskName(), - producerSubTask, - ) - _, producerOutputs, err = tasks[producerSubTask].GetParameters() + // Get the producer's outputs. + _, producerOutputs, err := producer.GetParameters() + if err != nil { + return nil, paramError(fmt.Errorf("get producer output parameters: %w", err)) + } + glog.V(4).Info("producer output parameters: ", producerOutputs) + // Deserialize them. + var producerOutputsMap map[string]string + b, err := producerOutputs["Output"].GetStructValue().MarshalJSON() if err != nil { return nil, err } - glog.V(4).Info("producerSubTask output parameters: ", producerOutputs) - // The only reason we're updating this is to make the downstream - // logging more accurate. 
- taskOutput.ProducerTask = producerOutputsMap["producer_subtask"] + json.Unmarshal(b, &producerOutputsMap) + glog.V(4).Info("producerOutputsMap: ", producerOutputsMap) + + // If the producer's output includes a producer subtask, which means + // that the producer is a DAG that is getting its output from one of + // the tasks in the DAG, then we want to roll up the output from the + // producer subtask to the producer, so that the downstream logic + // can retrieve it appropriately. + if producerSubTask, ok := producerOutputsMap["producer_subtask"]; ok { + glog.V(4).Infof( + "Overriding producer task, %v, output with producer_subtask, %v, output.", + producer.TaskName(), + producerSubTask, + ) + _, producerOutputs, err = tasks[producerSubTask].GetParameters() + if err != nil { + return nil, err + } + glog.V(4).Info("producerSubTask output parameters: ", producerOutputs) + // The only reason we're updating this is to make the downstream + // logging more accurate. + taskOutput.ProducerTask = producerOutputsMap["producer_subtask"] + // Grab the value of the producer output. + producerOutputValue, ok := producerOutputs[taskOutput.GetOutputParameterKey()] + if !ok { + return nil, paramError(fmt.Errorf("cannot find output parameter key %q in producer task %q", taskOutput.GetOutputParameterKey(), taskOutput.GetProducerTask())) + } + // Update the input to be the producer output value. + inputs.ParameterValues[name] = producerOutputValue + } else { + // The producer subtask is not a DAG, so we exit the loop. + producerSubTaskMaybeDAG = false + } } - // Grab the value of the producer output. - producerOutputValue, ok := producerOutputs[taskOutput.GetOutputParameterKey()] - if !ok { - return nil, paramError(fmt.Errorf("cannot find output parameter key %q in producer task %q", taskOutput.GetOutputParameterKey(), taskOutput.GetProducerTask())) - } - // Update the input to be the producer output value. - inputs.ParameterValues[name] = producerOutputValue case *pipelinespec.TaskInputsSpec_InputParameterSpec_RuntimeValue: runtimeValue := paramSpec.GetRuntimeValue() switch t := runtimeValue.Value.(type) { @@ -1292,7 +1309,7 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, if taskOutput.GetOutputArtifactKey() == "" { return nil, artifactError(fmt.Errorf("output artifact key is empty")) } - tasks, err := getDAGTasks() + tasks, err := getDAGTasks(ctx, dag, pipeline, mlmd, nil) if err != nil { return nil, artifactError(err) } From ccafc310e326bed3723b40826016d7a71649c40c Mon Sep 17 00:00:00 2001 From: zazulam Date: Tue, 17 Sep 2024 09:58:39 -0400 Subject: [PATCH 3/8] handle edge case Signed-off-by: zazulam Co-authored-by: droctothorpe --- backend/src/v2/driver/driver.go | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index f916b4f058d0..f38be9a73a93 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -1267,6 +1267,7 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, } else { // The producer subtask is not a DAG, so we exit the loop. 
producerSubTaskMaybeDAG = false + inputs.ParameterValues[name] = producerOutputs[taskOutput.GetOutputParameterKey()] } } From c1289b9b07ce8a3a1257e7105caa06ada61c0ace Mon Sep 17 00:00:00 2001 From: droctothorpe Date: Fri, 13 Sep 2024 09:14:30 -0400 Subject: [PATCH 4/8] Handle artifact outputs as well Signed-off-by: droctothorpe Co-authored-by: zazulam Co-authored-by: CarterFendley Co-authored-by: edmondop --- backend/src/v2/cmd/driver/main.go | 1 + backend/src/v2/driver/driver.go | 100 +++++++++++++++++++++--------- backend/src/v2/metadata/client.go | 18 +++++- 3 files changed, 90 insertions(+), 29 deletions(-) diff --git a/backend/src/v2/cmd/driver/main.go b/backend/src/v2/cmd/driver/main.go index 793ccfe1b800..9437d889862f 100644 --- a/backend/src/v2/cmd/driver/main.go +++ b/backend/src/v2/cmd/driver/main.go @@ -85,6 +85,7 @@ func init() { flag.Set("logtostderr", "true") // Change the WARNING to INFO level for debugging. flag.Set("stderrthreshold", "WARNING") + flag.Set("v", "4") } func validate() error { diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index f38be9a73a93..444fa177d27a 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -757,6 +757,7 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E ecfg.IterationIndex = iterationIndex ecfg.NotTriggered = !execution.WillTrigger() + // Handle writing output parameters to MLMD. outputParameters := opts.Component.GetDag().GetOutputs().GetParameters() glog.V(4).Info("outputParameters: ", outputParameters) for _, value := range outputParameters { @@ -777,6 +778,11 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E } } + // Handle writing output artifacts to MLMD. + outputArtifacts := opts.Component.GetDag().GetOutputs().GetArtifacts() + glog.V(4).Info("outputArtifacts: ", outputArtifacts) + ecfg.OutputArtifacts = outputArtifacts + if opts.Task.GetArtifactIterator() != nil { return execution, fmt.Errorf("ArtifactIterator is not implemented") } @@ -1177,6 +1183,7 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, return inputs, nil } + // Handle parameters. for name, paramSpec := range task.GetInputs().GetParameters() { glog.V(4).Infof("name: %v", name) glog.V(4).Infof("paramSpec: %v", paramSpec) @@ -1224,50 +1231,50 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, glog.V(4).Info("producer: ", producer) // Get the producer's outputs. - _, producerOutputs, err := producer.GetParameters() + _, producerOutputParameters, err := producer.GetParameters() if err != nil { return nil, paramError(fmt.Errorf("get producer output parameters: %w", err)) } - glog.V(4).Info("producer output parameters: ", producerOutputs) + glog.V(4).Info("producer output parameters: ", producerOutputParameters) // Deserialize them. 
- var producerOutputsMap map[string]string - b, err := producerOutputs["Output"].GetStructValue().MarshalJSON() + var producerOutputParametersMap map[string]string + b, err := producerOutputParameters["Output"].GetStructValue().MarshalJSON() if err != nil { return nil, err } - json.Unmarshal(b, &producerOutputsMap) - glog.V(4).Info("producerOutputsMap: ", producerOutputsMap) + json.Unmarshal(b, &producerOutputParametersMap) + glog.V(4).Info("producerOutputParametersMap: ", producerOutputParametersMap) // If the producer's output includes a producer subtask, which means // that the producer is a DAG that is getting its output from one of // the tasks in the DAG, then we want to roll up the output from the // producer subtask to the producer, so that the downstream logic // can retrieve it appropriately. - if producerSubTask, ok := producerOutputsMap["producer_subtask"]; ok { + if producerSubTask, ok := producerOutputParametersMap["producer_subtask"]; ok { glog.V(4).Infof( "Overriding producer task, %v, output with producer_subtask, %v, output.", producer.TaskName(), producerSubTask, ) - _, producerOutputs, err = tasks[producerSubTask].GetParameters() + _, producerOutputParameters, err = tasks[producerSubTask].GetParameters() if err != nil { return nil, err } - glog.V(4).Info("producerSubTask output parameters: ", producerOutputs) + glog.V(4).Info("producerSubTask output parameters: ", producerOutputParameters) // The only reason we're updating this is to make the downstream // logging more accurate. - taskOutput.ProducerTask = producerOutputsMap["producer_subtask"] + taskOutput.ProducerTask = producerOutputParametersMap["producer_subtask"] // Grab the value of the producer output. - producerOutputValue, ok := producerOutputs[taskOutput.GetOutputParameterKey()] + producerOutputParameterValue, ok := producerOutputParameters[taskOutput.GetOutputParameterKey()] if !ok { return nil, paramError(fmt.Errorf("cannot find output parameter key %q in producer task %q", taskOutput.GetOutputParameterKey(), taskOutput.GetProducerTask())) } // Update the input to be the producer output value. - inputs.ParameterValues[name] = producerOutputValue + inputs.ParameterValues[name] = producerOutputParameterValue } else { // The producer subtask is not a DAG, so we exit the loop. producerSubTaskMaybeDAG = false - inputs.ParameterValues[name] = producerOutputs[taskOutput.GetOutputParameterKey()] + inputs.ParameterValues[name] = producerOutputParameters[taskOutput.GetOutputParameterKey()] } } @@ -1286,6 +1293,8 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, return nil, paramError(fmt.Errorf("parameter spec of type %T not implemented yet", t)) } } + + // Handle artifacts. 
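+	// Input artifacts are resolved analogously to parameters: an artifact that
+	// comes from an upstream DAG is traced through the DAG's recorded
+	// output_artifacts custom property down to the sub-task that actually
+	// produced it.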
for name, artifactSpec := range task.GetInputs().GetArtifacts() { artifactError := func(err error) error { return fmt.Errorf("failed to resolve input artifact %s with spec %s: %w", name, artifactSpec, err) @@ -1314,25 +1323,60 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, if err != nil { return nil, artifactError(err) } + producer, ok := tasks[taskOutput.GetProducerTask()] if !ok { return nil, artifactError(fmt.Errorf("cannot find producer task %q", taskOutput.GetProducerTask())) } - // TODO(Bobgy): cache results - outputs, err := mlmd.GetOutputArtifactsByExecutionId(ctx, producer.GetID()) - if err != nil { - return nil, artifactError(err) - } - artifact, ok := outputs[taskOutput.GetOutputArtifactKey()] - if !ok { - return nil, artifactError(fmt.Errorf("cannot find output artifact key %q in producer task %q", taskOutput.GetOutputArtifactKey(), taskOutput.GetProducerTask())) - } - runtimeArtifact, err := artifact.ToRuntimeArtifact() - if err != nil { - return nil, artifactError(err) - } - inputs.Artifacts[name] = &pipelinespec.ArtifactList{ - Artifacts: []*pipelinespec.RuntimeArtifact{runtimeArtifact}, + glog.V(4).Info("producer: ", producer) + currentTask := producer + var outputArtifactKey string + currentSubTaskMaybeDAG := true + for currentSubTaskMaybeDAG { + // If the current task is a DAG: + glog.V(4).Info("currentTask: ", currentTask.TaskName()) + if *currentTask.GetExecution().Type == "system.DAGExecution" { + // Get the sub-task. + outputArtifactsCustomProperty := currentTask.GetExecution().GetCustomProperties()["output_artifacts"] + // Deserialize the output artifacts. + var outputArtifacts map[string]*pipelinespec.DagOutputsSpec_DagOutputArtifactSpec + err := json.Unmarshal([]byte(outputArtifactsCustomProperty.GetStringValue()), &outputArtifacts) + if err != nil { + return nil, err + } + glog.V(4).Infof("Deserialized outputArtifacts: %v", outputArtifacts) + artifactSelectors := outputArtifacts["Output"].GetArtifactSelectors() + // TODO: Add support for multiple output artifacts. + subTaskName := artifactSelectors[0].ProducerSubtask + outputArtifactKey = artifactSelectors[0].OutputArtifactKey + glog.V(4).Info("subTaskName: ", subTaskName) + glog.V(4).Info("outputArtifactKey: ", outputArtifactKey) + currentSubTask := tasks[subTaskName] + // If the sub-task is a DAG, reassign currentTask and run + // through the loop again. + currentTask = currentSubTask + // } + } else { + // Base case, subtask is a container, not a DAG. + outputs, err := mlmd.GetOutputArtifactsByExecutionId(ctx, currentTask.GetID()) + if err != nil { + return nil, artifactError(err) + } + glog.V(4).Infof("outputs: %#v", outputs) + artifact, ok := outputs[outputArtifactKey] + if !ok { + return nil, artifactError(fmt.Errorf("cannot find output artifact key %q in producer task %q", taskOutput.GetOutputArtifactKey(), taskOutput.GetProducerTask())) + } + runtimeArtifact, err := artifact.ToRuntimeArtifact() + if err != nil { + return nil, artifactError(err) + } + inputs.Artifacts[name] = &pipelinespec.ArtifactList{ + Artifacts: []*pipelinespec.RuntimeArtifact{runtimeArtifact}, + } + // Since we are in the base case, escape the loop. 
+ currentSubTaskMaybeDAG = false + } } default: return nil, artifactError(fmt.Errorf("artifact spec of type %T not implemented yet", t)) diff --git a/backend/src/v2/metadata/client.go b/backend/src/v2/metadata/client.go index 1a670c9d12e1..c6729cb67b53 100644 --- a/backend/src/v2/metadata/client.go +++ b/backend/src/v2/metadata/client.go @@ -136,6 +136,9 @@ type ExecutionConfig struct { ParentDagID int64 // parent DAG execution ID. Only the root DAG does not have a parent DAG. InputParameters map[string]*structpb.Value OutputParameters map[string]*structpb.Value + // OutputArtifacts map[string]*structpb.Value + // OutputArtifacts []*pipelinespec.DagOutputsSpec_ArtifactSelectorSpec + OutputArtifacts map[string]*pipelinespec.DagOutputsSpec_DagOutputArtifactSpec InputArtifactIDs map[string][]int64 IterationIndex *int // Index of the iteration. @@ -516,7 +519,8 @@ const ( keyCacheFingerPrint = "cache_fingerprint" keyCachedExecutionID = "cached_execution_id" keyInputs = "inputs" - keyOutputs = "outputs" + keyOutputs = "outputs" // TODO: Consider renaming this to output_parameters to be consistent. + keyOutputArtifacts = "output_artifacts" keyParentDagID = "parent_dag_id" // Parent DAG Execution ID. keyIterationIndex = "iteration_index" keyIterationCount = "iteration_count" @@ -580,6 +584,10 @@ func (c *Client) CreateExecution(ctx context.Context, pipeline *Pipeline, config }, }} } + // We save the output parameter and output artifact relationships in MLMD in + // case they're provided by a sub-task so that we can follow the + // relationships and retrieve outputs downstream in components that depend + // on said outputs as inputs. if config.OutputParameters != nil { e.CustomProperties[keyOutputs] = &pb.Value{Value: &pb.Value_StructValue{ StructValue: &structpb.Struct{ @@ -587,6 +595,14 @@ func (c *Client) CreateExecution(ctx context.Context, pipeline *Pipeline, config }, }} } + if config.OutputArtifacts != nil { + b, err := json.Marshal(config.OutputArtifacts) + if err != nil { + return nil, err + } + e.CustomProperties[keyOutputArtifacts] = StringValue(string(b)) + } + req := &pb.PutExecutionRequest{ Execution: e, Contexts: []*pb.Context{pipeline.pipelineCtx, pipeline.pipelineRunCtx}, From 292cc69be38481173f9b1ea295defa9bfd74491e Mon Sep 17 00:00:00 2001 From: droctothorpe Date: Fri, 13 Sep 2024 09:14:30 -0400 Subject: [PATCH 5/8] Simplify parameter handling logic Signed-off-by: droctothorpe Co-authored-by: zazulam --- backend/src/v2/driver/driver.go | 84 +++++++++++++------------------ backend/src/v2/metadata/client.go | 2 - 2 files changed, 34 insertions(+), 52 deletions(-) diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index 444fa177d27a..556da204888c 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -1218,63 +1218,48 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, return nil, paramError(err) } + // The producer is the task that produces the output that we need to + // consume. + producer := tasks[taskOutput.GetProducerTask()] + glog.V(4).Info("producer: ", producer) + currentTask := producer // If the producer is a DAG, AND its output / producer subtask is // ALSO a DAG, then we need to cycle through this loop until we // arrive at a non-DAG subtask and essentially bubble up that // non-DAG subtask so that its value can be consumed. - producerSubTaskMaybeDAG := true - for producerSubTaskMaybeDAG { - // The producer is the task that produces the output that we need to - // consume. 
- producer := tasks[taskOutput.GetProducerTask()] - - glog.V(4).Info("producer: ", producer) - - // Get the producer's outputs. - _, producerOutputParameters, err := producer.GetParameters() - if err != nil { - return nil, paramError(fmt.Errorf("get producer output parameters: %w", err)) - } - glog.V(4).Info("producer output parameters: ", producerOutputParameters) - // Deserialize them. - var producerOutputParametersMap map[string]string - b, err := producerOutputParameters["Output"].GetStructValue().MarshalJSON() + currentSubTaskMaybeDAG := true + for currentSubTaskMaybeDAG { + glog.V(4).Info("currentTask: ", currentTask.TaskName()) + _, outputParametersCustomProperty, err := currentTask.GetParameters() if err != nil { return nil, err } - json.Unmarshal(b, &producerOutputParametersMap) - glog.V(4).Info("producerOutputParametersMap: ", producerOutputParametersMap) - - // If the producer's output includes a producer subtask, which means - // that the producer is a DAG that is getting its output from one of - // the tasks in the DAG, then we want to roll up the output from the - // producer subtask to the producer, so that the downstream logic - // can retrieve it appropriately. - if producerSubTask, ok := producerOutputParametersMap["producer_subtask"]; ok { - glog.V(4).Infof( - "Overriding producer task, %v, output with producer_subtask, %v, output.", - producer.TaskName(), - producerSubTask, - ) - _, producerOutputParameters, err = tasks[producerSubTask].GetParameters() + // If the current task is a DAG: + if *currentTask.GetExecution().Type == "system.DAGExecution" { + // Since currentTask is a DAG, we need to deserialize its + // output parameter map so that we can look its + // corresponding producer sub-task, reassign currentTask, + // and iterate through this loop again. + var outputParametersMap map[string]string + b, err := outputParametersCustomProperty["Output"].GetStructValue().MarshalJSON() if err != nil { return nil, err } - glog.V(4).Info("producerSubTask output parameters: ", producerOutputParameters) - // The only reason we're updating this is to make the downstream - // logging more accurate. - taskOutput.ProducerTask = producerOutputParametersMap["producer_subtask"] - // Grab the value of the producer output. - producerOutputParameterValue, ok := producerOutputParameters[taskOutput.GetOutputParameterKey()] - if !ok { - return nil, paramError(fmt.Errorf("cannot find output parameter key %q in producer task %q", taskOutput.GetOutputParameterKey(), taskOutput.GetProducerTask())) - } - // Update the input to be the producer output value. - inputs.ParameterValues[name] = producerOutputParameterValue + json.Unmarshal(b, &outputParametersMap) + glog.V(4).Info("Deserialized outputParametersMap: ", outputParametersMap) + subTaskName := outputParametersMap["producer_subtask"] + glog.V(4).Infof( + "Overriding currentTask, %v, output with currentTask's producer_subtask, %v, output.", + currentTask.TaskName(), + subTaskName, + ) + + // Reassign sub-task before running through the loop again. + currentTask = tasks[subTaskName] } else { - // The producer subtask is not a DAG, so we exit the loop. - producerSubTaskMaybeDAG = false - inputs.ParameterValues[name] = producerOutputParameters[taskOutput.GetOutputParameterKey()] + inputs.ParameterValues[name] = outputParametersCustomProperty[taskOutput.GetOutputParameterKey()] + // Exit the loop. 
+ currentSubTaskMaybeDAG = false } } @@ -1333,8 +1318,8 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, var outputArtifactKey string currentSubTaskMaybeDAG := true for currentSubTaskMaybeDAG { - // If the current task is a DAG: glog.V(4).Info("currentTask: ", currentTask.TaskName()) + // If the current task is a DAG: if *currentTask.GetExecution().Type == "system.DAGExecution" { // Get the sub-task. outputArtifactsCustomProperty := currentTask.GetExecution().GetCustomProperties()["output_artifacts"] @@ -1351,13 +1336,12 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, outputArtifactKey = artifactSelectors[0].OutputArtifactKey glog.V(4).Info("subTaskName: ", subTaskName) glog.V(4).Info("outputArtifactKey: ", outputArtifactKey) - currentSubTask := tasks[subTaskName] // If the sub-task is a DAG, reassign currentTask and run // through the loop again. - currentTask = currentSubTask + currentTask = tasks[subTaskName] // } } else { - // Base case, subtask is a container, not a DAG. + // Base case, currentTask is a container, not a DAG. outputs, err := mlmd.GetOutputArtifactsByExecutionId(ctx, currentTask.GetID()) if err != nil { return nil, artifactError(err) diff --git a/backend/src/v2/metadata/client.go b/backend/src/v2/metadata/client.go index c6729cb67b53..c11d5f356f03 100644 --- a/backend/src/v2/metadata/client.go +++ b/backend/src/v2/metadata/client.go @@ -136,8 +136,6 @@ type ExecutionConfig struct { ParentDagID int64 // parent DAG execution ID. Only the root DAG does not have a parent DAG. InputParameters map[string]*structpb.Value OutputParameters map[string]*structpb.Value - // OutputArtifacts map[string]*structpb.Value - // OutputArtifacts []*pipelinespec.DagOutputsSpec_ArtifactSelectorSpec OutputArtifacts map[string]*pipelinespec.DagOutputsSpec_DagOutputArtifactSpec InputArtifactIDs map[string][]int64 IterationIndex *int // Index of the iteration. From 3f71af77afb432e731531080f0f9e350fd30b743 Mon Sep 17 00:00:00 2001 From: droctothorpe Date: Thu, 19 Sep 2024 11:58:53 -0400 Subject: [PATCH 6/8] Begin decomposition Signed-off-by: droctothorpe --- backend/src/v2/driver/driver.go | 387 +++++++++++++++++++------------- 1 file changed, 231 insertions(+), 156 deletions(-) diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index 556da204888c..20fd3e518df8 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -973,44 +973,6 @@ func validateNonRoot(opts Options) error { return nil } -// getDAGTasks gets all the tasks associated with the specified DAG and all of -// its subDAGs. -func getDAGTasks( - ctx context.Context, - dag *metadata.DAG, - pipeline *metadata.Pipeline, - mlmd *metadata.Client, - flattenedTasks map[string]*metadata.Execution, -) (map[string]*metadata.Execution, error) { - if flattenedTasks == nil { - flattenedTasks = make(map[string]*metadata.Execution) - } - currentExecutionTasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline) - if err != nil { - return nil, err - } - for k, v := range currentExecutionTasks { - flattenedTasks[k] = v - } - for _, v := range currentExecutionTasks { - if v.GetExecution().GetType() == "system.DAGExecution" { - glog.V(4).Infof("Found a task, %v, with an execution type of system.DAGExecution. 
Adding its tasks to the task list.", v.TaskName()) - subDAG, err := mlmd.GetDAG(ctx, v.GetExecution().GetId()) - if err != nil { - return nil, err - } - // Pass the subDAG into a recursive call to getDAGTasks and update - // tasks to include the subDAG's tasks. - flattenedTasks, err = getDAGTasks(ctx, subDAG, pipeline, mlmd, flattenedTasks) - if err != nil { - return nil, err - } - } - } - - return flattenedTasks, nil -} - func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, pipeline *metadata.Pipeline, task *pipelinespec.PipelineTaskSpec, inputsSpec *pipelinespec.ComponentInputsSpec, mlmd *metadata.Client, expr *expression.Expr) (inputs *pipelinespec.ExecutorInput_Inputs, err error) { defer func() { if err != nil { @@ -1202,65 +1164,20 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, } inputs.ParameterValues[name] = v - // This is the case where we are consuming an output parameter from an - // upstream task. That task can be a container or a DAG. + // This is the case where the input comes from the output of an upstream task. case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter: - taskOutput := paramSpec.GetTaskOutputParameter() - glog.V(4).Info("taskOutput: ", taskOutput) - if taskOutput.GetProducerTask() == "" { - return nil, paramError(fmt.Errorf("producer task is empty")) + cfg := resolveUpstreamParametersConfig{ + ctx: ctx, + paramSpec: paramSpec, + dag: dag, + pipeline: pipeline, + mlmd: mlmd, + inputs: inputs, + name: name, + paramError: paramError, } - if taskOutput.GetOutputParameterKey() == "" { - return nil, paramError(fmt.Errorf("output parameter key is empty")) - } - tasks, err := getDAGTasks(ctx, dag, pipeline, mlmd, nil) - if err != nil { - return nil, paramError(err) - } - - // The producer is the task that produces the output that we need to - // consume. - producer := tasks[taskOutput.GetProducerTask()] - glog.V(4).Info("producer: ", producer) - currentTask := producer - // If the producer is a DAG, AND its output / producer subtask is - // ALSO a DAG, then we need to cycle through this loop until we - // arrive at a non-DAG subtask and essentially bubble up that - // non-DAG subtask so that its value can be consumed. - currentSubTaskMaybeDAG := true - for currentSubTaskMaybeDAG { - glog.V(4).Info("currentTask: ", currentTask.TaskName()) - _, outputParametersCustomProperty, err := currentTask.GetParameters() - if err != nil { - return nil, err - } - // If the current task is a DAG: - if *currentTask.GetExecution().Type == "system.DAGExecution" { - // Since currentTask is a DAG, we need to deserialize its - // output parameter map so that we can look its - // corresponding producer sub-task, reassign currentTask, - // and iterate through this loop again. - var outputParametersMap map[string]string - b, err := outputParametersCustomProperty["Output"].GetStructValue().MarshalJSON() - if err != nil { - return nil, err - } - json.Unmarshal(b, &outputParametersMap) - glog.V(4).Info("Deserialized outputParametersMap: ", outputParametersMap) - subTaskName := outputParametersMap["producer_subtask"] - glog.V(4).Infof( - "Overriding currentTask, %v, output with currentTask's producer_subtask, %v, output.", - currentTask.TaskName(), - subTaskName, - ) - - // Reassign sub-task before running through the loop again. - currentTask = tasks[subTaskName] - } else { - inputs.ParameterValues[name] = outputParametersCustomProperty[taskOutput.GetOutputParameterKey()] - // Exit the loop. 
- currentSubTaskMaybeDAG = false - } + if err := resolveUpstreamParameters(cfg); err != nil { + return nil, err } case *pipelinespec.TaskInputsSpec_InputParameterSpec_RuntimeValue: @@ -1297,77 +1214,235 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, inputs.Artifacts[name] = v case *pipelinespec.TaskInputsSpec_InputArtifactSpec_TaskOutputArtifact: - taskOutput := artifactSpec.GetTaskOutputArtifact() - if taskOutput.GetProducerTask() == "" { - return nil, artifactError(fmt.Errorf("producer task is empty")) + cfg := resolveUpstreamArtifactsConfig{ + ctx: ctx, + artifactSpec: artifactSpec, + dag: dag, + pipeline: pipeline, + mlmd: mlmd, + inputs: inputs, + name: name, + artifactError: artifactError, } - if taskOutput.GetOutputArtifactKey() == "" { - return nil, artifactError(fmt.Errorf("output artifact key is empty")) + if err := resolveUpstreamArtifacts(cfg); err != nil { + return nil, err } - tasks, err := getDAGTasks(ctx, dag, pipeline, mlmd, nil) + default: + return nil, artifactError(fmt.Errorf("artifact spec of type %T not implemented yet", t)) + } + } + // TODO(Bobgy): validate executor inputs match component inputs definition + return inputs, nil +} + +// resolveUpstreamParametersConfig is just a config struct used to store the +// input parameters of the resolveUpstreamParameters function. +type resolveUpstreamParametersConfig struct { + ctx context.Context + paramSpec *pipelinespec.TaskInputsSpec_InputParameterSpec + dag *metadata.DAG + pipeline *metadata.Pipeline + mlmd *metadata.Client + inputs *pipelinespec.ExecutorInput_Inputs + name string + paramError func(error) error +} + +// resolveUpstreamParameters resolves input parameters that come from upstream +// tasks. These tasks can be components/containers, which is relatively +// straightforward, or DAGs, in which case, we need to traverse the graph until +// we arrive at a component/container (since there can be n nested DAGs). +func resolveUpstreamParameters(cfg resolveUpstreamParametersConfig) error { + taskOutput := cfg.paramSpec.GetTaskOutputParameter() + glog.V(4).Info("taskOutput: ", taskOutput) + if taskOutput.GetProducerTask() == "" { + return cfg.paramError(fmt.Errorf("producer task is empty")) + } + if taskOutput.GetOutputParameterKey() == "" { + return cfg.paramError(fmt.Errorf("output parameter key is empty")) + } + tasks, err := getDAGTasks(cfg.ctx, cfg.dag, cfg.pipeline, cfg.mlmd, nil) + if err != nil { + return cfg.paramError(err) + } + + // The producer is the task that produces the output that we need to + // consume. + producer := tasks[taskOutput.GetProducerTask()] + glog.V(4).Info("producer: ", producer) + currentTask := producer + currentSubTaskMaybeDAG := true + // Continue looping until we reach a sub-task that is NOT a DAG. + for currentSubTaskMaybeDAG { + glog.V(4).Info("currentTask: ", currentTask.TaskName()) + _, outputParametersCustomProperty, err := currentTask.GetParameters() + if err != nil { + return err + } + // If the current task is a DAG: + if *currentTask.GetExecution().Type == "system.DAGExecution" { + // Since currentTask is a DAG, we need to deserialize its + // output parameter map so that we can look its + // corresponding producer sub-task, reassign currentTask, + // and iterate through this loop again. 
+ var outputParametersMap map[string]string + b, err := outputParametersCustomProperty["Output"].GetStructValue().MarshalJSON() if err != nil { - return nil, artifactError(err) + return err } + json.Unmarshal(b, &outputParametersMap) + glog.V(4).Info("Deserialized outputParametersMap: ", outputParametersMap) + subTaskName := outputParametersMap["producer_subtask"] + glog.V(4).Infof( + "Overriding currentTask, %v, output with currentTask's producer_subtask, %v, output.", + currentTask.TaskName(), + subTaskName, + ) + + // Reassign sub-task before running through the loop again. + currentTask = tasks[subTaskName] + } else { + cfg.inputs.ParameterValues[cfg.name] = outputParametersCustomProperty[taskOutput.GetOutputParameterKey()] + // Exit the loop. + currentSubTaskMaybeDAG = false + } + } + + return nil +} + +// resolveUpstreamArtifactsConfig is just a config struct used to store the +// input parameters of the resolveUpstreamArtifacts function. +type resolveUpstreamArtifactsConfig struct { + ctx context.Context + artifactSpec *pipelinespec.TaskInputsSpec_InputArtifactSpec + dag *metadata.DAG + pipeline *metadata.Pipeline + mlmd *metadata.Client + inputs *pipelinespec.ExecutorInput_Inputs + name string + artifactError func(error) error +} - producer, ok := tasks[taskOutput.GetProducerTask()] +// resolveUpstreamArtifacts resolves input artifacts that come from upstream +// tasks. These tasks can be components/containers, which is relatively +// straightforward, or DAGs, in which case, we need to traverse the graph until +// we arrive at a component/container (since there can be n nested DAGs). +func resolveUpstreamArtifacts(cfg resolveUpstreamArtifactsConfig) error { + taskOutput := cfg.artifactSpec.GetTaskOutputArtifact() + if taskOutput.GetProducerTask() == "" { + return cfg.artifactError(fmt.Errorf("producer task is empty")) + } + if taskOutput.GetOutputArtifactKey() == "" { + cfg.artifactError(fmt.Errorf("output artifact key is empty")) + } + tasks, err := getDAGTasks(cfg.ctx, cfg.dag, cfg.pipeline, cfg.mlmd, nil) + if err != nil { + cfg.artifactError(err) + } + + producer, ok := tasks[taskOutput.GetProducerTask()] + if !ok { + cfg.artifactError( + fmt.Errorf("cannot find producer task %q", taskOutput.GetProducerTask()), + ) + } + glog.V(4).Info("producer: ", producer) + currentTask := producer + var outputArtifactKey string + currentSubTaskMaybeDAG := true + // Continue looping until we reach a sub-task that is NOT a DAG. + for currentSubTaskMaybeDAG { + glog.V(4).Info("currentTask: ", currentTask.TaskName()) + // If the current task is a DAG: + if *currentTask.GetExecution().Type == "system.DAGExecution" { + // Get the sub-task. + outputArtifactsCustomProperty := currentTask.GetExecution().GetCustomProperties()["output_artifacts"] + // Deserialize the output artifacts. + var outputArtifacts map[string]*pipelinespec.DagOutputsSpec_DagOutputArtifactSpec + err := json.Unmarshal([]byte(outputArtifactsCustomProperty.GetStringValue()), &outputArtifacts) + if err != nil { + return err + } + glog.V(4).Infof("Deserialized outputArtifacts: %v", outputArtifacts) + artifactSelectors := outputArtifacts["Output"].GetArtifactSelectors() + // TODO: Add support for multiple output artifacts. + subTaskName := artifactSelectors[0].ProducerSubtask + outputArtifactKey = artifactSelectors[0].OutputArtifactKey + glog.V(4).Info("subTaskName: ", subTaskName) + glog.V(4).Info("outputArtifactKey: ", outputArtifactKey) + // If the sub-task is a DAG, reassign currentTask and run + // through the loop again. 
+			currentTask = tasks[subTaskName]
+			// }
+		} else {
+			// Base case, currentTask is a container, not a DAG.
+			outputs, err := cfg.mlmd.GetOutputArtifactsByExecutionId(cfg.ctx, currentTask.GetID())
+			if err != nil {
+				return cfg.artifactError(err)
+			}
+			glog.V(4).Infof("outputs: %#v", outputs)
+			artifact, ok := outputs[outputArtifactKey]
+			if !ok {
+				return cfg.artifactError(
+					fmt.Errorf(
+						"cannot find output artifact key %q in producer task %q",
+						taskOutput.GetOutputArtifactKey(),
+						taskOutput.GetProducerTask(),
+					),
+				)
+			}
+			runtimeArtifact, err := artifact.ToRuntimeArtifact()
+			if err != nil {
+				return cfg.artifactError(err)
+			}
+			cfg.inputs.Artifacts[cfg.name] = &pipelinespec.ArtifactList{
+				Artifacts: []*pipelinespec.RuntimeArtifact{runtimeArtifact},
+			}
+			// Since we are in the base case, escape the loop.
+			currentSubTaskMaybeDAG = false
+		}
+	}
+
+	return nil
+}
+
+// getDAGTasks gets all the tasks associated with the specified DAG and all of
+// its subDAGs.
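+// Tasks from nested DAGs are flattened into a single map keyed by task name,
+// which assumes task names are unique across the pipeline run.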
+func getDAGTasks( + ctx context.Context, + dag *metadata.DAG, + pipeline *metadata.Pipeline, + mlmd *metadata.Client, + flattenedTasks map[string]*metadata.Execution, +) (map[string]*metadata.Execution, error) { + if flattenedTasks == nil { + flattenedTasks = make(map[string]*metadata.Execution) + } + currentExecutionTasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline) + if err != nil { + return nil, err + } + for k, v := range currentExecutionTasks { + flattenedTasks[k] = v + } + for _, v := range currentExecutionTasks { + if v.GetExecution().GetType() == "system.DAGExecution" { + glog.V(4).Infof("Found a task, %v, with an execution type of system.DAGExecution. Adding its tasks to the task list.", v.TaskName()) + subDAG, err := mlmd.GetDAG(ctx, v.GetExecution().GetId()) + if err != nil { + return nil, err + } + // Pass the subDAG into a recursive call to getDAGTasks and update + // tasks to include the subDAG's tasks. + flattenedTasks, err = getDAGTasks(ctx, subDAG, pipeline, mlmd, flattenedTasks) + if err != nil { + return nil, err + } + } + } + + return flattenedTasks, nil } func provisionOutputs(pipelineRoot, taskName string, outputsSpec *pipelinespec.ComponentOutputsSpec) *pipelinespec.ExecutorInput_Outputs { From 539c053d7600204231d4d744448262bbdb699493 Mon Sep 17 00:00:00 2001 From: Tyler Kalbach Date: Mon, 23 Sep 2024 22:56:56 -0400 Subject: [PATCH 7/8] Add support for multiple artifacts and params Signed-off-by: Tyler Kalbach --- backend/src/v2/driver/driver.go | 39 ++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index 20fd3e518df8..60d04d7f53a9 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -760,9 +760,10 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E // Handle writing output parameters to MLMD. outputParameters := opts.Component.GetDag().GetOutputs().GetParameters() glog.V(4).Info("outputParameters: ", outputParameters) - for _, value := range outputParameters { - outputParameterKey := value.GetValueFromParameter().OutputParameterKey - producerSubTask := value.GetValueFromParameter().ProducerSubtask + ecfg.OutputParameters = make(map[string]*structpb.Value) + for name, value := range outputParameters { + outputParameterKey := value.GetValueFromParameter().GetOutputParameterKey() + producerSubTask := value.GetValueFromParameter().GetProducerSubtask() glog.V(4).Info("outputParameterKey: ", outputParameterKey) glog.V(4).Info("producerSubtask: ", producerSubTask) @@ -773,9 +774,7 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E outputParameterStruct, _ := structpb.NewValue(outputParameterMap) - ecfg.OutputParameters = map[string]*structpb.Value{ - value.GetValueFromParameter().OutputParameterKey: outputParameterStruct, - } + ecfg.OutputParameters[name] = outputParameterStruct } // Handle writing output artifacts to MLMD. @@ -1198,6 +1197,8 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, // Handle artifacts. 
for name, artifactSpec := range task.GetInputs().GetArtifacts() { + glog.V(4).Infof("inputs: %#v", task.GetInputs()) + glog.V(4).Infof("artifacts: %#v", task.GetInputs().GetArtifacts()) artifactError := func(err error) error { return fmt.Errorf("failed to resolve input artifact %s with spec %s: %w", name, artifactSpec, err) } @@ -1269,6 +1270,7 @@ func resolveUpstreamParameters(cfg resolveUpstreamParametersConfig) error { // The producer is the task that produces the output that we need to // consume. producer := tasks[taskOutput.GetProducerTask()] + outputParameterKey := taskOutput.GetOutputParameterKey() glog.V(4).Info("producer: ", producer) currentTask := producer currentSubTaskMaybeDAG := true @@ -1286,13 +1288,14 @@ func resolveUpstreamParameters(cfg resolveUpstreamParametersConfig) error { // corresponding producer sub-task, reassign currentTask, // and iterate through this loop again. var outputParametersMap map[string]string - b, err := outputParametersCustomProperty["Output"].GetStructValue().MarshalJSON() + b, err := outputParametersCustomProperty[outputParameterKey].GetStructValue().MarshalJSON() if err != nil { return err } json.Unmarshal(b, &outputParametersMap) glog.V(4).Info("Deserialized outputParametersMap: ", outputParametersMap) subTaskName := outputParametersMap["producer_subtask"] + outputParameterKey = outputParametersMap["output_parameter_key"] glog.V(4).Infof( "Overriding currentTask, %v, output with currentTask's producer_subtask, %v, output.", currentTask.TaskName(), @@ -1302,7 +1305,7 @@ func resolveUpstreamParameters(cfg resolveUpstreamParametersConfig) error { // Reassign sub-task before running through the loop again. currentTask = tasks[subTaskName] } else { - cfg.inputs.ParameterValues[cfg.name] = outputParametersCustomProperty[taskOutput.GetOutputParameterKey()] + cfg.inputs.ParameterValues[cfg.name] = outputParametersCustomProperty[outputParameterKey] // Exit the loop. currentSubTaskMaybeDAG = false } @@ -1329,6 +1332,7 @@ type resolveUpstreamArtifactsConfig struct { // straightforward, or DAGs, in which case, we need to traverse the graph until // we arrive at a component/container (since there can be n nested DAGs). func resolveUpstreamArtifacts(cfg resolveUpstreamArtifactsConfig) error { + glog.V(4).Infof("artifactSpec: %#v", cfg.artifactSpec) taskOutput := cfg.artifactSpec.GetTaskOutputArtifact() if taskOutput.GetProducerTask() == "" { return cfg.artifactError(fmt.Errorf("producer task is empty")) @@ -1349,7 +1353,7 @@ func resolveUpstreamArtifacts(cfg resolveUpstreamArtifactsConfig) error { } glog.V(4).Info("producer: ", producer) currentTask := producer - var outputArtifactKey string + var outputArtifactKey string = taskOutput.GetOutputArtifactKey() currentSubTaskMaybeDAG := true // Continue looping until we reach a sub-task that is NOT a DAG. for currentSubTaskMaybeDAG { @@ -1365,12 +1369,17 @@ func resolveUpstreamArtifacts(cfg resolveUpstreamArtifactsConfig) error { return err } glog.V(4).Infof("Deserialized outputArtifacts: %v", outputArtifacts) - artifactSelectors := outputArtifacts["Output"].GetArtifactSelectors() - // TODO: Add support for multiple output artifacts. 
- subTaskName := artifactSelectors[0].ProducerSubtask - outputArtifactKey = artifactSelectors[0].OutputArtifactKey - glog.V(4).Info("subTaskName: ", subTaskName) - glog.V(4).Info("outputArtifactKey: ", outputArtifactKey) + // Adding support for multiple output artifacts + var subTaskName string + value := outputArtifacts[outputArtifactKey].GetArtifactSelectors() + + for _, v := range value { + glog.V(4).Infof("v: %v", v) + glog.V(4).Infof("v.ProducerSubtask: %v", v.ProducerSubtask) + glog.V(4).Infof("v.OutputArtifactKey: %v", v.OutputArtifactKey) + subTaskName = v.ProducerSubtask + outputArtifactKey = v.OutputArtifactKey + } // If the sub-task is a DAG, reassign currentTask and run // through the loop again. currentTask = tasks[subTaskName] From bf92df0d983767a087c341e9587662a3cd159e30 Mon Sep 17 00:00:00 2001 From: droctothorpe Date: Tue, 1 Oct 2024 11:05:31 -0400 Subject: [PATCH 8/8] Implement large tests for subdagio Signed-off-by: droctothorpe Co-authored-by: zazulam Co-authored-by: CarterFendley --- backend/src/v2/driver/driver.go | 6 +- samples/v2/sample_test.py | 65 +++++++++++++----- samples/v2/subdagio/__init__.py | 7 ++ samples/v2/subdagio/artifact.py | 47 +++++++++++++ samples/v2/subdagio/artifact_cache.py | 42 ++++++++++++ samples/v2/subdagio/mixed_parameters.py | 48 ++++++++++++++ .../subdagio/multiple_artifacts_namedtuple.py | 66 +++++++++++++++++++ .../multiple_parameters_namedtuple.py | 51 ++++++++++++++ samples/v2/subdagio/parameter.py | 45 +++++++++++++ samples/v2/subdagio/parameter_cache.py | 40 +++++++++++ .../kfp/compiler/pipeline_spec_builder.py | 11 +++- 11 files changed, 406 insertions(+), 22 deletions(-) create mode 100644 samples/v2/subdagio/__init__.py create mode 100644 samples/v2/subdagio/artifact.py create mode 100644 samples/v2/subdagio/artifact_cache.py create mode 100644 samples/v2/subdagio/mixed_parameters.py create mode 100644 samples/v2/subdagio/multiple_artifacts_namedtuple.py create mode 100644 samples/v2/subdagio/multiple_parameters_namedtuple.py create mode 100644 samples/v2/subdagio/parameter.py create mode 100644 samples/v2/subdagio/parameter_cache.py diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index 60d04d7f53a9..f58a9677ea14 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -1353,7 +1353,7 @@ func resolveUpstreamArtifacts(cfg resolveUpstreamArtifactsConfig) error { } glog.V(4).Info("producer: ", producer) currentTask := producer - var outputArtifactKey string = taskOutput.GetOutputArtifactKey() + outputArtifactKey := taskOutput.GetOutputArtifactKey() currentSubTaskMaybeDAG := true // Continue looping until we reach a sub-task that is NOT a DAG. 
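+	// Sketch of the traversal with hypothetical nested pipelines
+	// crust -> mantle -> core: crust consumes mantle's output artifact,
+	// mantle's artifact selector points at core's output key, and core is a
+	// container task, so the loop stops there and core's artifact is consumed.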
 	for currentSubTaskMaybeDAG {
@@ -1371,9 +1371,9 @@ func resolveUpstreamArtifacts(cfg resolveUpstreamArtifactsConfig) error {
 		glog.V(4).Infof("Deserialized outputArtifacts: %v", outputArtifacts)
 		// Adding support for multiple output artifacts
 		var subTaskName string
-		value := outputArtifacts[outputArtifactKey].GetArtifactSelectors()
+		artifactSelectors := outputArtifacts[outputArtifactKey].GetArtifactSelectors()
 
-		for _, v := range value {
+		for _, v := range artifactSelectors {
 			glog.V(4).Infof("v: %v", v)
 			glog.V(4).Infof("v.ProducerSubtask: %v", v.ProducerSubtask)
 			glog.V(4).Infof("v.OutputArtifactKey: %v", v.OutputArtifactKey)
diff --git a/samples/v2/sample_test.py b/samples/v2/sample_test.py
index d34599a3c18e..ed5fa0da825c 100644
--- a/samples/v2/sample_test.py
+++ b/samples/v2/sample_test.py
@@ -11,20 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
-import unittest
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
+import inspect
+import os
 from pprint import pprint
 from typing import List
+import unittest
 
+import component_with_optional_inputs
+import hello_world
 import kfp
 from kfp.dsl.graph_component import GraphComponent
-import component_with_optional_inputs
+import pipeline_container_no_input
 import pipeline_with_env
-import hello_world
 import producer_consumer_param
-import pipeline_container_no_input
 import two_step_pipeline_containerized
 
 _MINUTE = 60  # seconds
@@ -38,16 +40,21 @@ class TestCase:
 
 class SampleTest(unittest.TestCase):
-    _kfp_host_and_port = os.getenv('KFP_API_HOST_AND_PORT', 'http://localhost:8888')
-    _kfp_ui_and_port = os.getenv('KFP_UI_HOST_AND_PORT', 'http://localhost:8080')
+    _kfp_host_and_port = os.getenv('KFP_API_HOST_AND_PORT',
+                                   'http://localhost:8888')
+    _kfp_ui_and_port = os.getenv('KFP_UI_HOST_AND_PORT',
+                                 'http://localhost:8080')
     _client = kfp.Client(host=_kfp_host_and_port, ui_host=_kfp_ui_and_port)
 
     def test(self):
         test_cases: List[TestCase] = [
             TestCase(pipeline_func=hello_world.pipeline_hello_world),
-            TestCase(pipeline_func=producer_consumer_param.producer_consumer_param_pipeline),
-            TestCase(pipeline_func=pipeline_container_no_input.pipeline_container_no_input),
-            TestCase(pipeline_func=two_step_pipeline_containerized.two_step_pipeline_containerized),
+            TestCase(pipeline_func=producer_consumer_param
+                     .producer_consumer_param_pipeline),
+            TestCase(pipeline_func=pipeline_container_no_input
+                     .pipeline_container_no_input),
+            TestCase(pipeline_func=two_step_pipeline_containerized
+                     .two_step_pipeline_containerized),
             TestCase(pipeline_func=component_with_optional_inputs.pipeline),
             TestCase(pipeline_func=pipeline_with_env.pipeline_with_env),
@@ -56,27 +63,51 @@ def test(self):
             # TestCase(pipeline_func=pipeline_with_volume.pipeline_with_volume),
             # TestCase(pipeline_func=pipeline_with_secret_as_volume.pipeline_secret_volume),
            # TestCase(pipeline_func=pipeline_with_secret_as_env.pipeline_secret_env),
+
+            # This next set of tests needs to stay commented out until issue
+            # https://github.com/kubeflow/pipelines/issues/11239#issuecomment-2374792592
+            # is addressed, or until the driver image used in CI is updated,
+            # because otherwise the tests run against an incompatible version
+            # of the driver. In the meantime, for local validation, these
+            # tests can be executed (once you've manually deployed an updated
+            # driver image).
+
+            # TestCase(pipeline_func=subdagio.parameter.crust),
+            # TestCase(pipeline_func=subdagio.parameter_cache.crust),
+            # TestCase(pipeline_func=subdagio.artifact_cache.crust),
+            # TestCase(pipeline_func=subdagio.artifact.crust),
+            # TestCase(pipeline_func=subdagio.mixed_parameters.crust),
+            # TestCase(pipeline_func=subdagio.multiple_parameters_namedtuple.crust),
+            # TestCase(pipeline_func=subdagio.multiple_artifacts_namedtuple.crust),
         ]
 
         with ThreadPoolExecutor() as executor:
             futures = [
-                executor.submit(self.run_test_case, test_case.pipeline_func, test_case.timeout)
-                for test_case in test_cases
+                executor.submit(self.run_test_case, test_case.pipeline_func,
+                                test_case.timeout) for test_case in test_cases
             ]
             for future in as_completed(futures):
                 future.result()
 
     def run_test_case(self, pipeline_func: GraphComponent, timeout: int):
         with self.subTest(pipeline=pipeline_func, msg=pipeline_func.name):
-            run_result = self._client.create_run_from_pipeline_func(pipeline_func=pipeline_func)
+            print(
+                f'Running pipeline: {inspect.getmodule(pipeline_func.pipeline_func).__name__}/{pipeline_func.name}.'
+            )
+            run_result = self._client.create_run_from_pipeline_func(
+                pipeline_func=pipeline_func)
             run_response = run_result.wait_for_run_completion(timeout)
 
             pprint(run_response.run_details)
-            print("Run details page URL:")
-            print(f"{self._kfp_ui_and_port}/#/runs/details/{run_response.run_id}")
+            print('Run details page URL:')
+            print(
+                f'{self._kfp_ui_and_port}/#/runs/details/{run_response.run_id}')
 
-            self.assertEqual(run_response.state, "SUCCEEDED")
+            self.assertEqual(run_response.state, 'SUCCEEDED')
+            print(
+                f'Pipeline, {inspect.getmodule(pipeline_func.pipeline_func).__name__}/{pipeline_func.name}, succeeded.'
+ ) if __name__ == '__main__': diff --git a/samples/v2/subdagio/__init__.py b/samples/v2/subdagio/__init__.py new file mode 100644 index 000000000000..dc8f8b3ceaee --- /dev/null +++ b/samples/v2/subdagio/__init__.py @@ -0,0 +1,7 @@ +from subdagio import artifact +from subdagio import artifact_cache +from subdagio import mixed_parameters +from subdagio import multiple_artifacts_namedtuple +from subdagio import multiple_parameters_namedtuple +from subdagio import parameter +from subdagio import parameter_cache diff --git a/samples/v2/subdagio/artifact.py b/samples/v2/subdagio/artifact.py new file mode 100644 index 000000000000..8f425662a1b7 --- /dev/null +++ b/samples/v2/subdagio/artifact.py @@ -0,0 +1,47 @@ +import os + +from kfp import Client +from kfp import dsl + + +@dsl.component +def core_comp(dataset: dsl.Output[dsl.Dataset]): + with open(dataset.path, 'w') as f: + f.write('foo') + + +@dsl.component +def crust_comp(input: dsl.Dataset): + with open(input.path, 'r') as f: + print('input: ', f.read()) + + +@dsl.pipeline +def core() -> dsl.Dataset: + task = core_comp() + task.set_caching_options(False) + + return task.output + + +@dsl.pipeline +def mantle() -> dsl.Dataset: + dag_task = core() + dag_task.set_caching_options(False) + + return dag_task.output + + +@dsl.pipeline(name=os.path.basename(__file__).removesuffix('.py') + '-pipeline') +def crust(): + dag_task = mantle() + dag_task.set_caching_options(False) + + task = crust_comp(input=dag_task.output) + task.set_caching_options(False) + + +if __name__ == '__main__': + # Compiler().compile(pipeline_func=crust, package_path=f"{__file__.removesuffix('.py')}.yaml") + client = Client() + client.create_run_from_pipeline_func(crust) diff --git a/samples/v2/subdagio/artifact_cache.py b/samples/v2/subdagio/artifact_cache.py new file mode 100644 index 000000000000..5b52b25fb23d --- /dev/null +++ b/samples/v2/subdagio/artifact_cache.py @@ -0,0 +1,42 @@ +import os + +from kfp import Client +from kfp import dsl + + +@dsl.component +def core_comp(dataset: dsl.Output[dsl.Dataset]): + with open(dataset.path, 'w') as f: + f.write('foo') + + +@dsl.component +def crust_comp(input: dsl.Dataset): + with open(input.path, 'r') as f: + print('input: ', f.read()) + + +@dsl.pipeline +def core() -> dsl.Dataset: + task = core_comp() + + return task.output + + +@dsl.pipeline +def mantle() -> dsl.Dataset: + dag_task = core() + return dag_task.output + + +@dsl.pipeline(name=os.path.basename(__file__).removesuffix('.py') + '-pipeline') +def crust(): + dag_task = mantle() + + task = crust_comp(input=dag_task.output) + + +if __name__ == '__main__': + # Compiler().compile(pipeline_func=crust, package_path=f"{__file__.removesuffix('.py')}.yaml") + client = Client() + client.create_run_from_pipeline_func(crust) diff --git a/samples/v2/subdagio/mixed_parameters.py b/samples/v2/subdagio/mixed_parameters.py new file mode 100644 index 000000000000..0a660d335d95 --- /dev/null +++ b/samples/v2/subdagio/mixed_parameters.py @@ -0,0 +1,48 @@ +import os + +from kfp import Client +from kfp import dsl +from kfp.compiler import Compiler + + +@dsl.component +def core_comp() -> int: + return 1 + + +@dsl.component +def crust_comp(x: int, y: int): + print('sum :', x + y) + + +@dsl.pipeline +def core() -> int: + task = core_comp() + task.set_caching_options(False) + + return task.output + + +@dsl.pipeline +def mantle() -> int: + dag_task = core() + dag_task.set_caching_options(False) + + return dag_task.output + + +@dsl.pipeline(name=os.path.basename(__file__).removesuffix('.py') + 
'-pipeline') +def crust(): + dag_task = mantle() + dag_task.set_caching_options(False) + + task = crust_comp(x=2, y=dag_task.output) + task.set_caching_options(False) + + +if __name__ == '__main__': + Compiler().compile( + pipeline_func=crust, + package_path=f"{__file__.removesuffix('.py')}.yaml") + client = Client() + client.create_run_from_pipeline_func(crust) diff --git a/samples/v2/subdagio/multiple_artifacts_namedtuple.py b/samples/v2/subdagio/multiple_artifacts_namedtuple.py new file mode 100644 index 000000000000..7d2777d38b06 --- /dev/null +++ b/samples/v2/subdagio/multiple_artifacts_namedtuple.py @@ -0,0 +1,66 @@ +import os +from typing import NamedTuple + +from kfp import Client +from kfp import dsl + + +@dsl.component +def core_comp(ds1: dsl.Output[dsl.Dataset], ds2: dsl.Output[dsl.Dataset]): + with open(ds1.path, 'w') as f: + f.write('foo') + with open(ds2.path, 'w') as f: + f.write('bar') + + +@dsl.component +def crust_comp( + ds1: dsl.Dataset, + ds2: dsl.Dataset, +): + with open(ds1.path, 'r') as f: + print('ds1: ', f.read()) + with open(ds2.path, 'r') as f: + print('ds2: ', f.read()) + + +@dsl.pipeline +def core() -> NamedTuple( + 'outputs', + ds1=dsl.Dataset, + ds2=dsl.Dataset, +): # type: ignore + task = core_comp() + task.set_caching_options(False) + + return task.outputs + + +@dsl.pipeline +def mantle() -> NamedTuple( + 'outputs', + ds1=dsl.Dataset, + ds2=dsl.Dataset, +): # type: ignore + dag_task = core() + dag_task.set_caching_options(False) + + return dag_task.outputs + + +@dsl.pipeline(name=os.path.basename(__file__).removesuffix('.py') + '-pipeline') +def crust(): + dag_task = mantle() + dag_task.set_caching_options(False) + + task = crust_comp( + ds1=dag_task.outputs['ds1'], + ds2=dag_task.outputs['ds2'], + ) + task.set_caching_options(False) + + +if __name__ == '__main__': + # Compiler().compile(pipeline_func=crust, package_path=f"{__file__.removesuffix('.py')}.yaml") + client = Client() + client.create_run_from_pipeline_func(crust) diff --git a/samples/v2/subdagio/multiple_parameters_namedtuple.py b/samples/v2/subdagio/multiple_parameters_namedtuple.py new file mode 100644 index 000000000000..29699088554d --- /dev/null +++ b/samples/v2/subdagio/multiple_parameters_namedtuple.py @@ -0,0 +1,51 @@ +import os +from typing import NamedTuple + +from kfp import Client +from kfp import dsl + + +@dsl.component +def core_comp() -> NamedTuple('outputs', val1=str, val2=str): # type: ignore + outputs = NamedTuple('outputs', val1=str, val2=str) + return outputs('foo', 'bar') + + +@dsl.component +def crust_comp(val1: str, val2: str): + print('val1: ', val1) + print('val2: ', val2) + + +@dsl.pipeline +def core() -> NamedTuple('outputs', val1=str, val2=str): # type: ignore + task = core_comp() + task.set_caching_options(False) + + return task.outputs + + +@dsl.pipeline +def mantle() -> NamedTuple('outputs', val1=str, val2=str): # type: ignore + dag_task = core() + dag_task.set_caching_options(False) + + return dag_task.outputs + + +@dsl.pipeline(name=os.path.basename(__file__).removesuffix('.py') + '-pipeline') +def crust(): + dag_task = mantle() + dag_task.set_caching_options(False) + + task = crust_comp( + val1=dag_task.outputs['val1'], + val2=dag_task.outputs['val2'], + ) + task.set_caching_options(False) + + +if __name__ == '__main__': + # Compiler().compile(pipeline_func=crust, package_path=f"{__file__.removesuffix('.py')}.yaml") + client = Client() + client.create_run_from_pipeline_func(crust) diff --git a/samples/v2/subdagio/parameter.py b/samples/v2/subdagio/parameter.py 
new file mode 100644 index 000000000000..c00439dd1c80 --- /dev/null +++ b/samples/v2/subdagio/parameter.py @@ -0,0 +1,45 @@ +import os + +from kfp import Client +from kfp import dsl + + +@dsl.component +def core_comp() -> str: + return 'foo' + + +@dsl.component +def crust_comp(input: str): + print('input :', input) + + +@dsl.pipeline +def core() -> str: + task = core_comp() + task.set_caching_options(False) + + return task.output + + +@dsl.pipeline +def mantle() -> str: + dag_task = core() + dag_task.set_caching_options(False) + + return dag_task.output + + +@dsl.pipeline(name=os.path.basename(__file__).removesuffix('.py') + '-pipeline') +def crust(): + dag_task = mantle() + dag_task.set_caching_options(False) + + task = crust_comp(input=dag_task.output) + task.set_caching_options(False) + + +if __name__ == '__main__': + # Compiler().compile(pipeline_func=crust, package_path=f"{__file__.removesuffix('.py')}.yaml") + client = Client() + client.create_run_from_pipeline_func(crust) diff --git a/samples/v2/subdagio/parameter_cache.py b/samples/v2/subdagio/parameter_cache.py new file mode 100644 index 000000000000..9fe2402e2b8c --- /dev/null +++ b/samples/v2/subdagio/parameter_cache.py @@ -0,0 +1,40 @@ +import os + +from kfp import Client +from kfp import dsl + + +@dsl.component +def core_comp() -> str: + return 'foo' + + +@dsl.component +def crust_comp(input: str): + print('input :', input) + + +@dsl.pipeline +def core() -> str: + task = core_comp() + + return task.output + + +@dsl.pipeline +def mantle() -> str: + dag_task = core() + + return dag_task.output + + +@dsl.pipeline(name=os.path.basename(__file__).removesuffix('.py') + '-pipeline') +def crust(): + dag_task = mantle() + task = crust_comp(input=dag_task.output) + + +if __name__ == '__main__': + # Compiler().compile(pipeline_func=crust, package_path=f"{__file__.removesuffix('.py')}.yaml") + client = Client() + client.create_run_from_pipeline_func(crust) diff --git a/sdk/python/kfp/compiler/pipeline_spec_builder.py b/sdk/python/kfp/compiler/pipeline_spec_builder.py index 6e4bc4e8690c..fbc3bb463df3 100644 --- a/sdk/python/kfp/compiler/pipeline_spec_builder.py +++ b/sdk/python/kfp/compiler/pipeline_spec_builder.py @@ -1872,7 +1872,9 @@ def validate_pipeline_outputs_dict( f'Pipeline outputs may only be returned from the top level of the pipeline function scope. Got pipeline output from within the control flow group dsl.{channel.task.parent_task_group.__class__.__name__}.' ) else: - raise ValueError(f'Got unknown pipeline output: {channel}.') + raise ValueError( + f'Got unknown pipeline output, {channel}, of type {type(channel)}.' + ) def create_pipeline_spec( @@ -2006,13 +2008,18 @@ def convert_pipeline_outputs_to_dict( output name to PipelineChannel.""" if pipeline_outputs is None: return {} + elif isinstance(pipeline_outputs, dict): + # This condition is required to support pipelines that return NamedTuples. + return pipeline_outputs elif isinstance(pipeline_outputs, pipeline_channel.PipelineChannel): return {component_factory.SINGLE_OUTPUT_NAME: pipeline_outputs} elif isinstance(pipeline_outputs, tuple) and hasattr( pipeline_outputs, '_asdict'): return dict(pipeline_outputs._asdict()) else: - raise ValueError(f'Got unknown pipeline output: {pipeline_outputs}') + raise ValueError( + f'Got unknown pipeline output, {pipeline_outputs}, of type {type(pipeline_outputs)}.' + ) def write_pipeline_spec_to_file(