Skip to content

Commit 97258f8

Browse files
committed
temp title: change title
1 parent 99bd234 commit 97258f8

12 files changed

+680
-128
lines changed

sdk/python/kfp/compiler/compiler_test.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3382,31 +3382,31 @@ def simple_pipeline():
33823382
['exec-return-1']['container'])
33833383

33843384
self.assertEqual(
3385-
5, dict_format['deploymentSpec']['executors']['exec-return-1-2']
3386-
['container']['resources']['cpuLimit'])
3385+
'5', dict_format['deploymentSpec']['executors']['exec-return-1-2']
3386+
['container']['resources']['resourceCpuLimit'])
33873387
self.assertNotIn(
33883388
'memoryLimit', dict_format['deploymentSpec']['executors']
33893389
['exec-return-1-2']['container']['resources'])
33903390

33913391
self.assertEqual(
3392-
50, dict_format['deploymentSpec']['executors']['exec-return-1-3']
3393-
['container']['resources']['memoryLimit'])
3392+
'50G', dict_format['deploymentSpec']['executors']['exec-return-1-3']
3393+
['container']['resources']['resourceMemoryLimit'])
33943394
self.assertNotIn(
33953395
'cpuLimit', dict_format['deploymentSpec']['executors']
33963396
['exec-return-1-3']['container']['resources'])
33973397

33983398
self.assertEqual(
3399-
2, dict_format['deploymentSpec']['executors']['exec-return-1-4']
3400-
['container']['resources']['cpuRequest'])
3399+
'2', dict_format['deploymentSpec']['executors']['exec-return-1-4']
3400+
['container']['resources']['resourceCpuRequest'])
34013401
self.assertEqual(
3402-
5, dict_format['deploymentSpec']['executors']['exec-return-1-4']
3403-
['container']['resources']['cpuLimit'])
3402+
'5', dict_format['deploymentSpec']['executors']['exec-return-1-4']
3403+
['container']['resources']['resourceCpuLimit'])
34043404
self.assertEqual(
3405-
4, dict_format['deploymentSpec']['executors']['exec-return-1-4']
3406-
['container']['resources']['memoryRequest'])
3405+
'4G', dict_format['deploymentSpec']['executors']['exec-return-1-4']
3406+
['container']['resources']['resourceMemoryRequest'])
34073407
self.assertEqual(
3408-
50, dict_format['deploymentSpec']['executors']['exec-return-1-4']
3409-
['container']['resources']['memoryLimit'])
3408+
'50G', dict_format['deploymentSpec']['executors']['exec-return-1-4']
3409+
['container']['resources']['resourceMemoryLimit'])
34103410

34113411

34123412
class TestPlatformConfig(unittest.TestCase):

