@@ -17,10 +17,11 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"github.com/kubeflow/pipelines/backend/src/v2/objectstore"
 	"strconv"
 	"time"
 
+	"github.com/kubeflow/pipelines/backend/src/v2/objectstore"
+
 	"github.com/golang/glog"
 	"github.com/golang/protobuf/ptypes/timestamp"
 	"github.com/google/uuid"
@@ -125,6 +126,8 @@ func RootDAG(ctx context.Context, opts Options, mlmd *metadata.Client) (executio
 			err = fmt.Errorf("driver.RootDAG(%s) failed: %w", opts.info(), err)
 		}
 	}()
+	b, _ := json.Marshal(opts)
+	glog.V(4).Info("RootDAG opts: ", string(b))
 	err = validateRootDAG(opts)
 	if err != nil {
 		return nil, err
@@ -230,6 +233,8 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl
 			err = fmt.Errorf("driver.Container(%s) failed: %w", opts.info(), err)
 		}
 	}()
+	b, _ := json.Marshal(opts)
+	glog.V(4).Info("Container opts: ", string(b))
 	err = validateContainer(opts)
 	if err != nil {
 		return nil, err
@@ -699,6 +704,8 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
 			err = fmt.Errorf("driver.DAG(%s) failed: %w", opts.info(), err)
 		}
 	}()
+	b, _ := json.Marshal(opts)
+	glog.V(4).Info("DAG opts: ", string(b))
 	err = validateDAG(opts)
 	if err != nil {
 		return nil, err
@@ -749,6 +756,27 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
 	ecfg.ParentDagID = dag.Execution.GetID()
 	ecfg.IterationIndex = iterationIndex
 	ecfg.NotTriggered = !execution.WillTrigger()
+
+	outputParameters := opts.Component.GetDag().GetOutputs().GetParameters()
+	glog.V(4).Info("outputParameters: ", outputParameters)
+	for _, value := range outputParameters {
+		outputParameterKey := value.GetValueFromParameter().OutputParameterKey
+		producerSubTask := value.GetValueFromParameter().ProducerSubtask
+		glog.V(4).Info("outputParameterKey: ", outputParameterKey)
+		glog.V(4).Info("producerSubtask: ", producerSubTask)
+
+		outputParameterMap := map[string]interface{}{
+			"output_parameter_key": outputParameterKey,
+			"producer_subtask":     producerSubTask,
+		}
+
+		outputParameterStruct, _ := structpb.NewValue(outputParameterMap)
+
+		ecfg.OutputParameters = map[string]*structpb.Value{
+			value.GetValueFromParameter().OutputParameterKey: outputParameterStruct,
+		}
+	}
+
 	if opts.Task.GetArtifactIterator() != nil {
 		return execution, fmt.Errorf("ArtifactIterator is not implemented")
 	}
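Not part of the diff, but for reference: a minimal, self-contained sketch of the value this loop records for each DAG output parameter, using only the structpb API already used in this hunk. The key name "Output" and subtask name "produce-value" are made-up example values.

	package main

	import (
		"fmt"

		"google.golang.org/protobuf/types/known/structpb"
	)

	func main() {
		// Hypothetical example values standing in for OutputParameterKey and
		// ProducerSubtask from the component's DAG output spec.
		v, err := structpb.NewValue(map[string]interface{}{
			"output_parameter_key": "Output",
			"producer_subtask":     "produce-value",
		})
		if err != nil {
			panic(err)
		}
		b, _ := v.MarshalJSON()
		// Prints JSON like:
		// {"output_parameter_key":"Output","producer_subtask":"produce-value"}
		fmt.Println(string(b))
	}

This is the JSON that resolveInputs later unmarshals into a map[string]string to detect that a producer is a DAG.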
@@ -793,6 +821,12 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
 		ecfg.IterationCount = &count
 		execution.IterationCount = &count
 	}
+
+	glog.V(4).Info("pipeline: ", pipeline)
+	b, _ = json.Marshal(*ecfg)
+	glog.V(4).Info("ecfg: ", string(b))
+	glog.V(4).Infof("dag: %v", dag)
+
 	// TODO(Bobgy): change execution state to pending, because this is driver, execution hasn't started.
 	createdExecution, err := mlmd.CreateExecution(ctx, pipeline, ecfg)
 	if err != nil {
@@ -939,6 +973,8 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
 			err = fmt.Errorf("failed to resolve inputs: %w", err)
 		}
 	}()
+	glog.V(4).Infof("dag: %v", dag)
+	glog.V(4).Infof("task: %v", task)
 	inputParams, _, err := dag.Execution.GetParameters()
 	if err != nil {
 		return nil, err
@@ -1112,10 +1148,31 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
 		if err != nil {
 			return nil, err
 		}
+		// TODO: Make this recursive.
+		for _, v := range tasks {
+			if v.GetExecution().GetType() == "system.DAGExecution" {
+				glog.V(4).Infof("Found a task, %v, with an execution type of system.DAGExecution. Adding its tasks to the task list.", v.TaskName())
+				dag, err := mlmd.GetDAG(ctx, v.GetExecution().GetId())
+				if err != nil {
+					return nil, err
+				}
+				subdagTasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline)
+				if err != nil {
+					return nil, err
+				}
+				for k, v := range subdagTasks {
+					tasks[k] = v
+				}
+			}
+		}
 		tasksCache = tasks
+
 		return tasks, nil
 	}
+
 	for name, paramSpec := range task.GetInputs().GetParameters() {
+		glog.V(4).Infof("name: %v", name)
+		glog.V(4).Infof("paramSpec: %v", paramSpec)
 		paramError := func(err error) error {
 			return fmt.Errorf("resolving input parameter %s with spec %s: %w", name, paramSpec, err)
 		}
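The loop above is flagged "TODO: Make this recursive". A hedged sketch of what a recursive flattening helper could look like, built only from the calls already used in this hunk (GetExecutionsInDAG, GetDAG, GetExecution().GetType()/GetId()). The helper name collectDAGTasks, the *metadata.Pipeline parameter type, and the assumption that GetExecutionsInDAG returns a map of task name to *metadata.Execution (as its use above suggests) are mine, not the repository's:

	package driver

	import (
		"context"

		"github.com/kubeflow/pipelines/backend/src/v2/metadata"
	)

	// collectDAGTasks merges the executions of dag, and of every nested
	// system.DAGExecution found inside it, into acc.
	func collectDAGTasks(
		ctx context.Context,
		mlmd *metadata.Client,
		pipeline *metadata.Pipeline,
		dag *metadata.DAG,
		acc map[string]*metadata.Execution,
	) error {
		tasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline)
		if err != nil {
			return err
		}
		for name, task := range tasks {
			acc[name] = task
			if task.GetExecution().GetType() == "system.DAGExecution" {
				subDAG, err := mlmd.GetDAG(ctx, task.GetExecution().GetId())
				if err != nil {
					return err
				}
				// Recurse so that sub-DAGs of sub-DAGs are flattened too.
				if err := collectDAGTasks(ctx, mlmd, pipeline, subDAG, acc); err != nil {
					return err
				}
			}
		}
		return nil
	}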
@@ -1131,8 +1188,11 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
 			}
 			inputs.ParameterValues[name] = v
 
+		// This is the case where we are consuming an output parameter from an
+		// upstream task. That task can be a container or a DAG.
 		case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter:
 			taskOutput := paramSpec.GetTaskOutputParameter()
+			glog.V(4).Info("taskOutput: ", taskOutput)
 			if taskOutput.GetProducerTask() == "" {
 				return nil, paramError(fmt.Errorf("producer task is empty"))
 			}
@@ -1143,19 +1203,56 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
 			if err != nil {
 				return nil, paramError(err)
 			}
+
+			// The producer is the task that produces the output that we need to
+			// consume.
 			producer, ok := tasks[taskOutput.GetProducerTask()]
-			if !ok {
-				return nil, paramError(fmt.Errorf("cannot find producer task %q", taskOutput.GetProducerTask()))
-			}
-			_, outputs, err := producer.GetParameters()
+
+			glog.V(4).Info("producer: ", producer)
+
+			// Get the producer's outputs.
+			_, producerOutputs, err := producer.GetParameters()
 			if err != nil {
 				return nil, paramError(fmt.Errorf("get producer output parameters: %w", err))
 			}
-			param, ok := outputs[taskOutput.GetOutputParameterKey()]
+			glog.V(4).Info("producer output parameters: ", producerOutputs)
+			// Deserialize them.
+			var producerOutputsMap map[string]string
+			b, err := producerOutputs["Output"].GetStructValue().MarshalJSON()
+			if err != nil {
+				return nil, err
+			}
+			json.Unmarshal(b, &producerOutputsMap)
+			glog.V(4).Info("producerOutputsMap: ", producerOutputsMap)
+
+			// If the producer's output includes a producer subtask, which means
+			// that the producer is a DAG that is getting its output from one of
+			// the tasks in the DAG, then we want to roll up the output
+			// from the producer subtask to the producer, so that the downstream
+			// logic can retrieve it appropriately.
+			if producerSubTask, ok := producerOutputsMap["producer_subtask"]; ok {
+				glog.V(4).Infof(
+					"Overriding producer task, %v, output with producer_subtask, %v, output.",
+					producer.TaskName(),
+					producerSubTask,
+				)
+				_, producerOutputs, err = tasks[producerSubTask].GetParameters()
+				if err != nil {
+					return nil, err
+				}
+				glog.V(4).Info("producerSubTask output parameters: ", producerOutputs)
+				// The only reason we're updating this is to make the downstream
+				// logging more accurate.
+				taskOutput.ProducerTask = producerOutputsMap["producer_subtask"]
+			}
+
+			// Grab the value of the producer output.
+			producerOutputValue, ok := producerOutputs[taskOutput.GetOutputParameterKey()]
 			if !ok {
 				return nil, paramError(fmt.Errorf("cannot find output parameter key %q in producer task %q", taskOutput.GetOutputParameterKey(), taskOutput.GetProducerTask()))
 			}
-			inputs.ParameterValues[name] = param
+			// Update the input to be the producer output value.
+			inputs.ParameterValues[name] = producerOutputValue
 		case *pipelinespec.TaskInputsSpec_InputParameterSpec_RuntimeValue:
 			runtimeValue := paramSpec.GetRuntimeValue()
 			switch t := runtimeValue.Value.(type) {
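To show the consumer-side rollup in isolation, here is a hedged, self-contained sketch of the same logic as the branch added above. resolveDAGProducer is a hypothetical helper that does not exist in the repository, and the subtaskOutputs shape is an assumption standing in for the tasks map used in resolveInputs.

	package driver

	import (
		"encoding/json"
		"fmt"

		"google.golang.org/protobuf/types/known/structpb"
	)

	// resolveDAGProducer takes the "Output" value recorded by a DAG producer and
	// the output parameters of that DAG's subtasks, and returns the rolled-up
	// value that the consuming task should receive.
	func resolveDAGProducer(
		recorded *structpb.Value,
		subtaskOutputs map[string]map[string]*structpb.Value, // task name -> output key -> value
	) (*structpb.Value, error) {
		// Deserialize the {"output_parameter_key": ..., "producer_subtask": ...}
		// struct written by the DAG driver.
		b, err := recorded.GetStructValue().MarshalJSON()
		if err != nil {
			return nil, err
		}
		var m map[string]string
		if err := json.Unmarshal(b, &m); err != nil {
			return nil, err
		}
		subtask, ok := m["producer_subtask"]
		if !ok {
			return nil, fmt.Errorf("recorded DAG output has no producer_subtask")
		}
		value, ok := subtaskOutputs[subtask][m["output_parameter_key"]]
		if !ok {
			return nil, fmt.Errorf("cannot find output %q on subtask %q", m["output_parameter_key"], subtask)
		}
		return value, nil
	}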