From 1e775e863e09e7388f46585bc4d2f1f94395b3b3 Mon Sep 17 00:00:00 2001 From: redartera Date: Thu, 31 Oct 2024 13:56:01 +0000 Subject: [PATCH 01/12] enable interruptible override in FlyteRemote Signed-off-by: redartera --- flytekit/models/execution.py | 11 ++++ flytekit/remote/remote.py | 65 ++++++++++++++----- .../integration/remote/test_remote.py | 21 ++++++ 3 files changed, 81 insertions(+), 16 deletions(-) diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 7e4ff02645..dc4801c5c9 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -10,6 +10,7 @@ import flyteidl.admin.execution_pb2 as _execution_pb2 import flyteidl.admin.node_execution_pb2 as _node_execution_pb2 import flyteidl.admin.task_execution_pb2 as _task_execution_pb2 +from google.protobuf import wrappers_pb2 as _google_wrappers_pb2 import flytekit from flytekit.models import common as _common_models @@ -179,6 +180,7 @@ def __init__( max_parallelism: Optional[int] = None, security_context: Optional[security.SecurityContext] = None, overwrite_cache: Optional[bool] = None, + interruptible: Optional[bool] = None, envs: Optional[_common_models.Envs] = None, tags: Optional[typing.List[str]] = None, cluster_assignment: Optional[ClusterAssignment] = None, @@ -198,6 +200,7 @@ def __init__( parallelism/concurrency of MapTasks is independent from this. :param security_context: Optional security context to use for this execution. :param overwrite_cache: Optional flag to overwrite the cache for this execution. + :param interruptible: Optional flag to interrupt the execution. :param envs: flytekit.models.common.Envs environment variables to set for this execution. :param tags: Optional list of tags to apply to the execution. :param execution_cluster_label: Optional execution cluster label to use for this execution. @@ -213,6 +216,7 @@ def __init__( self._max_parallelism = max_parallelism self._security_context = security_context self._overwrite_cache = overwrite_cache + self._interruptible = interruptible self._envs = envs self._tags = tags self._cluster_assignment = cluster_assignment @@ -287,6 +291,10 @@ def security_context(self) -> typing.Optional[security.SecurityContext]: def overwrite_cache(self) -> Optional[bool]: return self._overwrite_cache + @property + def interruptible(self) -> Optional[bool]: + return self._interruptible + @property def envs(self) -> Optional[_common_models.Envs]: return self._envs @@ -321,6 +329,8 @@ def to_flyte_idl(self): max_parallelism=self.max_parallelism, security_context=self.security_context.to_flyte_idl() if self.security_context else None, overwrite_cache=self.overwrite_cache, + # NOTE: 'interruptible' has to be a BoolValue, and BoolValue(value=None) works the same as passing None directly + interruptible=_google_wrappers_pb2.BoolValue(value=self.interruptible), envs=self.envs.to_flyte_idl() if self.envs else None, tags=self.tags, cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None, @@ -351,6 +361,7 @@ def from_flyte_idl(cls, p): if p.security_context else None, overwrite_cache=p.overwrite_cache, + interruptible=p.interruptible.value, # NOTE: p.interruptible is always a BoolValue no matter what. 
envs=_common_models.Envs.from_flyte_idl(p.envs) if p.HasField("envs") else None, tags=p.tags, cluster_assignment=ClusterAssignment.from_flyte_idl(p.cluster_assignment) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 7eda76ddfa..7c3f94fd6b 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -1290,6 +1290,7 @@ def _execute( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1308,6 +1309,7 @@ def _execute( :param overwrite_cache: Allows for all cached values of a workflow and its tasks to be overwritten for a single execution. If enabled, all calculations are performed even if cached results would be available, overwriting the stored data once execution finishes successfully. + :param interruptible: Optional flag to override the default interruptible flag of the executed entity. :param envs: Environment variables to set for the execution. :param tags: Tags to set for the execution. :param cluster_pool: Specify cluster pool on which newly created execution should be placed. @@ -1379,6 +1381,7 @@ def _execute( 0, ), overwrite_cache=overwrite_cache, + interruptible=interruptible, notifications=notifications, disable_all=options.disable_notifications, labels=options.labels, @@ -1455,6 +1458,7 @@ def execute( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1495,6 +1499,7 @@ def execute( :param overwrite_cache: Allows for all cached values of a workflow and its tasks to be overwritten for a single execution. If enabled, all calculations are performed even if cached results would be available, overwriting the stored data once execution finishes successfully. + :param interruptible: Optional flag to override the default interruptible flag of the executed entity. :param envs: Environment variables to be set for the execution. :param tags: Tags to be set for the execution. :param cluster_pool: Specify cluster pool on which newly created execution should be placed. 
@@ -1519,6 +1524,7 @@ def execute( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1536,6 +1542,7 @@ def execute( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1551,6 +1558,7 @@ def execute( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1566,6 +1574,7 @@ def execute( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1581,6 +1590,7 @@ def execute( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1599,6 +1609,7 @@ def execute( image_config=image_config, wait=wait, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1618,6 +1629,7 @@ def execute( options=options, wait=wait, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1636,6 +1648,7 @@ def execute( options=options, wait=wait, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1658,6 +1671,7 @@ def execute_remote_task_lp( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1678,6 +1692,7 @@ def execute_remote_task_lp( options=options, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1696,6 +1711,7 @@ def execute_remote_wf( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1717,6 +1733,7 @@ def execute_remote_wf( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1735,6 +1752,7 @@ def execute_reference_task( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1766,6 +1784,7 @@ def execute_reference_task( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1782,6 +1801,7 @@ def execute_reference_workflow( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1827,6 
+1847,7 @@ def execute_reference_workflow( options=options, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1843,6 +1864,7 @@ def execute_reference_launch_plan( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1874,6 +1896,7 @@ def execute_reference_launch_plan( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1896,6 +1919,7 @@ def execute_local_task( image_config: typing.Optional[ImageConfig] = None, wait: bool = False, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1914,6 +1938,7 @@ def execute_local_task( :param image_config: If provided, will use this image config in the pod. :param wait: If True, will wait for the execution to complete before returning. :param overwrite_cache: If True, will overwrite the cache. + :param interruptible: Optional flag to override the default interruptible flag of the executed entity. :param envs: Environment variables to set for the execution. :param tags: Tags to set for the execution. :param cluster_pool: Specify cluster pool on which newly created execution should be placed. @@ -1954,6 +1979,7 @@ def execute_local_task( wait=wait, type_hints=entity.python_interface.inputs, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1974,6 +2000,7 @@ def execute_local_workflow( options: typing.Optional[Options] = None, wait: bool = False, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1981,22 +2008,23 @@ def execute_local_workflow( ) -> FlyteWorkflowExecution: """ Execute an @workflow decorated function. 
- :param entity: - :param inputs: - :param project: - :param domain: - :param name: - :param version: - :param execution_name: - :param image_config: - :param options: - :param wait: - :param overwrite_cache: - :param envs: - :param tags: - :param cluster_pool: - :param execution_cluster_label: - :return: + :param entity: The workflow to execute + :param inputs: Input dictionary + :param project: Project to execute in + :param domain: Domain to execute in + :param name: Optional name override for the workflow + :param version: Optional version for the workflow + :param execution_name: Optional name for the execution + :param image_config: Optional image config override + :param options: Optional Options object + :param wait: Whether to wait for execution completion + :param overwrite_cache: If True, will overwrite the cache + :param interruptible: Optional flag to override the default interruptible flag of the executed entity + :param envs: Environment variables to set for the execution + :param tags: Tags to set for the execution + :param cluster_pool: Specify cluster pool on which newly created execution should be placed + :param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed + :return: FlyteWorkflowExecution object """ if not image_config: image_config = ImageConfig.auto_default_image() @@ -2052,6 +2080,7 @@ def execute_local_workflow( options=options, type_hints=entity.python_interface.inputs, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -2071,12 +2100,14 @@ def execute_local_launch_plan( options: typing.Optional[Options] = None, wait: bool = False, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, execution_cluster_label: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """ + Execute a locally defined `LaunchPlan`. :param entity: The locally defined launch plan object :param inputs: Inputs to be passed into the execution as a dict with Python native values. @@ -2088,6 +2119,7 @@ def execute_local_launch_plan( :param options: Options to be passed into the execution. :param wait: If True, will wait for the execution to complete before returning. :param overwrite_cache: If True, will overwrite the cache. + :param interruptible: Optional flag to override the default interruptible flag of the executed entity. :param envs: Environment variables to be passed into the execution. :param tags: Tags to be passed into the execution. :param cluster_pool: Specify cluster pool on which newly created execution should be placed. 
@@ -2119,6 +2151,7 @@ def execute_local_launch_plan( wait=wait, type_hints=entity.python_interface.inputs, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 4d77e1b610..3590170ef0 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -528,6 +528,27 @@ def test_execute_workflow_with_maptask(register): ) assert execution.outputs["o0"] == [4, 5, 6] +def test_executes_nested_workflow_dictating_interruptible(register): + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + flyte_launch_plan = remote.fetch_launch_plan(name="basic.child_workflow.parent_wf", version=VERSION) + # Set the interruptible value at launch plan creation time + creation_interruptible_values = [True, False, None] + # The interruptible values we expect to see in the materialized executions + expected_interruptible_values = [True, False, False] + executions = [] + for creation_interruptible in creation_interruptible_values: + execution = remote.execute(flyte_launch_plan, inputs={"a": 10}, wait=False, interruptible=creation_interruptible) + executions.append(execution) + # Wait for all executions to complete + for execution, expected_interruptible in zip(executions, expected_interruptible_values): + execution = remote.wait(execution, timeout=300) + # Check that the parent workflow is interruptible as expected + assert execution.spec.interruptible == expected_interruptible + # Check that the child workflow is interruptible as expected + subwf_execution_id = execution.node_executions["n1"].closure.workflow_node_metadata.execution_id.name + subwf_execution = remote.fetch_execution(project=PROJECT, domain=DOMAIN, name=subwf_execution_id) + assert subwf_execution.spec.interruptible == expected_interruptible + @pytest.mark.lftransfers class TestLargeFileTransfers: From 5f17a303675b43f7ab7bb93aae26da9023dafe17 Mon Sep 17 00:00:00 2001 From: redartera Date: Thu, 31 Oct 2024 14:00:07 +0000 Subject: [PATCH 02/12] pass None explicitly if applicable for interruptible Signed-off-by: redartera --- flytekit/models/execution.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index dc4801c5c9..5bd0c71083 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -329,8 +329,8 @@ def to_flyte_idl(self): max_parallelism=self.max_parallelism, security_context=self.security_context.to_flyte_idl() if self.security_context else None, overwrite_cache=self.overwrite_cache, - # NOTE: 'interruptible' has to be a BoolValue, and BoolValue(value=None) works the same as passing None directly - interruptible=_google_wrappers_pb2.BoolValue(value=self.interruptible), + # NOTE: Pass BoolValue ONLY if interruptible is not None - otherwise this overrides the default value + interruptible=None if self.interruptible is None else _google_wrappers_pb2.BoolValue(value=self.interruptible), envs=self.envs.to_flyte_idl() if self.envs else None, tags=self.tags, cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None, @@ -361,7 +361,7 @@ def from_flyte_idl(cls, p): if p.security_context else None, overwrite_cache=p.overwrite_cache, - interruptible=p.interruptible.value, # NOTE: p.interruptible is always a BoolValue no matter what. 
+ interruptible=p.interruptible.value, # NOTE: p.interruptible is always a BoolValue here no matter what - cannot make distinction between None and False envs=_common_models.Envs.from_flyte_idl(p.envs) if p.HasField("envs") else None, tags=p.tags, cluster_assignment=ClusterAssignment.from_flyte_idl(p.cluster_assignment) From 4e84778cad4fd26a643aee9c80a6662a2257b8de Mon Sep 17 00:00:00 2001 From: redartera Date: Thu, 31 Oct 2024 14:00:29 +0000 Subject: [PATCH 03/12] format Signed-off-by: redartera --- flytekit/models/execution.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 5bd0c71083..72b0da99db 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -330,7 +330,9 @@ def to_flyte_idl(self): security_context=self.security_context.to_flyte_idl() if self.security_context else None, overwrite_cache=self.overwrite_cache, # NOTE: Pass BoolValue ONLY if interruptible is not None - otherwise this overrides the default value - interruptible=None if self.interruptible is None else _google_wrappers_pb2.BoolValue(value=self.interruptible), + interruptible=None + if self.interruptible is None + else _google_wrappers_pb2.BoolValue(value=self.interruptible), envs=self.envs.to_flyte_idl() if self.envs else None, tags=self.tags, cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None, From 378552349b3714bcfa8773ba1e2a99262e7b5906 Mon Sep 17 00:00:00 2001 From: redartera Date: Thu, 31 Oct 2024 14:06:59 +0000 Subject: [PATCH 04/12] minor docstring nit Signed-off-by: redartera --- flytekit/models/execution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 72b0da99db..bbc5988932 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -200,7 +200,7 @@ def __init__( parallelism/concurrency of MapTasks is independent from this. :param security_context: Optional security context to use for this execution. :param overwrite_cache: Optional flag to overwrite the cache for this execution. - :param interruptible: Optional flag to interrupt the execution. + :param interruptible: Optional flag to override the default interruptible flag of the executed entity. :param envs: flytekit.models.common.Envs environment variables to set for this execution. :param tags: Optional list of tags to apply to the execution. :param execution_cluster_label: Optional execution cluster label to use for this execution. 
From e161429d00101ccd92fc9e9aa40b10fa67244005 Mon Sep 17 00:00:00 2001 From: redartera Date: Thu, 31 Oct 2024 15:07:28 +0000 Subject: [PATCH 05/12] fix interruptible flag deserialization Signed-off-by: redartera --- flytekit/models/execution.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index bbc5988932..f645df8f9d 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -329,10 +329,9 @@ def to_flyte_idl(self): max_parallelism=self.max_parallelism, security_context=self.security_context.to_flyte_idl() if self.security_context else None, overwrite_cache=self.overwrite_cache, - # NOTE: Pass BoolValue ONLY if interruptible is not None - otherwise this overrides the default value - interruptible=None - if self.interruptible is None - else _google_wrappers_pb2.BoolValue(value=self.interruptible), + interruptible=_google_wrappers_pb2.BoolValue(value=self.interruptible) + if self.interruptible is not None + else None, envs=self.envs.to_flyte_idl() if self.envs else None, tags=self.tags, cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None, @@ -363,7 +362,7 @@ def from_flyte_idl(cls, p): if p.security_context else None, overwrite_cache=p.overwrite_cache, - interruptible=p.interruptible.value, # NOTE: p.interruptible is always a BoolValue here no matter what - cannot make distinction between None and False + interruptible=p.interruptible.value if p.HasField("interruptible") else None, envs=_common_models.Envs.from_flyte_idl(p.envs) if p.HasField("envs") else None, tags=p.tags, cluster_assignment=ClusterAssignment.from_flyte_idl(p.cluster_assignment) From b27ca6a1f713751d0927d878eb2f171b100ca79b Mon Sep 17 00:00:00 2001 From: redartera Date: Thu, 31 Oct 2024 15:11:37 +0000 Subject: [PATCH 06/12] adjust test to handle None Signed-off-by: redartera --- tests/flytekit/integration/remote/test_remote.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 3590170ef0..d24c1ffbb3 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -531,16 +531,14 @@ def test_execute_workflow_with_maptask(register): def test_executes_nested_workflow_dictating_interruptible(register): remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) flyte_launch_plan = remote.fetch_launch_plan(name="basic.child_workflow.parent_wf", version=VERSION) - # Set the interruptible value at launch plan creation time - creation_interruptible_values = [True, False, None] - # The interruptible values we expect to see in the materialized executions - expected_interruptible_values = [True, False, False] + # The values we want to test for + interruptible_values = [True, False, None] executions = [] - for creation_interruptible in creation_interruptible_values: + for creation_interruptible in interruptible_values: execution = remote.execute(flyte_launch_plan, inputs={"a": 10}, wait=False, interruptible=creation_interruptible) executions.append(execution) # Wait for all executions to complete - for execution, expected_interruptible in zip(executions, expected_interruptible_values): + for execution, expected_interruptible in zip(executions, interruptible_values): execution = remote.wait(execution, timeout=300) # Check that the parent workflow is interruptible as expected assert 
execution.spec.interruptible == expected_interruptible From e82b2ce50dd668f57d0d67fd5edea26ee09c9c96 Mon Sep 17 00:00:00 2001 From: redartera <120470035+redartera@users.noreply.github.com> Date: Thu, 21 Nov 2024 08:53:43 -0500 Subject: [PATCH 07/12] Update flytekit/remote/remote.py Signed-off-by: redartera --- flytekit/remote/remote.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 7c3f94fd6b..2d53e4dd4c 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -2008,6 +2008,7 @@ def execute_local_workflow( ) -> FlyteWorkflowExecution: """ Execute an @workflow decorated function. + :param entity: The workflow to execute :param inputs: Input dictionary :param project: Project to execute in From a0d01665ef03f3543936c977dc240943fe90f04a Mon Sep 17 00:00:00 2001 From: redartera Date: Thu, 21 Nov 2024 13:59:58 +0000 Subject: [PATCH 08/12] remove trailing whitespaces Signed-off-by: redartera --- flytekit/remote/remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 2d53e4dd4c..8f5f8861ec 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -2008,7 +2008,7 @@ def execute_local_workflow( ) -> FlyteWorkflowExecution: """ Execute an @workflow decorated function. - + :param entity: The workflow to execute :param inputs: Input dictionary :param project: Project to execute in From 4a73d8c26358be09e578a8f0dfbfb40097a3cf75 Mon Sep 17 00:00:00 2001 From: redartera <120470035+redartera@users.noreply.github.com> Date: Thu, 21 Nov 2024 20:17:32 +0000 Subject: [PATCH 09/12] retrigger checks Signed-off-by: redartera <120470035+redartera@users.noreply.github.com> From bcbb9915b6f995c1210de1291dd36c99694e91a4 Mon Sep 17 00:00:00 2001 From: redartera <120470035+redartera@users.noreply.github.com> Date: Thu, 23 Jan 2025 19:51:36 +0000 Subject: [PATCH 10/12] pass interruptible as a bool instead of grpc wrapper Signed-off-by: redartera <120470035+redartera@users.noreply.github.com> --- flytekit/models/execution.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index f645df8f9d..4f2e684227 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -10,7 +10,6 @@ import flyteidl.admin.execution_pb2 as _execution_pb2 import flyteidl.admin.node_execution_pb2 as _node_execution_pb2 import flyteidl.admin.task_execution_pb2 as _task_execution_pb2 -from google.protobuf import wrappers_pb2 as _google_wrappers_pb2 import flytekit from flytekit.models import common as _common_models @@ -329,9 +328,7 @@ def to_flyte_idl(self): max_parallelism=self.max_parallelism, security_context=self.security_context.to_flyte_idl() if self.security_context else None, overwrite_cache=self.overwrite_cache, - interruptible=_google_wrappers_pb2.BoolValue(value=self.interruptible) - if self.interruptible is not None - else None, + interruptible=self.interruptible if self.interruptible is not None else None, envs=self.envs.to_flyte_idl() if self.envs else None, tags=self.tags, cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None, From 73cae9e6bc8eb8f77535b2a2bc2abfe02e87c215 Mon Sep 17 00:00:00 2001 From: redartera <120470035+redartera@users.noreply.github.com> Date: Thu, 23 Jan 2025 19:51:48 +0000 Subject: [PATCH 11/12] add interruptible checks to models unit test Signed-off-by: redartera 
<120470035+redartera@users.noreply.github.com> --- tests/flytekit/unit/models/test_execution.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/flytekit/unit/models/test_execution.py b/tests/flytekit/unit/models/test_execution.py index fec2b5cfbb..8e1dfa749a 100644 --- a/tests/flytekit/unit/models/test_execution.py +++ b/tests/flytekit/unit/models/test_execution.py @@ -166,6 +166,7 @@ def test_execution_spec(literal_value_pair): ), raw_output_data_config=_common_models.RawOutputDataConfig(output_location_prefix="raw_output"), max_parallelism=100, + interruptible=True ) assert obj.launch_plan.resource_type == _identifier.ResourceType.LAUNCH_PLAN assert obj.launch_plan.domain == "domain" @@ -183,6 +184,7 @@ def test_execution_spec(literal_value_pair): ] assert obj.disable_all is None assert obj.max_parallelism == 100 + assert obj.interruptible == True assert obj.raw_output_data_config.output_location_prefix == "raw_output" obj2 = _execution.ExecutionSpec.from_flyte_idl(obj.to_flyte_idl()) @@ -203,6 +205,7 @@ def test_execution_spec(literal_value_pair): ] assert obj2.disable_all is None assert obj2.max_parallelism == 100 + assert obj2.interruptible == True assert obj2.raw_output_data_config.output_location_prefix == "raw_output" obj = _execution.ExecutionSpec( @@ -220,6 +223,7 @@ def test_execution_spec(literal_value_pair): assert obj.metadata.principal == "tester" assert obj.notifications is None assert obj.disable_all is True + assert obj.interruptible is None obj2 = _execution.ExecutionSpec.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 @@ -233,6 +237,7 @@ def test_execution_spec(literal_value_pair): assert obj2.metadata.principal == "tester" assert obj2.notifications is None assert obj2.disable_all is True + assert obj2.interruptible is None def test_workflow_execution_data_response(): From e2083daaeda16d066f88c9afc27267a9cb5d8809 Mon Sep 17 00:00:00 2001 From: redartera <120470035+redartera@users.noreply.github.com> Date: Tue, 28 Jan 2025 14:15:26 +0000 Subject: [PATCH 12/12] Use a boolvalue wrapper for interruptible Signed-off-by: redartera <120470035+redartera@users.noreply.github.com> --- flytekit/models/execution.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 4f2e684227..0019e4d79b 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -10,6 +10,7 @@ import flyteidl.admin.execution_pb2 as _execution_pb2 import flyteidl.admin.node_execution_pb2 as _node_execution_pb2 import flyteidl.admin.task_execution_pb2 as _task_execution_pb2 +from google.protobuf.wrappers_pb2 import BoolValue import flytekit from flytekit.models import common as _common_models @@ -328,7 +329,7 @@ def to_flyte_idl(self): max_parallelism=self.max_parallelism, security_context=self.security_context.to_flyte_idl() if self.security_context else None, overwrite_cache=self.overwrite_cache, - interruptible=self.interruptible if self.interruptible is not None else None, + interruptible=BoolValue(value=self.interruptible) if self.interruptible is not None else None, envs=self.envs.to_flyte_idl() if self.envs else None, tags=self.tags, cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None,
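
Below is a minimal usage sketch of the override this series adds, assuming a reachable Flyte backend and an already-registered launch plan; the config path, project, domain, and version strings are placeholders modeled on the integration test above, not values fixed by the patches. Per PATCH 05 and PATCH 12, the flag is only wrapped in a BoolValue when it is not None, so leaving interruptible unset preserves the entity's registered behaviour instead of forcing it to False.

# Hypothetical example, not part of the patch series.
from flytekit.configuration import Config
from flytekit.remote import FlyteRemote

# Placeholder config path, project, and domain.
remote = FlyteRemote(Config.auto(config_file="config.yaml"), "flytesnacks", "development")

# Placeholder version; the launch plan name mirrors the integration test.
lp = remote.fetch_launch_plan(name="basic.child_workflow.parent_wf", version="v1")

# interruptible=True/False overrides the entity's default for this execution only;
# omitting it (None) keeps whatever was set at registration time.
execution = remote.execute(lp, inputs={"a": 10}, interruptible=True, wait=True)

# Assumes the execution completes successfully within the default wait timeout.
assert execution.spec.interruptible is True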