Skip to content

Commit a7c6e63

Browse files
committed
update for backward
1 parent 586d1aa commit a7c6e63

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

thunder/examine/__init__.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import thunder
77
from thunder.core.trace import TraceCtx
88
from thunder.core.transforms import bsym_list_to_dag, Node
9-
from thunder.core.proxies import TensorProxy
9+
from thunder.core.proxies import TensorProxy, CollectionProxy
1010
from thunder.core.symbol import BoundSymbol
1111
from thunder.torch import _torch_to_thunder_function_map
1212
from thunder.torch.default_torch_ops import torch_auto_registered_ops
@@ -291,10 +291,14 @@ def resize_graph(dot, size_per_element=0.15, min_size=12):
291291
dot.graph_attr.update(size=size_str)
292292

293293

294-
def _repr_tensor_proxy(t_proxy, show_metadata=False):
295-
assert isinstance(t_proxy, TensorProxy)
296-
extra_meta = f"\n shape:{t_proxy.shape} \n dtype:{t_proxy.dtype}" if show_metadata else ""
297-
return f"name:{t_proxy.name}" + extra_meta
294+
def _repr_proxy(t_proxy, show_metadata=False):
295+
if isinstance(t_proxy, TensorProxy):
296+
# Should we just delegate to TensorProxy.__repr__ ?
297+
extra_meta = f"\n shape:{t_proxy.shape} \n dtype:{t_proxy.dtype}" if show_metadata else ""
298+
return f"name:{t_proxy.name}" + extra_meta
299+
300+
# For any other proxy, we just print the name.
301+
return f"name:{t_proxy.name}"
298302

299303

300304
def make_trace_dot(trace: TraceCtx, show_metadata=False):
@@ -356,18 +360,24 @@ def _get_color(node_id):
356360
node_id = id(node)
357361
visited.add(node_id)
358362
color = _get_color(node_id)
359-
dot.node(str(node_id), node.bsym.python(indent=0, print_depth=1)[0], fillcolor=color)
363+
364+
# Unpacking collection might be a multi-line.
365+
node_repr = "\n".join(node.bsym.python(indent=0, print_depth=1))
366+
node_repr = node_repr.replace("\\", "")
367+
dot.node(str(node_id), node_repr, fillcolor=color)
360368

361369
# Add node for args and connect args
362370
for arg in node.bsym.flat_args:
363-
if isinstance(arg, TensorProxy):
371+
# We have collection proxies in backward
372+
if isinstance(arg, (TensorProxy, CollectionProxy)):
364373
arg_id = arg.name
365-
dot.node(arg_id, _repr_tensor_proxy(arg, show_metadata))
374+
dot.node(arg_id, _repr_proxy(arg, show_metadata))
366375
dot.edge(arg_id, str(node_id))
367376

368377
# Connect outputs
369378
for out in node.bsym.flat_outs:
370-
if isinstance(out, TensorProxy):
379+
# We have collection proxies in backward
380+
if isinstance(out, (TensorProxy, CollectionProxy)):
371381
out_id = out.name
372382
dot.edge(str(node_id), out_id)
373383

0 commit comments

Comments
 (0)