Skip to content

Commit

Permalink
Added kwargs to bound symbol rhs hash (#538)
Browse files Browse the repository at this point in the history
  • Loading branch information
riccardofelluga authored Jun 5, 2024
1 parent 4a7097d commit fbaa3a8
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 28 deletions.
34 changes: 18 additions & 16 deletions thunder/core/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import thunder.core.dtypes as dtypes
import thunder.core.devices as devices
from thunder.core.proxies import Proxy, NumberProxy, variableify, CollectionProxy
from thunder.core.utils import FrozenDict

from thunder.core.trace import (
get_tracectx,
Expand Down Expand Up @@ -647,36 +648,37 @@ def has_tags(bsym: BoundSymbol, tags: set[OpTags]) -> bool:
return not tags.isdisjoint(gather_tags(bsym))


# NOTE: A wrapper class that hashes and equates only the right hand side of a BoundSymbol.
# Wrapper class that hashes and equates only the right hand side of a BoundSymbol for CSE.
# That is to say, its symbol, args, and kwargs, but not its output.
# The intent is that this will be useful in writing a common subexpression elimination pass, beacuse
# it will allow dictionary lookups to find equivalent BoundSymbols.
@dataclass(**baseutils.default_dataclass_params)
@dataclass
class BoundSymbolRHS:
parent: BoundSymbol
_hash: int | None = None
_frozen_kwargs: FrozenDict

def _do_hash(self) -> int:
if self.parent.kwargs and len(self.parent.kwargs) > 0:
def __init__(self, parent: BoundSymbol) -> None:
self.parent = parent
self._frozen_kwargs = FrozenDict(parent._var_kwargs)

@functools.cached_property
def _hash(self) -> int:
# TODO: Find a better way to identify inputs by id instead of hash.
if self.parent.sym.name == "unpack_trivial":
return id(self)
try:
return hash((self.parent.sym, self.parent._var_args))
return hash((self.parent.sym, self.parent._var_args, self._frozen_kwargs))
except:
return id(self)

def __hash__(self) -> int:
if not self._hash:
h = self._do_hash()
object.__setattr__(self, "_hash", h)
return h
return self._hash

# TODO: Deal with kwargs, in __eq__ and __hash__, just like with BoundSymbol.
def __eq__(self, other: BoundSymbolRHS) -> bool:
if not isinstance(other, BoundSymbolRHS):
return False
if self.parent is other.parent:
return True
if len(self.parent.kwargs) > 0 or len(other.parent.kwargs) > 0:
return False
return (self.parent.sym, self.parent._var_args) == (other.parent.sym, other.parent._var_args)
return (self.parent.sym, self.parent._var_args, self._frozen_kwargs) == (
other.parent.sym,
other.parent._var_args,
other._frozen_kwargs,
)
17 changes: 5 additions & 12 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,6 @@ def mul_rhs(a, b):
all_eq([hash(b.rhs()) for b in bsyms])
all_eq([b.rhs() for b in bsyms])

# TODO Update needed here
# The current way BoundSymbols are compared treats args and kwargs the same,
# so the same semantic call can be considered 'equal' if the arguments are
# passed differently.
Expand All @@ -1321,8 +1320,6 @@ def mul_rhs_kwargs(a, b):
d = ltorch.mul(a, b)
return c, d

# Assert the current behavior.
# When the test case is supported, switch this to all_eq.
bsyms = extract_bsyms(mul_rhs_kwargs, (a, b), ("mul",))
all_eq([hash(b.rhs()) for b in bsyms])
all_eq([b.rhs() for b in bsyms])
Expand All @@ -1331,33 +1328,29 @@ def mul_rhs_kwargs(a, b):
all_eq([b.sym for b in bsyms])
all_eq([hash(b.sym) for b in bsyms])

# TODO: We also currently cannot assert that the right hand side of
# identical operators with kwargs are equal.
# Assert that rhs of identical operators with same kwargs are equal.
def same_kwargs(device, dtype):
a = ltorch.full((2, 2), 5, device=device, dtype=dtype)
b = ltorch.full((2, 2), 5, device=device, dtype=dtype)
return a + b

# Assert the current behavior.
# When the test case is supported, switch the all_neq below to all_eq.
bsyms = extract_bsyms(same_kwargs, (device, dtype), ("full",))
all_eq([hash(b.rhs()) for b in bsyms])
all_neq([b.rhs() for b in bsyms])
all_eq([b.rhs() for b in bsyms])

# Again, the symbols should be the same.
# The symbols should be the same.
all_eq([b.sym for b in bsyms])
all_eq([hash(b.sym) for b in bsyms])

# We can, however, know when the number of kwargs are different,
# or the args are different.
# Assert that the kwargs are different and hash differently.
def diff_kwargs(device, dtype):
a = ltorch.full((1, 2), 2, device=device, dtype=dtype)
b = ltorch.full((2, 3), 5, device=device, dtype=dtype)
c = ltorch.full((2, 3), 5, device=device)
return a, b, c

bsyms = extract_bsyms(diff_kwargs, (device, dtype), ("full",))
all_eq([hash(b.rhs()) for b in bsyms])
all_neq([hash(b.rhs()) for b in bsyms])
all_neq([b.rhs() for b in bsyms])

# Assert that boundsymbols for different ops hash/compare differently.
Expand Down

0 comments on commit fbaa3a8

Please sign in to comment.