From 6f2376dc924003420269c49e376be31932fbec40 Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Tue, 3 Sep 2024 17:55:20 -0700 Subject: [PATCH] Handles partials in source code hash (#1116) * Handles partials in source code hash This is a stopgap measure to handle partials for the CacheAdapter. I put the change here rather than in the source hash function, since for now it appears that this behavior is specific to the cache adapter. * Adds unit tests * Check for partial explicitly * Update test docstrings --- hamilton/lifecycle/default.py | 8 ++- tests/lifecycle/test_cache_adapter.py | 87 ++++++++++++++++++++++++--- 2 files changed, 86 insertions(+), 9 deletions(-) diff --git a/hamilton/lifecycle/default.py b/hamilton/lifecycle/default.py index 333bc4e41..3966ceca4 100644 --- a/hamilton/lifecycle/default.py +++ b/hamilton/lifecycle/default.py @@ -8,6 +8,7 @@ import random import shelve import time +from functools import partial from typing import Any, Callable, Dict, List, Optional, Type, Union from hamilton import graph_types, htypes @@ -359,7 +360,7 @@ def __init__( def run_before_graph_execution(self, *, graph: HamiltonGraph, **kwargs): """Set `cache_vars` to all nodes if received None during `__init__`""" self.cache = shelve.open(self.cache_path) - if self.cache_vars == []: + if len(self.cache_vars) == 0: self.cache_vars = [n.name for n in graph.nodes] def run_to_execute_node( @@ -376,7 +377,10 @@ def run_to_execute_node( if node_name not in self.cache_vars: return node_callable(**node_kwargs) - node_hash = graph_types.hash_source_code(node_callable, strip=True) + source_of_node_callable = node_callable + while isinstance(source_of_node_callable, partial): # handle partials + source_of_node_callable = source_of_node_callable.func + node_hash = graph_types.hash_source_code(source_of_node_callable, strip=True) cache_key = CacheAdapter.create_key(node_hash, node_kwargs) from_cache = self.cache.get(cache_key, None) diff --git
a/tests/lifecycle/test_cache_adapter.py b/tests/lifecycle/test_cache_adapter.py index a55cb0c45..d96ad7f65 100644 --- a/tests/lifecycle/test_cache_adapter.py +++ b/tests/lifecycle/test_cache_adapter.py @@ -1,4 +1,4 @@ -import inspect +import functools import pathlib import shelve @@ -8,12 +8,8 @@ from hamilton.lifecycle.default import CacheAdapter -def _callable_to_node(callable) -> node.Node: - return node.Node( - name=callable.__name__, - typ=inspect.signature(callable).return_annotation, - callabl=callable, - ) +def _callable_to_node(callable, name=None) -> node.Node: + return node.Node.from_fn(callable, name) @pytest.fixture() @@ -52,6 +48,37 @@ def A(external_input: int) -> int: return _callable_to_node(A) +@pytest.fixture() +def node_a_partial(): + """The function A() is a partial""" + + def A(external_input: int, remainder: int) -> int: + return external_input % remainder + + base_node: node.Node = _callable_to_node(A) + + A = functools.partial(A, remainder=7) + base_node._callable = A + del base_node.input_types["remainder"] + return base_node + + +@pytest.fixture() +def node_a_nested_partial(): + """The function A() is a partial""" + + def A(external_input: int, remainder: int, extra: int) -> int: + return external_input % remainder + + base_node: node.Node = _callable_to_node(A) + A = functools.partial(A, remainder=7) + A = functools.partial(A, extra=7) + base_node._callable = A + del base_node.input_types["remainder"] + del base_node.input_types["extra"] + return base_node + + def test_set_result(hook: CacheAdapter, node_a: node.Node): """Hook sets value and assert value in cache""" node_hash = graph_types.hash_source_code(node_a.callable, strip=True) @@ -138,3 +165,49 @@ def test_commit_nodes_history(hook: CacheAdapter): # need to reopen the hook cache with shelve.open(hook.cache_path) as cache: assert cache.get(CacheAdapter.nodes_history_key) == hook.nodes_history + + +def test_partial_handling(hook: CacheAdapter, node_a_partial: node.Node): + """Tests 
partial functions are handled properly""" + hook.cache_vars = [node_a_partial.name] + hook.run_before_graph_execution(graph=graph_types.HamiltonGraph([])) # needed to open cache + node_kwargs = dict(external_input=7) + result = hook.run_to_execute_node( + node_name=node_a_partial.name, + node_kwargs=node_kwargs, + node_callable=node_a_partial.callable, + ) + hook.run_after_node_execution( + node_name=node_a_partial.name, + node_kwargs=node_kwargs, + result=result, + ) + result2 = hook.run_to_execute_node( + node_name=node_a_partial.name, + node_kwargs=node_kwargs, + node_callable=node_a_partial.callable, + ) + assert result2 == result + + +def test_nested_partial_handling(hook: CacheAdapter, node_a_nested_partial: node.Node): + """Tests nested partial functions are handled properly""" + hook.cache_vars = [node_a_nested_partial.name] + hook.run_before_graph_execution(graph=graph_types.HamiltonGraph([])) # needed to open cache + node_kwargs = dict(external_input=7) + result = hook.run_to_execute_node( + node_name=node_a_nested_partial.name, + node_kwargs=node_kwargs, + node_callable=node_a_nested_partial.callable, + ) + hook.run_after_node_execution( + node_name=node_a_nested_partial.name, + node_kwargs=node_kwargs, + result=result, + ) + result2 = hook.run_to_execute_node( + node_name=node_a_nested_partial.name, + node_kwargs=node_kwargs, + node_callable=node_a_nested_partial.callable, + ) + assert result2 == result