diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index c1d8b18855..edede5335e 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -21,6 +21,7 @@ import thunder.core.dtypes as dtypes import thunder.core.devices as devices from thunder.core.proxies import Proxy, NumberProxy, variableify, CollectionProxy +from thunder.core.utils import FrozenDict from thunder.core.trace import ( get_tracectx, @@ -647,36 +648,37 @@ def has_tags(bsym: BoundSymbol, tags: set[OpTags]) -> bool: return not tags.isdisjoint(gather_tags(bsym)) -# NOTE: A wrapper class that hashes and equates only the right hand side of a BoundSymbol. +# Wrapper class that hashes and equates only the right hand side of a BoundSymbol for CSE. # That is to say, its symbol, args, and kwargs, but not its output. -# The intent is that this will be useful in writing a common subexpression elimination pass, beacuse -# it will allow dictionary lookups to find equivalent BoundSymbols. -@dataclass(**baseutils.default_dataclass_params) +@dataclass class BoundSymbolRHS: parent: BoundSymbol - _hash: int | None = None + _frozen_kwargs: FrozenDict - def _do_hash(self) -> int: - if self.parent.kwargs and len(self.parent.kwargs) > 0: + def __init__(self, parent: BoundSymbol) -> None: + self.parent = parent + self._frozen_kwargs = FrozenDict(parent._var_kwargs) + + @functools.cached_property + def _hash(self) -> int: + # TODO: Find a better way to identify inputs by id instead of hash. + if self.parent.sym.name == "unpack_trivial": return id(self) try: - return hash((self.parent.sym, self.parent._var_args)) + return hash((self.parent.sym, self.parent._var_args, self._frozen_kwargs)) except: return id(self) def __hash__(self) -> int: - if not self._hash: - h = self._do_hash() - object.__setattr__(self, "_hash", h) - return h return self._hash - # TODO: Deal with kwargs, in __eq__ and __hash__, just like with BoundSymbol. def __eq__(self, other: BoundSymbolRHS) -> bool: if not isinstance(other, BoundSymbolRHS): return False if self.parent is other.parent: return True - if len(self.parent.kwargs) > 0 or len(other.parent.kwargs) > 0: - return False - return (self.parent.sym, self.parent._var_args) == (other.parent.sym, other.parent._var_args) + return (self.parent.sym, self.parent._var_args, self._frozen_kwargs) == ( + other.parent.sym, + other.parent._var_args, + other._frozen_kwargs, + ) diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index ac25df8e2c..06d508ccae 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -1312,7 +1312,6 @@ def mul_rhs(a, b): all_eq([hash(b.rhs()) for b in bsyms]) all_eq([b.rhs() for b in bsyms]) - # TODO Update needed here # The current way BoundSymbols are compared treats args and kwargs the same, # so the same semantic call can be considered 'equal' if the arguments are # passed differently. @@ -1321,8 +1320,6 @@ def mul_rhs_kwargs(a, b): d = ltorch.mul(a, b) return c, d - # Assert the current behavior. - # When the test case is supported, switch this to all_eq. bsyms = extract_bsyms(mul_rhs_kwargs, (a, b), ("mul",)) all_eq([hash(b.rhs()) for b in bsyms]) all_eq([b.rhs() for b in bsyms]) @@ -1331,25 +1328,21 @@ def mul_rhs_kwargs(a, b): all_eq([b.sym for b in bsyms]) all_eq([hash(b.sym) for b in bsyms]) - # TODO: We also currently cannot assert that the right hand side of - # identical operators with kwargs are equal. + # Assert that rhs of identical operators with same kwargs are equal. def same_kwargs(device, dtype): a = ltorch.full((2, 2), 5, device=device, dtype=dtype) b = ltorch.full((2, 2), 5, device=device, dtype=dtype) return a + b - # Assert the current behavior. - # When the test case is supported, switch the all_neq below to all_eq. bsyms = extract_bsyms(same_kwargs, (device, dtype), ("full",)) all_eq([hash(b.rhs()) for b in bsyms]) - all_neq([b.rhs() for b in bsyms]) + all_eq([b.rhs() for b in bsyms]) - # Again, the symbols should be the same. + # The symbols should be the same. all_eq([b.sym for b in bsyms]) all_eq([hash(b.sym) for b in bsyms]) - # We can, however, know when the number of kwargs are different, - # or the args are different. + # Assert that the kwargs are different and hash differently. def diff_kwargs(device, dtype): a = ltorch.full((1, 2), 2, device=device, dtype=dtype) b = ltorch.full((2, 3), 5, device=device, dtype=dtype) @@ -1357,7 +1350,7 @@ def diff_kwargs(device, dtype): return a, b, c bsyms = extract_bsyms(diff_kwargs, (device, dtype), ("full",)) - all_eq([hash(b.rhs()) for b in bsyms]) + all_neq([hash(b.rhs()) for b in bsyms]) all_neq([b.rhs() for b in bsyms]) # Assert that boundsymbols for different ops hash/compare differently.