Skip to content

Commit 3f49522

Browse files
feat(sdk): Allow disabling default caching via a CLI flag and env var (#11222)
* feat(sdk): Allow setting a default of execution caching disabled via a compiler CLI flag and env var Co-authored-by: Greg Sheremeta <gshereme@redhat.com> Signed-off-by: ddalvi <ddalvi@redhat.com> * Add tests for disabling default caching var and flag Signed-off-by: ddalvi <ddalvi@redhat.com> --------- Signed-off-by: ddalvi <ddalvi@redhat.com> Co-authored-by: Greg Sheremeta <gshereme@redhat.com>
1 parent 880e46d commit 3f49522

File tree

6 files changed

+165
-1
lines changed

6 files changed

+165
-1
lines changed

sdk/python/kfp/cli/cli_test.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from click import testing
2828
from kfp.cli import cli
2929
from kfp.cli import compile_
30+
import yaml
3031

3132

3233
class TestCliNounAliases(unittest.TestCase):
@@ -196,5 +197,88 @@ def test(self, noun: str, verb: str):
196197
self.assertEqual(result.exit_code, 0)
197198

198199

200+
class TestKfpDslCompile(unittest.TestCase):
201+
202+
def invoke(self, args):
203+
starting_args = ['dsl', 'compile']
204+
args = starting_args + args
205+
runner = testing.CliRunner()
206+
return runner.invoke(
207+
cli=cli.cli, args=args, catch_exceptions=False, obj={})
208+
209+
def create_pipeline_file(self):
210+
pipeline_code = b"""
211+
from kfp import dsl
212+
213+
@dsl.component
214+
def my_component():
215+
pass
216+
217+
@dsl.pipeline(name="tiny-pipeline")
218+
def my_pipeline():
219+
my_component_task = my_component()
220+
"""
221+
temp_pipeline = tempfile.NamedTemporaryFile(suffix='.py', delete=False)
222+
temp_pipeline.write(pipeline_code)
223+
temp_pipeline.flush()
224+
return temp_pipeline
225+
226+
def load_output_yaml(self, output_file):
227+
with open(output_file, 'r') as f:
228+
return yaml.safe_load(f)
229+
230+
def test_compile_with_caching_flag_enabled(self):
231+
temp_pipeline = self.create_pipeline_file()
232+
output_file = 'test_output.yaml'
233+
234+
result = self.invoke(
235+
['--py', temp_pipeline.name, '--output', output_file])
236+
self.assertEqual(result.exit_code, 0)
237+
238+
output_data = self.load_output_yaml(output_file)
239+
self.assertIn('root', output_data)
240+
self.assertIn('tasks', output_data['root']['dag'])
241+
for task in output_data['root']['dag']['tasks'].values():
242+
self.assertIn('cachingOptions', task)
243+
caching_options = task['cachingOptions']
244+
self.assertEqual(caching_options.get('enableCache'), True)
245+
246+
def test_compile_with_caching_flag_disabled(self):
247+
temp_pipeline = self.create_pipeline_file()
248+
output_file = 'test_output.yaml'
249+
250+
result = self.invoke([
251+
'--py', temp_pipeline.name, '--output', output_file,
252+
'--disable-execution-caching-by-default'
253+
])
254+
self.assertEqual(result.exit_code, 0)
255+
256+
output_data = self.load_output_yaml(output_file)
257+
self.assertIn('root', output_data)
258+
self.assertIn('tasks', output_data['root']['dag'])
259+
for task in output_data['root']['dag']['tasks'].values():
260+
self.assertIn('cachingOptions', task)
261+
caching_options = task['cachingOptions']
262+
self.assertEqual(caching_options, {})
263+
264+
def test_compile_with_caching_disabled_env_var(self):
265+
temp_pipeline = self.create_pipeline_file()
266+
output_file = 'test_output.yaml'
267+
268+
os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] = 'true'
269+
result = self.invoke(
270+
['--py', temp_pipeline.name, '--output', output_file])
271+
self.assertEqual(result.exit_code, 0)
272+
del os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT']
273+
274+
output_data = self.load_output_yaml(output_file)
275+
self.assertIn('root', output_data)
276+
self.assertIn('tasks', output_data['root']['dag'])
277+
for task in output_data['root']['dag']['tasks'].values():
278+
self.assertIn('cachingOptions', task)
279+
caching_options = task['cachingOptions']
280+
self.assertEqual(caching_options, {})
281+
282+
199283
if __name__ == '__main__':
200284
unittest.main()

sdk/python/kfp/cli/compile_.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from kfp import compiler
2525
from kfp.dsl import base_component
2626
from kfp.dsl import graph_component
27+
from kfp.dsl.pipeline_context import Pipeline
2728

2829

