diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py
index 94cba1426c..a7b35bc34c 100644
--- a/flytekit/core/array_node_map_task.py
+++ b/flytekit/core/array_node_map_task.py
@@ -4,6 +4,7 @@
 import logging
 import math
 import os  # TODO: use flytekit logger
+import typing
 from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Set, Union, cast
@@ -13,14 +14,18 @@
 from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
 from flytekit.core.interface import transform_interface_to_list_interface
 from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
+from flytekit.core.type_engine import TypeEngine, is_annotated
 from flytekit.core.utils import timeit
 from flytekit.exceptions import scopes as exception_scopes
 from flytekit.loggers import logger
+from flytekit.models import literals as _literal_models
 from flytekit.models.array_job import ArrayJob
 from flytekit.models.core.workflow import NodeMetadata
 from flytekit.models.interface import Variable
 from flytekit.models.task import Container, K8sPod, Sql, Task
 from flytekit.tools.module_loader import load_object_from_module
+from flytekit.types.pickle import pickle
+from flytekit.types.pickle.pickle import FlytePickleTransformer
 
 
 class ArrayNodeMapTask(PythonTask):
@@ -57,6 +62,16 @@ def __init__(
         if not (isinstance(actual_task, PythonFunctionTask) or isinstance(actual_task, PythonInstanceTask)):
             raise ValueError("Only PythonFunctionTask and PythonInstanceTask are supported in map tasks.")
 
+        for k, v in actual_task.python_interface.inputs.items():
+            if bound_inputs and k in bound_inputs:
+                continue
+            transformer = TypeEngine.get_transformer(v)
+            if isinstance(transformer, FlytePickleTransformer):
+                if is_annotated(v):
+                    for annotation in typing.get_args(v)[1:]:
+                        if isinstance(annotation, pickle.BatchSize):
+                            raise ValueError("Choosing a BatchSize for map tasks inputs is not supported.")
+
         n_outputs = len(actual_task.python_interface.outputs)
         if n_outputs > 1:
             raise ValueError("Only tasks with a single output are supported in map tasks.")
@@ -208,24 +223,38 @@ def __call__(self, *args, **kwargs):
         kwargs = {**self._partial.keywords, **kwargs}
         return super().__call__(*args, **kwargs)
 
+    def _literal_map_to_python_input(
+        self, literal_map: _literal_models.LiteralMap, ctx: FlyteContext
+    ) -> Dict[str, Any]:
+        ctx = FlyteContextManager.current_context()
+        inputs_interface = self.python_interface.inputs
+        inputs_map = literal_map
+        # If we run locally, we will need to process all of the inputs. If we are running in a remote task execution
+        # environment, then we should process/download/extract only the inputs that are needed for the current task.
+        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
+            map_task_inputs = {}
+            task_index = self._compute_array_job_index()
+            inputs_interface = self._run_task.python_interface.inputs
+            for k in self.interface.inputs.keys():
+                v = literal_map.literals[k]
+
+                if k not in self.bound_inputs:
+                    # assert that v.collection is not None
+                    if not v.collection or not isinstance(v.collection.literals, list):
+                        raise ValueError(f"Expected a list of literals for {k}")
+                    map_task_inputs[k] = v.collection.literals[task_index]
+                else:
+                    map_task_inputs[k] = v
+            inputs_map = _literal_models.LiteralMap(literals=map_task_inputs)
+        return TypeEngine.literal_map_to_kwargs(ctx, inputs_map, inputs_interface)
+
     def execute(self, **kwargs) -> Any:
         ctx = FlyteContextManager.current_context()
         if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
-            return self._execute_map_task(ctx, **kwargs)
+            return exception_scopes.user_entry_point(self.python_function_task.execute)(**kwargs)
         return self._raw_execute(**kwargs)
 
-    def _execute_map_task(self, _: FlyteContext, **kwargs) -> Any:
-        task_index = self._compute_array_job_index()
-        map_task_inputs = {}
-        for k in self.interface.inputs.keys():
-            v = kwargs[k]
-            if isinstance(v, list) and k not in self.bound_inputs:
-                map_task_inputs[k] = v[task_index]
-            else:
-                map_task_inputs[k] = v
-        return exception_scopes.user_entry_point(self.python_function_task.execute)(**map_task_inputs)
-
     @staticmethod
     def _compute_array_job_index() -> int:
         """
@@ -276,8 +305,8 @@ def _raw_execute(self, **kwargs) -> Any:
         outputs = []
         mapped_tasks_count = 0
-        if self._run_task.interface.inputs.items():
-            for k in self._run_task.interface.inputs.keys():
+        if self.python_function_task.interface.inputs.items():
+            for k in self.python_function_task.interface.inputs.keys():
                 v = kwargs[k]
                 if isinstance(v, list) and k not in self.bound_inputs:
                     mapped_tasks_count = len(v)
diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py
index 5c84c60984..9b0144096e 100644
--- a/tests/flytekit/unit/core/test_array_node_map_task.py
+++ b/tests/flytekit/unit/core/test_array_node_map_task.py
@@ -7,9 +7,12 @@
 from flytekit import map_task, task, workflow
 from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
+from flytekit.core import context_manager
 from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver
 from flytekit.core.task import TaskMetadata
+from flytekit.core.type_engine import TypeEngine
 from flytekit.tools.translator import get_serializable
+from flytekit.types.pickle import BatchSize
 
 
 @pytest.fixture
@@ -54,6 +57,39 @@ def wf() -> List[str]:
     assert wf() == ["hello hello earth!!", "hello hello mars!!"]
 
 
+def test_remote_execution(serialization_settings):
+    @task
+    def say_hello(name: str) -> str:
+        return f"hello {name}!"
+
+    ctx = context_manager.FlyteContextManager.current_context()
+    with context_manager.FlyteContextManager.with_context(
+        ctx.with_execution_state(
+            ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
+        )
+    ) as ctx:
+        t = map_task(say_hello)
+        lm = TypeEngine.dict_to_literal_map(ctx, {"name": ["earth", "mars"]}, type_hints={"name": typing.List[str]})
+        res = t.dispatch_execute(ctx, lm)
+        assert len(res.literals) == 1
+        assert res.literals["o0"].scalar.primitive.string_value == "hello earth!"
+
+
+def test_map_task_with_pickle():
+    @task
+    def say_hello(name: typing.Annotated[typing.Any, BatchSize(10)]) -> str:
+        return f"hello {name}!"
+
+    with pytest.raises(ValueError, match="Choosing a BatchSize for map tasks inputs is not supported."):
+        map_task(say_hello)(name=["abc", "def"])
+
+    @task
+    def say_hello(name: typing.Any) -> str:
+        return f"hello {name}!"
+
+    map_task(say_hello)(name=["abc", "def"])
+
+
 def test_serialization(serialization_settings):
     @task
     def t1(a: int) -> int:
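
The remote-execution branch added to `_literal_map_to_python_input` above boils down to one step: for each subtask, pick a single element out of every mapped literal collection while passing bound (partial) inputs through untouched, and only then hand the resulting literal map to the type engine. A minimal standalone sketch of that selection step, using the literal models already imported in the diff; the helper name `select_subtask_literals` is hypothetical and not part of the change:

```python
from typing import Set

from flytekit.models import literals as _literal_models


def select_subtask_literals(
    literal_map: _literal_models.LiteralMap,
    task_index: int,
    bound_inputs: Set[str],
) -> _literal_models.LiteralMap:
    """Pick the single element each array subtask should see from a map of literal collections."""
    picked = {}
    for name, lit in literal_map.literals.items():
        if name in bound_inputs:
            # Bound (partial) inputs are passed through unchanged to every subtask.
            picked[name] = lit
        else:
            # Mapped inputs arrive as a literal collection; each subtask takes only
            # the element at its own array-job index.
            if not lit.collection or not isinstance(lit.collection.literals, list):
                raise ValueError(f"Expected a list of literals for {name}")
            picked[name] = lit.collection.literals[task_index]
    return _literal_models.LiteralMap(literals=picked)
```

Because the selection happens on literals before `TypeEngine.literal_map_to_kwargs`, a remote subtask only converts (and downloads) its own slice of the inputs instead of materializing the whole list, which is what the comment in the new method describes.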
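A short usage sketch of the behavior the new tests exercise, assuming a default local flytekit setup; the task names `greet` and `greet_batched` are illustrative only, and the local return shape follows flytekit's local execution of map tasks:

```python
import typing

from flytekit import map_task, task
from flytekit.types.pickle import BatchSize


@task
def greet(name: typing.Any) -> str:
    # typing.Any inputs fall back to the pickle transformer.
    return f"hello {name}!"


# Mapping over a pickled (typing.Any) input works; run locally this should
# produce one greeting per element.
print(map_task(greet)(name=["earth", "mars"]))


@task
def greet_batched(name: typing.Annotated[typing.Any, BatchSize(10)]) -> str:
    return f"hello {name}!"


try:
    # A BatchSize annotation on a mapped input is rejected by the map task.
    map_task(greet_batched)(name=["earth", "mars"])
except ValueError as err:
    print(err)  # Choosing a BatchSize for map tasks inputs is not supported.
```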