Skip to content

Commit

Permalink
Handles partials in source code hash (#1116)
Browse files Browse the repository at this point in the history
* Handles partials in source code hash

This is a stop gap measure to handle
partials for the CacheAdapter.

I put the change here rather than in the source hash function,
since for now it appears that this behavior is specific to the cache
adapter..

* Adds unit tests

* Check for partial explicitly

* Update test doc strings
  • Loading branch information
skrawcz authored Sep 4, 2024
1 parent d90212f commit 6f2376d
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 9 deletions.
8 changes: 6 additions & 2 deletions hamilton/lifecycle/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import random
import shelve
import time
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Type, Union

from hamilton import graph_types, htypes
Expand Down Expand Up @@ -359,7 +360,7 @@ def __init__(
def run_before_graph_execution(self, *, graph: HamiltonGraph, **kwargs):
"""Set `cache_vars` to all nodes if received None during `__init__`"""
self.cache = shelve.open(self.cache_path)
if self.cache_vars == []:
if len(self.cache_vars) == 0:
self.cache_vars = [n.name for n in graph.nodes]

def run_to_execute_node(
Expand All @@ -376,7 +377,10 @@ def run_to_execute_node(
if node_name not in self.cache_vars:
return node_callable(**node_kwargs)

node_hash = graph_types.hash_source_code(node_callable, strip=True)
source_of_node_callable = node_callable
while isinstance(source_of_node_callable, partial): # handle partials
source_of_node_callable = source_of_node_callable.func
node_hash = graph_types.hash_source_code(source_of_node_callable, strip=True)
cache_key = CacheAdapter.create_key(node_hash, node_kwargs)

from_cache = self.cache.get(cache_key, None)
Expand Down
87 changes: 80 additions & 7 deletions tests/lifecycle/test_cache_adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import inspect
import functools
import pathlib
import shelve

Expand All @@ -8,12 +8,8 @@
from hamilton.lifecycle.default import CacheAdapter


def _callable_to_node(callable) -> node.Node:
return node.Node(
name=callable.__name__,
typ=inspect.signature(callable).return_annotation,
callabl=callable,
)
def _callable_to_node(callable, name=None) -> node.Node:
return node.Node.from_fn(callable, name)


@pytest.fixture()
Expand Down Expand Up @@ -52,6 +48,37 @@ def A(external_input: int) -> int:
return _callable_to_node(A)


@pytest.fixture()
def node_a_partial():
"""The function A() is a partial"""

def A(external_input: int, remainder: int) -> int:
return external_input % remainder

base_node: node.Node = _callable_to_node(A)

A = functools.partial(A, remainder=7)
base_node._callable = A
del base_node.input_types["remainder"]
return base_node


@pytest.fixture()
def node_a_nested_partial():
"""The function A() is a partial"""

def A(external_input: int, remainder: int, extra: int) -> int:
return external_input % remainder

base_node: node.Node = _callable_to_node(A)
A = functools.partial(A, remainder=7)
A = functools.partial(A, extra=7)
base_node._callable = A
del base_node.input_types["remainder"]
del base_node.input_types["extra"]
return base_node


def test_set_result(hook: CacheAdapter, node_a: node.Node):
"""Hook sets value and assert value in cache"""
node_hash = graph_types.hash_source_code(node_a.callable, strip=True)
Expand Down Expand Up @@ -138,3 +165,49 @@ def test_commit_nodes_history(hook: CacheAdapter):
# need to reopen the hook cache
with shelve.open(hook.cache_path) as cache:
assert cache.get(CacheAdapter.nodes_history_key) == hook.nodes_history


def test_partial_handling(hook: CacheAdapter, node_a_partial: node.Node):
"""Tests partial functions are handled properly"""
hook.cache_vars = [node_a_partial.name]
hook.run_before_graph_execution(graph=graph_types.HamiltonGraph([])) # needed to open cache
node_kwargs = dict(external_input=7)
result = hook.run_to_execute_node(
node_name=node_a_partial.name,
node_kwargs=node_kwargs,
node_callable=node_a_partial.callable,
)
hook.run_after_node_execution(
node_name=node_a_partial.name,
node_kwargs=node_kwargs,
result=result,
)
result2 = hook.run_to_execute_node(
node_name=node_a_partial.name,
node_kwargs=node_kwargs,
node_callable=node_a_partial.callable,
)
assert result2 == result


def test_nested_partial_handling(hook: CacheAdapter, node_a_nested_partial: node.Node):
"""Tests nested partial functions are handled properly"""
hook.cache_vars = [node_a_nested_partial.name]
hook.run_before_graph_execution(graph=graph_types.HamiltonGraph([])) # needed to open cache
node_kwargs = dict(external_input=7)
result = hook.run_to_execute_node(
node_name=node_a_nested_partial.name,
node_kwargs=node_kwargs,
node_callable=node_a_nested_partial.callable,
)
hook.run_after_node_execution(
node_name=node_a_nested_partial.name,
node_kwargs=node_kwargs,
result=result,
)
result2 = hook.run_to_execute_node(
node_name=node_a_nested_partial.name,
node_kwargs=node_kwargs,
node_callable=node_a_nested_partial.callable,
)
assert result2 == result

0 comments on commit 6f2376d

Please sign in to comment.