diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml index e5a8118923..4aceee472e 100644 --- a/.github/workflows/pythonpublish.yml +++ b/.github/workflows/pythonpublish.yml @@ -224,7 +224,7 @@ jobs: cache-to: type=gha,mode=max - name: Confirm Agent can start run: | - docker run --rm ghcr.io/${{ github.repository_owner }}/flyteagent:${{ github.sha }} pyflyte serve agent --port 8000 --timeout 1 + docker run --rm ghcr.io/${{ github.repository_owner }}/flyteagent-slim:${{ github.sha }} pyflyte serve agent --port 8000 --timeout 1 - name: Push flyteagent-all Image to GitHub Registry uses: docker/build-push-action@v2 with: diff --git a/flytekit/core/array_node.py b/flytekit/core/array_node.py index 0cb2c8d25c..466058a791 100644 --- a/flytekit/core/array_node.py +++ b/flytekit/core/array_node.py @@ -19,6 +19,7 @@ flyte_entity_call_handler, translate_inputs_to_literals, ) +from flytekit.core.task import ReferenceTask from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models @@ -34,8 +35,7 @@ class ArrayNode: def __init__( self, - target: Union[LaunchPlan, "FlyteLaunchPlan"], - execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE, + target: Union[LaunchPlan, ReferenceTask, "FlyteLaunchPlan"], bindings: Optional[List[_literal_models.Binding]] = None, concurrency: Optional[int] = None, min_successes: Optional[int] = None, @@ -51,17 +51,17 @@ def __init__( :param min_successes: The minimum number of successful executions. If set, this takes precedence over min_success_ratio :param min_success_ratio: The minimum ratio of successful executions. 
- :param execution_mode: The execution mode for propeller to use when handling ArrayNode :param metadata: The metadata for the underlying node """ from flytekit.remote import FlyteLaunchPlan self.target = target self._concurrency = concurrency - self._execution_mode = execution_mode self.id = target.name self._bindings = bindings or [] self.metadata = metadata + self._data_mode = None + self._execution_mode = None if min_successes is not None: self._min_successes = min_successes @@ -92,9 +92,12 @@ def __init__( else: raise ValueError("No interface found for the target entity.") - if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan): - if self._execution_mode != _core_workflow.ArrayNode.FULL_STATE: - raise ValueError("Only execution version 1 is supported for LaunchPlans.") + if isinstance(target, (LaunchPlan, FlyteLaunchPlan)): + self._data_mode = _core_workflow.ArrayNode.SINGLE_INPUT_FILE + self._execution_mode = _core_workflow.ArrayNode.FULL_STATE + elif isinstance(target, ReferenceTask): + self._data_mode = _core_workflow.ArrayNode.INDIVIDUAL_INPUT_FILES + self._execution_mode = _core_workflow.ArrayNode.MINIMAL_STATE else: raise ValueError(f"Only LaunchPlans are supported for now, but got {type(target)}") @@ -133,6 +136,10 @@ def upstream_nodes(self) -> List[Node]: def flyte_entity(self) -> Any: return self.target + @property + def data_mode(self) -> _core_workflow.ArrayNode.DataMode: + return self._data_mode + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]: if self._remote_interface: raise ValueError("Mapping over remote entities is not supported in local execution.") @@ -254,7 +261,7 @@ def __call__(self, *args, **kwargs): def array_node( - target: Union[LaunchPlan, "FlyteLaunchPlan"], + target: Union[LaunchPlan, ReferenceTask, "FlyteLaunchPlan"], concurrency: Optional[int] = None, min_success_ratio: Optional[float] = None, min_successes: Optional[int] = None, @@ -275,8 +282,8 @@ def 
array_node( """ from flytekit.remote import FlyteLaunchPlan - if not isinstance(target, LaunchPlan) and not isinstance(target, FlyteLaunchPlan): - raise ValueError("Only LaunchPlans are supported for now.") + if not isinstance(target, (LaunchPlan, FlyteLaunchPlan, ReferenceTask)): + raise ValueError("Only LaunchPlans and ReferenceTasks are supported for now.") node = ArrayNode( target=target, diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 44458a53d2..78b9611651 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -18,6 +18,7 @@ from flytekit.core.interface import transform_interface_to_list_interface from flytekit.core.launch_plan import LaunchPlan from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask +from flytekit.core.task import ReferenceTask from flytekit.core.type_engine import TypeEngine from flytekit.core.utils import timeit from flytekit.loggers import logger @@ -390,7 +391,7 @@ def map_task( """ from flytekit.remote import FlyteLaunchPlan - if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan): + if isinstance(target, (LaunchPlan, FlyteLaunchPlan, ReferenceTask)): return array_node( target=target, concurrency=concurrency, diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 59dfc91a94..1077858c27 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -718,7 +718,9 @@ def new_compilation_state(self, prefix: str = "") -> CompilationState: Creates and returns a default compilation state. 
For most of the code this should be the entrypoint of compilation, otherwise the code should always uses - with_compilation_state """ - return CompilationState(prefix=prefix) + from flytekit.core.python_auto_container import default_task_resolver + + return CompilationState(prefix=prefix, task_resolver=default_task_resolver) def new_execution_state(self, working_dir: Optional[os.PathLike] = None) -> ExecutionState: """ diff --git a/flytekit/core/local_cache.py b/flytekit/core/local_cache.py index 7cd87e2a49..d6c7f93f99 100644 --- a/flytekit/core/local_cache.py +++ b/flytekit/core/local_cache.py @@ -1,9 +1,11 @@ from typing import Optional, Tuple from diskcache import Cache +from flyteidl.core.literals_pb2 import LiteralMap from flytekit import lazy_module -from flytekit.models.literals import Literal, LiteralCollection, LiteralMap +from flytekit.models.literals import Literal, LiteralCollection +from flytekit.models.literals import LiteralMap as ModelLiteralMap joblib = lazy_module("joblib") @@ -23,13 +25,16 @@ def _recursive_hash_placement(literal: Literal) -> Literal: literal_map = {} for key, literal_value in literal.map.literals.items(): literal_map[key] = _recursive_hash_placement(literal_value) - return Literal(map=LiteralMap(literal_map)) + return Literal(map=ModelLiteralMap(literal_map)) else: return literal def _calculate_cache_key( - task_name: str, cache_version: str, input_literal_map: LiteralMap, cache_ignore_input_vars: Tuple[str, ...] = () + task_name: str, + cache_version: str, + input_literal_map: ModelLiteralMap, + cache_ignore_input_vars: Tuple[str, ...] = (), ) -> str: # Traverse the literals and replace the literal with a new literal that only contains the hash literal_map_overridden = {} @@ -40,7 +45,7 @@ def _calculate_cache_key( # Generate a stable representation of the underlying protobuf by passing `deterministic=True` to the # protobuf library. 
- hashed_inputs = LiteralMap(literal_map_overridden).to_flyte_idl().SerializeToString(deterministic=True) + hashed_inputs = ModelLiteralMap(literal_map_overridden).to_flyte_idl().SerializeToString(deterministic=True) # Use joblib to hash the string representation of the literal into a fixed length string return f"{task_name}-{cache_version}-{joblib.hash(hashed_inputs)}" @@ -66,24 +71,47 @@ def clear(): @staticmethod def get( - task_name: str, cache_version: str, input_literal_map: LiteralMap, cache_ignore_input_vars: Tuple[str, ...] - ) -> Optional[LiteralMap]: + task_name: str, cache_version: str, input_literal_map: ModelLiteralMap, cache_ignore_input_vars: Tuple[str, ...] + ) -> Optional[ModelLiteralMap]: if not LocalTaskCache._initialized: LocalTaskCache.initialize() - return LocalTaskCache._cache.get( + serialized_obj = LocalTaskCache._cache.get( _calculate_cache_key(task_name, cache_version, input_literal_map, cache_ignore_input_vars) ) + if serialized_obj is None: + return None + + # If the serialized object is a model file, first convert it back to a proto object (which will force it to + # use the installed flyteidl proto messages) and then convert it to a model object. This will guarantee + # that the object is in the correct format. + if isinstance(serialized_obj, ModelLiteralMap): + return ModelLiteralMap.from_flyte_idl(ModelLiteralMap.to_flyte_idl(serialized_obj)) + elif isinstance(serialized_obj, bytes): + # If it is a bytes object, then it is a serialized proto object. 
+ # We need to convert it to a model object first. + pb_literal_map = LiteralMap() + pb_literal_map.ParseFromString(serialized_obj) + return ModelLiteralMap.from_flyte_idl(pb_literal_map) + else: + raise ValueError(f"Unexpected object type {type(serialized_obj)}") + @staticmethod def set( task_name: str, cache_version: str, - input_literal_map: LiteralMap, + input_literal_map: ModelLiteralMap, cache_ignore_input_vars: Tuple[str, ...], - value: LiteralMap, + value: ModelLiteralMap, ) -> None: if not LocalTaskCache._initialized: LocalTaskCache.initialize() LocalTaskCache._cache.set( - _calculate_cache_key(task_name, cache_version, input_literal_map, cache_ignore_input_vars), value + _calculate_cache_key( + task_name, + cache_version, + input_literal_map, + cache_ignore_input_vars, + ), + value.to_flyte_idl().SerializeToString(), ) diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index b0a6525ecd..4a14f8d3c5 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -344,7 +344,6 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, :param f: A task or any other callable :return: [name to use: str, module_name: str, function_name: str, full_path: str] """ - if isinstance(f, TrackedInstance): if hasattr(f, "task_function"): mod, mod_name, name = _task_module_from_callable(f.task_function) diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index bb48cde73b..4e6535b492 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -696,9 +696,13 @@ def task_name(self, t: PythonAutoContainerTask) -> str: # type: ignore return f"{self.name}.{t.__module__}.{t.name}" def _validate_add_on_failure_handler(self, ctx: FlyteContext, prefix: str, wf_args: Dict[str, Promise]): - # Compare + resolver = ( + ctx.compilation_state.task_resolver + if ctx.compilation_state and ctx.compilation_state.task_resolver + else self + ) with FlyteContextManager.with_context( - 
ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=self)) + ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=resolver)) ) as inner_comp_ctx: # Now lets compile the failure-node if it exists if self.on_failure: @@ -736,9 +740,14 @@ def compile(self, **kwargs): ctx = FlyteContextManager.current_context() all_nodes = [] prefix = ctx.compilation_state.prefix if ctx.compilation_state is not None else "" + resolver = ( + ctx.compilation_state.task_resolver + if ctx.compilation_state and ctx.compilation_state.task_resolver + else self + ) with FlyteContextManager.with_context( - ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=self)) + ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=resolver)) ) as comp_ctx: # Construct the default input promise bindings, but then override with the provided inputs, if any input_kwargs = construct_input_promises([k for k in self.interface.inputs.keys()]) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 5919ab0611..ee97b9ddfb 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -25,6 +25,7 @@ from prometheus_client import Counter, Summary from flytekit import logger +from flytekit.bin.entrypoint import get_traceback_str from flytekit.exceptions.system import FlyteAgentNotFound from flytekit.extend.backend.base_agent import AgentRegistry, SyncAgentBase, mirror_async_methods from flytekit.models.literals import LiteralMap @@ -63,7 +64,7 @@ def _handle_exception(e: Exception, context: grpc.ServicerContext, task_type: st context.set_details(error_message) request_failure_count.labels(task_type=task_type, operation=operation, error_code=HTTPStatus.NOT_FOUND).inc() else: - error_message = f"failed to {operation} {task_type} task with error: {e}." + error_message = f"failed to {operation} {task_type} task with error:\n {get_traceback_str(e)}." 
logger.error(error_message) context.set_code(grpc.StatusCode.INTERNAL) context.set_details(error_message) diff --git a/flytekit/interactive/vscode_lib/decorator.py b/flytekit/interactive/vscode_lib/decorator.py index 2ed3406cb4..055bda6639 100644 --- a/flytekit/interactive/vscode_lib/decorator.py +++ b/flytekit/interactive/vscode_lib/decorator.py @@ -302,7 +302,7 @@ def prepare_resume_task_python(pid: int): def prepare_launch_json(): """ - Generate the launch.json for users to easily launch interactive debugging and task resumption. + Generate the launch.json and settings.json for users to easily launch interactive debugging and task resumption. """ task_function_source_dir = os.path.dirname( @@ -337,6 +337,10 @@ def prepare_launch_json(): with open(os.path.join(vscode_directory, "launch.json"), "w") as file: json.dump(launch_json, file, indent=4) + settings_json = {"python.defaultInterpreterPath": sys.executable} + with open(os.path.join(vscode_directory, "settings.json"), "w") as file: + json.dump(settings_json, file, indent=4) + VSCODE_TYPE_VALUE = "vscode" diff --git a/flytekit/loggers.py b/flytekit/loggers.py index 8c6e0de196..0224d177b4 100644 --- a/flytekit/loggers.py +++ b/flytekit/loggers.py @@ -177,7 +177,7 @@ def get_level_from_cli_verbosity(verbosity: int) -> int: :return: logging level """ if verbosity == 0: - return logging.CRITICAL + return _get_env_logging_level(default_level=logging.CRITICAL) elif verbosity == 1: return logging.WARNING elif verbosity == 2: diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index 8d8bf9c9ef..f3fed3d4f3 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -390,6 +390,7 @@ def __init__( min_success_ratio=None, execution_mode=None, is_original_sub_node_interface=False, + data_mode=None, ) -> None: """ TODO: docstring @@ -401,6 +402,7 @@ def __init__( self._min_success_ratio = min_success_ratio self._execution_mode = execution_mode 
self._is_original_sub_node_interface = is_original_sub_node_interface + self._data_mode = data_mode @property def node(self) -> "Node": @@ -414,6 +416,7 @@ def to_flyte_idl(self) -> _core_workflow.ArrayNode: min_success_ratio=self._min_success_ratio, execution_mode=self._execution_mode, is_original_sub_node_interface=BoolValue(value=self._is_original_sub_node_interface), + data_mode=self._data_mode, ) @classmethod diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index d65ebfafae..a4b5a1d359 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -979,7 +979,15 @@ def offloaded_metadata(self) -> Optional[LiteralOffloadedMetadata]: """ This value holds metadata about the offloaded literal. """ - return self._offloaded_metadata + # The following check might seem non-sensical, since `_offloaded_metadata` is set in the constructor. + # This is here to support backwards compatibility caused by the local cache implementation. Let me explain. + # The local cache pickles values and unpickles them. When unpickling, the constructor is not called, so there + # are cases where the `_offloaded_metadata` is not set (for example if you cache a value using flytekit<=1.13.6 + # and you load that value later using flytekit>1.13.6). + # In other words, this is a workaround to support backwards compatibility with the local cache. 
+ if hasattr(self, "_offloaded_metadata"): + return self._offloaded_metadata + return None def to_flyte_idl(self): """ diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 868f657610..85e4847bd8 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -597,6 +597,7 @@ def get_serializable_array_node( min_success_ratio=array_node.min_success_ratio, execution_mode=array_node.execution_mode, is_original_sub_node_interface=array_node.is_original_sub_node_interface, + data_mode=array_node.data_mode, ) diff --git a/plugins/flytekit-flyteinteractive/tests/test_flyteinteractive_vscode.py b/plugins/flytekit-flyteinteractive/tests/test_flyteinteractive_vscode.py index 96fa9261c2..7996a0da41 100644 --- a/plugins/flytekit-flyteinteractive/tests/test_flyteinteractive_vscode.py +++ b/plugins/flytekit-flyteinteractive/tests/test_flyteinteractive_vscode.py @@ -2,6 +2,9 @@ import mock import pytest + +from flytekit.core import context_manager +from flytekit.core.python_auto_container import default_task_resolver from flytekitplugins.flyteinteractive import ( CODE_TOGETHER_CONFIG, CODE_TOGETHER_EXTENSION, @@ -24,9 +27,9 @@ is_extension_installed, ) -from flytekit import task, workflow +from flytekit import task, workflow, dynamic from flytekit.configuration import Image, ImageConfig, SerializationSettings -from flytekit.core.context_manager import ExecutionState +from flytekit.core.context_manager import ExecutionState, FlyteContextManager from flytekit.tools.translator import get_serializable_task @@ -402,3 +405,35 @@ def test_get_installed_extensions_failed(mock_run): expected_extensions = [] assert installed_extensions == expected_extensions + + +def test_vscode_with_dynamic(vscode_patches): + ( + mock_process, + mock_prepare_interactive_python, + mock_exit_handler, + mock_download_vscode, + mock_signal, + mock_prepare_resume_task_python, + mock_prepare_launch_json, + ) = vscode_patches + + mock_exit_handler.return_value = None + + 
@task() + def train(): + print("forward") + print("backward") + + @dynamic() + @vscode + def d1(): + print("dynamic", flush=True) + train() + + ctx = FlyteContextManager.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)) + ): + d1() + assert d1.task_resolver == default_task_resolver diff --git a/pyproject.toml b/pyproject.toml index 6b50981faf..e8320d30bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "diskcache>=5.2.1", "docker>=4.0.0", "docstring-parser>=0.9.0", - "flyteidl>=1.13.9", + "flyteidl>=1.14.1", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", "googleapis-common-protos>=1.57", diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 70be22527e..848dbbf6e1 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -477,10 +477,7 @@ def test_nested_workflow(working_dir, wf_path, monkeypatch: pytest.MonkeyPatch): ], catch_exceptions=False, ) - assert ( - result.stdout.strip() - == "Running Execution on local.\nRunning Execution on local." - ) + assert ("Running Execution on local." in result.stdout.strip()) assert result.exit_code == 0 @@ -853,7 +850,8 @@ def test_list_default_arguments(task_path): catch_exceptions=False, ) assert result.exit_code == 0 - assert result.stdout == "Running Execution on local.\n0 Hello Color.RED\n\n" + assert "Running Execution on local." 
in result.stdout + assert "Hello Color.RED" in result.stdout def test_entity_non_found_in_file(): diff --git a/tests/flytekit/unit/core/test_local_cache.py b/tests/flytekit/unit/core/test_local_cache.py index cf3e90e338..0990541a84 100644 --- a/tests/flytekit/unit/core/test_local_cache.py +++ b/tests/flytekit/unit/core/test_local_cache.py @@ -1,4 +1,6 @@ import datetime +import pathlib +import pickle import re import sys import typing @@ -627,3 +629,23 @@ def test_set_cache_ignore_input_vars_without_set_cache(): @task(cache_ignore_input_vars=["a"]) def add(a: int, b: int) -> int: return a + b + + +@pytest.mark.serial +def test_cache_old_version_of_literal_map(): + cache_key = "t.produce_dc-1-ea65cfadb0079394a8be1f4aa1e96e2b" + + # Load a literal map from a previous version of the cache from a local file + with open(pathlib.Path(__file__).parent / "testdata/pickled_value.bin", "rb") as f: + literal_map = pickle.loads(f.read()) + LocalTaskCache._cache.set(cache_key, literal_map) + + assert _calculate_cache_key("t.produce_dc", "1", LiteralMap(literals={})) == cache_key + + # Hit the cache directly and confirm that the loaded object does not have the `_offloaded_metadata` attribute + literal_map = LocalTaskCache._cache.get(cache_key) + assert hasattr(literal_map.literals['o0'], "_offloaded_metadata") is False + + # Now load the same object from the cache and confirm that the `_offloaded_metadata` attribute is now present + loaded_literal_map = LocalTaskCache.get("t.produce_dc", "1", LiteralMap(literals={}), ()) + assert hasattr(loaded_literal_map.literals['o0'], "_offloaded_metadata") is True diff --git a/tests/flytekit/unit/core/testdata/pickled_value.bin b/tests/flytekit/unit/core/testdata/pickled_value.bin new file mode 100644 index 0000000000..71a45e1909 Binary files /dev/null and b/tests/flytekit/unit/core/testdata/pickled_value.bin differ