diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 183541c4f..d15c7ece4 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -14,7 +14,7 @@ repos:
         args: [ --fix ]
       # Run the formatter.
       - id: ruff-format
-#        args: [ --diff ] # Use for previewing changes
+      # args: [ --diff ] # Use for previewing changes
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.6.0
     hooks:
diff --git a/examples/ray/ray_Hamilton_UI_tracking/README b/examples/ray/ray_Hamilton_UI_tracking/README
new file mode 100644
index 000000000..baa786687
--- /dev/null
+++ b/examples/ray/ray_Hamilton_UI_tracking/README
@@ -0,0 +1,25 @@
+# Tracking telemetry in Hamilton UI for Ray clusters
+
+We show how to combine the [RayGraphAdapter](https://hamilton.dagworks.io/en/latest/reference/graph-adapters/RayGraphAdapter/) and the [HamiltonTracker](https://hamilton.dagworks.io/en/latest/concepts/ui/) to run a dummy DAG.
+
+## ray_lineage.py
+Has three dummy functions that represent a basic DAG:
+- waiting 5s
+- waiting 1s and adding 1 to the previous result
+- waiting 1s and raising an error
+
+## run_lineage.py
+Is where the driver code lives to create the DAG and exercise it.
+
+To exercise it:
+> Install the dependencies: `pip install -r requirements.txt`
+> Have a running instance of the Hamilton UI: https://hamilton.dagworks.io/en/latest/concepts/ui/
+
+```bash
+python run_lineage.py [OPTIONS]
+
+Options:
+  --username TEXT       The username used to access the selected project in the Hamilton UI  [required]
+  --project_id INTEGER  The id of the project created in the Hamilton UI (default: 1)
+  --help                Show this message and exit.
+```
diff --git a/examples/ray/ray_Hamilton_UI_tracking/hamilton_notebook.ipynb b/examples/ray/ray_Hamilton_UI_tracking/hamilton_notebook.ipynb
new file mode 100644
index 000000000..165da38df
--- /dev/null
+++ b/examples/ray/ray_Hamilton_UI_tracking/hamilton_notebook.ipynb
@@ -0,0 +1,109 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Hamilton UI Adapter\n",
+    "\n",
+    "Needs a running instance of the Hamilton UI: https://hamilton.dagworks.io/en/latest/concepts/ui/"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from hamilton_sdk.adapters import HamiltonTracker\n",
+    "\n",
+    "# Inputs required to track into the correct project in the UI\n",
+    "project_id = 2\n",
+    "username = \"admin\"\n",
+    "\n",
+    "tracker_ray = HamiltonTracker(\n",
+    "    project_id=project_id,\n",
+    "    username=username,\n",
+    "    dag_name=\"telemetry_with_ray\",\n",
+    ")\n",
+    "\n",
+    "tracker_without_ray = HamiltonTracker(\n",
+    "    project_id=project_id,\n",
+    "    username=username,\n",
+    "    dag_name=\"telemetry_without_ray\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Ray adapter\n",
+    "\n",
+    "https://hamilton.dagworks.io/en/latest/reference/graph-adapters/RayGraphAdapter/"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from hamilton import base\n",
+    "from hamilton.plugins.h_ray import RayGraphAdapter\n",
+    "\n",
+    "rga = RayGraphAdapter(result_builder=base.PandasDataFrameResult())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Importing Hamilton and the DAG modules"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from hamilton import driver\n",
+    "import ray_lineage"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
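+    "# NOTE: node_1s_error raises a ValueError on purpose, so the run with the Ray\n",
+    "# adapter should show up as a failed run in the Hamilton UI without breaking telemetry.\n",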
+    "try:\n",
+    "    dr_ray = driver.Builder().with_modules(ray_lineage).with_adapters(rga, tracker_ray).build()\n",
+    "    result_ray = dr_ray.execute(\n",
+    "        final_vars=[\n",
+    "            \"node_5s\",\n",
+    "            \"node_1s_error\",\n",
+    "            \"add_1_to_previous\",\n",
+    "        ]\n",
+    "    )\n",
+    "    print(result_ray)\n",
+    "\n",
+    "except ValueError:\n",
+    "    print(\"UI should display failure.\")\n",
+    "finally:\n",
+    "    dr_without_ray = driver.Builder().with_modules(ray_lineage).with_adapters(tracker_without_ray).build()\n",
+    "    result_without_ray = dr_without_ray.execute(final_vars=[\"node_5s\", \"add_1_to_previous\"])\n",
+    "    print(result_without_ray)\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/ray/ray_Hamilton_UI_tracking/ray_lineage.py b/examples/ray/ray_Hamilton_UI_tracking/ray_lineage.py
new file mode 100644
index 000000000..d9d20c2ca
--- /dev/null
+++ b/examples/ray/ray_Hamilton_UI_tracking/ray_lineage.py
@@ -0,0 +1,18 @@
+import time
+
+
+def node_5s() -> float:
+    start = time.time()
+    time.sleep(5)
+    return time.time() - start
+
+
+def add_1_to_previous(node_5s: float) -> float:
+    start = time.time()
+    time.sleep(1)
+    return node_5s + (time.time() - start)
+
+
+def node_1s_error(node_5s: float) -> float:
+    time.sleep(1)
+    raise ValueError("Does not break telemetry if executed through ray")
diff --git a/examples/ray/ray_Hamilton_UI_tracking/requirements.txt b/examples/ray/ray_Hamilton_UI_tracking/requirements.txt
new file mode 100644
index 000000000..e6916bb92
--- /dev/null
+++ b/examples/ray/ray_Hamilton_UI_tracking/requirements.txt
@@ -0,0 +1 @@
+sf-hamilton[ray,sdk,ui]
diff --git a/examples/ray/ray_Hamilton_UI_tracking/run_lineage.py b/examples/ray/ray_Hamilton_UI_tracking/run_lineage.py
new file mode 100644
index 000000000..6d135efd1
--- /dev/null
+++ b/examples/ray/ray_Hamilton_UI_tracking/run_lineage.py
@@ -0,0 +1,45 @@
+import click
+import ray_lineage
+
+from hamilton import base, driver
+from hamilton.plugins.h_ray import RayGraphAdapter
+from hamilton_sdk import adapters
+
+
+@click.command()
+@click.option("--username", required=True, type=str)
+@click.option("--project_id", default=1, type=int)
+def run(project_id, username):
+    try:
+        tracker_ray = adapters.HamiltonTracker(
+            project_id=project_id,
+            username=username,
+            dag_name="telemetry_with_ray",
+        )
+        rga = RayGraphAdapter(result_builder=base.PandasDataFrameResult())
+        dr_ray = driver.Builder().with_modules(ray_lineage).with_adapters(rga, tracker_ray).build()
+        result_ray = dr_ray.execute(
+            final_vars=[
+                "node_5s",
+                "node_1s_error",
+                "add_1_to_previous",
+            ]
+        )
+        print(result_ray)
+
+    except ValueError:
+        print("UI should display failure.")
+    finally:
+        tracker = adapters.HamiltonTracker(
+            project_id=project_id,  # modify this as needed
+            username=username,
+            dag_name="telemetry_without_ray",
+        )
+        dr_without_ray = driver.Builder().with_modules(ray_lineage).with_adapters(tracker).build()
+
+        result_without_ray = dr_without_ray.execute(final_vars=["node_5s", "add_1_to_previous"])
+        print(result_without_ray)
+
+
+if __name__ == "__main__":
+    run()
diff --git a/hamilton/dev_utils/deprecation.py b/hamilton/dev_utils/deprecation.py
index e1bbe6d83..2ee2d000d 100644
--- a/hamilton/dev_utils/deprecation.py
+++ b/hamilton/dev_utils/deprecation.py
@@ -48,8 +48,8 @@ class deprecated:
     @deprecate(
         warn_starting=(1,10,0)
         fail_starting=(2,0,0),
-        use_instead=parameterize_values,
-        reason='We have redefined the parameterization decorators to consist of `parametrize`, `parametrize_inputs`, and `parametrize_values`
+        use_this=parameterize_values,
+        explanation='We have redefined the parameterization decorators to consist of `parametrize`, `parametrize_inputs`, and `parametrize_values`',
         migration_guide="https://github.com/dagworks-inc/hamilton/..."
     )
     class parameterized(...):
@@ -66,7 +66,7 @@ class parameterized(...):
     explanation: str
     migration_guide: Optional[
         str
-    ]  # If this is None, this means that the use_instead is a drop in replacement
+    ]  # If this is None, this means that the use_this is a drop-in replacement
     current_version: Union[Tuple[int, int, int], Version] = dataclasses.field(
         default_factory=lambda: CURRENT_VERSION
     )
diff --git a/hamilton/driver.py b/hamilton/driver.py
index a2ba558d5..e76e70afa 100644
--- a/hamilton/driver.py
+++ b/hamilton/driver.py
@@ -19,6 +19,7 @@
 import pandas as pd

 from hamilton import common, graph_types, htypes
+from hamilton.dev_utils import deprecation
 from hamilton.execution import executors, graph_functions, grouping, state
 from hamilton.graph_types import HamiltonNode
 from hamilton.io import materialization
@@ -579,26 +580,52 @@ def execute(
                 "Please use visualize_execution()."
             )
         start_time = time.time()
+        run_id = str(uuid.uuid4())
         run_successful = True
-        error = None
+        error_execution = None
+        error_telemetry = None
+        outputs = None
         _final_vars = self._create_final_vars(final_vars)
+        if self.adapter.does_hook("pre_graph_execute", is_async=False):
+            self.adapter.call_all_lifecycle_hooks_sync(
+                "pre_graph_execute",
+                run_id=run_id,
+                graph=self.graph,
+                final_vars=_final_vars,
+                inputs=inputs,
+                overrides=overrides,
+            )
         try:
-            outputs = self.raw_execute(_final_vars, overrides, display_graph, inputs=inputs)
+            outputs = self.__raw_execute(
+                _final_vars, overrides, display_graph, inputs=inputs, _run_id=run_id
+            )
             if self.adapter.does_method("do_build_result", is_async=False):
                 # Build the result if we have a result builder
-                return self.adapter.call_lifecycle_method_sync("do_build_result", outputs=outputs)
+                outputs = self.adapter.call_lifecycle_method_sync(
+                    "do_build_result", outputs=outputs
+                )
             # Otherwise just return a dict
-            return outputs
         except Exception as e:
             run_successful = False
             logger.error(SLACK_ERROR_MESSAGE)
-            error = telemetry.sanitize_error(*sys.exc_info())
+            error_execution = e
+            error_telemetry = telemetry.sanitize_error(*sys.exc_info())
             raise e
         finally:
+            if self.adapter.does_hook("post_graph_execute", is_async=False):
+                self.adapter.call_all_lifecycle_hooks_sync(
+                    "post_graph_execute",
+                    run_id=run_id,
+                    graph=self.graph,
+                    success=run_successful,
+                    error=error_execution,
+                    results=outputs,
+                )
             duration = time.time() - start_time
             self.capture_execute_telemetry(
-                error, _final_vars, inputs, overrides, run_successful, duration
+                error_telemetry, _final_vars, inputs, overrides, run_successful, duration
             )
+        return outputs

     def _create_final_vars(self, final_vars: List[Union[str, Callable, Variable]]) -> List[str]:
         """Creates the final variables list - converting functions names as required.
@@ -649,6 +676,13 @@ def capture_execute_telemetry(
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(f"Error caught in processing telemetry: \n{e}")

+    @deprecation.deprecated(
+        warn_starting=(1, 0, 0),
+        fail_starting=(2, 0, 0),
+        use_this=None,
+        explanation="This has become a private method and no longer guarantees that all the adapters will work correctly.",
+        migration_guide="Don't use this entry point for execution directly. Always go through `.execute()` or `.materialize()`.",
+    )
     def raw_execute(
         self,
         final_vars: List[str],
@@ -659,7 +693,7 @@
     ) -> Dict[str, Any]:
         """Raw execute function that does the meat of execute.

-        Don't use this entry point for execution directly. Always go through `.execute()`.
+        Don't use this entry point for execution directly. Always go through `.execute()` or `.materialize()`.
         In case you are using `.raw_execute()` directly, please switch to `.execute()` using a `base.DictResult()`.

         Note: `base.DictResult()` is the default return of execute if you are using the
         `driver.Builder()` class to create a `Driver()` object.
@@ -725,6 +759,52 @@
         )
         return results

+    def __raw_execute(
+        self,
+        final_vars: List[str],
+        overrides: Dict[str, Any] = None,
+        display_graph: bool = False,
+        inputs: Dict[str, Any] = None,
+        _fn_graph: graph.FunctionGraph = None,
+        _run_id: str = None,
+    ) -> Dict[str, Any]:
+        """Raw execute function that does the meat of execute.
+
+        Private method: result building and the post_graph_execute lifecycle hook happen outside of it, so on its own it returns an incomplete result.
+
+        :param final_vars: Final variables to compute
+        :param overrides: Overrides to run.
+        :param display_graph: DEPRECATED. DO NOT USE. Whether or not to display the graph when running it
+        :param inputs: Runtime inputs to the DAG
+        :return: Raw results, as a dictionary of variable name -> computed value.
+        """
+        function_graph = _fn_graph if _fn_graph is not None else self.graph
+        run_id = _run_id
+        nodes, user_nodes = function_graph.get_upstream_nodes(final_vars, inputs, overrides)
+        Driver.validate_inputs(
+            function_graph, self.adapter, user_nodes, inputs, nodes
+        )  # TODO -- validate within the function graph itself
+        if display_graph:  # deprecated flow.
+            logger.warning(
+                "display_graph=True is deprecated. It will be removed in the 2.0.0 release. "
+                "Please use visualize_execution()."
+            )
+            self.visualize_execution(final_vars, "test-output/execute.gv", {"view": True})
+        if self.has_cycles(
+            final_vars, function_graph
+        ):  # here for backwards compatible driver behavior.
+            raise ValueError("Error: cycles detected in your graph.")
+        all_nodes = nodes | user_nodes
+        self.graph_executor.validate(list(all_nodes))
+        results = self.graph_executor.execute(
+            function_graph,
+            final_vars,
+            overrides if overrides is not None else {},
+            inputs if inputs is not None else {},
+            run_id,
+        )
+        return results
+
     @capture_function_usage
     def list_available_variables(
         self, *, tag_filter: Dict[str, Union[Optional[str], List[str]]] = None
@@ -1516,8 +1596,10 @@ def materialize(
         additional_vars = []
         start_time = time.time()
         run_successful = True
-        error = None
-
+        error_execution = None
+        error_telemetry = None
+        run_id = str(uuid.uuid4())
+        outputs = (None, None)
         final_vars = self._create_final_vars(additional_vars)
         # This is so the finally logging statement does not accidentally die
         materializer_vars = []
@@ -1544,32 +1626,58 @@
             # Note we will not run the loaders if they're not upstream of the
             # materializers or additional_vars
             materializer_vars = [m.id for m in materializer_factories]
+            if self.adapter.does_hook("pre_graph_execute", is_async=False):
+                self.adapter.call_all_lifecycle_hooks_sync(
+                    "pre_graph_execute",
+                    run_id=run_id,
+                    graph=function_graph,
+                    final_vars=final_vars + materializer_vars,
+                    inputs=inputs,
+                    overrides=overrides,
+                )
+
             nodes, user_nodes = function_graph.get_upstream_nodes(
                 final_vars + materializer_vars, inputs, overrides
             )
             Driver.validate_inputs(function_graph, self.adapter, user_nodes, inputs, nodes)
             all_nodes = nodes | user_nodes
             self.graph_executor.validate(list(all_nodes))
-            raw_results = self.raw_execute(
+            raw_results = self.__raw_execute(
                 final_vars=final_vars + materializer_vars,
                 inputs=inputs,
                 overrides=overrides,
                 _fn_graph=function_graph,
+                _run_id=run_id,
             )
             materialization_output = {key: raw_results[key] for key in materializer_vars}
             raw_results_output = {key: raw_results[key] for key in final_vars}
-
-            return materialization_output, raw_results_output
+            outputs = materialization_output, raw_results_output
         except Exception as e:
            run_successful = False
            logger.error(SLACK_ERROR_MESSAGE)
-           error = telemetry.sanitize_error(*sys.exc_info())
+           error_telemetry = telemetry.sanitize_error(*sys.exc_info())
+           error_execution = e
            raise e
        finally:
+           if self.adapter.does_hook("post_graph_execute", is_async=False):
+               self.adapter.call_all_lifecycle_hooks_sync(
+                   "post_graph_execute",
+                   run_id=run_id,
+                   graph=function_graph,
+                   success=run_successful,
+                   error=error_execution,
+                   results=outputs[1],
+               )
            duration = time.time() - start_time
            self.capture_execute_telemetry(
-               error, final_vars + materializer_vars, inputs, overrides, run_successful, duration
+               error_telemetry,
+               final_vars + materializer_vars,
+               inputs,
+               overrides,
+               run_successful,
+               duration,
            )
+       return outputs

     @capture_function_usage
     def visualize_materialization(
diff --git a/hamilton/execution/graph_functions.py b/hamilton/execution/graph_functions.py
index 3a86d8b60..b3379cdde 100644
--- a/hamilton/execution/graph_functions.py
+++ b/hamilton/execution/graph_functions.py
@@ -1,5 +1,6 @@
 import logging
 import pprint
+from functools import partial
 from typing import Any, Collection, Dict, List, Optional, Set, Tuple

 from hamilton import node
@@ -201,59 +202,24 @@ def dfs_traverse(
         for dependency in node_.dependencies:
             if dependency.name in computed:
                 kwargs[dependency.name] = computed[dependency.name]
-        error = None
-        result = None
-        success = True
-        pre_node_execute_errored = False
-        try:
adapter.does_hook("pre_node_execute", is_async=False): - try: - adapter.call_all_lifecycle_hooks_sync( - "pre_node_execute", - run_id=run_id, - node_=node_, - kwargs=kwargs, - task_id=task_id, - ) - except Exception as e: - pre_node_execute_errored = True - raise e - if adapter.does_method("do_node_execute", is_async=False): - result = adapter.call_lifecycle_method_sync( - "do_node_execute", - run_id=run_id, - node_=node_, - kwargs=kwargs, - task_id=task_id, - ) - else: - result = node_(**kwargs) - except Exception as e: - success = False - error = e - step = "[pre-node-execute]" if pre_node_execute_errored else "" - message = create_error_message(kwargs, node_, step) - logger.exception(message) - raise - finally: - if not pre_node_execute_errored and adapter.does_hook( - "post_node_execute", is_async=False - ): - try: - adapter.call_all_lifecycle_hooks_sync( - "post_node_execute", - run_id=run_id, - node_=node_, - kwargs=kwargs, - success=success, - error=error, - result=result, - task_id=task_id, - ) - except Exception: - message = create_error_message(kwargs, node_, "[post-node-execute]") - logger.exception(message) - raise + + execute_lifecycle_for_node_partial = partial( + execute_lifecycle_for_node, + __node_=node_, + __adapter=adapter, + __run_id=run_id, + __task_id=task_id, + ) + + if adapter.does_method("do_remote_execute", is_async=False): + result = adapter.call_lifecycle_method_sync( + "do_remote_execute", + node=node_, + execute_lifecycle_for_node=execute_lifecycle_for_node_partial, + **kwargs, + ) + else: + result = execute_lifecycle_for_node_partial(**kwargs) computed[node_.name] = result # > pruning the graph @@ -285,6 +251,86 @@ def dfs_traverse( return computed +# TODO: better function name +def execute_lifecycle_for_node( + __node_: node.Node, + __adapter: LifecycleAdapterSet, + __run_id: str, + __task_id: str, + **__kwargs: Dict[str, Any], +): + """Helper function to properly execute node lifecycle. + + Firstly, we execute the pre-node-execute hooks if supplied adapters have any, then we execute the node function, and lastly, we execute the post-node-execute hooks if present in the adapters. + + For local runtime gets execute directy. Otherwise, serves as a sandwich function that guarantees the pre_node and post_node lifecycle hooks are executed in the remote environment. + + :param __node_: Node that is being executed + :param __adapter: Adapter to use to compute + :param __run_id: ID of the run, unique in scope of the driver. 
+    :param __task_id: ID of the task, defaults to None if not in a task setting
+    :param __kwargs: Keyword arguments that are being passed into the node
+    """
+
+    error = None
+    result = None
+    success = True
+    pre_node_execute_errored = False
+
+    try:
+        if __adapter.does_hook("pre_node_execute", is_async=False):
+            try:
+                __adapter.call_all_lifecycle_hooks_sync(
+                    "pre_node_execute",
+                    run_id=__run_id,
+                    node_=__node_,
+                    kwargs=__kwargs,
+                    task_id=__task_id,
+                )
+            except Exception as e:
+                pre_node_execute_errored = True
+                raise e
+        if __adapter.does_method("do_node_execute", is_async=False):
+            result = __adapter.call_lifecycle_method_sync(
+                "do_node_execute",
+                run_id=__run_id,
+                node_=__node_,
+                kwargs=__kwargs,
+                task_id=__task_id,
+            )
+        else:
+            result = __node_(**__kwargs)
+
+        return result
+
+    except Exception as e:
+        success = False
+        error = e
+        step = "[pre-node-execute]" if pre_node_execute_errored else ""
+        message = create_error_message(__kwargs, __node_, step)
+        logger.exception(message)
+        raise
+    finally:
+        if not pre_node_execute_errored and __adapter.does_hook(
+            "post_node_execute", is_async=False
+        ):
+            try:
+                __adapter.call_all_lifecycle_hooks_sync(
+                    "post_node_execute",
+                    run_id=__run_id,
+                    node_=__node_,
+                    kwargs=__kwargs,
+                    success=success,
+                    error=error,
+                    result=result,
+                    task_id=__task_id,
+                )
+            except Exception:
+                message = create_error_message(__kwargs, __node_, "[post-node-execute]")
+                logger.exception(message)
+                raise
+
+
 def nodes_between(
     end_node: node.Node,
     search_condition: lambda node_: bool,
diff --git a/hamilton/lifecycle/base.py b/hamilton/lifecycle/base.py
index ef9d07387..12ea36b2a 100644
--- a/hamilton/lifecycle/base.py
+++ b/hamilton/lifecycle/base.py
@@ -519,6 +519,27 @@ def do_node_execute(
         pass


+@lifecycle.base_method("do_remote_execute")
+class BaseDoRemoteExecute(abc.ABC):
+    @abc.abstractmethod
+    def do_remote_execute(
+        self,
+        *,
+        node: "node.Node",
+        execute_lifecycle_for_node: Callable,
+        **kwargs: Any,
+    ) -> Any:
+        """Method that is called to implement correct remote execution of hooks. This makes sure that all the pre-node and post-node hooks get executed in the remote environment, which is necessary for some adapters. Node execution itself is still invoked the same way as before, through "do_node_execute".
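+        The implementing adapter is responsible for running "execute_lifecycle_for_node" in the remote environment (the RayGraphAdapter, for example, submits it via "ray.remote"), so that the hooks fire where the node actually executes.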
+
+
+        :param node: Node that is being executed
+        :param kwargs: Keyword arguments that are being passed into the node
+        :param execute_lifecycle_for_node: Function that executes the node's lifecycle hooks and methods
+        """
+        pass
+
+
 @lifecycle.base_method("do_node_execute")
 class BaseDoNodeExecuteAsync(abc.ABC):
     @abc.abstractmethod
diff --git a/hamilton/plugins/h_ray.py b/hamilton/plugins/h_ray.py
index 0f0692803..66d0a9255 100644
--- a/hamilton/plugins/h_ray.py
+++ b/hamilton/plugins/h_ray.py
@@ -1,12 +1,14 @@
+import abc
 import functools
 import json
 import logging
+import time
 import typing

 import ray
 from ray import workflow

-from hamilton import base, htypes, node
+from hamilton import base, htypes, lifecycle, node
 from hamilton.execution import executors
 from hamilton.execution.executors import TaskFuture
 from hamilton.execution.grouping import TaskImplementation
@@ -50,7 +52,14 @@ def parse_ray_remote_options_from_tags(tags: typing.Dict[str, str]) -> typing.Di
     return ray_options


-class RayGraphAdapter(base.HamiltonGraphAdapter, base.ResultMixin):
+class RayGraphAdapter(
+    lifecycle.base.BaseDoRemoteExecute,
+    lifecycle.base.BaseDoBuildResult,
+    lifecycle.base.BaseDoValidateInput,
+    lifecycle.base.BaseDoCheckEdgeTypesMatch,
+    lifecycle.base.BasePostGraphExecute,
+    abc.ABC,
+):
     """Class representing what's required to make Hamilton run on Ray.

     This walks the graph and translates it to run onto `Ray <https://ray.io/>`__.
@@ -86,13 +95,21 @@ class RayGraphAdapter(base.HamiltonGraphAdapter, base.ResultMixin):
     DISCLAIMER -- this class is experimental, so signature changes are a possibility!
     """

-    def __init__(self, result_builder: base.ResultMixin):
+    def __init__(
+        self,
+        result_builder: base.ResultMixin,
+        ray_init_config: typing.Dict[str, typing.Any] = None,
+        shutdown_ray_on_completion: bool = False,
+    ):
         """Constructor

         You have the ability to pass in a ResultMixin object to the constructor to control the return type that gets \
         produce by running on Ray.

         :param result_builder: Required. An implementation of base.ResultMixin.
+        :param ray_init_config: allows you to connect to an existing cluster or start a new one with a custom configuration (https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html)
+        :param shutdown_ray_on_completion: by default we leave the cluster open; set this to True to shut it down once execution completes
+
         """
         self.result_builder = result_builder
         if not self.result_builder:
@@ -100,28 +117,40 @@ def __init__(self, result_builder: base.ResultMixin):
                 "Error: ResultMixin object required. Please pass one in for `result_builder`."
             )
+        self.shutdown_ray_on_completion = shutdown_ray_on_completion
+
+        if ray_init_config is not None:
+            ray.init(**ray_init_config)
+
     @staticmethod
-    def check_input_type(node_type: typing.Type, input_value: typing.Any) -> bool:
+    def do_validate_input(node_type: typing.Type, input_value: typing.Any) -> bool:
         # NOTE: the type of a raylet is unknown until they are computed
         if isinstance(input_value, ray._raylet.ObjectRef):
             return True
         return htypes.check_input_type(node_type, input_value)

     @staticmethod
-    def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type) -> bool:
-        return node_type == input_type
+    def do_check_edge_types_match(type_from: typing.Type, type_to: typing.Type) -> bool:
+        return type_from == type_to

-    def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any:
+    def do_remote_execute(
+        self,
+        *,
+        execute_lifecycle_for_node: typing.Callable,
+        node: node.Node,
+        **kwargs: typing.Any,
+    ) -> typing.Any:
         """Function that is called as we walk the graph to determine how to execute a hamilton function.

         :param node: the node from the graph.
+        :param execute_lifecycle_for_node: wrapper function that executes the lifecycle hooks and methods around the node
         :param kwargs: the arguments that should be passed to it.
         :return: returns a ray object reference.
         """
         ray_options = parse_ray_remote_options_from_tags(node.tags)
-        return ray.remote(raify(node.callable)).options(**ray_options).remote(**kwargs)
+        return ray.remote(raify(execute_lifecycle_for_node)).options(**ray_options).remote(**kwargs)

-    def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any:
+    def do_build_result(self, outputs: typing.Dict[str, typing.Any]) -> typing.Any:
         """Builds the result and brings it back to this running process.

         :param outputs: the dictionary of key -> Union[ray object reference | value]
@@ -135,6 +164,14 @@ def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any:
         result = ray.get(remote_combine)  # this materializes the object locally
         return result

+    def post_graph_execute(self, *args, **kwargs):
+        """Optionally shuts the Ray cluster down after execution."""
+
+        if self.shutdown_ray_on_completion:
+            # Give adapters like the HamiltonTracker enough time to properly flush
+            time.sleep(5)
+            ray.shutdown()
+

 class RayWorkflowGraphAdapter(base.HamiltonGraphAdapter, base.ResultMixin):
     """Class representing what's required to make Hamilton run Ray Workflows
diff --git a/pyproject.toml b/pyproject.toml
index 1a61e8aae..e41829d23 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,7 +44,7 @@ dask-distributed = ["dask[distributed]"]
 datadog = ["ddtrace"]
 dev = [
     "pre-commit",
-    "ruff",
+    "ruff==0.5.7", # this should match `.pre-commit-config.yaml`
 ]
 diskcache = ["diskcache"]
 docs = [
diff --git a/tests/lifecycle/lifecycle_adapters_for_testing.py b/tests/lifecycle/lifecycle_adapters_for_testing.py
index a151c4ecc..37f3b3961 100644
--- a/tests/lifecycle/lifecycle_adapters_for_testing.py
+++ b/tests/lifecycle/lifecycle_adapters_for_testing.py
@@ -9,6 +9,7 @@
 from hamilton.lifecycle.base import (
     BaseDoBuildResult,
     BaseDoNodeExecute,
+    BaseDoRemoteExecute,
     BaseDoValidateInput,
     BasePostGraphConstruct,
     BasePostGraphExecute,
@@ -187,6 +188,23 @@ def do_node_execute(
         return node_(**kwargs)


+class TrackingDoRemoteExecuteHook(ExtendToTrackCalls, BaseDoRemoteExecute):
+    def __init__(self, name: str, additional_value: int):
+        super().__init__(name)
+        self._additional_value = additional_value
+
+    def do_remote_execute(
+        self,
+        node: "node.Node",
+        execute_lifecycle_for_node: Callable,
+        **kwargs: Dict[str, Any],
+    ) -> Any:
+        node_ = node
+        if node_.type == int and node_.name != "n_iters":
+            return execute_lifecycle_for_node(**kwargs) + self._additional_value
+        return execute_lifecycle_for_node(**kwargs)
+
+
 class TrackingDoBuildResultMethod(ExtendToTrackCalls, BaseDoBuildResult):
     def __init__(self, name: str, result: Any):
         super().__init__(name)
diff --git a/tests/lifecycle/test_lifecycle_adapters_end_to_end.py b/tests/lifecycle/test_lifecycle_adapters_end_to_end.py
index cf57c02c0..a753ea874 100644
--- a/tests/lifecycle/test_lifecycle_adapters_end_to_end.py
+++ b/tests/lifecycle/test_lifecycle_adapters_end_to_end.py
@@ -1,5 +1,5 @@
 from types import ModuleType
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

 import pytest

@@ -7,6 +7,7 @@
 from hamilton.io.materialization import to
 from hamilton.lifecycle.base import (
     BaseDoNodeExecute,
+    BaseDoRemoteExecute,
     BasePostGraphConstruct,
     BasePostGraphExecute,
     BasePostNodeExecute,
@@ -22,6 +23,7 @@
     SentinelException,
     TrackingDoBuildResultMethod,
     TrackingDoNodeExecuteHook,
+    TrackingDoRemoteExecuteHook,
     TrackingDoValidateInputMethod,
     TrackingPostGraphConstructHook,
     TrackingPostGraphExecuteHook,
@@ -343,6 +345,78 @@ def post_graph_execute(
     assert len(calls) == 16


+def test_multi_hook_remote():
+    class MultiHook(
+        BasePreDoAnythingHook,
+        BasePostGraphConstruct,
+        BasePreGraphExecute,
+        BasePreNodeExecute,
+        BaseDoRemoteExecute,
+        BasePostNodeExecute,
+        BasePostGraphExecute,
+        ExtendToTrackCalls,
+    ):
+        def do_remote_execute(
+            self,
+            node: node.Node,
+            execute_lifecycle_for_node: Callable,
+            **kwargs: Dict[str, Any],
+        ):
+            return execute_lifecycle_for_node(**kwargs)
+
+        def pre_do_anything(self):
+            pass
+
+        def post_graph_construct(
+            self, graph: "FunctionGraph", modules: List[ModuleType], config: Dict[str, Any]
+        ):
+            pass
+
+        def pre_graph_execute(
+            self,
+            run_id: str,
+            graph: "FunctionGraph",
+            final_vars: List[str],
+            inputs: Dict[str, Any],
+            overrides: Dict[str, Any],
+        ):
+            pass
+
+        def pre_node_execute(
+            self, run_id: str, node_: Node, kwargs: Dict[str, Any], task_id: Optional[str] = None
+        ):
+            pass
+
+        def post_node_execute(
+            self,
+            run_id: str,
+            node_: node.Node,
+            kwargs: Dict[str, Any],
+            success: bool,
+            error: Optional[Exception],
+            result: Optional[Any],
+            task_id: Optional[str] = None,
+        ):
+            pass
+
+        def post_graph_execute(
+            self,
+            run_id: str,
+            graph: "FunctionGraph",
+            success: bool,
+            error: Optional[Exception],
+            results: Optional[Dict[str, Any]],
+        ):
+            pass
+
+    multi_hook = MultiHook(name="multi_hook")
+
+    dr = _sample_driver(multi_hook)
+    dr.execute(["d"], inputs={"input": 1})
+    calls = multi_hook.calls
+    assert len(calls) == 16
+
+
 def test_individual_do_validate_input_method():
     method_name = "do_validate_input"
     method = TrackingDoValidateInputMethod(name=method_name, valid=True)
@@ -378,6 +452,16 @@ def test_individual_do_node_execute_method():
     assert res == {"d": 17**3 + 1}  # adding one to each one


+def test_individual_do_remote_execute_method():
+    method_name = "do_remote_execute"
+    method = TrackingDoRemoteExecuteHook(name=method_name, additional_value=1)
+    dr = _sample_driver(method)
+    res = dr.execute(["d"], inputs={"input": 1})
+    relevant_calls = [item for item in method.calls if item.name == method_name]
+    assert len(relevant_calls) == 4
+    assert res == {"d": 17**3 + 1}  # adding one to each one
+
+
 def test_individual_do_build_results_method():
     method_name = "do_build_result"
     method = TrackingDoBuildResultMethod(name=method_name, result=-1)
diff --git a/ui/sdk/src/hamilton_sdk/adapters.py b/ui/sdk/src/hamilton_sdk/adapters.py
index 995143e64..23cbb3a2c 100644
--- a/ui/sdk/src/hamilton_sdk/adapters.py
+++ b/ui/sdk/src/hamilton_sdk/adapters.py
@@ -103,6 +103,10 @@ def __init__(
         # if you're using a float value.
         self.seed = None

+    def stop(self):
+        """Initiates a stop, so the client can flush cleanly when running in a remote environment."""
+        self.client.stop()
+
     def post_graph_construct(
         self, graph: h_graph.FunctionGraph, modules: List[ModuleType], config: Dict[str, Any]
     ):
diff --git a/ui/sdk/src/hamilton_sdk/api/clients.py b/ui/sdk/src/hamilton_sdk/api/clients.py
index 87e687027..551216495 100644
--- a/ui/sdk/src/hamilton_sdk/api/clients.py
+++ b/ui/sdk/src/hamilton_sdk/api/clients.py
@@ -1,5 +1,6 @@
 import abc
 import asyncio
+import atexit
 import datetime
 import functools
 import logging
@@ -196,6 +197,9 @@ def __init__(
         ).start()
         self.worker_thread.start()

+        # Makes sure the process stops even if running in a remote environment
+        atexit.register(self.stop)
+
     def __getstate__(self):
         # Copy the object's state from self.__dict__ which contains
         # all our instance attributes. Always use the dict.copy()
@@ -217,6 +221,7 @@ def __setstate__(self, state):
             target=lambda: threading.main_thread().join() or self.data_queue.put(None)
         ).start()
         self.worker_thread.start()
+        atexit.register(self.stop)

     def worker(self):
         """Worker thread to process the queue."""