Skip to content

Commit af3c3e1

Browse files
droctothorpezazulamCarterFendley
committed
fix(backend): implement subdag output resolution
Signed-off-by: droctothorpe <mythicalsunlight@gmail.com> Co-authored-by: zazulam <m.zazula@gmail.com> Co-authored-by: CarterFendley <carter.fendley@gmail.com>
1 parent 1cded35 commit af3c3e1

File tree

2 files changed

+117
-10
lines changed

2 files changed

+117
-10
lines changed

backend/src/v2/driver/driver.go

Lines changed: 104 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@ import (
1717
"context"
1818
"encoding/json"
1919
"fmt"
20-
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"
2120
"strconv"
2221
"time"
2322

23+
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"
24+
2425
"github.com/golang/glog"
2526
"github.com/golang/protobuf/ptypes/timestamp"
2627
"github.com/google/uuid"
@@ -125,6 +126,8 @@ func RootDAG(ctx context.Context, opts Options, mlmd *metadata.Client) (executio
125126
err = fmt.Errorf("driver.RootDAG(%s) failed: %w", opts.info(), err)
126127
}
127128
}()
129+
b, _ := json.Marshal(opts)
130+
glog.V(4).Info("RootDAG opts: ", string(b))
128131
err = validateRootDAG(opts)
129132
if err != nil {
130133
return nil, err
@@ -230,6 +233,8 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl
230233
err = fmt.Errorf("driver.Container(%s) failed: %w", opts.info(), err)
231234
}
232235
}()
236+
b, _ := json.Marshal(opts)
237+
glog.V(4).Info("Container opts: ", string(b))
233238
err = validateContainer(opts)
234239
if err != nil {
235240
return nil, err
@@ -699,6 +704,8 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
699704
err = fmt.Errorf("driver.DAG(%s) failed: %w", opts.info(), err)
700705
}
701706
}()
707+
b, _ := json.Marshal(opts)
708+
glog.V(4).Info("DAG opts: ", string(b))
702709
err = validateDAG(opts)
703710
if err != nil {
704711
return nil, err
@@ -749,6 +756,27 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
749756
ecfg.ParentDagID = dag.Execution.GetID()
750757
ecfg.IterationIndex = iterationIndex
751758
ecfg.NotTriggered = !execution.WillTrigger()
759+
760+
outputParameters := opts.Component.GetDag().GetOutputs().GetParameters()
761+
glog.V(4).Info("outputParameters: ", outputParameters)
762+
for _, value := range outputParameters {
763+
outputParameterKey := value.GetValueFromParameter().OutputParameterKey
764+
producerSubTask := value.GetValueFromParameter().ProducerSubtask
765+
glog.V(4).Info("outputParameterKey: ", outputParameterKey)
766+
glog.V(4).Info("producerSubtask: ", producerSubTask)
767+
768+
outputParameterMap := map[string]interface{}{
769+
"output_parameter_key": outputParameterKey,
770+
"producer_subtask": producerSubTask,
771+
}
772+
773+
outputParameterStruct, _ := structpb.NewValue(outputParameterMap)
774+
775+
ecfg.OutputParameters = map[string]*structpb.Value{
776+
value.GetValueFromParameter().OutputParameterKey: outputParameterStruct,
777+
}
778+
}
779+
752780
if opts.Task.GetArtifactIterator() != nil {
753781
return execution, fmt.Errorf("ArtifactIterator is not implemented")
754782
}
@@ -793,6 +821,12 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
793821
ecfg.IterationCount = &count
794822
execution.IterationCount = &count
795823
}
824+
825+
glog.V(4).Info("pipeline: ", pipeline)
826+
b, _ = json.Marshal(*ecfg)
827+
glog.V(4).Info("ecfg: ", string(b))
828+
glog.V(4).Infof("dag: %v", dag)
829+
796830
// TODO(Bobgy): change execution state to pending, because this is driver, execution hasn't started.
797831
createdExecution, err := mlmd.CreateExecution(ctx, pipeline, ecfg)
798832
if err != nil {
@@ -939,6 +973,8 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
939973
err = fmt.Errorf("failed to resolve inputs: %w", err)
940974
}
941975
}()
976+
glog.V(4).Infof("dag: %v", dag)
977+
glog.V(4).Infof("task: %v", task)
942978
inputParams, _, err := dag.Execution.GetParameters()
943979
if err != nil {
944980
return nil, err
@@ -1112,10 +1148,31 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
11121148
if err != nil {
11131149
return nil, err
11141150
}
1151+
// TODO: Make this recursive.
1152+
for _, v := range tasks {
1153+
if v.GetExecution().GetType() == "system.DAGExecution" {
1154+
glog.V(4).Infof("Found a task, %v, with an execution type of system.DAGExecution. Adding its tasks to the task list.", v.TaskName())
1155+
dag, err := mlmd.GetDAG(ctx, v.GetExecution().GetId())
1156+
if err != nil {
1157+
return nil, err
1158+
}
1159+
subdagTasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline)
1160+
if err != nil {
1161+
return nil, err
1162+
}
1163+
for k, v := range subdagTasks {
1164+
tasks[k] = v
1165+
}
1166+
}
1167+
}
11151168
tasksCache = tasks
1169+
11161170
return tasks, nil
11171171
}
1172+
11181173
for name, paramSpec := range task.GetInputs().GetParameters() {
1174+
glog.V(4).Infof("name: %v", name)
1175+
glog.V(4).Infof("paramSpec: %v", paramSpec)
11191176
paramError := func(err error) error {
11201177
return fmt.Errorf("resolving input parameter %s with spec %s: %w", name, paramSpec, err)
11211178
}
@@ -1131,8 +1188,11 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
11311188
}
11321189
inputs.ParameterValues[name] = v
11331190

