Skip to content

Commit

Permalink
Fix issue where tasks preceeding parallel for loops that recieve pipe…
Browse files Browse the repository at this point in the history
…line parameters, are wrongly expected to have task attributes

Signed-off-by: David Farrington <david@shipit.ltd>
  • Loading branch information
farridav committed Jan 2, 2025
1 parent 2ebb853 commit e0fd54b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 18 deletions.
47 changes: 30 additions & 17 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def test_set_description_through_pipeline_decorator(self):

@dsl.pipeline(description='Prefer me.')
def my_pipeline():
"""Don't prefer me"""
"""Don't prefer me."""
VALID_PRODUCER_COMPONENT_SAMPLE(input_param='input')

self.assertEqual(my_pipeline.pipeline_spec.pipeline_info.description,
Expand All @@ -441,7 +441,8 @@ def test_set_description_through_pipeline_docstring_long(self):
def my_pipeline():
"""Docstring-specified description.
More information about this pipeline."""
More information about this pipeline.
"""
VALID_PRODUCER_COMPONENT_SAMPLE(input_param='input')

self.assertEqual(
Expand Down Expand Up @@ -2429,6 +2430,7 @@ def pipeline_with_multiline_definition(
sample_input1: bool = True,
sample_input2: str = 'string') -> str:
"""docstring short description.
docstring long description. docstring long description.
"""
op1 = my_comp(string=sample_input2, model=sample_input1)
Expand All @@ -2455,10 +2457,9 @@ def pipeline_with_multiline_definition(
def pipeline_with_multiline_definition(
sample_input1: bool = True,
sample_input2: str = 'string') -> str:
"""
docstring long description.
docstring long description.
docstring long description.
"""docstring long description.
docstring long description. docstring long description.
"""
op1 = my_comp(string=sample_input2, model=sample_input1)
result = op1.output
Expand Down Expand Up @@ -2487,8 +2488,8 @@ def test_idempotency_on_comment_with_multiline_docstring(self):
def my_pipeline(sample_input1: bool = True,
sample_input2: str = 'string') -> str:
"""docstring short description.
docstring long description.
docstring long description.
docstring long description. docstring long description.
"""
op1 = my_comp(string=sample_input2, model=sample_input1)
result = op1.output
Expand Down Expand Up @@ -4144,7 +4145,7 @@ def my_pipeline(
string: str,
in_artifact: Input[Artifact],
) -> Outputs:
"""Pipeline description. Returns
"""Pipeline description. Returns.
Args:
string: Return Pipeline input string. Returns
Expand Down Expand Up @@ -4607,7 +4608,9 @@ class TestDslOneOf(unittest.TestCase):
# To help narrow the tests further (we already test lots of aspects in the following cases), we choose focus on the dsl.OneOf behavior, not the conditional logic if If/Elif/Else. This is more verbose, but more maintainable and the behavior under test is clearer.

def test_if_else_returned(self):
"""Uses If and Else branches, parameters passed to dsl.OneOf, dsl.OneOf returned from a pipeline, and different output keys on dsl.OneOf channels."""
"""Uses If and Else branches, parameters passed to dsl.OneOf, dsl.OneOf
returned from a pipeline, and different output keys on dsl.OneOf
channels."""

@dsl.pipeline
def roll_die_pipeline() -> str:
Expand Down Expand Up @@ -4668,7 +4671,9 @@ def roll_die_pipeline() -> str:
)

def test_if_elif_else_returned(self):
"""Uses If, Elif, and Else branches, parameters passed to dsl.OneOf, dsl.OneOf returned from a pipeline, and different output keys on dsl.OneOf channels."""
"""Uses If, Elif, and Else branches, parameters passed to dsl.OneOf,
dsl.OneOf returned from a pipeline, and different output keys on
dsl.OneOf channels."""

@dsl.pipeline
def roll_die_pipeline() -> str:
Expand Down Expand Up @@ -4743,7 +4748,9 @@ def roll_die_pipeline() -> str:
)

def test_if_elif_else_consumed(self):
"""Uses If, Elif, and Else branches, parameters passed to dsl.OneOf, dsl.OneOf passed to a consumer task, and different output keys on dsl.OneOf channels."""
"""Uses If, Elif, and Else branches, parameters passed to dsl.OneOf,
dsl.OneOf passed to a consumer task, and different output keys on
dsl.OneOf channels."""

@dsl.pipeline
def roll_die_pipeline():
Expand Down Expand Up @@ -4820,7 +4827,9 @@ def roll_die_pipeline():
)

def test_if_else_consumed_and_returned(self):
"""Uses If, Elif, and Else branches, parameters passed to dsl.OneOf, and dsl.OneOf passed to a consumer task and returned from the pipeline."""
"""Uses If, Elif, and Else branches, parameters passed to dsl.OneOf,
and dsl.OneOf passed to a consumer task and returned from the
pipeline."""

@dsl.pipeline
def flip_coin_pipeline() -> str:
Expand Down Expand Up @@ -4893,7 +4902,8 @@ def flip_coin_pipeline() -> str:
)

def test_if_else_consumed_and_returned_artifacts(self):
"""Uses If, Elif, and Else branches, artifacts passed to dsl.OneOf, and dsl.OneOf passed to a consumer task and returned from the pipeline."""
"""Uses If, Elif, and Else branches, artifacts passed to dsl.OneOf, and
dsl.OneOf passed to a consumer task and returned from the pipeline."""

@dsl.pipeline
def flip_coin_pipeline() -> Artifact:
Expand Down Expand Up @@ -5060,7 +5070,8 @@ def flip_coin_pipeline(execute_pipeline: bool):
print_task_2.outputs['a'])

def test_deeply_nested_consumed(self):
"""Uses If, Elif, Else, and OneOf deeply nested within multiple dub-DAGs."""
"""Uses If, Elif, Else, and OneOf deeply nested within multiple dub-
DAGs."""

@dsl.pipeline
def flip_coin_pipeline(execute_pipeline: bool):
Expand Down Expand Up @@ -5159,7 +5170,8 @@ def flip_coin_pipeline(execute_pipeline: bool):
print_task_2.outputs['a'])

def test_oneof_in_condition(self):
"""Tests that dsl.OneOf's channel can be consumed in a downstream group nested one level"""
"""Tests that dsl.OneOf's channel can be consumed in a downstream group
nested one level."""

@dsl.pipeline
def roll_die_pipeline(repeat_on: str = 'Got heads!'):
Expand Down Expand Up @@ -5212,7 +5224,8 @@ def roll_die_pipeline(repeat_on: str = 'Got heads!'):
)

def test_consumed_in_nested_groups(self):
"""Tests that dsl.OneOf's channel can be consumed in a downstream group nested multiple levels"""
"""Tests that dsl.OneOf's channel can be consumed in a downstream group
nested multiple levels."""

@dsl.pipeline
def roll_die_pipeline(
Expand Down
4 changes: 3 additions & 1 deletion sdk/python/kfp/compiler/compiler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,9 @@ def get_dependencies(
# then make this validation dsl.Collected-aware
elif isinstance(upstream_parent_group, tasks_group.ParallelFor):
upstream_tasks_that_downstream_consumers_from = [
channel.task.name for channel in task._channel_inputs
channel.task.name
for channel in task._channel_inputs
if channel.task
]
has_data_exchange = upstream_task.name in upstream_tasks_that_downstream_consumers_from
# don't raise for .after
Expand Down

0 comments on commit e0fd54b

Please sign in to comment.