Multiple accesses for non-cached property fails in Thunder JIT #729

Open
IvanYashchuk opened this issue Jul 8, 2024 · 0 comments
🐛 Bug

It seems that Thunder JIT currently assumes that repeated accesses to the same attribute always return the same object. That holds for `@functools.cached_property`, but not for `@property`: a method decorated with `@property` runs its getter on every access, so each access should be treated as a method call.

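```python
import torch
import thunder

class Test(torch.nn.Module):
    @property
    def test(self):
        # A plain @property runs this getter on every access,
        # so each access returns a new, distinct object.
        return object()

    def forward(self):
        # Two accesses to the same property within one traced call
        return self.test, self.test

jtest = thunder.jit(Test())
jtest()  # raises the AssertionError shown in the traceback
```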
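The difference between the two decorators can be reproduced in plain Python, without Thunder or PyTorch. This is a minimal sketch; `WithProperty` and `WithCachedProperty` are illustrative names, not part of either library:

```python
import functools


class WithProperty:
    @property
    def value(self):
        # The getter body runs on every attribute access.
        return object()


class WithCachedProperty:
    @functools.cached_property
    def value(self):
        # Runs once; the result is then stored in the instance __dict__
        # and returned directly on later accesses.
        return object()


p = WithProperty()
c = WithCachedProperty()
print(p.value is p.value)  # False: two distinct objects
print(c.value is c.value)  # True: the cached object is reused
```

A tracer that reuses the first result of `p.value` would therefore be modeling a `@property` as if it were a `cached_property`.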

Traceback:

```
File ~/dev/lightning-thunder/thunder/core/interpreter.py:6164, in _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)
   6162 if lookaside_fn:
   6163     runtimectx.record_lookaside(lookaside_fn)
-> 6164     res = lookaside_fn(*args, **kwargs)
   6165     return res
   6167 # TODO: disabled as partial is just like any other class
   6168 # (3) Handles partial objects

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:732, in _general_jit_getattr_lookaside(obj, name, *maybe_default)
    729 getattr_lookaside = default_lookaside(getattr)
    730 assert getattr_lookaside is not None
--> 732 value = getattr_lookaside(obj, name, *maybe_default)
    733 if value is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
    734     return value

File ~/dev/lightning-thunder/thunder/core/interpreter.py:1678, in _getattr_lookaside(obj, name, *maybe_default)
   1676 if result is not INTERPRETER_SIGNALS.EXCEPTION_RAISED or not isinstance(ctx.curexc, AttributeError):
   1677     if result is not INTERPRETER_SIGNALS.EXCEPTION_RAISED and compilectx._with_provenance_tracking:
-> 1678         result = wrap_attribute(result, obj, name)
   1679     return result
   1681 # `__getattr__` is only triggered if `__getattribute__` fails.
   1682 # TODO: this should be `_interpret_call_with_unwrapping(getattr, obj, "__getattr__", null := object())`, but that would require multiple current exceptions.

File ~/dev/lightning-thunder/thunder/core/interpreter.py:1632, in wrap_attribute(plain_result, obj, name)
   1629 # note: there are cases where "is" will always fail (e.g. BuiltinMethods
   1630 #       are recreated every time)
   1631 if known_wrapper is not None:
-> 1632     assert plausibly_wrapper_of(
   1633         known_wrapper, plain_result
   1634     ), f"attribute {name.value} of {type(obj.value).__name__} object out of sync: {known_wrapper.value} vs. {plain_result}"
   1635     return known_wrapper
   1637 pr = ProvenanceRecord(PseudoInst.LOAD_ATTR, inputs=[obj.provenance, name.provenance])

AssertionError: attribute test of Test object out of sync: <object object at 0x7f412657f850> vs. <object object at 0x7f412657f960>
```
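Judging from the traceback, `wrap_attribute` returns a previously recorded wrapper (`known_wrapper`) for the same attribute and only asserts that the new plain value plausibly matches it, which cannot hold when the getter builds a new object on each access. One possible direction, sketched here as an assumption rather than Thunder's actual fix, is to check statically whether the attribute is a `property` before reusing a recorded result. `may_reuse_attribute` and `Example` are hypothetical names; the check relies only on the standard-library `inspect.getattr_static`.

```python
import functools
import inspect


class Example:
    @property
    def test(self):
        return object()  # fresh object on every access

    @functools.cached_property
    def cached(self):
        return object()  # computed once, then stored in the instance __dict__


def may_reuse_attribute(obj, name):
    """Hypothetical check: may a tracer reuse the first result of obj.<name>?

    inspect.getattr_static looks the attribute up without running descriptors,
    so a @property shows up as the `property` object itself. Its getter runs on
    every access, so its result must not be reused. Everything else (plain
    values, functools.cached_property) is treated as stable; other custom
    descriptors are NOT handled by this sketch.
    """
    static_attr = inspect.getattr_static(obj, name)
    return not isinstance(static_attr, property)


t = Example()
print(may_reuse_attribute(t, "test"))    # False
print(may_reuse_attribute(t, "cached"))  # True
```

Using `inspect.getattr_static` avoids running the getter just to classify it. Custom descriptors whose `__get__` may return a new object each time would need the same treatment as `property`.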

cc @apaz-cli
