Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Handle artifact outputs as well
Browse files Browse the repository at this point in the history
Signed-off-by: droctothorpe <mythicalsunlight@gmail.com>
Co-authored-by: zazulam <m.zazula@gmail.com>
Co-authored-by: CarterFendley <carter.fendley@gmail.com>
Co-authored-by: edmondop <edmondo.porcu@gmail.com>
4 people committed Sep 17, 2024

Verified

This commit was signed with the committer’s verified signature.
jalling97 John Alling
1 parent 1cb4db8 commit d920cf1
Showing 3 changed files with 90 additions and 29 deletions.
1 change: 1 addition & 0 deletions backend/src/v2/cmd/driver/main.go
Original file line number Diff line number Diff line change
@@ -85,6 +85,7 @@ func init() {
flag.Set("logtostderr", "true")
// Change the WARNING to INFO level for debugging.
flag.Set("stderrthreshold", "WARNING")
flag.Set("v", "4")
}

func validate() error {
100 changes: 72 additions & 28 deletions backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
@@ -757,6 +757,7 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
ecfg.IterationIndex = iterationIndex
ecfg.NotTriggered = !execution.WillTrigger()

// Handle writing output parameters to MLMD.
outputParameters := opts.Component.GetDag().GetOutputs().GetParameters()
glog.V(4).Info("outputParameters: ", outputParameters)
for _, value := range outputParameters {
@@ -777,6 +778,11 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
}
}

// Handle writing output artifacts to MLMD.
outputArtifacts := opts.Component.GetDag().GetOutputs().GetArtifacts()
glog.V(4).Info("outputArtifacts: ", outputArtifacts)
ecfg.OutputArtifacts = outputArtifacts

if opts.Task.GetArtifactIterator() != nil {
return execution, fmt.Errorf("ArtifactIterator is not implemented")
}
@@ -1177,6 +1183,7 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
return inputs, nil
}

// Handle parameters.
for name, paramSpec := range task.GetInputs().GetParameters() {
glog.V(4).Infof("name: %v", name)
glog.V(4).Infof("paramSpec: %v", paramSpec)
@@ -1224,50 +1231,50 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
glog.V(4).Info("producer: ", producer)

// Get the producer's outputs.
_, producerOutputs, err := producer.GetParameters()
_, producerOutputParameters, err := producer.GetParameters()
if err != nil {
return nil, paramError(fmt.Errorf("get producer output parameters: %w", err))
}
glog.V(4).Info("producer output parameters: ", producerOutputs)
glog.V(4).Info("producer output parameters: ", producerOutputParameters)
// Deserialize them.
var producerOutputsMap map[string]string
b, err := producerOutputs["Output"].GetStructValue().MarshalJSON()
var producerOutputParametersMap map[string]string
b, err := producerOutputParameters["Output"].GetStructValue().MarshalJSON()
if err != nil {
return nil, err
}
json.Unmarshal(b, &producerOutputsMap)
glog.V(4).Info("producerOutputsMap: ", producerOutputsMap)
json.Unmarshal(b, &producerOutputParametersMap)
glog.V(4).Info("producerOutputParametersMap: ", producerOutputParametersMap)

// If the producer's output includes a producer subtask, which means
// that the producer is a DAG that is getting its output from one of
// the tasks in the DAG, then we want to roll up the output from the
// producer subtask to the producer, so that the downstream logic
// can retrieve it appropriately.
if producerSubTask, ok := producerOutputsMap["producer_subtask"]; ok {
if producerSubTask, ok := producerOutputParametersMap["producer_subtask"]; ok {
glog.V(4).Infof(
"Overriding producer task, %v, output with producer_subtask, %v, output.",
producer.TaskName(),
producerSubTask,
)
_, producerOutputs, err = tasks[producerSubTask].GetParameters()
_, producerOutputParameters, err = tasks[producerSubTask].GetParameters()
if err != nil {
return nil, err
}
glog.V(4).Info("producerSubTask output parameters: ", producerOutputs)
glog.V(4).Info("producerSubTask output parameters: ", producerOutputParameters)
// The only reason we're updating this is to make the downstream
// logging more accurate.
taskOutput.ProducerTask = producerOutputsMap["producer_subtask"]
taskOutput.ProducerTask = producerOutputParametersMap["producer_subtask"]
// Grab the value of the producer output.
producerOutputValue, ok := producerOutputs[taskOutput.GetOutputParameterKey()]
producerOutputParameterValue, ok := producerOutputParameters[taskOutput.GetOutputParameterKey()]
if !ok {
return nil, paramError(fmt.Errorf("cannot find output parameter key %q in producer task %q", taskOutput.GetOutputParameterKey(), taskOutput.GetProducerTask()))
}
// Update the input to be the producer output value.
inputs.ParameterValues[name] = producerOutputValue
inputs.ParameterValues[name] = producerOutputParameterValue
} else {
// The producer subtask is not a DAG, so we exit the loop.
producerSubTaskMaybeDAG = false
inputs.ParameterValues[name] = producerOutputs[taskOutput.GetOutputParameterKey()]
inputs.ParameterValues[name] = producerOutputParameters[taskOutput.GetOutputParameterKey()]
}
}

@@ -1286,6 +1293,8 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
return nil, paramError(fmt.Errorf("parameter spec of type %T not implemented yet", t))
}
}

