diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index eb605dc2e8..8cf179e263 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1937,8 +1937,19 @@ def __init__(self, *args, **kwargs): self._non_tensors, output=self, ) + # NOTE(crcrpar): A callable being `thunder.jit`ed can call `MySubclassTensor(...)` + # inside of it either directly or indirectly: indirect way is to call it through + # a custom `torch.autograd.Function` as in + # https://github.com/pytorch/ao/blob/000a490/torchao/float8/float8_tensor.py#L139-L209. + # If it's a direct call, `trace.bound_symbols` and `trace.scopes[-1]` are identical, + # but not, otherwise. As [the lookasdie of `torch.autograd.Function`]( + # https://github.com/Lightning-AI/lightning-thunder/blob/3d42c10/thunder/core/jit_ext.py#L655) + # puts the temporary scope to the current trace. current_trace = get_tracectx() - current_trace.scopes[-1].append(bsym) + if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]): + current_trace.add_bound_symbol(bsym) + else: + cur_tail_scope.append(bsym) def replace(self, **changes): r"""Return a copy of the SubclassTensorProxy object with new values for the specified fields as given to the constructor as arguments.