sdk/python/kfp/compiler/pipeline_spec_builder.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,12 @@ def build_task_spec_for_task(
127127
if task._task_spec.retry_policy is not None:
128128
pipeline_task_spec.retry_policy.CopyFrom(
129129
task._task_spec.retry_policy.to_proto())
130+
131+
# Inject resource fields into inputs
132+
if task.container_spec and task.container_spec.resources:
133+
for key, val in task.container_spec.resources.__dict__.items():
134+
if val and pipeline_channel.extract_pipeline_channels_from_any(val):
135+
task.inputs[key] = val
130136

131137
for input_name, input_value in task.inputs.items():
132138
# Since LoopParameterArgument and LoopArtifactArgument and LoopArgumentVariable are narrower
@@ -607,6 +613,24 @@ def build_container_spec_for_task(
607613
Returns:
608614
A PipelineContainerSpec object for the task.
609615
"""
616+
def convert_to_placeholder(input_value: str) -> str:
617+
"""Checks if input is a pipeline channel and if so, converts to
618+
compiler injected input name."""
619+
pipeline_channels = (
620+
pipeline_channel.extract_pipeline_channels_from_any(input_value)
621+
)
622+
if pipeline_channels:
623+
assert len(pipeline_channels) == 1
624+
channel = pipeline_channels[0]
625+
additional_input_name = (
626+
compiler_utils.additional_input_name_for_pipeline_channel(
627+
channel))
628+
additional_input_placeholder = placeholders.InputValuePlaceholder(
629+
additional_input_name)._to_string()
630+
input_value = input_value.replace(
631+
channel.pattern, additional_input_placeholder)
632+
return input_value
633+
610634
container_spec = (
611635
pipeline_spec_pb2.PipelineDeploymentConfig.PipelineContainerSpec(
612636
image=task.container_spec.image,
@@ -620,23 +644,23 @@ def build_container_spec_for_task(
620644

621645
if task.container_spec.resources is not None:
622646
if task.container_spec.resources.cpu_request is not None:
623-
container_spec.resources.cpu_request = (
624-
task.container_spec.resources.cpu_request)
647+
container_spec.resources.resource_cpu_request = (
648+
convert_to_placeholder(task.container_spec.resources.cpu_request))
625649
if task.container_spec.resources.cpu_limit is not None:
626-
container_spec.resources.cpu_limit = (
627-
task.container_spec.resources.cpu_limit)
650+
container_spec.resources.resource_cpu_limit = (
651+
convert_to_placeholder(task.container_spec.resources.cpu_limit))
628652
if task.container_spec.resources.memory_request is not None:
629-
container_spec.resources.memory_request = (
630-
task.container_spec.resources.memory_request)
653+
container_spec.resources.resource_memory_request = (
654+
convert_to_placeholder(task.container_spec.resources.memory_request))
631655
if task.container_spec.resources.memory_limit is not None:
632-
container_spec.resources.memory_limit = (
633-
task.container_spec.resources.memory_limit)
656+
container_spec.resources.resource_memory_limit = (
657+
convert_to_placeholder(task.container_spec.resources.memory_limit))
634658
if task.container_spec.resources.accelerator_count is not None:
635659
container_spec.resources.accelerator.CopyFrom(
636660
pipeline_spec_pb2.PipelineDeploymentConfig.PipelineContainerSpec
637661
.ResourceSpec.AcceleratorConfig(
638-
type=task.container_spec.resources.accelerator_type,
639-
count=task.container_spec.resources.accelerator_count,
662+
resource_type=convert_to_placeholder(task.container_spec.resources.accelerator_type),
663+
resource_count=convert_to_placeholder(task.container_spec.resources.accelerator_count),
640664
))
641665

642666
return container_spec

sdk/python/kfp/dsl/pipeline_task.py

Lines changed: 40 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from kfp.dsl import placeholders
2828
from kfp.dsl import structures
2929
from kfp.dsl import utils
30+
from kfp.dsl import pipeline_channel
3031
from kfp.dsl.types import type_utils
3132
from kfp.local import pipeline_orchestrator
3233
from kfp.pipeline_spec import pipeline_spec_pb2
@@ -321,9 +322,9 @@ def _ensure_container_spec_exists(self) -> None:
321322
f'{caller_method_name} can only be used on single-step components, not pipelines used as components, or special components like importers.'
322323
)
323324

324-
def _validate_cpu_request_limit(self, cpu: str) -> float:
325+
def _validate_cpu_request_limit(self, cpu: str) -> str:
325326
"""Validates cpu request/limit string and converts to its numeric
326-
value.
327+
string value.
327328
328329
Args:
329330
cpu: CPU requests or limits. This string should be a number or a
@@ -335,17 +336,19 @@ def _validate_cpu_request_limit(self, cpu: str) -> float:
335336
ValueError if the cpu request/limit string value is invalid.
336337
337338
Returns:
338-
The numeric value (float) of the cpu request/limit.
339+
The numeric string of the cpu request/limit.
339340
"""
340-
if re.match(r'([0-9]*[.])?[0-9]+m?$', cpu) is None:
341-
raise ValueError(
342-
'Invalid cpu string. Should be float or integer, or integer'
343-
' followed by "m".')
344-
345-
return float(cpu[:-1]) / 1000 if cpu.endswith('m') else float(cpu)
341+
if isinstance(cpu, pipeline_channel.PipelineChannel):
342+
cpu = str(cpu)
343+
else:
344+
if re.match(r'([0-9]*[.])?[0-9]+m?$', cpu) is None:
345+
raise ValueError(
346+
'Invalid cpu string. Should be float or integer, or integer'
347+
' followed by "m".')
348+
return cpu
346349

347350
@block_if_final()
348-
def set_cpu_request(self, cpu: str) -> 'PipelineTask':
351+
def set_cpu_request(self, cpu: Union[str, pipeline_channel.PipelineChannel]) -> 'PipelineTask':
349352
"""Sets CPU request (minimum) for the task.
350353
351354
Args:
@@ -370,7 +373,7 @@ def set_cpu_request(self, cpu: str) -> 'PipelineTask':
370373
return self
371374

372375
@block_if_final()
373-
def set_cpu_limit(self, cpu: str) -> 'PipelineTask':
376+
def set_cpu_limit(self, cpu: Union[str, pipeline_channel.PipelineChannel]) -> 'PipelineTask':
374377
"""Sets CPU limit (maximum) for the task.
375378
376379
Args:
@@ -395,7 +398,7 @@ def set_cpu_limit(self, cpu: str) -> 'PipelineTask':
395398
return self
396399

397400
@block_if_final()
398-
def set_accelerator_limit(self, limit: int) -> 'PipelineTask':
401+
def set_accelerator_limit(self, limit: Union[int, str, pipeline_channel.PipelineChannel]) -> 'PipelineTask':
399402
"""Sets accelerator limit (maximum) for the task. Only applies if
400403
accelerator type is also set via .set_accelerator_type().
401404
@@ -406,11 +409,13 @@ def set_accelerator_limit(self, limit: int) -> 'PipelineTask':
406409
Self return to allow chained setting calls.
407410
"""
408411
self._ensure_container_spec_exists()
409-
410-
if isinstance(limit, str):
411-
if re.match(r'[1-9]\d*$', limit) is None:
412-
raise ValueError(f'{"limit"!r} must be positive integer.')
413-
limit = int(limit)
412+
if isinstance(limit, pipeline_channel.PipelineChannel):
413+
limit = str(limit)
414+
else:
415+
if isinstance(limit, int):
416+
limit = str(limit)
417+
if isinstance(limit, str) and re.match(r'^0$|^1$|^2$|^4$|^8$|^16$', limit) is None:
418+
raise ValueError(f'{"limit"!r} must be one of 0, 1, 2, 4, 8, 16.')
414419

415420
if self.container_spec.resources is not None:
416421
self.container_spec.resources.accelerator_count = limit
@@ -438,9 +443,9 @@ def set_gpu_limit(self, gpu: str) -> 'PipelineTask':
438443
category=DeprecationWarning)
439444
return self.set_accelerator_limit(gpu)
440445

441-
def _validate_memory_request_limit(self, memory: str) -> float:
446+
def _validate_memory_request_limit(self, memory: str) -> str:
442447
"""Validates memory request/limit string and converts to its numeric
443-
value.
448+
string value.
444449
445450
Args:
446451
memory: Memory requests or limits. This string should be a number or
@@ -451,47 +456,21 @@ def _validate_memory_request_limit(self, memory: str) -> float:
451456
ValueError if the memory request/limit string value is invalid.
452457
453458
Returns:
454-
The numeric value (float) of the memory request/limit.
459+
The numeric string value of the memory request/limit.
455460
"""
456-
if re.match(r'^[0-9]+(E|Ei|P|Pi|T|Ti|G|Gi|M|Mi|K|Ki){0,1}$',
457-
memory) is None:
458-
raise ValueError(
459-
'Invalid memory string. Should be a number or a number '
460-
'followed by one of "E", "Ei", "P", "Pi", "T", "Ti", "G", '
461-
'"Gi", "M", "Mi", "K", "Ki".')
462-
463-
if memory.endswith('E'):
464-
memory = float(memory[:-1]) * constants._E / constants._G
465-
elif memory.endswith('Ei'):
466-
memory = float(memory[:-2]) * constants._EI / constants._G
467-
elif memory.endswith('P'):
468-
memory = float(memory[:-1]) * constants._P / constants._G
469-
elif memory.endswith('Pi'):
470-
memory = float(memory[:-2]) * constants._PI / constants._G
471-
elif memory.endswith('T'):
472-
memory = float(memory[:-1]) * constants._T / constants._G
473-
elif memory.endswith('Ti'):
474-
memory = float(memory[:-2]) * constants._TI / constants._G
475-
elif memory.endswith('G'):
476-
memory = float(memory[:-1])
477-
elif memory.endswith('Gi'):
478-
memory = float(memory[:-2]) * constants._GI / constants._G
479-
elif memory.endswith('M'):
480-
memory = float(memory[:-1]) * constants._M / constants._G
481-
elif memory.endswith('Mi'):
482-
memory = float(memory[:-2]) * constants._MI / constants._G
483-
elif memory.endswith('K'):
484-
memory = float(memory[:-1]) * constants._K / constants._G
485-
elif memory.endswith('Ki'):
486-
memory = float(memory[:-2]) * constants._KI / constants._G
461+
if isinstance(memory, pipeline_channel.PipelineChannel):
462+
memory = str(memory)
487463
else:
488-
# By default interpret as a plain integer, in the unit of Bytes.
489-
memory = float(memory) / constants._G
490-
464+
if re.match(r'^[0-9]+(E|Ei|P|Pi|T|Ti|G|Gi|M|Mi|K|Ki){0,1}$',
465+
memory) is None:
466+
raise ValueError(
467+
'Invalid memory string. Should be a number or a number '
468+
'followed by one of "E", "Ei", "P", "Pi", "T", "Ti", "G", '
469+
'"Gi", "M", "Mi", "K", "Ki".')
491470
return memory
492471

493472
@block_if_final()
494-
def set_memory_request(self, memory: str) -> 'PipelineTask':
473+
def set_memory_request(self, memory: Union[str, pipeline_channel.PipelineChannel]) -> 'PipelineTask':
495474
"""Sets memory request (minimum) for the task.
496475
497476
Args:
@@ -515,7 +494,7 @@ def set_memory_request(self, memory: str) -> 'PipelineTask':
515494
return self
516495

517496
@block_if_final()
518-
def set_memory_limit(self, memory: str) -> 'PipelineTask':
497+
def set_memory_limit(self, memory: Union[str, pipeline_channel.PipelineChannel]) -> 'PipelineTask':
519498
"""Sets memory limit (maximum) for the task.
520499
521500
Args:
@@ -579,7 +558,7 @@ def add_node_selector_constraint(self, accelerator: str) -> 'PipelineTask':
579558
return self.set_accelerator_type(accelerator)
580559

581560
@block_if_final()
582-
def set_accelerator_type(self, accelerator: str) -> 'PipelineTask':
561+
def set_accelerator_type(self, accelerator: Union[str, pipeline_channel.PipelineChannel]) -> 'PipelineTask':
583562
"""Sets accelerator type to use when executing this task.
584563
585564
Args:
@@ -589,14 +568,16 @@ def set_accelerator_type(self, accelerator: str) -> 'PipelineTask':
589568
Self return to allow chained setting calls.
590569
"""
591570
self._ensure_container_spec_exists()
571+
if isinstance(accelerator, pipeline_channel.PipelineChannel):
572+
accelerator = str(accelerator)
592573

593574
if self.container_spec.resources is not None:
594575
self.container_spec.resources.accelerator_type = accelerator
595576
if self.container_spec.resources.accelerator_count is None:
596-
self.container_spec.resources.accelerator_count = 1
577+
self.container_spec.resources.accelerator_count = '1'
597578
else:
598579
self.container_spec.resources = structures.ResourceSpec(
599-
accelerator_count=1, accelerator_type=accelerator)
580+
accelerator_count='1', accelerator_type=accelerator)
600581

601582
return self
602583

0 commit comments

Comments
 (0)