From baee79219605fbc475f36e36fa115aee8514dcd2 Mon Sep 17 00:00:00 2001 From: ddalvi Date: Sat, 16 Nov 2024 13:50:26 -0500 Subject: [PATCH] feat(sdk) Add SemaphoreKey and MutexName fields to DSL --- .../kfp/compiler/pipeline_spec_builder.py | 12 +++++++----- sdk/python/kfp/dsl/pipeline_config.py | 19 +++++++++++++++++-- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/sdk/python/kfp/compiler/pipeline_spec_builder.py b/sdk/python/kfp/compiler/pipeline_spec_builder.py index 0395146c80b..9ec24a8bab7 100644 --- a/sdk/python/kfp/compiler/pipeline_spec_builder.py +++ b/sdk/python/kfp/compiler/pipeline_spec_builder.py @@ -2011,11 +2011,13 @@ def write_pipeline_spec_to_file( def _merge_pipeline_config(pipelineConfig: pipeline_config.PipelineConfig, platformSpec: pipeline_spec_pb2.PlatformSpec): - # TODO: add pipeline config options (ttl, semaphore, etc.) to the dict - # json_format.ParseDict( - # {'pipelineConfig': { - # '': pipelineConfig., - # }}, platformSpec.platforms['kubernetes']) + json_format.ParseDict( + { + 'pipelineConfig': { + 'semaphoreKey': pipelineConfig.semaphore_key, + 'mutexName': pipelineConfig.mutex_name, + } + }, platformSpec.platforms['kubernetes']) return platformSpec diff --git a/sdk/python/kfp/dsl/pipeline_config.py b/sdk/python/kfp/dsl/pipeline_config.py index a4e90c28a01..8a730548d8b 100644 --- a/sdk/python/kfp/dsl/pipeline_config.py +++ b/sdk/python/kfp/dsl/pipeline_config.py @@ -18,6 +18,21 @@ class PipelineConfig: """PipelineConfig contains pipeline-level config options.""" def __init__(self): - pass + self.semaphore_key = None + self.mutex_name = None - # TODO add pipeline level configs + def set_semaphore_key(self, semaphore_key: str): + """Set the name of the semaphore to control pipeline concurrency. + + Args: + semaphore_key (str): Name of the semaphore. + """ + self.semaphore_key = semaphore_key.strip() + + def set_mutex_name(self, mutex_name: str): + """Set the name of the mutex to ensure mutual exclusion. + + Args: + mutex_name (str): Name of the mutex. + """ + self.mutex_name = mutex_name.strip()