diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index c701aadc2ee..68899f14a5b 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + package driver import ( @@ -448,19 +449,28 @@ func initPodSpecPatch( accelerator := container.GetResources().GetAccelerator() if accelerator != nil { if accelerator.GetType() != "" && accelerator.GetCount() > 0 { + acceleratorType, err := resolvePodSpecInputRuntimeParameter(accelerator.GetType(), executorInput) + if err != nil { + return nil, fmt.Errorf("failed to init podSpecPatch: %w", err) + } q, err := k8sres.ParseQuantity(fmt.Sprintf("%v", accelerator.GetCount())) if err != nil { return nil, fmt.Errorf("failed to init podSpecPatch: %w", err) } - res.Limits[k8score.ResourceName(accelerator.GetType())] = q + res.Limits[k8score.ResourceName(acceleratorType)] = q } } + + containerImage, err := resolvePodSpecInputRuntimeParameter(container.Image, executorInput) + if err != nil { + return nil, fmt.Errorf("failed to init podSpecPatch: %w", err) + } podSpec := &k8score.PodSpec{ Containers: []k8score.Container{{ Name: "main", // argo task user container is always called "main" Command: launcherCmd, Args: userCmdArgs, - Image: container.Image, + Image: containerImage, Resources: res, Env: userEnvVar, }}, diff --git a/backend/src/v2/driver/util.go b/backend/src/v2/driver/util.go new file mode 100644 index 00000000000..b85e08ffe10 --- /dev/null +++ b/backend/src/v2/driver/util.go @@ -0,0 +1,78 @@ +// Copyright 2021-2024 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
// inputPipelineChannelPattern is a regex pattern that matches one input
// parameter channel placeholder and captures the parameter name held between
// the single quotes. The dots are escaped so only a literal "." is accepted.
// An example input channel looks like "{{$.inputs.parameters['pipelinechannel--val']}}".
const inputPipelineChannelPattern = `\$\.inputs\.parameters\['(.+?)'\]`

// inputPipelineChannelRegexp is inputPipelineChannelPattern compiled once at
// package initialization so the helpers below do not recompile it per call.
var inputPipelineChannelRegexp = regexp.MustCompile(inputPipelineChannelPattern)

// isInputParameterChannel reports whether inputChannel contains exactly one
// input parameter channel placeholder. Plain text, malformed placeholders and
// values holding more than one placeholder all yield false.
func isInputParameterChannel(inputChannel string) bool {
	// Asking for at most two matches is enough to tell "exactly one" apart
	// from "several"; a single FindStringSubmatch call cannot do that because
	// it only ever reports the first match.
	matches := inputPipelineChannelRegexp.FindAllStringSubmatch(inputChannel, 2)
	return len(matches) == 1
}

// extractInputParameterFromChannel takes an inputChannel that adheres to
// inputPipelineChannelPattern and extracts the channel parameter name.
// For example given an input channel of the form
// "{{$.inputs.parameters['pipelinechannel--val']}}" the channel parameter name
// "pipelinechannel--val" is returned. An error is returned when inputChannel
// does not contain a placeholder.
func extractInputParameterFromChannel(inputChannel string) (string, error) {
	match := inputPipelineChannelRegexp.FindStringSubmatch(inputChannel)
	// match[0] is the whole placeholder and match[1] the captured name, so
	// anything shorter than two elements means there was no match.
	if len(match) < 2 {
		return "", fmt.Errorf("failed to extract input parameter from channel: %s", inputChannel)
	}
	return match[1], nil
}
parameterValue takes the form of: +// "{{$.inputs.parameters['pipelinechannel--someParameterName']}}" +// +// parameterValue is a runtime parameter value that has been resolved and included within +// the executor input. Since the pod spec patch cannot dynamically update the underlying +// container template's inputs in an Argo Workflow, this is a workaround for resolving +// such parameters. +// +// If parameter value is not a parameter channel, then a constant value is assumed and +// returned as is. +func resolvePodSpecInputRuntimeParameter(parameterValue string, executorInput *pipelinespec.ExecutorInput) (string, error) { + if isInputParameterChannel(parameterValue) { + inputImage, err := extractInputParameterFromChannel(parameterValue) + if err != nil { + return "", err + } + if val, ok := executorInput.Inputs.ParameterValues[inputImage]; ok { + return val.GetStringValue(), nil + } else { + return "", fmt.Errorf("executorInput did not contain container Image input parameter") + } + } + return parameterValue, nil +} diff --git a/backend/src/v2/driver/util_test.go b/backend/src/v2/driver/util_test.go new file mode 100644 index 00000000000..15d0ffc7e82 --- /dev/null +++ b/backend/src/v2/driver/util_test.go @@ -0,0 +1,159 @@ +// Copyright 2021-2024 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package driver + +import ( + "github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" + "github.com/stretchr/testify/assert" + structpb "google.golang.org/protobuf/types/known/structpb" + "testing" +) + +func Test_isInputParameterChannel(t *testing.T) { + tests := []struct { + name string + input string + isValid bool + }{ + { + name: "wellformed pipeline channel should produce no errors", + input: "{{$.inputs.parameters['pipelinechannel--someParameterName']}}", + isValid: true, + }, + { + name: "pipeline channel index should have quotes", + input: "{{$.inputs.parameters[pipelinechannel--someParameterName]}}", + isValid: false, + }, + { + name: "plain text as pipelinechannel of parameter type is invalid", + input: "randomtext", + isValid: false, + }, + { + name: "inputs should be prefixed with $.", + input: "{{inputs.parameters['pipelinechannel--someParameterName']}}", + isValid: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, isInputParameterChannel(test.input), test.isValid) + }) + } +} + +func Test_extractInputParameterFromChannel(t *testing.T) { + tests := []struct { + name string + input string + expected string + wantErr bool + }{ + { + name: "standard parameter pipeline channel input", + input: "{{$.inputs.parameters['pipelinechannel--someParameterName']}}", + expected: "pipelinechannel--someParameterName", + wantErr: false, + }, + { + name: "a more complex parameter pipeline channel input", + input: "{{$.inputs.parameters['pipelinechannel--somePara-me_terName']}}", + expected: "pipelinechannel--somePara-me_terName", + wantErr: false, + }, + { + name: "invalid input should return err", + input: "invalidvalue", + wantErr: true, + }, + { + name: "invalid input should return err 2", + input: "pipelinechannel--somePara-me_terName", + wantErr: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actual, err := extractInputParameterFromChannel(test.input) + if 
test.wantErr { + assert.NotNil(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, actual, test.expected) + } + }) + } +} + +func Test_resolvePodSpecRuntimeParameter(t *testing.T) { + tests := []struct { + name string + input string + expected string + executorInput *pipelinespec.ExecutorInput + wantErr bool + }{ + { + name: "should retrieve correct parameter value", + input: "{{$.inputs.parameters['pipelinechannel--someParameterName']}}", + expected: "test2", + executorInput: &pipelinespec.ExecutorInput{ + Inputs: &pipelinespec.ExecutorInput_Inputs{ + ParameterValues: map[string]*structpb.Value{ + "pipelinechannel--": structpb.NewStringValue("test1"), + "pipelinechannel--someParameterName": structpb.NewStringValue("test2"), + "someParameterName": structpb.NewStringValue("test3"), + }, + }, + }, + wantErr: false, + }, + { + name: "return err when no match is found", + input: "{{$.inputs.parameters['pipelinechannel--someParameterName']}}", + expected: "test1", + executorInput: &pipelinespec.ExecutorInput{ + Inputs: &pipelinespec.ExecutorInput_Inputs{ + ParameterValues: map[string]*structpb.Value{ + "doesNotMatch": structpb.NewStringValue("test2"), + }, + }, + }, + wantErr: true, + }, + { + name: "return const val when input is not a pipeline channel", + input: "not-pipeline-channel", + expected: "not-pipeline-channel", + executorInput: &pipelinespec.ExecutorInput{}, + wantErr: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actual, err := resolvePodSpecInputRuntimeParameter(test.input, test.executorInput) + if test.wantErr { + assert.NotNil(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, actual, test.expected) + } + }) + } +} diff --git a/sdk/RELEASE.md b/sdk/RELEASE.md index c38c89a9271..b66ca742793 100644 --- a/sdk/RELEASE.md +++ b/sdk/RELEASE.md @@ -2,6 +2,7 @@ ## Features * Expose `--existing-token` flag in `kfp` CLI to allow users to provide an existing token for authentication. 
def test_pipeline_with_parameterized_container_image(self):
    """Compiles a pipeline whose task image comes from a pipeline parameter
    and checks the image placeholder and the task inputs that are emitted."""
    with tempfile.TemporaryDirectory() as tmpdir:

        @dsl.component(base_image='docker.io/python:3.9.17')
        def empty_component():
            pass

        @dsl.pipeline()
        def simple_pipeline(img: str):
            task = empty_component()
            # overwrite base_image="docker.io/python:3.9.17"
            task.set_container_image(img)

        output_yaml = os.path.join(tmpdir, 'result.yaml')
        compiler.Compiler().compile(
            pipeline_func=simple_pipeline,
            package_path=output_yaml,
            pipeline_parameters={'img': 'someimage'})
        self.assertTrue(os.path.exists(output_yaml))

        with open(output_yaml, 'r') as f:
            pipeline_spec = yaml.safe_load(f)
            container = pipeline_spec['deploymentSpec']['executors'][
                'exec-empty-component']['container']
            self.assertEqual(
                container['image'],
                "{{$.inputs.parameters['pipelinechannel--img']}}")
            # A parameterized image should result in two task input
            # parameters: 'base_image', which stores the pipeline channel to
            # be resolved at runtime, and 'pipelinechannel--img', the key
            # that the image placeholder above refers to.
            input_parameters = pipeline_spec['root']['dag']['tasks'][
                'empty-component']['inputs']['parameters']
            self.assertIn('base_image', input_parameters)
            self.assertIn('pipelinechannel--img', input_parameters)


