diff --git a/backend/src/v2/component/launcher_v2.go b/backend/src/v2/component/launcher_v2.go index 2b5297c7b09c..8b40ff346e08 100644 --- a/backend/src/v2/component/launcher_v2.go +++ b/backend/src/v2/component/launcher_v2.go @@ -148,6 +148,17 @@ func (l *LauncherV2) Execute(ctx context.Context) (err error) { } } glog.Infof("publish success.") + // At the end of the current task, we check the statuses of all tasks in the current DAG and update the DAG's status accordingly. + // TODO: If there's a pipeline whose only components are DAGs, this launcher logic will never run and as a result the dag status will never be updated. We need to implement a mechanism to handle this edge case. + dag, err := l.metadataClient.GetDAG(ctx, execution.GetExecution().CustomProperties["parent_dag_id"].GetIntValue()) + if err != nil { + glog.Errorf("DAG Status Update: failed to get DAG: %s", err.Error()) + } + pipeline, _ := l.metadataClient.GetPipelineFromExecution(ctx, execution.GetID()) + err = l.metadataClient.UpdateDAGExecutionsState(ctx, dag, pipeline) + if err != nil { + glog.Errorf("failed to update DAG state: %s", err.Error()) + } }() executedStartedTime := time.Now().Unix() execution, err = l.prePublish(ctx) diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index 8e1801864c2f..45ac48a0daa6 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -344,7 +344,7 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl return execution, err } if opts.KubernetesExecutorConfig != nil { - dagTasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline) + dagTasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline, true) if err != nil { return execution, err } @@ -758,29 +758,12 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E ecfg.NotTriggered = !execution.WillTrigger() // Handle writing output parameters to MLMD. - outputParameters := opts.Component.GetDag().GetOutputs().GetParameters() - glog.V(4).Info("outputParameters: ", outputParameters) - ecfg.OutputParameters = make(map[string]*structpb.Value) - for name, value := range outputParameters { - outputParameterKey := value.GetValueFromParameter().GetOutputParameterKey() - producerSubTask := value.GetValueFromParameter().GetProducerSubtask() - glog.V(4).Info("outputParameterKey: ", outputParameterKey) - glog.V(4).Info("producerSubtask: ", producerSubTask) - - outputParameterMap := map[string]interface{}{ - "output_parameter_key": outputParameterKey, - "producer_subtask": producerSubTask, - } - - outputParameterStruct, _ := structpb.NewValue(outputParameterMap) - - ecfg.OutputParameters[name] = outputParameterStruct - } + ecfg.OutputParameters = opts.Component.GetDag().GetOutputs().GetParameters() + glog.V(4).Info("outputParameters: ", ecfg.OutputParameters) // Handle writing output artifacts to MLMD. - outputArtifacts := opts.Component.GetDag().GetOutputs().GetArtifacts() - glog.V(4).Info("outputArtifacts: ", outputArtifacts) - ecfg.OutputArtifacts = outputArtifacts + ecfg.OutputArtifacts = opts.Component.GetDag().GetOutputs().GetArtifacts() + glog.V(4).Info("outputArtifacts: ", ecfg.OutputArtifacts) if opts.Task.GetArtifactIterator() != nil { return execution, fmt.Errorf("ArtifactIterator is not implemented") @@ -1256,55 +1239,136 @@ type resolveUpstreamParametersConfig struct { func resolveUpstreamParameters(cfg resolveUpstreamParametersConfig) error { taskOutput := cfg.paramSpec.GetTaskOutputParameter() glog.V(4).Info("taskOutput: ", taskOutput) - if taskOutput.GetProducerTask() == "" { - return cfg.paramError(fmt.Errorf("producer task is empty")) + producerTaskName := taskOutput.GetProducerTask() + if producerTaskName == "" { + return cfg.paramError(fmt.Errorf("producerTaskName is empty")) } - if taskOutput.GetOutputParameterKey() == "" { + outputParameterKey := taskOutput.GetOutputParameterKey() + if outputParameterKey == "" { return cfg.paramError(fmt.Errorf("output parameter key is empty")) } - tasks, err := getDAGTasks(cfg.ctx, cfg.dag, cfg.pipeline, cfg.mlmd, nil) + tasks, err := cfg.mlmd.GetExecutionsInDAG(cfg.ctx, cfg.dag, cfg.pipeline, false) if err != nil { return cfg.paramError(err) } - + glog.V(4).Infof("tasks: %#v", tasks) // The producer is the task that produces the output that we need to // consume. - producer := tasks[taskOutput.GetProducerTask()] - outputParameterKey := taskOutput.GetOutputParameterKey() + producer, ok := tasks[producerTaskName] + if !ok { + return cfg.paramError(fmt.Errorf("producer task, %v, not in tasks", producerTaskName)) + } glog.V(4).Info("producer: ", producer) currentTask := producer currentSubTaskMaybeDAG := true // Continue looping until we reach a sub-task that is NOT a DAG. for currentSubTaskMaybeDAG { glog.V(4).Info("currentTask: ", currentTask.TaskName()) - _, outputParametersCustomProperty, err := currentTask.GetParameters() - if err != nil { - return err - } // If the current task is a DAG: if *currentTask.GetExecution().Type == "system.DAGExecution" { // Since currentTask is a DAG, we need to deserialize its - // output parameter map so that we can look its + // output parameter map so that we can look up its // corresponding producer sub-task, reassign currentTask, // and iterate through this loop again. - var outputParametersMap map[string]string - b, err := outputParametersCustomProperty[outputParameterKey].GetStructValue().MarshalJSON() - if err != nil { - return err + outputParametersCustomProperty, ok := currentTask.GetExecution().GetCustomProperties()["parameter_producer_task"] + if !ok { + return cfg.paramError(fmt.Errorf("Task, %v, does not have a parameter_producer_task custom property", currentTask.TaskName())) + } + glog.V(4).Infof("outputParametersCustomProperty: %#v", outputParametersCustomProperty) + + dagOutputParametersMap := make(map[string]*pipelinespec.DagOutputsSpec_DagOutputParameterSpec) + glog.V(4).Infof("outputParametersCustomProperty: %v", outputParametersCustomProperty.GetStructValue()) + + for name, value := range outputParametersCustomProperty.GetStructValue().GetFields() { + outputSpec := &pipelinespec.DagOutputsSpec_DagOutputParameterSpec{} + err := protojson.Unmarshal([]byte(value.GetStringValue()), outputSpec) + if err != nil { + return err + } + dagOutputParametersMap[name] = outputSpec + } + + glog.V(4).Infof("Deserialized dagOutputParametersMap: %v", dagOutputParametersMap) + + // Support for the 2 DagOutputParameterSpec types: + // ValueFromParameter & ValueFromOneof + var subTaskName string + switch dagOutputParametersMap[outputParameterKey].Kind.(type) { + case *pipelinespec.DagOutputsSpec_DagOutputParameterSpec_ValueFromParameter: + subTaskName = dagOutputParametersMap[outputParameterKey].GetValueFromParameter().GetProducerSubtask() + outputParameterKey = dagOutputParametersMap[outputParameterKey].GetValueFromParameter().GetOutputParameterKey() + case *pipelinespec.DagOutputsSpec_DagOutputParameterSpec_ValueFromOneof: + // When OneOf is specified in a pipeline, the output of only 1 task is consumed even though there may be more than 1 task output set. In this case we will attempt to grab the first successful task output. + paramSelectors := dagOutputParametersMap[outputParameterKey].GetValueFromOneof().GetParameterSelectors() + glog.V(4).Infof("paramSelectors: %v", paramSelectors) + // Since we have the tasks map, we can iterate through the parameterSelectors if the ProducerSubTask is not present in the task map and then assign the new OutputParameterKey only if it exists. + successfulOneOfTask := false + for !successfulOneOfTask { + for _, paramSelector := range paramSelectors { + subTaskName = paramSelector.GetProducerSubtask() + glog.V(4).Infof("subTaskName from paramSelector: %v", subTaskName) + glog.V(4).Infof("outputParameterKey from paramSelector: %v", paramSelector.GetOutputParameterKey()) + if subTask, ok := tasks[subTaskName]; ok { + subTaskState := subTask.GetExecution().LastKnownState.String() + glog.V(4).Infof("subTask: %w , subTaskState: %v", subTaskName, subTaskState) + if subTaskState == "CACHED" || subTaskState == "COMPLETE" { + + outputParameterKey = paramSelector.GetOutputParameterKey() + successfulOneOfTask = true + break + } + } + } + return cfg.paramError(fmt.Errorf("Processing OneOf: No successful task found")) + } + } + // if reflect.TypeOf(dagOutputParametersMap[outputParameterKey].Kind).String() == "*pipelinespec.DagOutputsSpec_DagOutputParameterSpec_ValueFromParameter" { + + // } else { + // // Type of dagOutputParametersMap[outputParameterKey].Kind is *pipelinespec.DagOutputsSpec_DagOutputParameterSpec_ValueFromOneof + // paramSelectors := dagOutputParametersMap[outputParameterKey].GetValueFromOneof().GetParameterSelectors() + // glog.V(4).Infof("paramSelectors: %v", paramSelectors) + // // Since we have the tasks map, we can iterate through the parameterSelectors if the ProducerSubTask is not present in the task map and then assign the new OutputParameterKey only if it exists. + // successfulOneOfTask := false + // for !successfulOneOfTask { + // for _, paramSelector := range paramSelectors { + // subTaskName = paramSelector.GetProducerSubtask() + // glog.V(4).Infof("subTaskName from paramSelector: %v", subTaskName) + // glog.V(4).Infof("outputParameterKey from paramSelector: %v", paramSelector.GetOutputParameterKey()) + // if subTask, ok := tasks[subTaskName]; ok { + // subTaskState := subTask.GetExecution().LastKnownState.String() + // glog.V(4).Infof("subTask: %w , subTaskState: %v", subTaskName, subTaskState) + // if subTaskState == "CACHED" || subTaskState == "COMPLETE" { + + // outputParameterKey = paramSelector.GetOutputParameterKey() + // successfulOneOfTask = true + // break + // } + // } + // } + // return cfg.paramError(fmt.Errorf("Processing OneOf: No successful task found")) + // } + // } + glog.V(4).Infof("SubTaskName from outputParams: %v", subTaskName) + glog.V(4).Infof("OutputParameterKey from outputParams: %v", outputParameterKey) + if subTaskName == "" { + return cfg.paramError(fmt.Errorf("producer_subtask not in outputParams")) } - json.Unmarshal(b, &outputParametersMap) - glog.V(4).Info("Deserialized outputParametersMap: ", outputParametersMap) - subTaskName := outputParametersMap["producer_subtask"] - outputParameterKey = outputParametersMap["output_parameter_key"] glog.V(4).Infof( "Overriding currentTask, %v, output with currentTask's producer_subtask, %v, output.", currentTask.TaskName(), subTaskName, ) + currentTask, ok = tasks[subTaskName] + if !ok { + return cfg.paramError(fmt.Errorf("subTaskName, %v, not in tasks", subTaskName)) + } - // Reassign sub-task before running through the loop again. - currentTask = tasks[subTaskName] } else { + _, outputParametersCustomProperty, err := currentTask.GetParameters() + if err != nil { + return err + } cfg.inputs.ParameterValues[cfg.name] = outputParametersCustomProperty[outputParameterKey] // Exit the loop. currentSubTaskMaybeDAG = false @@ -1340,7 +1404,7 @@ func resolveUpstreamArtifacts(cfg resolveUpstreamArtifactsConfig) error { if taskOutput.GetOutputArtifactKey() == "" { cfg.artifactError(fmt.Errorf("output artifact key is empty")) } - tasks, err := getDAGTasks(cfg.ctx, cfg.dag, cfg.pipeline, cfg.mlmd, nil) + tasks, err := cfg.mlmd.GetExecutionsInDAG(cfg.ctx, cfg.dag, cfg.pipeline, false) if err != nil { cfg.artifactError(err) } @@ -1361,7 +1425,7 @@ func resolveUpstreamArtifacts(cfg resolveUpstreamArtifactsConfig) error { // If the current task is a DAG: if *currentTask.GetExecution().Type == "system.DAGExecution" { // Get the sub-task. - outputArtifactsCustomProperty := currentTask.GetExecution().GetCustomProperties()["output_artifacts"] + outputArtifactsCustomProperty := currentTask.GetExecution().GetCustomProperties()["artifact_producer_task"] // Deserialize the output artifacts. var outputArtifacts map[string]*pipelinespec.DagOutputsSpec_DagOutputArtifactSpec err := json.Unmarshal([]byte(outputArtifactsCustomProperty.GetStringValue()), &outputArtifacts) @@ -1416,44 +1480,6 @@ func resolveUpstreamArtifacts(cfg resolveUpstreamArtifactsConfig) error { return nil } -// getDAGTasks gets all the tasks associated with the specified DAG and all of -// its subDAGs. -func getDAGTasks( - ctx context.Context, - dag *metadata.DAG, - pipeline *metadata.Pipeline, - mlmd *metadata.Client, - flattenedTasks map[string]*metadata.Execution, -) (map[string]*metadata.Execution, error) { - if flattenedTasks == nil { - flattenedTasks = make(map[string]*metadata.Execution) - } - currentExecutionTasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline) - if err != nil { - return nil, err - } - for k, v := range currentExecutionTasks { - flattenedTasks[k] = v - } - for _, v := range currentExecutionTasks { - if v.GetExecution().GetType() == "system.DAGExecution" { - glog.V(4).Infof("Found a task, %v, with an execution type of system.DAGExecution. Adding its tasks to the task list.", v.TaskName()) - subDAG, err := mlmd.GetDAG(ctx, v.GetExecution().GetId()) - if err != nil { - return nil, err - } - // Pass the subDAG into a recursive call to getDAGTasks and update - // tasks to include the subDAG's tasks. - flattenedTasks, err = getDAGTasks(ctx, subDAG, pipeline, mlmd, flattenedTasks) - if err != nil { - return nil, err - } - } - } - - return flattenedTasks, nil -} - func provisionOutputs(pipelineRoot, taskName string, outputsSpec *pipelinespec.ComponentOutputsSpec, outputUriSalt string) *pipelinespec.ExecutorInput_Outputs { outputs := &pipelinespec.ExecutorInput_Outputs{ Artifacts: make(map[string]*pipelinespec.ArtifactList), diff --git a/backend/src/v2/metadata/client.go b/backend/src/v2/metadata/client.go index c11d5f356f03..a3702b20f3ba 100644 --- a/backend/src/v2/metadata/client.go +++ b/backend/src/v2/metadata/client.go @@ -89,7 +89,9 @@ type ClientInterface interface { GetExecutions(ctx context.Context, ids []int64) ([]*pb.Execution, error) GetExecution(ctx context.Context, id int64) (*Execution, error) GetPipelineFromExecution(ctx context.Context, id int64) (*Pipeline, error) - GetExecutionsInDAG(ctx context.Context, dag *DAG, pipeline *Pipeline) (executionsMap map[string]*Execution, err error) + GetExecutionsInDAG(ctx context.Context, dag *DAG, pipeline *Pipeline, filter bool) (executionsMap map[string]*Execution, err error) + UpdateDAGExecutionsState(ctx context.Context, dag *DAG, pipeline *Pipeline) (err error) + PutDAGExecutionState(ctx context.Context, executionID int64, state pb.Execution_State) (err error) GetEventsByArtifactIDs(ctx context.Context, artifactIds []int64) ([]*pb.Event, error) GetArtifactName(ctx context.Context, artifactId int64) (string, error) GetArtifacts(ctx context.Context, ids []int64) ([]*pb.Artifact, error) @@ -135,7 +137,7 @@ type ExecutionConfig struct { NotTriggered bool // optional, not triggered executions will have CANCELED state. ParentDagID int64 // parent DAG execution ID. Only the root DAG does not have a parent DAG. InputParameters map[string]*structpb.Value - OutputParameters map[string]*structpb.Value + OutputParameters map[string]*pipelinespec.DagOutputsSpec_DagOutputParameterSpec OutputArtifacts map[string]*pipelinespec.DagOutputsSpec_DagOutputArtifactSpec InputArtifactIDs map[string][]int64 IterationIndex *int // Index of the iteration. @@ -505,23 +507,25 @@ func (c *Client) PublishExecution(ctx context.Context, execution *Execution, out // metadata keys const ( - keyDisplayName = "display_name" - keyTaskName = "task_name" - keyImage = "image" - keyPodName = "pod_name" - keyPodUID = "pod_uid" - keyNamespace = "namespace" - keyResourceName = "resource_name" - keyPipelineRoot = "pipeline_root" - keyStoreSessionInfo = "store_session_info" - keyCacheFingerPrint = "cache_fingerprint" - keyCachedExecutionID = "cached_execution_id" - keyInputs = "inputs" - keyOutputs = "outputs" // TODO: Consider renaming this to output_parameters to be consistent. - keyOutputArtifacts = "output_artifacts" - keyParentDagID = "parent_dag_id" // Parent DAG Execution ID. - keyIterationIndex = "iteration_index" - keyIterationCount = "iteration_count" + keyDisplayName = "display_name" + keyTaskName = "task_name" + keyImage = "image" + keyPodName = "pod_name" + keyPodUID = "pod_uid" + keyNamespace = "namespace" + keyResourceName = "resource_name" + keyPipelineRoot = "pipeline_root" + keyStoreSessionInfo = "store_session_info" + keyCacheFingerPrint = "cache_fingerprint" + keyCachedExecutionID = "cached_execution_id" + keyInputs = "inputs" + keyOutputs = "outputs" + keyParameterProducerTask = "parameter_producer_task" + keyOutputArtifacts = "output_artifacts" + keyArtifactProducerTask = "artifact_producer_task" + keyParentDagID = "parent_dag_id" // Parent DAG Execution ID. + keyIterationIndex = "iteration_index" + keyIterationCount = "iteration_count" ) // CreateExecution creates a new MLMD execution under the specified Pipeline. @@ -587,9 +591,25 @@ func (c *Client) CreateExecution(ctx context.Context, pipeline *Pipeline, config // relationships and retrieve outputs downstream in components that depend // on said outputs as inputs. if config.OutputParameters != nil { - e.CustomProperties[keyOutputs] = &pb.Value{Value: &pb.Value_StructValue{ + // Convert OutputParameters to a format that can be saved in MLMD. + glog.V(4).Info("outputParameters: ", config.OutputParameters) + outputParametersCustomPropertyProtoMap := make(map[string]*structpb.Value) + + for name, value := range config.OutputParameters { + if outputParameterProtoMsg, ok := interface{}(value).(proto.Message); ok { + glog.V(4).Infof("name: %v, value: %w", name, value) + glog.V(4).Info("protoMessage: ", outputParameterProtoMsg) + b, err := protojson.Marshal(outputParameterProtoMsg) + if err != nil { + return nil, err + } + outputValue, _ := structpb.NewValue(string(b)) + outputParametersCustomPropertyProtoMap[name] = outputValue + } + } + e.CustomProperties[keyParameterProducerTask] = &pb.Value{Value: &pb.Value_StructValue{ StructValue: &structpb.Struct{ - Fields: config.OutputParameters, + Fields: outputParametersCustomPropertyProtoMap, }, }} } @@ -598,7 +618,7 @@ func (c *Client) CreateExecution(ctx context.Context, pipeline *Pipeline, config if err != nil { return nil, err } - e.CustomProperties[keyOutputArtifacts] = StringValue(string(b)) + e.CustomProperties[keyArtifactProducerTask] = StringValue(string(b)) } req := &pb.PutExecutionRequest{ @@ -664,6 +684,61 @@ func (c *Client) PrePublishExecution(ctx context.Context, execution *Execution, return execution, nil } +// UpdateDAGExecutionState checks all the statuses of the tasks in the given DAG, based on that it will update the DAG to the corresponding status if necessary. +func (c *Client) UpdateDAGExecutionsState(ctx context.Context, dag *DAG, pipeline *Pipeline) error { + tasks, err := c.GetExecutionsInDAG(ctx, dag, pipeline, true) + if err != nil { + return err + } + glog.V(4).Infof("tasks: %v", tasks) + glog.V(4).Infof("Checking Tasks' State") + completedTasks := 0 + failedTasks := 0 + totalTasks := len(tasks) + for _, task := range tasks { + taskState := task.GetExecution().LastKnownState.String() + glog.V(4).Infof("task: %s", task.TaskName()) + glog.V(4).Infof("task state: %s", taskState) + switch taskState { + case "FAILED": + failedTasks++ + case "COMPLETE": + completedTasks++ + case "CACHED": + completedTasks++ + case "CANCELED": + completedTasks++ + } + } + glog.V(4).Infof("completedTasks: %d", completedTasks) + glog.V(4).Infof("failedTasks: %d", failedTasks) + glog.V(4).Infof("totalTasks: %d", totalTasks) + + glog.Infof("Attempting to update DAG state") + if completedTasks == totalTasks { + c.PutDAGExecutionState(ctx, dag.Execution.GetID(), pb.Execution_COMPLETE) + } else if failedTasks > 0 { + c.PutDAGExecutionState(ctx, dag.Execution.GetID(), pb.Execution_FAILED) + } else { + glog.V(4).Infof("DAG is still running") + } + return nil +} + +// PutDAGExecutionState updates the given DAG Id to the state provided. +func (c *Client) PutDAGExecutionState(ctx context.Context, executionID int64, state pb.Execution_State) error { + + e, err := c.GetExecution(ctx, executionID) + if err != nil { + return err + } + e.execution.LastKnownState = state.Enum() + _, err = c.svc.PutExecution(ctx, &pb.PutExecutionRequest{ + Execution: e.execution, + }) + return err +} + // GetExecutions ... func (c *Client) GetExecutions(ctx context.Context, ids []int64) ([]*pb.Execution, error) { req := &pb.GetExecutionsByIDRequest{ExecutionIds: ids} @@ -728,7 +803,7 @@ func (c *Client) GetPipelineFromExecution(ctx context.Context, id int64) (*Pipel // GetExecutionsInDAG gets all executions in the DAG, and organize them // into a map, keyed by task name. -func (c *Client) GetExecutionsInDAG(ctx context.Context, dag *DAG, pipeline *Pipeline) (executionsMap map[string]*Execution, err error) { +func (c *Client) GetExecutionsInDAG(ctx context.Context, dag *DAG, pipeline *Pipeline, filter bool) (executionsMap map[string]*Execution, err error) { defer func() { if err != nil { err = fmt.Errorf("failed to get executions in %s: %w", dag.Info(), err) @@ -737,7 +812,12 @@ func (c *Client) GetExecutionsInDAG(ctx context.Context, dag *DAG, pipeline *Pip executionsMap = make(map[string]*Execution) // Documentation on query syntax: // https://github.com/google/ml-metadata/blob/839c3501a195d340d2855b6ffdb2c4b0b49862c9/ml_metadata/proto/metadata_store.proto#L831 - parentDAGFilter := fmt.Sprintf("custom_properties.parent_dag_id.int_value = %v", dag.Execution.GetID()) + // If filter is set to true, the MLMD call will only grab executions for the current DAG, else it would grab all the execution for the context which includes sub-DAGs. + parentDAGFilter := "" + if filter { + parentDAGFilter = fmt.Sprintf("custom_properties.parent_dag_id.int_value = %v", dag.Execution.GetID()) + } + // Note, because MLMD does not have index on custom properties right now, we // take a pipeline run context to limit the number of executions the DB needs to // iterate through to find sub-executions. @@ -756,11 +836,16 @@ func (c *Client) GetExecutionsInDAG(ctx context.Context, dag *DAG, pipeline *Pip } execs := res.GetExecutions() + glog.V(4).Infof("execs: %v", execs) for _, e := range execs { execution := &Execution{execution: e} taskName := execution.TaskName() if taskName == "" { - return nil, fmt.Errorf("empty task name for execution ID: %v", execution.GetID()) + if e.GetCustomProperties()[keyParentDagID] != nil { + return nil, fmt.Errorf("empty task name for execution ID: %v", execution.GetID()) + } + // When retrieving executions without the parentDAGFilter, the rootDAG execution is supplied but does not have an associated TaskName nor is the parentDagID set, therefore we won't include it in the executionsMap. + continue } existing, ok := executionsMap[taskName] if ok { diff --git a/backend/src/v2/metadata/client_test.go b/backend/src/v2/metadata/client_test.go index 94f081b32b0b..3cb5e1cc64c0 100644 --- a/backend/src/v2/metadata/client_test.go +++ b/backend/src/v2/metadata/client_test.go @@ -311,7 +311,7 @@ func Test_DAG(t *testing.T) { t.Fatal(err) } rootDAG := &metadata.DAG{Execution: root} - rootChildren, err := client.GetExecutionsInDAG(ctx, rootDAG, pipeline) + rootChildren, err := client.GetExecutionsInDAG(ctx, rootDAG, pipeline, true) if err != nil { t.Fatal(err) } @@ -324,7 +324,7 @@ func Test_DAG(t *testing.T) { if rootChildren["task2"].GetID() != task2.GetID() { t.Errorf("executions[\"task2\"].GetID()=%v, task2.GetID()=%v. Not equal", rootChildren["task2"].GetID(), task2.GetID()) } - task1Children, err := client.GetExecutionsInDAG(ctx, &metadata.DAG{Execution: task1DAG}, pipeline) + task1Children, err := client.GetExecutionsInDAG(ctx, &metadata.DAG{Execution: task1DAG}, pipeline, true) if len(task1Children) != 1 { t.Errorf("len(task1Children)=%v, expect 1", len(task1Children)) } diff --git a/backend/src/v2/objectstore/config.go b/backend/src/v2/objectstore/config.go index cc8d6372eb65..c3c982fd4f07 100644 --- a/backend/src/v2/objectstore/config.go +++ b/backend/src/v2/objectstore/config.go @@ -18,12 +18,13 @@ package objectstore import ( "encoding/json" "fmt" - "github.com/golang/glog" "os" "path" "regexp" "strconv" "strings" + + "github.com/golang/glog" ) // The endpoint uses Kubernetes service DNS name with namespace: @@ -228,6 +229,9 @@ func StructuredS3Params(p map[string]string) (*S3Params, error) { return nil, err } sparams.ForcePathStyle = boolVal + } else { + // Default to true if not specified, added for backwards compatibilty. + sparams.ForcePathStyle = true } return sparams, nil } diff --git a/backend/src/v2/objectstore/object_store.go b/backend/src/v2/objectstore/object_store.go index 41b5118c49f0..42ec6418c430 100644 --- a/backend/src/v2/objectstore/object_store.go +++ b/backend/src/v2/objectstore/object_store.go @@ -17,6 +17,13 @@ package objectstore import ( "context" "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "regexp" + "strings" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" @@ -27,14 +34,8 @@ import ( "gocloud.dev/blob/s3blob" "gocloud.dev/gcp" "golang.org/x/oauth2/google" - "io" - "io/ioutil" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" - "os" - "path/filepath" - "regexp" - "strings" ) func OpenBucket(ctx context.Context, k8sClient kubernetes.Interface, namespace string, config *Config) (bucket *blob.Bucket, err error) { diff --git a/samples/v2/sample_test.py b/samples/v2/sample_test.py index ed5fa0da825c..b925e29d7623 100644 --- a/samples/v2/sample_test.py +++ b/samples/v2/sample_test.py @@ -28,6 +28,8 @@ import pipeline_with_env import producer_consumer_param import two_step_pipeline_containerized +# import subdagio + _MINUTE = 60 # seconds _DEFAULT_TIMEOUT = 5 * _MINUTE @@ -74,10 +76,11 @@ def test(self): # TestCase(pipeline_func=subdagio.parameter.crust), # TestCase(pipeline_func=subdagio.parameter_cache.crust), + # TestCase(pipeline_func=subdagio.mixed_parameters.crust), + # TestCase(pipeline_func=subdagio.multiple_parameters_namedtuple.crust), + # TestCase(pipeline_func=subdagio.parameter_oneof.crust), # TestCase(pipeline_func=subdagio.artifact_cache.crust), # TestCase(pipeline_func=subdagio.artifact.crust), - # TestCase(pipeline_func=subdagio.mixed_parameters.crust), - # TestCase(pipeline_func=subdagio.multiple_parameters_namedtuple.crust) # TestCase(pipeline_func=subdagio.multiple_artifacts_namedtuple.crust), ] diff --git a/samples/v2/subdagio/__init__.py b/samples/v2/subdagio/__init__.py index dc8f8b3ceaee..d95886b8c559 100644 --- a/samples/v2/subdagio/__init__.py +++ b/samples/v2/subdagio/__init__.py @@ -5,3 +5,4 @@ from subdagio import multiple_parameters_namedtuple from subdagio import parameter from subdagio import parameter_cache +from subdagio import parameter_oneof \ No newline at end of file diff --git a/samples/v2/subdagio/parameter_oneof.py b/samples/v2/subdagio/parameter_oneof.py new file mode 100644 index 000000000000..6459c155ef62 --- /dev/null +++ b/samples/v2/subdagio/parameter_oneof.py @@ -0,0 +1,54 @@ +import os + +from kfp import Client +from kfp import dsl + +@dsl.component +def flip_coin() -> str: + import random + return 'heads' if random.randint(0, 1) == 0 else 'tails' + +@dsl.component +def core_comp(input: str) -> str: + print('input :', input) + return input + +@dsl.component +def core_output_comp(input: str, output_key: dsl.OutputPath(str)): + print('input :', input) + with open(output_key, 'w') as f: + f.write(input) + +@dsl.component +def crust_comp(input: str): + print('input :', input) + +@dsl.pipeline +def core() -> str: + flip_coin_task = flip_coin().set_caching_options(False) + with dsl.If(flip_coin_task.output == 'heads'): + t1 = core_comp(input='Got heads!').set_caching_options(False) + with dsl.Else(): + t2 = core_output_comp(input='Got tails!').set_caching_options(False) + return dsl.OneOf(t1.output, t2.outputs['output_key']) + +@dsl.pipeline +def mantle() -> str: + dag_task = core() + dag_task.set_caching_options(False) + + return dag_task.output + +@dsl.pipeline(name=os.path.basename(__file__).removesuffix('.py') + '-pipeline') +def crust(): + dag_task = mantle() + dag_task.set_caching_options(False) + + task = crust_comp(input=dag_task.output) + task.set_caching_options(False) + + +if __name__ == '__main__': + # Compiler().compile(pipeline_func=crust, package_path=f"{__file__.removesuffix('.py')}.yaml") + client = Client() + client.create_run_from_pipeline_func(crust)