Skip to content

Commit 1c2da91

Browse files
committed
enable parameterization of container images
This change allows component base images to be parameterized using runtime pipeline parameters. The container images can be specified within an @pipeline decorated function, and takes precedence over the @component(base_image=..) argument. This change also adds logic to resolve these runtime parameters in the argo driver logic. It also includes resolution steps for resolving the accelerator type which functions the same way but was missing the resolution logic. The resolution logic is a generic workaround solution for any run time pod spec input parameters that cannot be resolved because they cannot be added dynamically in the argo pod spec container template. Signed-off-by: Humair Khan <HumairAK@users.noreply.github.com>
1 parent 634aadf commit 1c2da91

File tree

6 files changed

+342
-6
lines changed

6 files changed

+342
-6
lines changed

backend/src/v2/driver/driver.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
14+
1415
package driver
1516

1617
import (
@@ -448,19 +449,28 @@ func initPodSpecPatch(
448449
accelerator := container.GetResources().GetAccelerator()
449450
if accelerator != nil {
450451
if accelerator.GetType() != "" && accelerator.GetCount() > 0 {
451-
q, err := k8sres.ParseQuantity(fmt.Sprintf("%v", accelerator.GetCount()))
452-
if err != nil {
453-
return nil, fmt.Errorf("failed to init podSpecPatch: %w", err)
452+
acceleratorType, err1 := resolvePodSpecInputRuntimeParameter(accelerator.GetType(), executorInput)
453+
if err1 != nil {
454+
return nil, fmt.Errorf("failed to init podSpecPatch: %w", err1)
454455
}
455-
res.Limits[k8score.ResourceName(accelerator.GetType())] = q
456+
q, err1 := k8sres.ParseQuantity(fmt.Sprintf("%v", accelerator.GetCount()))
457+
if err1 != nil {
458+
return nil, fmt.Errorf("failed to init podSpecPatch: %w", err1)
459+
}
460+
res.Limits[k8score.ResourceName(acceleratorType)] = q
456461
}
457462
}
463+
464+
containerImage, err := resolvePodSpecInputRuntimeParameter(container.Image, executorInput)
465+
if err != nil {
466+
return nil, fmt.Errorf("failed to init podSpecPatch: %w", err)
467+
}
458468
podSpec := &k8score.PodSpec{
459469
Containers: []k8score.Container{{
460470
Name: "main", // argo task user container is always called "main"
461471
Command: launcherCmd,
462472
Args: userCmdArgs,
463-
Image: container.Image,
473+
Image: containerImage,
464474
Resources: res,
465475
Env: userEnvVar,
466476
}},

backend/src/v2/driver/util.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright 2021-2024 The Kubeflow Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package driver
16+
17+
import (
18+
"fmt"
19+
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
20+
"regexp"
21+
)
22+
23+
// InputPipelineChannelPattern define a regex pattern to match the content within single quotes
24+
// example input channel looks like "{{$.inputs.parameters['pipelinechannel--val']}}"
25+
const InputPipelineChannelPattern = `\$.inputs.parameters\['(.+?)'\]`
26+
27+
func isInputParameterChannel(inputChannel string) bool {
28+
re := regexp.MustCompile(InputPipelineChannelPattern)
29+
match := re.FindStringSubmatch(inputChannel)
30+
if len(match) == 2 {
31+
return true
32+
} else {
33+
// if len(match) > 2, then this is still incorrect because
34+
// inputChannel should contain only one parameter channel input
35+
return false
36+
}
37+
}
38+
39+
// extractInputParameterFromChannel takes an inputChannel that adheres to
40+
// InputPipelineChannelPattern and extracts the channel parameter name.
41+
// For example given an input channel of the form "{{$.inputs.parameters['pipelinechannel--val']}}"
42+
// the channel parameter name "pipelinechannel--val" is returned.
43+
func extractInputParameterFromChannel(inputChannel string) (string, error) {
44+
re := regexp.MustCompile(InputPipelineChannelPattern)
45+
match := re.FindStringSubmatch(inputChannel)
46+
if len(match) > 1 {
47+
extractedValue := match[1]
48+
return extractedValue, nil
49+
} else {
50+
return "", fmt.Errorf("failed to extract input parameter from channel: %s", inputChannel)
51+
}
52+
}
53+
54+
// resolvePodSpecInputRuntimeParameter resolves runtime value that is intended to be
55+
// utilized within the Pod Spec. parameterValue takes the form of:
56+
// "{{$.inputs.parameters['pipelinechannel--someParameterName']}}"
57+
//
58+
// parameterValue is a runtime parameter value that has been resolved and included within
59+
// the executor input. Since the pod spec patch cannot dynamically update the underlying
60+
// container template's inputs in an Argo Workflow, this is a workaround for resolving
61+
// such parameters.
62+
//
63+
// If parameter value is not a parameter channel, then a constant value is assumed and
64+
// returned as is.
65+
func resolvePodSpecInputRuntimeParameter(parameterValue string, executorInput *pipelinespec.ExecutorInput) (string, error) {
66+
if isInputParameterChannel(parameterValue) {
67+
inputImage, err1 := extractInputParameterFromChannel(parameterValue)
68+
if err1 != nil {
69+
return "", err1
70+
}
71+
if val, ok := executorInput.Inputs.ParameterValues[inputImage]; ok {
72+
return val.GetStringValue(), nil
73+
} else {
74+
return "", fmt.Errorf("executorInput did not contain container Image input parameter")
75+
}
76+
}
77+
return parameterValue, nil
78+
}

backend/src/v2/driver/util_test.go

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
// Copyright 2021-2024 The Kubeflow Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package driver
16+
17+
import (
18+
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
19+
"github.com/stretchr/testify/assert"
20+
structpb "google.golang.org/protobuf/types/known/structpb"
21+
"testing"
22+
)
23+
24+
func Test_isInputParameterChannel(t *testing.T) {
25+
tests := []struct {
26+
name string
27+
input string
28+
isValid bool
29+
}{
30+
{
31+
name: "wellformed pipeline channel should produce no errors",
32+
input: "{{$.inputs.parameters['pipelinechannel--someParameterName']}}",
33+
isValid: true,
34+
},
35+
{
36+
name: "pipeline channel index should have quotes",
37+
input: "{{$.inputs.parameters[pipelinechannel--someParameterName]}}",
38+
isValid: false,
39+
},
40+
{
41+
name: "plain text as pipelinechannel of parameter type is invalid",
42+
input: "randomtext",
43+
isValid: false,
44+
},
45+
{
46+
name: "inputs should be prefixed with $.",
47+
input: "{{inputs.parameters['pipelinechannel--someParameterName']}}",
48+
isValid: false,
49+
},
50+
}
51+
52+
for _, test := range tests {
53+
t.Run(test.name, func(t *testing.T) {
54+
assert.Equal(t, isInputParameterChannel(test.input), test.isValid)
55+
})
56+
}
57+
}
58+
59+
func Test_extractInputParameterFromChannel(t *testing.T) {
60+
tests := []struct {
61+
name string
62+
input string
63+
expected string
64+
wantErr bool
65+
}{
66+
{
67+
name: "standard parameter pipeline channel input",
68+
input: "{{$.inputs.parameters['pipelinechannel--someParameterName']}}",
69+
expected: "pipelinechannel--someParameterName",
70+
wantErr: false,
71+
},
72+
{
73+
name: "a more complex parameter pipeline channel input",
74+
input: "{{$.inputs.parameters['pipelinechannel--somePara-me_terName']}}",
75+
expected: "pipelinechannel--somePara-me_terName",
76+
wantErr: false,
77+
},
78+
{
79+
name: "invalid input should return err",
80+
input: "invalidvalue",
81+
wantErr: true,
82+
},
83+
{
84+
name: "invalid input should return err 2",
85+
input: "pipelinechannel--somePara-me_terName",
86+
wantErr: true,
87+
},
88+
}
89+
90+
for _, test := range tests {
91+
t.Run(test.name, func(t *testing.T) {
92+
actual, err := extractInputParameterFromChannel(test.input)
93+
if test.wantErr {
94+
assert.NotNil(t, err)
95+
} else {
96+
assert.NoError(t, err)
97+
assert.Equal(t, actual, test.expected)
98+
}
99+
})
100+
}
101+
}
102+
103+
func Test_resolvePodSpecRuntimeParameter(t *testing.T) {
104+
tests := []struct {
105+
name string
106+
input string
107+
expected string
108+
executorInput *pipelinespec.ExecutorInput
109+
wantErr bool
110+
}{
111+
{
112+
name: "should retrieve correct parameter value",
113+
input: "{{$.inputs.parameters['pipelinechannel--someParameterName']}}",
114+
expected: "test2",
115+
executorInput: &pipelinespec.ExecutorInput{
116+
Inputs: &pipelinespec.ExecutorInput_Inputs{
117+
ParameterValues: map[string]*structpb.Value{
118+
"pipelinechannel--": structpb.NewStringValue("test1"),
119+
"pipelinechannel--someParameterName": structpb.NewStringValue("test2"),
120+
"someParameterName": structpb.NewStringValue("test3"),
121+
},
122+
},
123+
},
124+
wantErr: false,
125+
},
126+
{
127+
name: "return err when no match is found",
128+
input: "{{$.inputs.parameters['pipelinechannel--someParameterName']}}",
129+
expected: "test1",
130+
executorInput: &pipelinespec.ExecutorInput{
131+
Inputs: &pipelinespec.ExecutorInput_Inputs{
132+
ParameterValues: map[string]*structpb.Value{
133+
"doesNotMatch": structpb.NewStringValue("test2"),
134+
},
135+
},
136+
},
137+
wantErr: true,
138+
},
139+
{
140+
name: "return const val when input is not a pipeline channel",
141+
input: "not-pipeline-channel",
142+
expected: "not-pipeline-channel",
143+
executorInput: &pipelinespec.ExecutorInput{},
144+
wantErr: false,
145+
},
146+
}
147+
148+
for _, test := range tests {
149+
t.Run(test.name, func(t *testing.T) {
150+
actual, err := resolvePodSpecInputRuntimeParameter(test.input, test.executorInput)
151+
if test.wantErr {
152+
assert.NotNil(t, err)
153+
} else {
154+
assert.NoError(t, err)
155+
assert.Equal(t, actual, test.expected)
156+
}
157+
})
158+
}
159+
}

sdk/python/kfp/compiler/compiler_test.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -909,6 +909,70 @@ def my_pipeline() -> NamedTuple('Outputs', [
909909
]):
910910
task = print_and_return(text='Hello')
911911

912+
def test_pipeline_with_parameterized_container_image(self):
913+
with tempfile.TemporaryDirectory() as tmpdir:
914+
915+
@dsl.component(base_image='docker.io/python:3.9.17')
916+
def empty_component():
917+
pass
918+
919+
@dsl.pipeline()
920+
def simple_pipeline(img: str):
921+
task = empty_component()
922+
# overwrite base_image="docker.io/python:3.9.17"
923+
task.set_container_image(img)
924+
925+
output_yaml = os.path.join(tmpdir, 'result.yaml')
926+
compiler.Compiler().compile(
927+
pipeline_func=simple_pipeline,
928+
package_path=output_yaml,
929+
pipeline_parameters={'img': 'someimage'})
930+
self.assertTrue(os.path.exists(output_yaml))
931+
932+
with open(output_yaml, 'r') as f:
933+
pipeline_spec = yaml.safe_load(f)
934+
container = pipeline_spec['deploymentSpec']['executors'][
935+
'exec-empty-component']['container']
936+
self.assertEqual(
937+
container['image'],
938+
"{{$.inputs.parameters['pipelinechannel--img']}}")
939+
# A parameter value should result in 2 input parameters
940+
# One for storing pipeline channel template to be resolved during runtime.
941+
# Two for holding the key to the resolved input.
942+
input_parameters = pipeline_spec['root']['dag']['tasks'][
943+
'empty-component']['inputs']['parameters']
944+
self.assertTrue('base_image' in input_parameters)
945+
self.assertTrue('pipelinechannel--img' in input_parameters)
946+
947+
def test_pipeline_with_constant_container_image(self):
948+
with tempfile.TemporaryDirectory() as tmpdir:
949+
950+
@dsl.component(base_image='docker.io/python:3.9.17')
951+
def empty_component():
952+
pass
953+
954+
@dsl.pipeline()
955+
def simple_pipeline():
956+
task = empty_component()
957+
# overwrite base_image="docker.io/python:3.9.17"
958+
task.set_container_image('constant-value')
959+
960+
output_yaml = os.path.join(tmpdir, 'result.yaml')
961+
compiler.Compiler().compile(
962+
pipeline_func=simple_pipeline, package_path=output_yaml)
963+
964+
self.assertTrue(os.path.exists(output_yaml))
965+
966+
with open(output_yaml, 'r') as f:
967+
pipeline_spec = yaml.safe_load(f)
968+
container = pipeline_spec['deploymentSpec']['executors'][
969+
'exec-empty-component']['container']
970+
self.assertEqual(container['image'], 'constant-value')
971+
# A constant value should yield no parameters
972+
dag_task = pipeline_spec['root']['dag']['tasks'][
973+
'empty-component']
974+
self.assertTrue('inputs' not in dag_task)
975+
912976

913977
class TestCompilePipelineCaching(unittest.TestCase):
914978

sdk/python/kfp/compiler/pipeline_spec_builder.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,11 @@ def build_task_spec_for_task(
135135
if val and pipeline_channel.extract_pipeline_channels_from_any(val):
136136
task.inputs[key] = val
137137

138+
if task.container_spec and task.container_spec.image:
139+
val = task.container_spec.image
140+
if val and pipeline_channel.extract_pipeline_channels_from_any(val):
141+
task.inputs['base_image'] = val
142+
138143
for input_name, input_value in task.inputs.items():
139144
# Since LoopParameterArgument and LoopArtifactArgument and LoopArgumentVariable are narrower
140145
# types than PipelineParameterChannel, start with them.
@@ -634,7 +639,7 @@ def convert_to_placeholder(input_value: str) -> str:
634639

635640
container_spec = (
636641
pipeline_spec_pb2.PipelineDeploymentConfig.PipelineContainerSpec(
637-
image=task.container_spec.image,
642+
image=convert_to_placeholder(task.container_spec.image),
638643
command=task.container_spec.command,
639644
args=task.container_spec.args,
640645
env=[

sdk/python/kfp/dsl/pipeline_task.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,26 @@ def set_env_variable(self, name: str, value: str) -> 'PipelineTask':
631631
self.container_spec.env = {name: value}
632632
return self
633633

634+
@block_if_final()
635+
def set_container_image(
636+
self,
637+
name: Union[str,
638+
pipeline_channel.PipelineChannel]) -> 'PipelineTask':
639+
"""Sets container type to use when executing this task. Takes
640+
precedence over @component(base_image=...)
641+
642+
Args:
643+
name: The name of the image, e.g. "python:3.9-alpine".
644+
645+
Returns:
646+
Self return to allow chained setting calls.
647+
"""
648+
self._ensure_container_spec_exists()
649+
if isinstance(name, pipeline_channel.PipelineChannel):
650+
name = str(name)
651+
self.container_spec.image = name
652+
return self
653+
634654
@block_if_final()
635655
def after(self, *tasks) -> 'PipelineTask':
636656
"""Specifies an explicit dependency on other tasks by requiring this

0 commit comments

Comments
 (0)