Skip to content

Commit

Permalink
Merge branch 'master' into fix-pydantic-default-input
Browse files Browse the repository at this point in the history
  • Loading branch information
Future-Outlier committed Jan 2, 2025
2 parents 0b5ac78 + c95cc63 commit e35e334
Show file tree
Hide file tree
Showing 26 changed files with 242 additions and 75 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pythonpublish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ jobs:
cache-to: type=gha,mode=max
- name: Confirm Agent can start
run: |
docker run --rm ghcr.io/${{ github.repository_owner }}/flyteagent:${{ github.sha }} pyflyte serve agent --port 8000 --timeout 1
docker run --rm ghcr.io/${{ github.repository_owner }}/flyteagent-slim:${{ github.sha }} pyflyte serve agent --port 8000 --timeout 1
- name: Push flyteagent-all Image to GitHub Registry
uses: docker/build-push-action@v2
with:
Expand Down
20 changes: 12 additions & 8 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ adlfs==2024.4.1
# via flytekit
aiobotocore==2.13.0
# via s3fs
aiohttp==3.9.5
aiohappyeyeballs==2.4.4
# via aiohttp
aiohttp==3.10.11
# via
# adlfs
# aiobotocore
Expand Down Expand Up @@ -113,10 +115,8 @@ filelock==3.14.0
# via
# snowflake-connector-python
# virtualenv
flyteidl @ git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl
# via
# -r dev-requirements.in
# flytekit
flyteidl==1.14.1
# via flytekit
frozenlist==1.4.1
# via
# aiohttp
Expand Down Expand Up @@ -244,7 +244,9 @@ keyring==25.2.1
keyrings-alt==5.0.1
# via -r dev-requirements.in
kubernetes==29.0.0
# via -r dev-requirements.in
# via
# -r dev-requirements.in
# flytekit
markdown-it-py==3.0.0
# via
# flytekit
Expand All @@ -260,7 +262,7 @@ marshmallow-enum==1.5.1
# flytekit
marshmallow-jsonschema==0.13.0
# via flytekit
mashumaro==3.13
mashumaro==3.15
# via flytekit
matplotlib-inline==0.1.7
# via
Expand Down Expand Up @@ -345,6 +347,8 @@ prometheus-client==0.20.0
# via -r dev-requirements.in
prompt-toolkit==3.0.45
# via ipython
propcache==0.2.1
# via yarl
proto-plus==1.23.0
# via
# google-api-core
Expand Down Expand Up @@ -557,7 +561,7 @@ websocket-client==1.8.0
# kubernetes
wrapt==1.16.0
# via aiobotocore
yarl==1.9.4
yarl==1.18.3
# via aiohttp
zipp==3.19.1
# via importlib-metadata
Expand Down
2 changes: 1 addition & 1 deletion flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,7 @@ def _create_command(
h = h + click.style(f" (LP Name: {loaded_entity.name})", fg="yellow")
else:
if loaded_entity.__doc__:
h = h + click.style(f"{loaded_entity.__doc__}", dim=True)
h = h + click.style(f" {loaded_entity.__doc__}", dim=True)
cmd = YamlFileReadingCommand(
name=entity_name,
params=params,
Expand Down
27 changes: 17 additions & 10 deletions flytekit/core/array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
flyte_entity_call_handler,
translate_inputs_to_literals,
)
from flytekit.core.task import ReferenceTask
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
Expand All @@ -34,8 +35,7 @@
class ArrayNode:
def __init__(
self,
target: Union[LaunchPlan, "FlyteLaunchPlan"],
execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE,
target: Union[LaunchPlan, ReferenceTask, "FlyteLaunchPlan"],
bindings: Optional[List[_literal_models.Binding]] = None,
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
Expand All @@ -51,17 +51,17 @@ def __init__(
:param min_successes: The minimum number of successful executions. If set, this takes precedence over
min_success_ratio
:param min_success_ratio: The minimum ratio of successful executions.
:param execution_mode: The execution mode for propeller to use when handling ArrayNode
:param metadata: The metadata for the underlying node
"""
from flytekit.remote import FlyteLaunchPlan

self.target = target
self._concurrency = concurrency
self._execution_mode = execution_mode
self.id = target.name
self._bindings = bindings or []
self.metadata = metadata
self._data_mode = None
self._execution_mode = None

if min_successes is not None:
self._min_successes = min_successes
Expand Down Expand Up @@ -92,9 +92,12 @@ def __init__(
else:
raise ValueError("No interface found for the target entity.")

if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan):
if self._execution_mode != _core_workflow.ArrayNode.FULL_STATE:
raise ValueError("Only execution version 1 is supported for LaunchPlans.")
if isinstance(target, (LaunchPlan, FlyteLaunchPlan)):
self._data_mode = _core_workflow.ArrayNode.SINGLE_INPUT_FILE
self._execution_mode = _core_workflow.ArrayNode.FULL_STATE
elif isinstance(target, ReferenceTask):
self._data_mode = _core_workflow.ArrayNode.INDIVIDUAL_INPUT_FILES
self._execution_mode = _core_workflow.ArrayNode.MINIMAL_STATE
else:
raise ValueError(f"Only LaunchPlans are supported for now, but got {type(target)}")

Expand Down Expand Up @@ -133,6 +136,10 @@ def upstream_nodes(self) -> List[Node]:
def flyte_entity(self) -> Any:
return self.target

@property
def data_mode(self) -> _core_workflow.ArrayNode.DataMode:
return self._data_mode

def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
if self._remote_interface:
raise ValueError("Mapping over remote entities is not supported in local execution.")
Expand Down Expand Up @@ -254,7 +261,7 @@ def __call__(self, *args, **kwargs):


def array_node(
target: Union[LaunchPlan, "FlyteLaunchPlan"],
target: Union[LaunchPlan, ReferenceTask, "FlyteLaunchPlan"],
concurrency: Optional[int] = None,
min_success_ratio: Optional[float] = None,
min_successes: Optional[int] = None,
Expand All @@ -275,8 +282,8 @@ def array_node(
"""
from flytekit.remote import FlyteLaunchPlan

if not isinstance(target, LaunchPlan) and not isinstance(target, FlyteLaunchPlan):
raise ValueError("Only LaunchPlans are supported for now.")
if not isinstance(target, (LaunchPlan, FlyteLaunchPlan, ReferenceTask)):
raise ValueError("Only LaunchPlans and ReferenceTasks are supported for now.")

node = ArrayNode(
target=target,
Expand Down
3 changes: 2 additions & 1 deletion flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.task import ReferenceTask
from flytekit.core.type_engine import TypeEngine
from flytekit.core.utils import timeit
from flytekit.loggers import logger
Expand Down Expand Up @@ -390,7 +391,7 @@ def map_task(
"""
from flytekit.remote import FlyteLaunchPlan

if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan):
if isinstance(target, (LaunchPlan, FlyteLaunchPlan, ReferenceTask)):
return array_node(
target=target,
concurrency=concurrency,
Expand Down
4 changes: 3 additions & 1 deletion flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,9 @@ def new_compilation_state(self, prefix: str = "") -> CompilationState:
Creates and returns a default compilation state. For most of the code this should be the entrypoint
of compilation, otherwise the code should always use - with_compilation_state
"""
return CompilationState(prefix=prefix)
from flytekit.core.python_auto_container import default_task_resolver

return CompilationState(prefix=prefix, task_resolver=default_task_resolver)

def new_execution_state(self, working_dir: Optional[os.PathLike] = None) -> ExecutionState:
"""
Expand Down
48 changes: 38 additions & 10 deletions flytekit/core/local_cache.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Optional, Tuple

from diskcache import Cache
from flyteidl.core.literals_pb2 import LiteralMap

from flytekit import lazy_module
from flytekit.models.literals import Literal, LiteralCollection, LiteralMap
from flytekit.models.literals import Literal, LiteralCollection
from flytekit.models.literals import LiteralMap as ModelLiteralMap

joblib = lazy_module("joblib")

Expand All @@ -23,13 +25,16 @@ def _recursive_hash_placement(literal: Literal) -> Literal:
literal_map = {}
for key, literal_value in literal.map.literals.items():
literal_map[key] = _recursive_hash_placement(literal_value)
return Literal(map=LiteralMap(literal_map))
return Literal(map=ModelLiteralMap(literal_map))
else:
return literal


def _calculate_cache_key(
task_name: str, cache_version: str, input_literal_map: LiteralMap, cache_ignore_input_vars: Tuple[str, ...] = ()
task_name: str,
cache_version: str,
input_literal_map: ModelLiteralMap,
cache_ignore_input_vars: Tuple[str, ...] = (),
) -> str:
# Traverse the literals and replace the literal with a new literal that only contains the hash
literal_map_overridden = {}
Expand All @@ -40,7 +45,7 @@ def _calculate_cache_key(

# Generate a stable representation of the underlying protobuf by passing `deterministic=True` to the
# protobuf library.
hashed_inputs = LiteralMap(literal_map_overridden).to_flyte_idl().SerializeToString(deterministic=True)
hashed_inputs = ModelLiteralMap(literal_map_overridden).to_flyte_idl().SerializeToString(deterministic=True)
# Use joblib to hash the string representation of the literal into a fixed length string
return f"{task_name}-{cache_version}-{joblib.hash(hashed_inputs)}"

Expand All @@ -66,24 +71,47 @@ def clear():

@staticmethod
def get(
task_name: str, cache_version: str, input_literal_map: LiteralMap, cache_ignore_input_vars: Tuple[str, ...]
) -> Optional[LiteralMap]:
task_name: str, cache_version: str, input_literal_map: ModelLiteralMap, cache_ignore_input_vars: Tuple[str, ...]
) -> Optional[ModelLiteralMap]:
if not LocalTaskCache._initialized:
LocalTaskCache.initialize()
return LocalTaskCache._cache.get(
serialized_obj = LocalTaskCache._cache.get(
_calculate_cache_key(task_name, cache_version, input_literal_map, cache_ignore_input_vars)
)

if serialized_obj is None:
return None

# If the serialized object is a model file, first convert it back to a proto object (which will force it to
# use the installed flyteidl proto messages) and then convert it to a model object. This will guarantee
# that the object is in the correct format.
if isinstance(serialized_obj, ModelLiteralMap):
return ModelLiteralMap.from_flyte_idl(ModelLiteralMap.to_flyte_idl(serialized_obj))
elif isinstance(serialized_obj, bytes):
# If it is a bytes object, then it is a serialized proto object.
# We need to convert it to a model object first.
pb_literal_map = LiteralMap()
pb_literal_map.ParseFromString(serialized_obj)
return ModelLiteralMap.from_flyte_idl(pb_literal_map)
else:
raise ValueError(f"Unexpected object type {type(serialized_obj)}")

@staticmethod
def set(
task_name: str,
cache_version: str,
input_literal_map: LiteralMap,
input_literal_map: ModelLiteralMap,
cache_ignore_input_vars: Tuple[str, ...],
value: LiteralMap,
value: ModelLiteralMap,
) -> None:
if not LocalTaskCache._initialized:
LocalTaskCache.initialize()
LocalTaskCache._cache.set(
_calculate_cache_key(task_name, cache_version, input_literal_map, cache_ignore_input_vars), value
_calculate_cache_key(
task_name,
cache_version,
input_literal_map,
cache_ignore_input_vars,
),
value.to_flyte_idl().SerializeToString(),
)
1 change: 0 additions & 1 deletion flytekit/core/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,6 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str,
:param f: A task or any other callable
:return: [name to use: str, module_name: str, function_name: str, full_path: str]
"""

if isinstance(f, TrackedInstance):
if hasattr(f, "task_function"):
mod, mod_name, name = _task_module_from_callable(f.task_function)
Expand Down
15 changes: 12 additions & 3 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,9 +696,13 @@ def task_name(self, t: PythonAutoContainerTask) -> str: # type: ignore
return f"{self.name}.{t.__module__}.{t.name}"

def _validate_add_on_failure_handler(self, ctx: FlyteContext, prefix: str, wf_args: Dict[str, Promise]):
# Compare
resolver = (
ctx.compilation_state.task_resolver
if ctx.compilation_state and ctx.compilation_state.task_resolver
else self
)
with FlyteContextManager.with_context(
ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=self))
ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=resolver))
) as inner_comp_ctx:
# Now let's compile the failure-node if it exists
if self.on_failure:
Expand Down Expand Up @@ -736,9 +740,14 @@ def compile(self, **kwargs):
ctx = FlyteContextManager.current_context()
all_nodes = []
prefix = ctx.compilation_state.prefix if ctx.compilation_state is not None else ""
resolver = (
ctx.compilation_state.task_resolver
if ctx.compilation_state and ctx.compilation_state.task_resolver
else self
)

with FlyteContextManager.with_context(
ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=self))
ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=resolver))
) as comp_ctx:
# Construct the default input promise bindings, but then override with the provided inputs, if any
input_kwargs = construct_input_promises([k for k in self.interface.inputs.keys()])
Expand Down
7 changes: 4 additions & 3 deletions flytekit/extend/backend/agent_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from prometheus_client import Counter, Summary