// Handle artifacts.
for name, artifactSpec := range task.GetInputs().GetArtifacts() {
artifactError := func(err error) error {
return fmt.Errorf("failed to resolve input artifact %s with spec %s: %w", name, artifactSpec, err)
@@ -1314,25 +1323,60 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
if err != nil {
return nil, artifactError(err)
}

producer, ok := tasks[taskOutput.GetProducerTask()]
if !ok {
return nil, artifactError(fmt.Errorf("cannot find producer task %q", taskOutput.GetProducerTask()))
}
// TODO(Bobgy): cache results
outputs, err := mlmd.GetOutputArtifactsByExecutionId(ctx, producer.GetID())
if err != nil {
return nil, artifactError(err)
}
artifact, ok := outputs[taskOutput.GetOutputArtifactKey()]
if !ok {
return nil, artifactError(fmt.Errorf("cannot find output artifact key %q in producer task %q", taskOutput.GetOutputArtifactKey(), taskOutput.GetProducerTask()))
}
runtimeArtifact, err := artifact.ToRuntimeArtifact()
if err != nil {
return nil, artifactError(err)
}
inputs.Artifacts[name] = &pipelinespec.ArtifactList{
Artifacts: []*pipelinespec.RuntimeArtifact{runtimeArtifact},
glog.V(4).Info("producer: ", producer)
currentTask := producer
var outputArtifactKey string
currentSubTaskMaybeDAG := true
for currentSubTaskMaybeDAG {
// If the current task is a DAG:
glog.V(4).Info("currentTask: ", currentTask.TaskName())
if *currentTask.GetExecution().Type == "system.DAGExecution" {
// Get the sub-task.
outputArtifactsCustomProperty := currentTask.GetExecution().GetCustomProperties()["output_artifacts"]
// Deserialize the output artifacts.
var outputArtifacts map[string]*pipelinespec.DagOutputsSpec_DagOutputArtifactSpec
err := json.Unmarshal([]byte(outputArtifactsCustomProperty.GetStringValue()), &outputArtifacts)
if err != nil {
return nil, err
}
glog.V(4).Infof("Deserialized outputArtifacts: %v", outputArtifacts)
artifactSelectors := outputArtifacts["Output"].GetArtifactSelectors()
// TODO: Add support for multiple output artifacts.
subTaskName := artifactSelectors[0].ProducerSubtask
outputArtifactKey = artifactSelectors[0].OutputArtifactKey
glog.V(4).Info("subTaskName: ", subTaskName)
glog.V(4).Info("outputArtifactKey: ", outputArtifactKey)
currentSubTask := tasks[subTaskName]
// If the sub-task is a DAG, reassign currentTask and run
// through the loop again.
currentTask = currentSubTask
// }
} else {
// Base case, subtask is a container, not a DAG.
outputs, err := mlmd.GetOutputArtifactsByExecutionId(ctx, currentTask.GetID())
if err != nil {
return nil, artifactError(err)
}
glog.V(4).Infof("outputs: %#v", outputs)
artifact, ok := outputs[outputArtifactKey]
if !ok {
return nil, artifactError(fmt.Errorf("cannot find output artifact key %q in producer task %q", taskOutput.GetOutputArtifactKey(), taskOutput.GetProducerTask()))
}
runtimeArtifact, err := artifact.ToRuntimeArtifact()
if err != nil {
return nil, artifactError(err)
}
inputs.Artifacts[name] = &pipelinespec.ArtifactList{
Artifacts: []*pipelinespec.RuntimeArtifact{runtimeArtifact},
}
// Since we are in the base case, escape the loop.
currentSubTaskMaybeDAG = false
}
}
default:
return nil, artifactError(fmt.Errorf("artifact spec of type %T not implemented yet", t))
18 changes: 17 additions & 1 deletion backend/src/v2/metadata/client.go
Original file line number Diff line number Diff line change
@@ -136,6 +136,9 @@ type ExecutionConfig struct {
ParentDagID int64 // parent DAG execution ID. Only the root DAG does not have a parent DAG.
InputParameters map[string]*structpb.Value
OutputParameters map[string]*structpb.Value
// OutputArtifacts map[string]*structpb.Value
// OutputArtifacts []*pipelinespec.DagOutputsSpec_ArtifactSelectorSpec
OutputArtifacts map[string]*pipelinespec.DagOutputsSpec_DagOutputArtifactSpec
InputArtifactIDs map[string][]int64
IterationIndex *int // Index of the iteration.

@@ -516,7 +519,8 @@ const (
keyCacheFingerPrint = "cache_fingerprint"
keyCachedExecutionID = "cached_execution_id"
keyInputs = "inputs"
keyOutputs = "outputs"
keyOutputs = "outputs" // TODO: Consider renaming this to output_parameters to be consistent.
keyOutputArtifacts = "output_artifacts"
keyParentDagID = "parent_dag_id" // Parent DAG Execution ID.
keyIterationIndex = "iteration_index"
keyIterationCount = "iteration_count"
@@ -580,13 +584,25 @@ func (c *Client) CreateExecution(ctx context.Context, pipeline *Pipeline, config
},
}}
}
// We save the output parameter and output artifact relationships in MLMD in
// case they're provided by a sub-task so that we can follow the
// relationships and retrieve outputs downstream in components that depend
// on said outputs as inputs.
if config.OutputParameters != nil {
e.CustomProperties[keyOutputs] = &pb.Value{Value: &pb.Value_StructValue{
StructValue: &structpb.Struct{
Fields: config.OutputParameters,
},
}}
}
if config.OutputArtifacts != nil {
b, err := json.Marshal(config.OutputArtifacts)
if err != nil {
return nil, err
}
e.CustomProperties[keyOutputArtifacts] = StringValue(string(b))
}

req := &pb.PutExecutionRequest{
Execution: e,
Contexts: []*pb.Context{pipeline.pipelineCtx, pipeline.pipelineRunCtx},

0 comments on commit d920cf1

Please sign in to comment.