2930
def is_pipeline_func(func: Callable) -> bool:
@@ -133,14 +134,23 @@ def parse_parameters(parameters: Optional[str]) -> Dict:
133134
is_flag=True,
134135
default=False,
135136
help='Whether to disable type checking.')
137+
@click.option(
138+
'--disable-execution-caching-by-default',
139+
is_flag=True,
140+
default=False,
141+
envvar='KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT',
142+
help='Whether to disable execution caching by default.')
136143
def compile_(
137144
py: str,
138145
output: str,
139146
function_name: Optional[str] = None,
140147
pipeline_parameters: Optional[str] = None,
141148
disable_type_check: bool = False,
149+
disable_execution_caching_by_default: bool = False,
142150
) -> None:
143151
"""Compiles a pipeline or component written in a .py file."""
152+
153+
Pipeline._execution_caching_default = not disable_execution_caching_by_default
144154
pipeline_func = collect_pipeline_or_component_func(
145155
python_file=py, function_name=function_name)
146156
parsed_parameters = parse_parameters(parameters=pipeline_parameters)

sdk/python/kfp/compiler/compiler_test.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,59 @@ def my_pipeline() -> NamedTuple('Outputs', [
910910
task = print_and_return(text='Hello')
911911

912912

913+
class TestCompilePipelineCaching(unittest.TestCase):
914+
915+
def test_compile_pipeline_with_caching_enabled(self):
916+
"""Test pipeline compilation with caching enabled."""
917+
918+
@dsl.component
919+
def my_component():
920+
pass
921+
922+
@dsl.pipeline(name='tiny-pipeline')
923+
def my_pipeline():
924+
my_task = my_component()
925+
my_task.set_caching_options(True)
926+
927+
with tempfile.TemporaryDirectory() as tempdir:
928+
output_yaml = os.path.join(tempdir, 'pipeline.yaml')
929+
compiler.Compiler().compile(
930+
pipeline_func=my_pipeline, package_path=output_yaml)
931+
932+
with open(output_yaml, 'r') as f:
933+
pipeline_spec = yaml.safe_load(f)
934+
935+
task_spec = pipeline_spec['root']['dag']['tasks']['my-component']
936+
caching_options = task_spec['cachingOptions']
937+
938+
self.assertTrue(caching_options['enableCache'])
939+
940+
def test_compile_pipeline_with_caching_disabled(self):
941+
"""Test pipeline compilation with caching disabled."""
942+
943+
@dsl.component
944+
def my_component():
945+
pass
946+
947+
@dsl.pipeline(name='tiny-pipeline')
948+
def my_pipeline():
949+
my_task = my_component()
950+
my_task.set_caching_options(False)
951+
952+
with tempfile.TemporaryDirectory() as tempdir:
953+
output_yaml = os.path.join(tempdir, 'pipeline.yaml')
954+
compiler.Compiler().compile(
955+
pipeline_func=my_pipeline, package_path=output_yaml)
956+
957+
with open(output_yaml, 'r') as f:
958+
pipeline_spec = yaml.safe_load(f)
959+
960+
task_spec = pipeline_spec['root']['dag']['tasks']['my-component']
961+
caching_options = task_spec.get('cachingOptions', {})
962+
963+
self.assertEqual(caching_options, {})
964+
965+
913966
class V2NamespaceAliasTest(unittest.TestCase):
914967
"""Test that imports of both modules and objects are aliased (e.g. all
915968
import path variants work)."""

sdk/python/kfp/dsl/base_component.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ def __call__(self, *args, **kwargs) -> pipeline_task.PipelineTask:
103103
args=task_inputs,
104104
execute_locally=pipeline_context.Pipeline.get_default_pipeline() is
105105
None,
106+
execution_caching_default=pipeline_context.Pipeline
107+
.get_execution_caching_default(),
106108
)
107109

108110
@property

sdk/python/kfp/dsl/pipeline_context.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""Definition for Pipeline."""
1515

1616
import functools
17+
import os
1718
from typing import Callable, Optional
1819

1920
from kfp.dsl import component_factory
@@ -101,6 +102,19 @@ def get_default_pipeline():
101102
"""Gets the default pipeline."""
102103
return Pipeline._default_pipeline
103104

105+
# _execution_caching_default can be disabled via the click option --disable-execution-caching-by-default
106+
# or the env var KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT.
107+
# align with click's treatment of env vars for boolean flags.
108+
# per click doc, "1", "true", "t", "yes", "y", and "on" are all converted to True
109+
_execution_caching_default = not str(
110+
os.getenv('KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT')).strip().lower(
111+
) in {'1', 'true', 't', 'yes', 'y', 'on'}
112+
113+
@staticmethod
114+
def get_execution_caching_default():
115+
"""Gets the default execution caching."""
116+
return Pipeline._execution_caching_default
117+
104118
def __init__(self, name: str):
105119
"""Creates a new instance of Pipeline.
106120

sdk/python/kfp/dsl/pipeline_task.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def __init__(
9898
component_spec: structures.ComponentSpec,
9999
args: Dict[str, Any],
100100
execute_locally: bool = False,
101+
execution_caching_default: bool = True,
101102
) -> None:
102103
"""Initilizes a PipelineTask instance."""
103104
# import within __init__ to avoid circular import
@@ -130,7 +131,7 @@ def __init__(
130131
inputs=dict(args.items()),
131132
dependent_tasks=[],
132133
component_ref=component_spec.name,
133-
enable_caching=True)
134+
enable_caching=execution_caching_default)
134135
self._run_after: List[str] = []
135136

136137
self.importer_spec = None

0 commit comments

Comments
 (0)