1191+
// This is the case where we are consuming an output parameter from an
1192+
// upstream task. That task can be a container or a DAG.
11341193
case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter:
11351194
taskOutput := paramSpec.GetTaskOutputParameter()
1195+
glog.V(4).Info("taskOutput: ", taskOutput)
11361196
if taskOutput.GetProducerTask() == "" {
11371197
return nil, paramError(fmt.Errorf("producer task is empty"))
11381198
}
@@ -1143,19 +1203,56 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
11431203
if err != nil {
11441204
return nil, paramError(err)
11451205
}
1206+
1207+
// The producer is the task that produces the output that we need to
1208+
// consume.
11461209
producer, ok := tasks[taskOutput.GetProducerTask()]
1147-
if !ok {
1148-
return nil, paramError(fmt.Errorf("cannot find producer task %q", taskOutput.GetProducerTask()))
1149-
}
1150-
_, outputs, err := producer.GetParameters()
1210+
1211+
glog.V(4).Info("producer: ", producer)
1212+
1213+
// Get the producer's outputs.
1214+
_, producerOutputs, err := producer.GetParameters()
11511215
if err != nil {
11521216
return nil, paramError(fmt.Errorf("get producer output parameters: %w", err))
11531217
}
1154-
param, ok := outputs[taskOutput.GetOutputParameterKey()]
1218+
glog.V(4).Info("producer output parameters: ", producerOutputs)
1219+
// Deserialize them.
1220+
var producerOutputsMap map[string]string
1221+
b, err := producerOutputs["Output"].GetStructValue().MarshalJSON()
1222+
if err != nil {
1223+
return nil, err
1224+
}
1225+
json.Unmarshal(b, &producerOutputsMap)
1226+
glog.V(4).Info("producerOutputsMap: ", producerOutputsMap)
1227+
1228+
// If the producer's output includes a producer subtask, which means
1229+
// that the producer is a DAG that is getting its output from one of
1230+
// the tasks in the DAG, then, we want to roll up the output
1231+
// from the producer subtask to the producer, so that the downstream
1232+
// logic can retrieve it appropriately.
1233+
if producerSubTask, ok := producerOutputsMap["producer_subtask"]; ok {
1234+
glog.V(4).Infof(
1235+
"Overriding producer task, %v, output with producer_subtask, %v, output.",
1236+
producer.TaskName(),
1237+
producerSubTask,
1238+
)
1239+
_, producerOutputs, err = tasks[producerSubTask].GetParameters()
1240+
if err != nil {
1241+
return nil, err
1242+
}
1243+
glog.V(4).Info("producerSubTask output parameters: ", producerOutputs)
1244+
// The only reason we're updating this is to make the downstream
1245+
// logging more accurate.
1246+
taskOutput.ProducerTask = producerOutputsMap["producer_subtask"]
1247+
}
1248+
1249+
// Grab the value of the producer output.
1250+
producerOutputValue, ok := producerOutputs[taskOutput.GetOutputParameterKey()]
11551251
if !ok {
11561252
return nil, paramError(fmt.Errorf("cannot find output parameter key %q in producer task %q", taskOutput.GetOutputParameterKey(), taskOutput.GetProducerTask()))
11571253
}
1158-
inputs.ParameterValues[name] = param
1254+
// Update the input to be the producer output value.
1255+
inputs.ParameterValues[name] = producerOutputValue
11591256
case *pipelinespec.TaskInputsSpec_InputParameterSpec_RuntimeValue:
11601257
runtimeValue := paramSpec.GetRuntimeValue()
11611258
switch t := runtimeValue.Value.(type) {

backend/src/v2/metadata/client.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,15 @@ import (
2121
"encoding/json"
2222
"errors"
2323
"fmt"
24-
"github.com/kubeflow/pipelines/backend/src/common/util"
25-
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"
2624
"path"
2725
"strconv"
2826
"strings"
2927
"sync"
3028
"time"
3129

30+
"github.com/kubeflow/pipelines/backend/src/common/util"
31+
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"
32+
3233
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
3334

3435
"github.com/golang/glog"
@@ -134,6 +135,7 @@ type ExecutionConfig struct {
134135
NotTriggered bool // optional, not triggered executions will have CANCELED state.
135136
ParentDagID int64 // parent DAG execution ID. Only the root DAG does not have a parent DAG.
136137
InputParameters map[string]*structpb.Value
138+
OutputParameters map[string]*structpb.Value
137139
InputArtifactIDs map[string][]int64
138140
IterationIndex *int // Index of the iteration.
139141

@@ -448,6 +450,8 @@ func getArtifactName(eventPath *pb.Event_Path) (string, error) {
448450
func (c *Client) PublishExecution(ctx context.Context, execution *Execution, outputParameters map[string]*structpb.Value, outputArtifacts []*OutputArtifact, state pb.Execution_State) error {
449451
e := execution.execution
450452
e.LastKnownState = state.Enum()
453+
glog.V(4).Infof("outputParameters: %v", outputParameters)
454+
glog.V(4).Infof("outputArtifacts: %v", outputArtifacts)
451455

452456
if outputParameters != nil {
453457
// Record output parameters.
@@ -576,7 +580,13 @@ func (c *Client) CreateExecution(ctx context.Context, pipeline *Pipeline, config
576580
},
577581
}}
578582
}
579-
583+
if config.OutputParameters != nil {
584+
e.CustomProperties[keyOutputs] = &pb.Value{Value: &pb.Value_StructValue{
585+
StructValue: &structpb.Struct{
586+
Fields: config.OutputParameters,
587+
},
588+
}}
589+
}
580590
req := &pb.PutExecutionRequest{
581591
Execution: e,
582592
Contexts: []*pb.Context{pipeline.pipelineCtx, pipeline.pipelineRunCtx},

0 commit comments

Comments
 (0)