diff --git a/sdk/python/kfp/dsl/tasks_group.py b/sdk/python/kfp/dsl/tasks_group.py index c19fed788dc..6e711572cb2 100644 --- a/sdk/python/kfp/dsl/tasks_group.py +++ b/sdk/python/kfp/dsl/tasks_group.py @@ -22,6 +22,7 @@ from kfp.dsl import pipeline_channel from kfp.dsl import pipeline_context from kfp.dsl import pipeline_task +from kfp.dsl.pipeline_channel import PipelineParameterChannel class TasksGroupType(str, enum.Enum): @@ -130,7 +131,9 @@ def __init__( is_root=False, ) - if exit_task.dependent_tasks: + self.exit_task = exit_task + + if self.__has_dependent_tasks(): raise ValueError('exit_task cannot depend on any other tasks.') # Removing exit_task form any group @@ -140,7 +143,19 @@ def __init__( # Set is_exit_handler since the compiler might be using this attribute. exit_task.is_exit_handler = True - self.exit_task = exit_task + def __has_dependent_tasks(self) -> bool: + if self.exit_task.dependent_tasks: + return True + + if not self.exit_task.inputs: + return False + + for task_input in self.exit_task.inputs.values(): + if isinstance( + task_input, + PipelineParameterChannel) and task_input.task is not None: + return True + return False class ConditionBranches(TasksGroup):