|
6 | 6 | import thunder
|
7 | 7 | from thunder.core.trace import TraceCtx
|
8 | 8 | from thunder.core.transforms import bsym_list_to_dag, Node
|
9 |
| -from thunder.core.proxies import TensorProxy |
| 9 | +from thunder.core.proxies import TensorProxy, CollectionProxy |
10 | 10 | from thunder.core.symbol import BoundSymbol
|
11 | 11 | from thunder.torch import _torch_to_thunder_function_map
|
12 | 12 | from thunder.torch.default_torch_ops import torch_auto_registered_ops
|
@@ -291,10 +291,14 @@ def resize_graph(dot, size_per_element=0.15, min_size=12):
|
291 | 291 | dot.graph_attr.update(size=size_str)
|
292 | 292 |
|
293 | 293 |
|
294 |
| -def _repr_tensor_proxy(t_proxy, show_metadata=False): |
295 |
| - assert isinstance(t_proxy, TensorProxy) |
296 |
| - extra_meta = f"\n shape:{t_proxy.shape} \n dtype:{t_proxy.dtype}" if show_metadata else "" |
297 |
| - return f"name:{t_proxy.name}" + extra_meta |
| 294 | +def _repr_proxy(t_proxy, show_metadata=False): |
| 295 | + if isinstance(t_proxy, TensorProxy): |
| 296 | + # Should we just delegate to TensorProxy.__repr__ ? |
| 297 | + extra_meta = f"\n shape:{t_proxy.shape} \n dtype:{t_proxy.dtype}" if show_metadata else "" |
| 298 | + return f"name:{t_proxy.name}" + extra_meta |
| 299 | + |
| 300 | + # For any other proxy, we just print the name. |
| 301 | + return f"name:{t_proxy.name}" |
298 | 302 |
|
299 | 303 |
|
300 | 304 | def make_trace_dot(trace: TraceCtx, show_metadata=False):
|
@@ -356,18 +360,24 @@ def _get_color(node_id):
|
356 | 360 | node_id = id(node)
|
357 | 361 | visited.add(node_id)
|
358 | 362 | color = _get_color(node_id)
|
359 |
| - dot.node(str(node_id), node.bsym.python(indent=0, print_depth=1)[0], fillcolor=color) |
| 363 | + |
| 364 | + # Unpacking collection might be a multi-line. |
| 365 | + node_repr = "\n".join(node.bsym.python(indent=0, print_depth=1)) |
| 366 | + node_repr = node_repr.replace("\\", "") |
| 367 | + dot.node(str(node_id), node_repr, fillcolor=color) |
360 | 368 |
|
361 | 369 | # Add node for args and connect args
|
362 | 370 | for arg in node.bsym.flat_args:
|
363 |
| - if isinstance(arg, TensorProxy): |
| 371 | + # We have collection proxies in backward |
| 372 | + if isinstance(arg, (TensorProxy, CollectionProxy)): |
364 | 373 | arg_id = arg.name
|
365 |
| - dot.node(arg_id, _repr_tensor_proxy(arg, show_metadata)) |
| 374 | + dot.node(arg_id, _repr_proxy(arg, show_metadata)) |
366 | 375 | dot.edge(arg_id, str(node_id))
|
367 | 376 |
|
368 | 377 | # Connect outputs
|
369 | 378 | for out in node.bsym.flat_outs:
|
370 |
| - if isinstance(out, TensorProxy): |
| 379 | + # We have collection proxies in backward |
| 380 | + if isinstance(out, (TensorProxy, CollectionProxy)): |
371 | 381 | out_id = out.name
|
372 | 382 | dot.edge(str(node_id), out_id)
|
373 | 383 |
|
|
0 commit comments