from flytekit import logger
from flytekit.bin.entrypoint import get_traceback_str
from flytekit.exceptions.system import FlyteAgentNotFound
from flytekit.extend.backend.base_agent import AgentRegistry, SyncAgentBase, mirror_async_methods
from flytekit.models.literals import LiteralMap
Expand Down Expand Up @@ -63,7 +64,7 @@ def _handle_exception(e: Exception, context: grpc.ServicerContext, task_type: st
context.set_details(error_message)
request_failure_count.labels(task_type=task_type, operation=operation, error_code=HTTPStatus.NOT_FOUND).inc()
else:
error_message = f"failed to {operation} {task_type} task with error: {e}."
error_message = f"failed to {operation} {task_type} task with error:\n {get_traceback_str(e)}."
logger.error(error_message)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(error_message)
Expand Down Expand Up @@ -116,14 +117,14 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon
task_execution_metadata = TaskExecutionMetadata.from_flyte_idl(request.task_execution_metadata)

logger.info(f"{agent.name} start creating the job")
resource_mata = await mirror_async_methods(
resource_meta = await mirror_async_methods(
agent.create,
task_template=template,
inputs=inputs,
output_prefix=request.output_prefix,
task_execution_metadata=task_execution_metadata,
)
return CreateTaskResponse(resource_meta=resource_mata.encode())
return CreateTaskResponse(resource_meta=resource_meta.encode())

@record_agent_metrics
async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse:
Expand Down
4 changes: 2 additions & 2 deletions flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,10 +335,10 @@ def execute(self: PythonTask, **kwargs) -> LiteralMap:
task_template = get_serializable(OrderedDict(), ss, self).template
self._agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version)

resource_mata = asyncio.run(
resource_meta = asyncio.run(
self._create(task_template=task_template, output_prefix=output_prefix, inputs=kwargs)
)
resource = asyncio.run(self._get(resource_meta=resource_mata))
resource = asyncio.run(self._get(resource_meta=resource_meta))

if resource.phase != TaskExecution.SUCCEEDED:
raise FlyteUserException(f"Failed to run the task {self.name} with error: {resource.message}")
Expand Down
6 changes: 5 additions & 1 deletion flytekit/interactive/vscode_lib/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def prepare_resume_task_python(pid: int):

def prepare_launch_json():
"""
Generate the launch.json for users to easily launch interactive debugging and task resumption.
Generate the launch.json and settings.json for users to easily launch interactive debugging and task resumption.
"""

task_function_source_dir = os.path.dirname(
Expand Down Expand Up @@ -337,6 +337,10 @@ def prepare_launch_json():
with open(os.path.join(vscode_directory, "launch.json"), "w") as file:
json.dump(launch_json, file, indent=4)

settings_json = {"python.defaultInterpreterPath": sys.executable}
with open(os.path.join(vscode_directory, "settings.json"), "w") as file:
json.dump(settings_json, file, indent=4)


VSCODE_TYPE_VALUE = "vscode"

Expand Down
Loading

0 comments on commit e35e334

Please sign in to comment.