def test_pipeline_with_constant_container_image(self):
    """Compiles a pipeline whose task image is a constant string and checks
    that the image is emitted verbatim without any task inputs."""
    with tempfile.TemporaryDirectory() as tmpdir:

        @dsl.component(base_image='docker.io/python:3.9.17')
        def empty_component():
            pass

        @dsl.pipeline()
        def simple_pipeline():
            task = empty_component()
            # overwrite base_image="docker.io/python:3.9.17"
            task.set_container_image('constant-value')

        output_yaml = os.path.join(tmpdir, 'result.yaml')
        compiler.Compiler().compile(
            pipeline_func=simple_pipeline, package_path=output_yaml)

        self.assertTrue(os.path.exists(output_yaml))

        with open(output_yaml, 'r') as f:
            pipeline_spec = yaml.safe_load(f)
            container = pipeline_spec['deploymentSpec']['executors'][
                'exec-empty-component']['container']
            self.assertEqual(container['image'], 'constant-value')
            # A constant value should yield no parameters
            dag_task = pipeline_spec['root']['dag']['tasks'][
                'empty-component']
            self.assertNotIn('inputs', dag_task)
@block_if_final()
def set_container_image(
    self,
    name: Union[str, pipeline_channel.PipelineChannel],
) -> 'PipelineTask':
    """Sets the container image to use when executing this task.

    The value set here takes precedence over @component(base_image=...).

    Args:
        name: The name of the image, e.g. "python:3.9-alpine". A pipeline
            channel is also accepted and is stored in its string form.

    Returns:
        Self return to allow chained setting calls.
    """
    self._ensure_container_spec_exists()
    # Pipeline channels are converted to their string form; plain values are
    # stored untouched.
    image = (
        str(name)
        if isinstance(name, pipeline_channel.PipelineChannel) else name)
    self.container_spec.image = image
    return self