From 4c481619942a81d0c9c2517b45929e615a57a7ca Mon Sep 17 00:00:00 2001 From: jernejfrank <50105951+jernejfrank@users.noreply.github.com> Date: Thu, 15 Aug 2024 14:40:42 -0700 Subject: [PATCH] Refactor to enable RayGraphAdapter and HamiltonTracker to work well together This is a squash commit: - issue=#1079 - PR=#1103 Describes what to do in `graph_functions.py` Adds comments to lifecycle base Update h_ray.py with comments for ray tracking compatibility Replicate previous error Inline function, unsure if catching errors and exceptions to be handadled differently BaseDoRemoteExecute has the added Callable function that snadwisched lifecycle hooks method fails, says AssertionError about ray.remote decorator simple script for now to check telemetry, execution yield the ray.remote AssertionError passing pointer through and arguments to lifecycle wrapper into ray.remote post-execute hook for node not called finally executed only when exception occurs, hamilton tracker not executed atexit.register does not work, node keeps running inui added stop() method, but doesn't get called Ray telemtry works for single node, problem with connected nodes Ray telemtry works for single node, problem with connected nodes Ray telemtry works for single node, problem with connected nodes Fixes ray object dereferencing Ray does not resolve nested arguments: https://docs.ray.io/en/latest/ray-core/objects.html#passing-object-arguments So one option is to make them all top level: - one way to do that is to make the other arguments not clash with any possible user parameters -- hence the `__` prefix. This is what I did. - another way would be in the ray adapter, wrap the incoming function, and explicitly do a ray.get() on any ray object references in the kwargs arguments. i.e. keep the nested structure, but when the ray task starts way for all inputs... not sure which is best, but this now works correctly. ray works checkpoint, pre-commit fixed fixed graph level telemtry proposal pinned ruff Correct output, added option to start ray cluster Unit test mimicks the DoNodeExecute unit test Refactored driver so all tests pass Workaround to not break ray by calling init on an open cluster raw_execute does not have post_graph_execute and is private now Correct version for depraction warning all tests work this looks better ruff version comment Refactored pre- and post-graph-execute hooks outside of raw_execute which now has deprecation warning added readme, notebook and made script cli interactive made cluster init optional through inserting config dict User has option to shutdown ray cluster Co-authored-by: Stefan Krawczyk --- .pre-commit-config.yaml | 2 +- examples/ray/ray_Hamilton_UI_tracking/README | 29 ++++ .../hamilton_notebook.ipynb | 107 ++++++++++++ .../ray_Hamilton_UI_tracking/ray_lineage.py | 18 +++ .../ray_Hamilton_UI_tracking/requirements.txt | 1 + .../ray_Hamilton_UI_tracking/run_lineage.py | 45 ++++++ hamilton/dev_utils/deprecation.py | 6 +- hamilton/driver.py | 140 ++++++++++++++-- hamilton/execution/graph_functions.py | 152 ++++++++++++------ hamilton/lifecycle/base.py | 20 +++ hamilton/plugins/h_ray.py | 56 +++++-- pyproject.toml | 2 +- .../lifecycle_adapters_for_testing.py | 18 +++ .../test_lifecycle_adapters_end_to_end.py | 86 +++++++++- ui/sdk/.pre-commit-config.yaml | 2 +- ui/sdk/src/hamilton_sdk/adapters.py | 4 + ui/sdk/src/hamilton_sdk/api/clients.py | 5 + 17 files changed, 609 insertions(+), 84 deletions(-) create mode 100644 examples/ray/ray_Hamilton_UI_tracking/README create mode 100644 examples/ray/ray_Hamilton_UI_tracking/hamilton_notebook.ipynb create mode 100644 examples/ray/ray_Hamilton_UI_tracking/ray_lineage.py create mode 100644 examples/ray/ray_Hamilton_UI_tracking/requirements.txt create mode 100644 examples/ray/ray_Hamilton_UI_tracking/run_lineage.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 183541c4f..d15c7ece4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: args: [ --fix ] # Run the formatter. - id: ruff-format -# args: [ --diff ] # Use for previewing changes + # args: [ --diff ] # Use for previewing changes - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 hooks: diff --git a/examples/ray/ray_Hamilton_UI_tracking/README b/examples/ray/ray_Hamilton_UI_tracking/README new file mode 100644 index 000000000..baa786687 --- /dev/null +++ b/examples/ray/ray_Hamilton_UI_tracking/README @@ -0,0 +1,29 @@ +# Tracking telemetry in Hamilton UI for Ray clusters + +We show the ability to combine the [RayGraphAdapter](https://hamilton.dagworks.io/en/latest/reference/graph-adapters/RayGraphAdapter/) and [HamiltonTracker](https://hamilton.dagworks.io/en/latest/concepts/ui/) to run a dummy DAG. + +# ray_lineage.py +Has three dummy functions: +- waiting 5s +- waiting 1s +- raising an error + +That represent a basic DAG. + +# run_lineage.py +Is where the driver code lives to create the DAG and exercise it. + +To exercise it: +> Have an open instance of Hamilton UI: https://hamilton.dagworks.io/en/latest/concepts/ui/ + +```bash +python -m run_lineage.py +Usage: python -m run_lineage.py [OPTIONS] COMMAND [ARGS]... + +Options: + --help Show this message and exit. + +Commands: + project_id This command will select the created project in Hamilton UI + username This command will input the correct username to access the selected project_id +``` diff --git a/examples/ray/ray_Hamilton_UI_tracking/hamilton_notebook.ipynb b/examples/ray/ray_Hamilton_UI_tracking/hamilton_notebook.ipynb new file mode 100644 index 000000000..165da38df --- /dev/null +++ b/examples/ray/ray_Hamilton_UI_tracking/hamilton_notebook.ipynb @@ -0,0 +1,107 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Hamilton UI Adapter\n", + "\n", + "Needs a running instance of Hamilton UI: https://hamilton.dagworks.io/en/latest/concepts/ui/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from hamilton_sdk.adapters import HamiltonTracker\n", + "\n", + "# Inputs required to track into correct project in the UI\n", + "project_id = 2\n", + "username = \"admin\"\n", + "\n", + "tracker_ray = HamiltonTracker(\n", + " project_id=project_id,\n", + " username=username,\n", + " dag_name=\"telemetry_with_ray\",)\n", + "\n", + "tracker_without_ray = HamiltonTracker(\n", + " project_id=project_id,\n", + " username=username,\n", + " dag_name=\"telemetry_without_ray\",\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Ray adapter\n", + "\n", + "https://hamilton.dagworks.io/en/latest/reference/graph-adapters/RayGraphAdapter/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from hamilton import base\n", + "from hamilton.plugins.h_ray import RayGraphAdapter\n", + "\n", + "rga = RayGraphAdapter(result_builder=base.PandasDataFrameResult())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Importing Hamilton and the DAG modules" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from hamilton import driver\n", + "import ray_lineage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " dr_ray = driver.Builder().with_modules(ray_lineage).with_adapters(rga, tracker_ray).build()\n", + " result_ray = dr_ray.execute(\n", + " final_vars=[\n", + " \"node_5s\",\n", + " \"node_1s_error\",\n", + " \"add_1_to_previous\",\n", + " ]\n", + " )\n", + " print(result_ray)\n", + "\n", + "except ValueError:\n", + " print(\"UI should display failure.\")\n", + "finally:\n", + " dr_without_ray = driver.Builder().with_modules(ray_lineage).with_adapters(tracker).build()\n", + " result_without_ray = dr_without_ray.execute(final_vars=[\"node_5s\", \"add_1_to_previous\"])\n", + " print(result_without_ray) \n" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/ray/ray_Hamilton_UI_tracking/ray_lineage.py b/examples/ray/ray_Hamilton_UI_tracking/ray_lineage.py new file mode 100644 index 000000000..d9d20c2ca --- /dev/null +++ b/examples/ray/ray_Hamilton_UI_tracking/ray_lineage.py @@ -0,0 +1,18 @@ +import time + + +def node_5s() -> float: + start = time.time() + time.sleep(5) + return time.time() - start + + +def add_1_to_previous(node_5s: float) -> float: + start = time.time() + time.sleep(1) + return node_5s + (time.time() - start) + + +def node_1s_error(node_5s: float) -> float: + time.sleep(1) + raise ValueError("Does not break telemetry if executed through ray") diff --git a/examples/ray/ray_Hamilton_UI_tracking/requirements.txt b/examples/ray/ray_Hamilton_UI_tracking/requirements.txt new file mode 100644 index 000000000..e6916bb92 --- /dev/null +++ b/examples/ray/ray_Hamilton_UI_tracking/requirements.txt @@ -0,0 +1 @@ +sf-hamilton[ray,sdk,ui] diff --git a/examples/ray/ray_Hamilton_UI_tracking/run_lineage.py b/examples/ray/ray_Hamilton_UI_tracking/run_lineage.py new file mode 100644 index 000000000..6d135efd1 --- /dev/null +++ b/examples/ray/ray_Hamilton_UI_tracking/run_lineage.py @@ -0,0 +1,45 @@ +import click +import ray_lineage + +from hamilton import base, driver +from hamilton.plugins.h_ray import RayGraphAdapter +from hamilton_sdk import adapters + + +@click.command() +@click.option("--username", required=True, type=str) +@click.option("--project_id", default=1, type=int) +def run(project_id, username): + try: + tracker_ray = adapters.HamiltonTracker( + project_id=project_id, + username=username, + dag_name="telemetry_with_ray", + ) + rga = RayGraphAdapter(result_builder=base.PandasDataFrameResult()) + dr_ray = driver.Builder().with_modules(ray_lineage).with_adapters(rga, tracker_ray).build() + result_ray = dr_ray.execute( + final_vars=[ + "node_5s", + "node_1s_error", + "add_1_to_previous", + ] + ) + print(result_ray) + + except ValueError: + print("UI should display failure.") + finally: + tracker = adapters.HamiltonTracker( + project_id=project_id, # modify this as needed + username=username, + dag_name="telemetry_without_ray", + ) + dr_without_ray = driver.Builder().with_modules(ray_lineage).with_adapters(tracker).build() + + result_without_ray = dr_without_ray.execute(final_vars=["node_5s", "add_1_to_previous"]) + print(result_without_ray) + + +if __name__ == "__main__": + run() diff --git a/hamilton/dev_utils/deprecation.py b/hamilton/dev_utils/deprecation.py index e1bbe6d83..2ee2d000d 100644 --- a/hamilton/dev_utils/deprecation.py +++ b/hamilton/dev_utils/deprecation.py @@ -48,8 +48,8 @@ class deprecated: @deprecate( warn_starting=(1,10,0) fail_starting=(2,0,0), - use_instead=parameterize_values, - reason='We have redefined the parameterization decorators to consist of `parametrize`, `parametrize_inputs`, and `parametrize_values` + use_this=parameterize_values, + explanation='We have redefined the parameterization decorators to consist of `parametrize`, `parametrize_inputs`, and `parametrize_values` migration_guide="https://github.com/dagworks-inc/hamilton/..." ) class parameterized(...): @@ -66,7 +66,7 @@ class parameterized(...): explanation: str migration_guide: Optional[ str - ] # If this is None, this means that the use_instead is a drop in replacement + ] # If this is None, this means that the use_this is a drop in replacement current_version: Union[Tuple[int, int, int], Version] = dataclasses.field( default_factory=lambda: CURRENT_VERSION ) diff --git a/hamilton/driver.py b/hamilton/driver.py index a2ba558d5..e76e70afa 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -19,6 +19,7 @@ import pandas as pd from hamilton import common, graph_types, htypes +from hamilton.dev_utils import deprecation from hamilton.execution import executors, graph_functions, grouping, state from hamilton.graph_types import HamiltonNode from hamilton.io import materialization @@ -579,26 +580,52 @@ def execute( "Please use visualize_execution()." ) start_time = time.time() + run_id = str(uuid.uuid4()) run_successful = True - error = None + error_execution = None + error_telemetry = None + outputs = None _final_vars = self._create_final_vars(final_vars) + if self.adapter.does_hook("pre_graph_execute", is_async=False): + self.adapter.call_all_lifecycle_hooks_sync( + "pre_graph_execute", + run_id=run_id, + graph=self.graph, + final_vars=_final_vars, + inputs=inputs, + overrides=overrides, + ) try: - outputs = self.raw_execute(_final_vars, overrides, display_graph, inputs=inputs) + outputs = self.__raw_execute( + _final_vars, overrides, display_graph, inputs=inputs, _run_id=run_id + ) if self.adapter.does_method("do_build_result", is_async=False): # Build the result if we have a result builder - return self.adapter.call_lifecycle_method_sync("do_build_result", outputs=outputs) + outputs = self.adapter.call_lifecycle_method_sync( + "do_build_result", outputs=outputs + ) # Otherwise just return a dict - return outputs except Exception as e: run_successful = False logger.error(SLACK_ERROR_MESSAGE) - error = telemetry.sanitize_error(*sys.exc_info()) + error_execution = e + error_telemetry = telemetry.sanitize_error(*sys.exc_info()) raise e finally: + if self.adapter.does_hook("post_graph_execute", is_async=False): + self.adapter.call_all_lifecycle_hooks_sync( + "post_graph_execute", + run_id=run_id, + graph=self.graph, + success=run_successful, + error=error_execution, + results=outputs, + ) duration = time.time() - start_time self.capture_execute_telemetry( - error, _final_vars, inputs, overrides, run_successful, duration + error_telemetry, _final_vars, inputs, overrides, run_successful, duration ) + return outputs def _create_final_vars(self, final_vars: List[Union[str, Callable, Variable]]) -> List[str]: """Creates the final variables list - converting functions names as required. @@ -649,6 +676,13 @@ def capture_execute_telemetry( if logger.isEnabledFor(logging.DEBUG): logger.debug(f"Error caught in processing telemetry: \n{e}") + @deprecation.deprecated( + warn_starting=(1, 0, 0), + fail_starting=(2, 0, 0), + use_this=None, + explanation="This has become a private method and does not guarantee that all the adapters work correctly.", + migration_guide="Don't use this entry point for execution directly. Always go through `.execute()`or `.materialize()`.", + ) def raw_execute( self, final_vars: List[str], @@ -659,7 +693,7 @@ def raw_execute( ) -> Dict[str, Any]: """Raw execute function that does the meat of execute. - Don't use this entry point for execution directly. Always go through `.execute()`. + Don't use this entry point for execution directly. Always go through `.execute()` or `.materialize()`. In case you are using `.raw_execute()` directly, please switch to `.execute()` using a `base.DictResult()`. Note: `base.DictResult()` is the default return of execute if you are using the `driver.Builder()` class to create a `Driver()` object. @@ -725,6 +759,56 @@ def raw_execute( ) return results + def __raw_execute( + self, + final_vars: List[str], + overrides: Dict[str, Any] = None, + display_graph: bool = False, + inputs: Dict[str, Any] = None, + _fn_graph: graph.FunctionGraph = None, + _run_id: str = None, + ) -> Dict[str, Any]: + """Raw execute function that does the meat of execute. + + Private method since the result building and post_graph_execute lifecycle hooks are performed outside and so this returns an incomplete result. + + :param final_vars: Final variables to compute + :param overrides: Overrides to run. + :param display_graph: DEPRECATED. DO NOT USE. Whether or not to display the graph when running it + :param inputs: Runtime inputs to the DAG + :return: + """ + function_graph = _fn_graph if _fn_graph is not None else self.graph + run_id = _run_id + nodes, user_nodes = function_graph.get_upstream_nodes(final_vars, inputs, overrides) + Driver.validate_inputs( + function_graph, self.adapter, user_nodes, inputs, nodes + ) # TODO -- validate within the function graph itself + if display_graph: # deprecated flow. + logger.warning( + "display_graph=True is deprecated. It will be removed in the 2.0.0 release. " + "Please use visualize_execution()." + ) + self.visualize_execution(final_vars, "test-output/execute.gv", {"view": True}) + if self.has_cycles( + final_vars, function_graph + ): # here for backwards compatible driver behavior. + raise ValueError("Error: cycles detected in your graph.") + all_nodes = nodes | user_nodes + self.graph_executor.validate(list(all_nodes)) + results = None + try: + results = self.graph_executor.execute( + function_graph, + final_vars, + overrides if overrides is not None else {}, + inputs if inputs is not None else {}, + run_id, + ) + return results + except Exception as e: + raise e + @capture_function_usage def list_available_variables( self, *, tag_filter: Dict[str, Union[Optional[str], List[str]]] = None @@ -1516,8 +1600,10 @@ def materialize( additional_vars = [] start_time = time.time() run_successful = True - error = None - + error_execution = None + error_telemetry = None + run_id = str(uuid.uuid4()) + outputs = (None, None) final_vars = self._create_final_vars(additional_vars) # This is so the finally logging statement does not accidentally die materializer_vars = [] @@ -1544,32 +1630,58 @@ def materialize( # Note we will not run the loaders if they're not upstream of the # materializers or additional_vars materializer_vars = [m.id for m in materializer_factories] + if self.adapter.does_hook("pre_graph_execute", is_async=False): + self.adapter.call_all_lifecycle_hooks_sync( + "pre_graph_execute", + run_id=run_id, + graph=function_graph, + final_vars=final_vars + materializer_vars, + inputs=inputs, + overrides=overrides, + ) + nodes, user_nodes = function_graph.get_upstream_nodes( final_vars + materializer_vars, inputs, overrides ) Driver.validate_inputs(function_graph, self.adapter, user_nodes, inputs, nodes) all_nodes = nodes | user_nodes self.graph_executor.validate(list(all_nodes)) - raw_results = self.raw_execute( + raw_results = self.__raw_execute( final_vars=final_vars + materializer_vars, inputs=inputs, overrides=overrides, _fn_graph=function_graph, + _run_id=run_id, ) materialization_output = {key: raw_results[key] for key in materializer_vars} raw_results_output = {key: raw_results[key] for key in final_vars} - - return materialization_output, raw_results_output + outputs = materialization_output, raw_results_output except Exception as e: run_successful = False logger.error(SLACK_ERROR_MESSAGE) - error = telemetry.sanitize_error(*sys.exc_info()) + error_telemetry = telemetry.sanitize_error(*sys.exc_info()) + error_execution = e raise e finally: + if self.adapter.does_hook("post_graph_execute", is_async=False): + self.adapter.call_all_lifecycle_hooks_sync( + "post_graph_execute", + run_id=run_id, + graph=function_graph, + success=run_successful, + error=error_execution, + results=outputs[1], + ) duration = time.time() - start_time self.capture_execute_telemetry( - error, final_vars + materializer_vars, inputs, overrides, run_successful, duration + error_telemetry, + final_vars + materializer_vars, + inputs, + overrides, + run_successful, + duration, ) + return outputs @capture_function_usage def visualize_materialization( diff --git a/hamilton/execution/graph_functions.py b/hamilton/execution/graph_functions.py index 3a86d8b60..b3379cdde 100644 --- a/hamilton/execution/graph_functions.py +++ b/hamilton/execution/graph_functions.py @@ -1,5 +1,6 @@ import logging import pprint +from functools import partial from typing import Any, Collection, Dict, List, Optional, Set, Tuple from hamilton import node @@ -201,59 +202,24 @@ def dfs_traverse( for dependency in node_.dependencies: if dependency.name in computed: kwargs[dependency.name] = computed[dependency.name] - error = None - result = None - success = True - pre_node_execute_errored = False - try: - if adapter.does_hook("pre_node_execute", is_async=False): - try: - adapter.call_all_lifecycle_hooks_sync( - "pre_node_execute", - run_id=run_id, - node_=node_, - kwargs=kwargs, - task_id=task_id, - ) - except Exception as e: - pre_node_execute_errored = True - raise e - if adapter.does_method("do_node_execute", is_async=False): - result = adapter.call_lifecycle_method_sync( - "do_node_execute", - run_id=run_id, - node_=node_, - kwargs=kwargs, - task_id=task_id, - ) - else: - result = node_(**kwargs) - except Exception as e: - success = False - error = e - step = "[pre-node-execute]" if pre_node_execute_errored else "" - message = create_error_message(kwargs, node_, step) - logger.exception(message) - raise - finally: - if not pre_node_execute_errored and adapter.does_hook( - "post_node_execute", is_async=False - ): - try: - adapter.call_all_lifecycle_hooks_sync( - "post_node_execute", - run_id=run_id, - node_=node_, - kwargs=kwargs, - success=success, - error=error, - result=result, - task_id=task_id, - ) - except Exception: - message = create_error_message(kwargs, node_, "[post-node-execute]") - logger.exception(message) - raise + + execute_lifecycle_for_node_partial = partial( + execute_lifecycle_for_node, + __node_=node_, + __adapter=adapter, + __run_id=run_id, + __task_id=task_id, + ) + + if adapter.does_method("do_remote_execute", is_async=False): + result = adapter.call_lifecycle_method_sync( + "do_remote_execute", + node=node_, + execute_lifecycle_for_node=execute_lifecycle_for_node_partial, + **kwargs, + ) + else: + result = execute_lifecycle_for_node_partial(**kwargs) computed[node_.name] = result # > pruning the graph @@ -285,6 +251,86 @@ def dfs_traverse( return computed +# TODO: better function name +def execute_lifecycle_for_node( + __node_: node.Node, + __adapter: LifecycleAdapterSet, + __run_id: str, + __task_id: str, + **__kwargs: Dict[str, Any], +): + """Helper function to properly execute node lifecycle. + + Firstly, we execute the pre-node-execute hooks if supplied adapters have any, then we execute the node function, and lastly, we execute the post-node-execute hooks if present in the adapters. + + For local runtime gets execute directy. Otherwise, serves as a sandwich function that guarantees the pre_node and post_node lifecycle hooks are executed in the remote environment. + + :param __node_: Node that is being executed + :param __adapter: Adapter to use to compute + :param __run_id: ID of the run, unique in scope of the driver. + :param __task_id: ID of the task, defaults to None if not in a task setting + :param ___kwargs: Keyword arguments that are being passed into the node + """ + + error = None + result = None + success = True + pre_node_execute_errored = False + + try: + if __adapter.does_hook("pre_node_execute", is_async=False): + try: + __adapter.call_all_lifecycle_hooks_sync( + "pre_node_execute", + run_id=__run_id, + node_=__node_, + kwargs=__kwargs, + task_id=__task_id, + ) + except Exception as e: + pre_node_execute_errored = True + raise e + if __adapter.does_method("do_node_execute", is_async=False): + result = __adapter.call_lifecycle_method_sync( + "do_node_execute", + run_id=__run_id, + node_=__node_, + kwargs=__kwargs, + task_id=__task_id, + ) + else: + result = __node_(**__kwargs) + + return result + + except Exception as e: + success = False + error = e + step = "[pre-node-execute]" if pre_node_execute_errored else "" + message = create_error_message(__kwargs, __node_, step) + logger.exception(message) + raise + finally: + if not pre_node_execute_errored and __adapter.does_hook( + "post_node_execute", is_async=False + ): + try: + __adapter.call_all_lifecycle_hooks_sync( + "post_node_execute", + run_id=__run_id, + node_=__node_, + kwargs=__kwargs, + success=success, + error=error, + result=result, + task_id=__task_id, + ) + except Exception: + message = create_error_message(__kwargs, __node_, "[post-node-execute]") + logger.exception(message) + raise + + def nodes_between( end_node: node.Node, search_condition: lambda node_: bool, diff --git a/hamilton/lifecycle/base.py b/hamilton/lifecycle/base.py index ef9d07387..12ea36b2a 100644 --- a/hamilton/lifecycle/base.py +++ b/hamilton/lifecycle/base.py @@ -519,6 +519,26 @@ def do_node_execute( pass +@lifecycle.base_method("do_remote_execute") +class BaseDoRemoteExecute(abc.ABC): + @abc.abstractmethod + def do_remote_execute( + self, + *, + node: "node.Node", + kwargs: Dict[str, Any], + execute_lifecycle_for_node: Callable, + ) -> Any: + """Method that is called to implement correct remote execution of hooks. This makes sure that all the pre-node and post-node hooks get executed in the remote environment which is necessary for some adapters. Node execution is called the same as before through "do_node_execute". + + + :param node: Node that is being executed + :param kwargs: Keyword arguments that are being passed into the node + :param execute_lifecycle_for_node: Function executing lifecycle_hooks and lifecycle_methods + """ + pass + + @lifecycle.base_method("do_node_execute") class BaseDoNodeExecuteAsync(abc.ABC): @abc.abstractmethod diff --git a/hamilton/plugins/h_ray.py b/hamilton/plugins/h_ray.py index 0f0692803..66d0a9255 100644 --- a/hamilton/plugins/h_ray.py +++ b/hamilton/plugins/h_ray.py @@ -1,12 +1,14 @@ +import abc import functools import json import logging +import time import typing import ray from ray import workflow -from hamilton import base, htypes, node +from hamilton import base, htypes, lifecycle, node from hamilton.execution import executors from hamilton.execution.executors import TaskFuture from hamilton.execution.grouping import TaskImplementation @@ -50,7 +52,14 @@ def parse_ray_remote_options_from_tags(tags: typing.Dict[str, str]) -> typing.Di return ray_options -class RayGraphAdapter(base.HamiltonGraphAdapter, base.ResultMixin): +class RayGraphAdapter( + lifecycle.base.BaseDoRemoteExecute, + lifecycle.base.BaseDoBuildResult, + lifecycle.base.BaseDoValidateInput, + lifecycle.base.BaseDoCheckEdgeTypesMatch, + lifecycle.base.BasePostGraphExecute, + abc.ABC, +): """Class representing what's required to make Hamilton run on Ray. This walks the graph and translates it to run onto `Ray `__. @@ -86,13 +95,21 @@ class RayGraphAdapter(base.HamiltonGraphAdapter, base.ResultMixin): DISCLAIMER -- this class is experimental, so signature changes are a possibility! """ - def __init__(self, result_builder: base.ResultMixin): + def __init__( + self, + result_builder: base.ResultMixin, + ray_init_config: typing.Dict[str, typing.Any] = None, + shutdown_ray_on_completion: bool = False, + ): """Constructor You have the ability to pass in a ResultMixin object to the constructor to control the return type that gets \ produce by running on Ray. :param result_builder: Required. An implementation of base.ResultMixin. + :param ray_init_config: allows to connect to an existing cluster or start a new one with custom configuration (https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html) + :param shutdown_ray_on_completion: by default we leave the cluster open, but we can also shut it down + """ self.result_builder = result_builder if not self.result_builder: @@ -100,28 +117,39 @@ def __init__(self, result_builder: base.ResultMixin): "Error: ResultMixin object required. Please pass one in for `result_builder`." ) + self.shutdown_ray_on_completion = shutdown_ray_on_completion + + if ray_init_config is not None: + ray.init(**ray_init_config) + @staticmethod - def check_input_type(node_type: typing.Type, input_value: typing.Any) -> bool: + def do_validate_input(node_type: typing.Type, input_value: typing.Any) -> bool: # NOTE: the type of a raylet is unknown until they are computed if isinstance(input_value, ray._raylet.ObjectRef): return True return htypes.check_input_type(node_type, input_value) @staticmethod - def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type) -> bool: - return node_type == input_type + def do_check_edge_types_match(type_from: typing.Type, type_to: typing.Type) -> bool: + return type_from == type_to - def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any: + def do_remote_execute( + self, + *, + execute_lifecycle_for_node: typing.Callable, + node: node.Node, + **kwargs: typing.Dict[str, typing.Any], + ) -> typing.Any: """Function that is called as we walk the graph to determine how to execute a hamilton function. - :param node: the node from the graph. + :param execute_lifecycle_for_node: wrapper function that executes lifecycle hooks and methods :param kwargs: the arguments that should be passed to it. :return: returns a ray object reference. """ ray_options = parse_ray_remote_options_from_tags(node.tags) - return ray.remote(raify(node.callable)).options(**ray_options).remote(**kwargs) + return ray.remote(raify(execute_lifecycle_for_node)).options(**ray_options).remote(**kwargs) - def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any: + def do_build_result(self, outputs: typing.Dict[str, typing.Any]) -> typing.Any: """Builds the result and brings it back to this running process. :param outputs: the dictionary of key -> Union[ray object reference | value] @@ -135,6 +163,14 @@ def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any: result = ray.get(remote_combine) # this materializes the object locally return result + def post_graph_execute(self, *args, **kwargs): + """We have the option to close the cluster down after execution.""" + + if self.shutdown_ray_on_completion: + # In case we have Hamilton Tracker to have enough time to properly flush + time.sleep(5) + ray.shutdown() + class RayWorkflowGraphAdapter(base.HamiltonGraphAdapter, base.ResultMixin): """Class representing what's required to make Hamilton run Ray Workflows diff --git a/pyproject.toml b/pyproject.toml index 1a61e8aae..e41829d23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ dask-distributed = ["dask[distributed]"] datadog = ["ddtrace"] dev = [ "pre-commit", - "ruff", + "ruff==0.5.7", # this should match `.pre-commit-config.yaml` ] diskcache = ["diskcache"] docs = [ diff --git a/tests/lifecycle/lifecycle_adapters_for_testing.py b/tests/lifecycle/lifecycle_adapters_for_testing.py index a151c4ecc..37f3b3961 100644 --- a/tests/lifecycle/lifecycle_adapters_for_testing.py +++ b/tests/lifecycle/lifecycle_adapters_for_testing.py @@ -9,6 +9,7 @@ from hamilton.lifecycle.base import ( BaseDoBuildResult, BaseDoNodeExecute, + BaseDoRemoteExecute, BaseDoValidateInput, BasePostGraphConstruct, BasePostGraphExecute, @@ -187,6 +188,23 @@ def do_node_execute( return node_(**kwargs) +class TrackingDoRemoteExecuteHook(ExtendToTrackCalls, BaseDoRemoteExecute): + def __init__(self, name: str, additional_value: int): + super().__init__(name) + self._additional_value = additional_value + + def do_remote_execute( + self, + node: "node.Node", + execute_lifecycle_for_node: Callable, + **kwargs: Dict[str, Any], + ) -> Any: + node_ = node + if node_.type == int and node_.name != "n_iters": + return execute_lifecycle_for_node(**kwargs) + self._additional_value + return execute_lifecycle_for_node(**kwargs) + + class TrackingDoBuildResultMethod(ExtendToTrackCalls, BaseDoBuildResult): def __init__(self, name: str, result: Any): super().__init__(name) diff --git a/tests/lifecycle/test_lifecycle_adapters_end_to_end.py b/tests/lifecycle/test_lifecycle_adapters_end_to_end.py index cf57c02c0..a753ea874 100644 --- a/tests/lifecycle/test_lifecycle_adapters_end_to_end.py +++ b/tests/lifecycle/test_lifecycle_adapters_end_to_end.py @@ -1,5 +1,5 @@ from types import ModuleType -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional import pytest @@ -7,6 +7,7 @@ from hamilton.io.materialization import to from hamilton.lifecycle.base import ( BaseDoNodeExecute, + BaseDoRemoteExecute, BasePostGraphConstruct, BasePostGraphExecute, BasePostNodeExecute, @@ -22,6 +23,7 @@ SentinelException, TrackingDoBuildResultMethod, TrackingDoNodeExecuteHook, + TrackingDoRemoteExecuteHook, TrackingDoValidateInputMethod, TrackingPostGraphConstructHook, TrackingPostGraphExecuteHook, @@ -343,6 +345,78 @@ def post_graph_execute( assert len(calls) == 16 +def test_multi_hook_remote(): + class MultiHook( + BasePreDoAnythingHook, + BasePostGraphConstruct, + BasePreGraphExecute, + BasePreNodeExecute, + BaseDoRemoteExecute, + BasePostNodeExecute, + BasePostGraphExecute, + ExtendToTrackCalls, + ): + def do_remote_execute( + self, + node: node.Node, + execute_lifecycle_for_node: Callable, + **kwargs: Dict[str, Any], + ): + return execute_lifecycle_for_node(**kwargs) + + def pre_do_anything(self): + pass + + def post_graph_construct( + self, graph: "FunctionGraph", modules: List[ModuleType], config: Dict[str, Any] + ): + pass + + def pre_graph_execute( + self, + run_id: str, + graph: "FunctionGraph", + final_vars: List[str], + inputs: Dict[str, Any], + overrides: Dict[str, Any], + ): + pass + + def pre_node_execute( + self, run_id: str, node_: Node, kwargs: Dict[str, Any], task_id: Optional[str] = None + ): + pass + + def post_node_execute( + self, + run_id: str, + node_: node.Node, + kwargs: Dict[str, Any], + success: bool, + error: Optional[Exception], + result: Optional[Any], + task_id: Optional[str] = None, + ): + pass + + def post_graph_execute( + self, + run_id: str, + graph: "FunctionGraph", + success: bool, + error: Optional[Exception], + results: Optional[Dict[str, Any]], + ): + pass + + multi_hook = MultiHook(name="multi_hook") + + dr = _sample_driver(multi_hook) + dr.execute(["d"], inputs={"input": 1}) + calls = multi_hook.calls + assert len(calls) == 16 + + def test_individual_do_validate_input_method(): method_name = "do_validate_input" method = TrackingDoValidateInputMethod(name=method_name, valid=True) @@ -378,6 +452,16 @@ def test_individual_do_node_execute_method(): assert res == {"d": 17**3 + 1} # adding one to each one +def test_individual_do_remote_execute_method(): + method_name = "do_remote_execute" + method = TrackingDoRemoteExecuteHook(name=method_name, additional_value=1) + dr = _sample_driver(method) + res = dr.execute(["d"], inputs={"input": 1}) + relevant_calls = [item for item in method.calls if item.name == method_name] + assert len(relevant_calls) == 4 + assert res == {"d": 17**3 + 1} # adding one to each one + + def test_individual_do_build_results_method(): method_name = "do_build_result" method = TrackingDoBuildResultMethod(name=method_name, result=-1) diff --git a/ui/sdk/.pre-commit-config.yaml b/ui/sdk/.pre-commit-config.yaml index 75a5287c3..6cddba381 100644 --- a/ui/sdk/.pre-commit-config.yaml +++ b/ui/sdk/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: rev: v0.0.265 hooks: - id: ruff - args: [ --fix, --exit-non-zero-on-fix ] + args: [ --fix , --exit-non-zero-on-fix ] - repo: https://github.com/ambv/black rev: 23.3.0 hooks: diff --git a/ui/sdk/src/hamilton_sdk/adapters.py b/ui/sdk/src/hamilton_sdk/adapters.py index 995143e64..23cbb3a2c 100644 --- a/ui/sdk/src/hamilton_sdk/adapters.py +++ b/ui/sdk/src/hamilton_sdk/adapters.py @@ -103,6 +103,10 @@ def __init__( # if you're using a float value. self.seed = None + def stop(self): + """Initiates stop if run in remote environment""" + self.client.stop() + def post_graph_construct( self, graph: h_graph.FunctionGraph, modules: List[ModuleType], config: Dict[str, Any] ): diff --git a/ui/sdk/src/hamilton_sdk/api/clients.py b/ui/sdk/src/hamilton_sdk/api/clients.py index 87e687027..551216495 100644 --- a/ui/sdk/src/hamilton_sdk/api/clients.py +++ b/ui/sdk/src/hamilton_sdk/api/clients.py @@ -1,5 +1,6 @@ import abc import asyncio +import atexit import datetime import functools import logging @@ -196,6 +197,9 @@ def __init__( ).start() self.worker_thread.start() + # Makes sure the process stops even if in remote environment + atexit.register(self.stop) + def __getstate__(self): # Copy the object's state from self.__dict__ which contains # all our instance attributes. Always use the dict.copy() @@ -217,6 +221,7 @@ def __setstate__(self, state): target=lambda: threading.main_thread().join() or self.data_queue.put(None) ).start() self.worker_thread.start() + atexit.register(self.stop) def worker(self): """Worker thread to process the queue."""