Skip to content

Commit

Permalink
explanation
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Nov 5, 2024
1 parent 1d85020 commit 21c2af8
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion thunder/core/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -1937,8 +1937,19 @@ def __init__(self, *args, **kwargs):
self._non_tensors,
output=self,
)
# NOTE(crcrpar): A callable being `thunder.jit`ed can call `MySubclassTensor(...)`
# inside of it either directly or indirectly: indirect way is to call it through
# a custom `torch.autograd.Function` as in
# https://github.com/pytorch/ao/blob/000a490/torchao/float8/float8_tensor.py#L139-L209.
# If it's a direct call, `trace.bound_symbols` and `trace.scopes[-1]` are identical,
# but not, otherwise. As [the lookasdie of `torch.autograd.Function`](
# https://github.com/Lightning-AI/lightning-thunder/blob/3d42c10/thunder/core/jit_ext.py#L655)
# puts the temporary scope to the current trace.
current_trace = get_tracectx()
current_trace.scopes[-1].append(bsym)
if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]):
current_trace.add_bound_symbol(bsym)
else:
cur_tail_scope.append(bsym)

def replace(self, **changes):
r"""Return a copy of the SubclassTensorProxy object with new values for the specified fields as given to the constructor as arguments.
Expand Down

0 comments on commit 21c2af8

Please sign in to comment.