From 80df2b04553f66119fe13c84a14da1d84b3bc46f Mon Sep 17 00:00:00 2001 From: ntny Date: Mon, 23 Sep 2024 23:32:37 +0300 Subject: [PATCH] check that requested cpu/memory is less than or equal to the corresponding limits Signed-off-by: ntny --- sdk/python/kfp/dsl/pipeline_task.py | 44 +++++ sdk/python/kfp/dsl/pipeline_task_test.py | 212 ++++++++++++++++++++++- 2 files changed, 249 insertions(+), 7 deletions(-) diff --git a/sdk/python/kfp/dsl/pipeline_task.py b/sdk/python/kfp/dsl/pipeline_task.py index 773fb1e06765..15372fac7a20 100644 --- a/sdk/python/kfp/dsl/pipeline_task.py +++ b/sdk/python/kfp/dsl/pipeline_task.py @@ -312,6 +312,15 @@ def set_caching_options(self, enable_caching: bool) -> 'PipelineTask': self._task_spec.enable_caching = enable_caching return self + def _ensure_resource_requests_meet_limits(self) -> None: + resources = self.container_spec.resources + if (resources.memory_request is not None + and resources.memory_limit is not None + and self._parse_memory_str_request(resources.memory_request) > self._parse_memory_str_request(resources.memory_limit)): + raise ValueError(f'Requested memory: {resources.memory_request} cannot be greater than memory limit: {resources.memory_limit}. 
' + 'Check the set_memory_request and set_memory_limit parameters.') + + def _ensure_container_spec_exists(self) -> None: """Ensures that the task has a container spec.""" caller_method_name = inspect.stack()[1][3] @@ -452,6 +461,37 @@ def set_gpu_limit(self, gpu: str) -> 'PipelineTask': category=DeprecationWarning) return self.set_accelerator_limit(gpu) + def _parse_memory_str_request(self, memory_str: str) -> float: + memory_request = float(0) + if memory_str.endswith('E'): + memory_request = float(memory_str[:-1]) * constants._E / constants._G + elif memory_str.endswith('Ei'): + memory_request = float(memory_str[:-2]) * constants._EI / constants._G + elif memory_str.endswith('P'): + memory_request = float(memory_str[:-1]) * constants._P / constants._G + elif memory_str.endswith('Pi'): + memory_request = float(memory_str[:-2]) * constants._PI / constants._G + elif memory_str.endswith('T'): + memory_request = float(memory_str[:-1]) * constants._T / constants._G + elif memory_str.endswith('Ti'): + memory_request = float(memory_str[:-2]) * constants._TI / constants._G + elif memory_str.endswith('G'): + memory_request = float(memory_str[:-1]) + elif memory_str.endswith('Gi'): + memory_request = float(memory_str[:-2]) * constants._GI / constants._G + elif memory_str.endswith('M'): + memory_request = float(memory_str[:-1]) * constants._M / constants._G + elif memory_str.endswith('Mi'): + memory_request = float(memory_str[:-2]) * constants._MI / constants._G + elif memory_str.endswith('K'): + memory_request = float(memory_str[:-1]) * constants._K / constants._G + elif memory_str.endswith('Ki'): + memory_request = float(memory_str[:-2]) * constants._KI / constants._G + else: + # By default interpret as a plain integer, in the unit of Bytes. + memory_request = float(memory_str) / constants._G + return memory_request + def _validate_memory_request_limit(self, memory: str) -> str: """Validates memory request/limit string and converts to its numeric string value. 
@@ -503,6 +543,8 @@ def set_memory_request( self.container_spec.resources = structures.ResourceSpec( memory_request=memory) + self._ensure_resource_requests_meet_limits() + return self @block_if_final() @@ -530,6 +572,8 @@ def set_memory_limit( self.container_spec.resources = structures.ResourceSpec( memory_limit=memory) + self._ensure_resource_requests_meet_limits() + return self @block_if_final() diff --git a/sdk/python/kfp/dsl/pipeline_task_test.py b/sdk/python/kfp/dsl/pipeline_task_test.py index 8543058b8268..152630a22b34 100644 --- a/sdk/python/kfp/dsl/pipeline_task_test.py +++ b/sdk/python/kfp/dsl/pipeline_task_test.py @@ -176,7 +176,7 @@ def test_set_valid_cpu_request_limit(self, cpu: str, expected_cpu: str): { 'gpu_limit': '1', 'expected_gpu_number': '1', - },) + }, ) def test_set_valid_gpu_limit(self, gpu_limit: str, expected_gpu_number: str): task = pipeline_task.PipelineTask( @@ -231,6 +231,207 @@ def test_set_accelerator_limit(self, limit, expected_limit): self.assertEqual(expected_limit, task.container_spec.resources.accelerator_count) + @parameterized.parameters( + { + 'memory': '2E', + 'limit': '1E', + }, + { + 'memory': '3Ei', + 'limit': '2Ei', + }, + { + 'memory': '20P', + 'limit': '2P', + }, + { + 'memory': '2P', + 'limit': '1999T', + }, + { + 'memory': '3P', + 'limit': '2000T', + }, + { + 'memory': '25Pi', + 'limit': '24Pi', + }, + { + 'memory': '1Pi', + 'limit': '1023Ti', + }, + { + 'memory': '14T', + 'limit': '4T', + }, + { + 'memory': '4T', + 'limit': '3999G', + }, + { + 'memory': '1P', + 'limit': '999T', + }, + { + 'memory': '1Ti', + 'limit': '999Gi', + }, + { + 'memory': '14G', + 'limit': '9G', + }, + { + 'memory': '1G', + 'limit': '999999K', + }, + { + 'memory': '1Gi', + 'limit': '1000M', + }, + { + 'memory': '10Gi', + 'limit': '9Gi', + }, + { + 'memory': '15M', + 'limit': '5M', + }, + { + 'memory': '5Mi', + 'limit': '5M', + }, + { + 'memory': '95Mi', + 'limit': '94Mi', + }, + { + 'memory': '6K', + 'limit': '5K', + }, + { + 'memory': 
'100Ki', + 'limit': '65Ki', + }, + { + 'memory': '1Mi', + 'limit': '10Ki', + }, + { + 'memory': '7001', + 'limit': '7000', + }, + ) + def test_set_memory_request_greater_than_limit_should_raise(self, memory: str, limit: str): + task = pipeline_task.PipelineTask( + component_spec=structures.ComponentSpec.from_yaml_documents( + V2_YAML), + args={'input1': 'value'}, + ) + with self.assertRaisesRegex( + ValueError, + f'Requested memory: {memory} cannot be greater than memory limit: {limit}. ' + 'Check the set_memory_request and set_memory_limit parameters.'): + task.set_memory_request(memory).set_memory_limit(limit) + + @parameterized.parameters( + { + 'memory': '1E', + 'limit': '2E', + }, + { + 'memory': '55Ei', + 'limit': '150Ei', + }, + { + 'memory': '2P', + 'limit': '20P', + }, + { + 'memory': '3P', + 'limit': '3000T', + }, + { + 'memory': '25Pi', + 'limit': '25Pi', + }, + { + 'memory': '1Pi', + 'limit': '1024Ti', + }, + { + 'memory': '4T', + 'limit': '14T', + }, + { + 'memory': '4T', + 'limit': '4000G', + }, + { + 'memory': '4T', + 'limit': '1P', + }, + { + 'memory': '1Ti', + 'limit': '1024Gi', + }, + { + 'memory': '4G', + 'limit': '14G', + }, + { + 'memory': '1G', + 'limit': '1000M', + }, + { + 'memory': '1Gi', + 'limit': '1024Mi', + }, + { + 'memory': '45Gi', + 'limit': '100Gi', + }, + { + 'memory': '5M', + 'limit': '15M', + }, + { + 'memory': '5M', + 'limit': '5Mi', + }, + { + 'memory': '95Mi', + 'limit': '95Mi', + }, + { + 'memory': '6K', + 'limit': '7K', + }, + { + 'memory': '65Ki', + 'limit': '100Ki', + }, + { + 'memory': '10Ki', + 'limit': '1Mi', + }, + { + 'memory': '7000', + 'limit': '7001', + }, + ) + def test_set_memory_request_and_limit(self, memory: str, limit: str): + task = pipeline_task.PipelineTask( + component_spec=structures.ComponentSpec.from_yaml_documents( + V2_YAML), + args={'input1': 'value'}, + ) + task.set_memory_request(memory) + self.assertEqual(memory, + task.container_spec.resources.memory_request) + task.set_memory_limit(limit) + 
self.assertEqual(limit, + task.container_spec.resources.memory_limit) + @parameterized.parameters( { 'memory': '1E', @@ -341,7 +542,6 @@ def test_set_display_name(self): self.assertEqual('test_name', task._task_spec.display_name) def test_set_cpu_limit_on_pipeline_should_raise(self): - @dsl.component def comp(): print('hello') @@ -354,7 +554,6 @@ def graph(): with self.assertRaisesRegex( ValueError, r'set_cpu_limit can only be used on single-step components'): - @dsl.pipeline def my_pipeline(): graph().set_cpu_limit('1') @@ -363,7 +562,6 @@ def my_pipeline(): class TestPlatformSpecificFunctionality(unittest.TestCase): def test_platform_config_to_platform_spec(self): - @dsl.component def comp(): pass @@ -496,9 +694,9 @@ def test_sampling_of_task_configuration_methods(self): def assert_artifacts_equal( - test_class: unittest.TestCase, - a1: dsl.Artifact, - a2: dsl.Artifact, + test_class: unittest.TestCase, + a1: dsl.Artifact, + a2: dsl.Artifact, ) -> None: test_class.assertEqual(a1.name, a2.name) test_class.assertEqual(a1.uri, a2.uri)