27
27
from kfp .dsl import placeholders
28
28
from kfp .dsl import structures
29
29
from kfp .dsl import utils
30
+ from kfp .dsl import pipeline_channel
30
31
from kfp .dsl .types import type_utils
31
32
from kfp .local import pipeline_orchestrator
32
33
from kfp .pipeline_spec import pipeline_spec_pb2
@@ -321,9 +322,9 @@ def _ensure_container_spec_exists(self) -> None:
321
322
f'{ caller_method_name } can only be used on single-step components, not pipelines used as components, or special components like importers.'
322
323
)
323
324
324
- def _validate_cpu_request_limit (self , cpu : str ) -> float :
325
+ def _validate_cpu_request_limit (self , cpu : str ) -> str :
325
326
"""Validates cpu request/limit string and converts to its numeric
326
- value.
327
+ string value.
327
328
328
329
Args:
329
330
cpu: CPU requests or limits. This string should be a number or a
@@ -335,17 +336,19 @@ def _validate_cpu_request_limit(self, cpu: str) -> float:
335
336
ValueError if the cpu request/limit string value is invalid.
336
337
337
338
Returns:
338
- The numeric value (float) of the cpu request/limit.
339
+ The numeric string of the cpu request/limit.
339
340
"""
340
- if re .match (r'([0-9]*[.])?[0-9]+m?$' , cpu ) is None :
341
- raise ValueError (
342
- 'Invalid cpu string. Should be float or integer, or integer'
343
- ' followed by "m".' )
344
-
345
- return float (cpu [:- 1 ]) / 1000 if cpu .endswith ('m' ) else float (cpu )
341
+ if isinstance (cpu , pipeline_channel .PipelineChannel ):
342
+ cpu = str (cpu )
343
+ else :
344
+ if re .match (r'([0-9]*[.])?[0-9]+m?$' , cpu ) is None :
345
+ raise ValueError (
346
+ 'Invalid cpu string. Should be float or integer, or integer'
347
+ ' followed by "m".' )
348
+ return cpu
346
349
347
350
@block_if_final ()
348
- def set_cpu_request (self , cpu : str ) -> 'PipelineTask' :
351
+ def set_cpu_request (self , cpu : Union [ str , pipeline_channel . PipelineChannel ] ) -> 'PipelineTask' :
349
352
"""Sets CPU request (minimum) for the task.
350
353
351
354
Args:
@@ -370,7 +373,7 @@ def set_cpu_request(self, cpu: str) -> 'PipelineTask':
370
373
return self
371
374
372
375
@block_if_final ()
373
- def set_cpu_limit (self , cpu : str ) -> 'PipelineTask' :
376
+ def set_cpu_limit (self , cpu : Union [ str , pipeline_channel . PipelineChannel ] ) -> 'PipelineTask' :
374
377
"""Sets CPU limit (maximum) for the task.
375
378
376
379
Args:
@@ -395,7 +398,7 @@ def set_cpu_limit(self, cpu: str) -> 'PipelineTask':
395
398
return self
396
399
397
400
@block_if_final ()
398
- def set_accelerator_limit (self , limit : int ) -> 'PipelineTask' :
401
+ def set_accelerator_limit (self , limit : Union [ int , str , pipeline_channel . PipelineChannel ] ) -> 'PipelineTask' :
399
402
"""Sets accelerator limit (maximum) for the task. Only applies if
400
403
accelerator type is also set via .set_accelerator_type().
401
404
@@ -406,11 +409,13 @@ def set_accelerator_limit(self, limit: int) -> 'PipelineTask':
406
409
Self return to allow chained setting calls.
407
410
"""
408
411
self ._ensure_container_spec_exists ()
409
-
410
- if isinstance (limit , str ):
411
- if re .match (r'[1-9]\d*$' , limit ) is None :
412
- raise ValueError (f'{ "limit" !r} must be positive integer.' )
413
- limit = int (limit )
412
+ if isinstance (limit , pipeline_channel .PipelineChannel ):
413
+ limit = str (limit )
414
+ else :
415
+ if isinstance (limit , int ):
416
+ limit = str (limit )
417
+ if isinstance (limit , str ) and re .match (r'^0$|^1$|^2$|^4$|^8$|^16$' , limit ) is None :
418
+ raise ValueError (f'{ "limit" !r} must be one of 0, 1, 2, 4, 8, 16.' )
414
419
415
420
if self .container_spec .resources is not None :
416
421
self .container_spec .resources .accelerator_count = limit
@@ -438,9 +443,9 @@ def set_gpu_limit(self, gpu: str) -> 'PipelineTask':
438
443
category = DeprecationWarning )
439
444
return self .set_accelerator_limit (gpu )
440
445
441
- def _validate_memory_request_limit (self , memory : str ) -> float :
446
+ def _validate_memory_request_limit (self , memory : str ) -> str :
442
447
"""Validates memory request/limit string and converts to its numeric
443
- value.
448
+ string value.
444
449
445
450
Args:
446
451
memory: Memory requests or limits. This string should be a number or
@@ -451,47 +456,21 @@ def _validate_memory_request_limit(self, memory: str) -> float:
451
456
ValueError if the memory request/limit string value is invalid.
452
457
453
458
Returns:
454
- The numeric value (float) of the memory request/limit.
459
+ The numeric string value of the memory request/limit.
455
460
"""
456
- if re .match (r'^[0-9]+(E|Ei|P|Pi|T|Ti|G|Gi|M|Mi|K|Ki){0,1}$' ,
457
- memory ) is None :
458
- raise ValueError (
459
- 'Invalid memory string. Should be a number or a number '
460
- 'followed by one of "E", "Ei", "P", "Pi", "T", "Ti", "G", '
461
- '"Gi", "M", "Mi", "K", "Ki".' )
462
-
463
- if memory .endswith ('E' ):
464
- memory = float (memory [:- 1 ]) * constants ._E / constants ._G
465
- elif memory .endswith ('Ei' ):
466
- memory = float (memory [:- 2 ]) * constants ._EI / constants ._G
467
- elif memory .endswith ('P' ):
468
- memory = float (memory [:- 1 ]) * constants ._P / constants ._G
469
- elif memory .endswith ('Pi' ):
470
- memory = float (memory [:- 2 ]) * constants ._PI / constants ._G
471
- elif memory .endswith ('T' ):
472
- memory = float (memory [:- 1 ]) * constants ._T / constants ._G
473
- elif memory .endswith ('Ti' ):
474
- memory = float (memory [:- 2 ]) * constants ._TI / constants ._G
475
- elif memory .endswith ('G' ):
476
- memory = float (memory [:- 1 ])
477
- elif memory .endswith ('Gi' ):
478
- memory = float (memory [:- 2 ]) * constants ._GI / constants ._G
479
- elif memory .endswith ('M' ):
480
- memory = float (memory [:- 1 ]) * constants ._M / constants ._G
481
- elif memory .endswith ('Mi' ):
482
- memory = float (memory [:- 2 ]) * constants ._MI / constants ._G
483
- elif memory .endswith ('K' ):
484
- memory = float (memory [:- 1 ]) * constants ._K / constants ._G
485
- elif memory .endswith ('Ki' ):
486
- memory = float (memory [:- 2 ]) * constants ._KI / constants ._G
461
+ if isinstance (memory , pipeline_channel .PipelineChannel ):
462
+ memory = str (memory )
487
463
else :
488
- # By default interpret as a plain integer, in the unit of Bytes.
489
- memory = float (memory ) / constants ._G
490
-
464
+ if re .match (r'^[0-9]+(E|Ei|P|Pi|T|Ti|G|Gi|M|Mi|K|Ki){0,1}$' ,
465
+ memory ) is None :
466
+ raise ValueError (
467
+ 'Invalid memory string. Should be a number or a number '
468
+ 'followed by one of "E", "Ei", "P", "Pi", "T", "Ti", "G", '
469
+ '"Gi", "M", "Mi", "K", "Ki".' )
491
470
return memory
492
471
493
472
@block_if_final ()
494
- def set_memory_request (self , memory : str ) -> 'PipelineTask' :
473
+ def set_memory_request (self , memory : Union [ str , pipeline_channel . PipelineChannel ] ) -> 'PipelineTask' :
495
474
"""Sets memory request (minimum) for the task.
496
475
497
476
Args:
@@ -515,7 +494,7 @@ def set_memory_request(self, memory: str) -> 'PipelineTask':
515
494
return self
516
495
517
496
@block_if_final ()
518
- def set_memory_limit (self , memory : str ) -> 'PipelineTask' :
497
+ def set_memory_limit (self , memory : Union [ str , pipeline_channel . PipelineChannel ] ) -> 'PipelineTask' :
519
498
"""Sets memory limit (maximum) for the task.
520
499
521
500
Args:
@@ -579,7 +558,7 @@ def add_node_selector_constraint(self, accelerator: str) -> 'PipelineTask':
579
558
return self .set_accelerator_type (accelerator )
580
559
581
560
@block_if_final ()
582
- def set_accelerator_type (self , accelerator : str ) -> 'PipelineTask' :
561
+ def set_accelerator_type (self , accelerator : Union [ str , pipeline_channel . PipelineChannel ] ) -> 'PipelineTask' :
583
562
"""Sets accelerator type to use when executing this task.
584
563
585
564
Args:
@@ -589,14 +568,16 @@ def set_accelerator_type(self, accelerator: str) -> 'PipelineTask':
589
568
Self return to allow chained setting calls.
590
569
"""
591
570
self ._ensure_container_spec_exists ()
571
+ if isinstance (accelerator , pipeline_channel .PipelineChannel ):
572
+ accelerator = str (accelerator )
592
573
593
574
if self .container_spec .resources is not None :
594
575
self .container_spec .resources .accelerator_type = accelerator
595
576
if self .container_spec .resources .accelerator_count is None :
596
- self .container_spec .resources .accelerator_count = 1
577
+ self .container_spec .resources .accelerator_count = '1'
597
578
else :
598
579
self .container_spec .resources = structures .ResourceSpec (
599
- accelerator_count = 1 , accelerator_type = accelerator )
580
+ accelerator_count = '1' , accelerator_type = accelerator )
600
581
601
582
return self
602
583
0 commit comments