Skip to content

Commit

Permalink
Enable or Disable cache for the Compiler
Browse files Browse the repository at this point in the history
Signed-off-by: Diego Lovison <diegolovison@gmail.com>
  • Loading branch information
diegolovison committed Sep 13, 2024
1 parent 123ed1e commit 87dafcc
Showing 4 changed files with 41 additions and 19 deletions.
11 changes: 10 additions & 1 deletion sdk/python/kfp/cli/compile_.py
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@

import click
from kfp import compiler
from kfp.cli.utils import parsing
from kfp.dsl import base_component
from kfp.dsl import graph_component

@@ -133,12 +134,19 @@ def parse_parameters(parameters: Optional[str]) -> Dict:
is_flag=True,
default=False,
help='Whether to disable type checking.')
@click.option(
'--enable-caching/--disable-caching',
type=bool,
default=None,
help=parsing.get_param_descr(compiler.Compiler.compile, 'enable_caching'),
)
def compile_(
py: str,
output: str,
function_name: Optional[str] = None,
pipeline_parameters: Optional[str] = None,
disable_type_check: bool = False,
enable_caching: Optional[bool] = None,
) -> None:
"""Compiles a pipeline or component written in a .py file."""
pipeline_func = collect_pipeline_or_component_func(
@@ -149,7 +157,8 @@ def compile_(
pipeline_func=pipeline_func,
pipeline_parameters=parsed_parameters,
package_path=package_path,
type_check=not disable_type_check)
type_check=not disable_type_check,
enable_caching=enable_caching)

click.echo(package_path)

19 changes: 3 additions & 16 deletions sdk/python/kfp/client/client.py
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@
from kfp.client import auth
from kfp.client import set_volume_credentials
from kfp.client.token_credentials_base import TokenCredentialsBase
from kfp.compiler.compiler import override_caching_options
from kfp.dsl import base_component
from kfp.pipeline_spec import pipeline_spec_pb2
import kfp_server_api
@@ -955,8 +956,8 @@ def _create_job_config(
# Caching option set at submission time overrides the compile time
# settings.
if enable_caching is not None:
_override_caching_options(pipeline_doc.pipeline_spec,
enable_caching)
override_caching_options(pipeline_doc.pipeline_spec,
enable_caching)
pipeline_spec = pipeline_doc.to_dict()

pipeline_version_reference = None
@@ -1676,17 +1677,3 @@ def _safe_load_yaml(stream: TextIO) -> _PipelineDoc:
raise ValueError(
f'The package_file {package_file} should end with one of the '
'following formats: [.tar.gz, .tgz, .zip, .yaml, .yml].')


def _override_caching_options(
    pipeline_spec: pipeline_spec_pb2.PipelineSpec,
    enable_caching: bool,
) -> None:
    """Overrides the caching option of every task in the root DAG.

    Only the tasks found directly under ``pipeline_spec.root.dag.tasks`` are
    updated; this function does not touch any other part of the spec.

    Args:
        pipeline_spec: The PipelineSpec object to update in-place.
        enable_caching: Overrides options, one of True, False.
    """
    # The task names (map keys) are not needed, so iterate the values only.
    for task_spec in pipeline_spec.root.dag.tasks.values():
        task_spec.caching_options.enable_cache = enable_caching
5 changes: 3 additions & 2 deletions sdk/python/kfp/client/client_test.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,8 @@
from google.protobuf import json_format
from kfp.client import auth
from kfp.client import client
from kfp.compiler import Compiler
from kfp.compiler.compiler import Compiler
from kfp.compiler.compiler import override_caching_options
from kfp.dsl import component
from kfp.dsl import pipeline
from kfp.pipeline_spec import pipeline_spec_pb2
@@ -88,7 +89,7 @@ def pipeline_with_two_component(text: str = 'hi there'):
pipeline_obj = yaml.safe_load(f)
pipeline_spec = json_format.ParseDict(
pipeline_obj, pipeline_spec_pb2.PipelineSpec())
client._override_caching_options(pipeline_spec, True)
override_caching_options(pipeline_spec, True)
pipeline_obj = json_format.MessageToDict(pipeline_spec)
self.assertTrue(pipeline_obj['root']['dag']['tasks']
['hello-word']['cachingOptions']['enableCache'])
25 changes: 25 additions & 0 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@
from kfp.compiler import pipeline_spec_builder as builder
from kfp.dsl import base_component
from kfp.dsl.types import type_utils
from kfp.pipeline_spec import pipeline_spec_pb2


class Compiler:
@@ -53,6 +54,7 @@ def compile(
pipeline_name: Optional[str] = None,
pipeline_parameters: Optional[Dict[str, Any]] = None,
type_check: bool = True,
enable_caching: Optional[bool] = None,
) -> None:
"""Compiles the pipeline or component function into IR YAML.
@@ -62,6 +64,12 @@ def compile(
pipeline_name: Name of the pipeline.
pipeline_parameters: Map of parameter names to argument values.
type_check: Whether to enable type checking of component interfaces during compilation.
enable_caching: Whether or not to enable caching for the
run. If not set, defaults to the compile time settings, which
is ``True`` for all tasks by default, while users may specify
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
"""

with type_utils.TypeCheckManager(enable=type_check):
@@ -78,9 +86,26 @@ def compile(
pipeline_parameters=pipeline_parameters,
)

if enable_caching is not None:
override_caching_options(pipeline_spec, enable_caching)

builder.write_pipeline_spec_to_file(
pipeline_spec=pipeline_spec,
pipeline_description=pipeline_func.description,
platform_spec=pipeline_func.platform_spec,
package_path=package_path,
)


def override_caching_options(
    pipeline_spec: pipeline_spec_pb2.PipelineSpec,
    enable_caching: bool,
) -> None:
    """Overrides the caching option of every task in the root DAG.

    Only the tasks found directly under ``pipeline_spec.root.dag.tasks`` are
    updated; this function does not touch any other part of the spec.

    Args:
        pipeline_spec: The PipelineSpec object to update in-place.
        enable_caching: Overrides options, one of True, False.
    """
    # The task names (map keys) are not needed, so iterate the values only.
    for task_spec in pipeline_spec.root.dag.tasks.values():
        task_spec.caching_options.enable_cache = enable_caching

0 comments on commit 87dafcc

Please sign in to comment.