def thunderfx_executor(fn: Callable) -> Callable:
    """Compile ``fn`` with ``torch.compile`` using Thunder's dynamo backend.

    Sets ``torch.backends.cuda.matmul.allow_tf32`` to ``True``, creates a
    ``thunder.dynamo.ThunderCompiler`` instance, calls ``torch._dynamo.reset()``
    and then hands ``fn`` to ``torch.compile`` with that instance as backend.

    Args:
        fn (Callable): The callable to compile.

    Returns:
        Callable: The object returned by ``torch.compile``.
    """
    torch.backends.cuda.matmul.allow_tf32 = True

    # A new backend object is created on every call.
    thunder_backend = thunder.dynamo.ThunderCompiler()

    # NOTE(review): presumably clears previously cached dynamo compilations so
    # this executor starts from a clean state -- confirm against torch docs.
    torch._dynamo.reset()

    compiled_fn = torch.compile(fn, backend=thunder_backend)
    return compiled_fn
def make_trace_dot(trace: TraceCtx):
    """
    Creates a directed graph of the given trace.

    This function is intended to be used with graphviz to visualize the computation graph of a trace.
    Beware, rendering out a graph for large traces might take a while.

    Requires graphviz to be installed, for more information check out -> https://graphviz.readthedocs.io/en/stable/index.html

    Args:
        trace (TraceCtx): The Thunder trace to be made into a graph.

    Returns:
        graphviz.Digraph | None: A graphviz directed graph, or ``None`` (after
        emitting a warning) when the ``graphviz`` package cannot be found.
    """
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule has been loaded, so import it explicitly before using it.
    import importlib.util

    if importlib.util.find_spec("graphviz") is None:
        warn("graphviz is not available. Graph cannot be created.")
        return None

    import graphviz

    node_attr = dict(
        style="filled", shape="box", align="left", fontsize="10", ranksep="0.1", height="0.2", fontname="monospace"
    )
    dot = graphviz.Digraph(
        node_attr=node_attr,
        graph_attr=dict(size="10,10"),
    )
    # Every edge is emitted from both of its endpoints below (once via
    # `children`, once via `parents`); a strict graph avoids drawing it twice.
    dot.strict = True

    roots, leaves = bsym_list_to_dag(trace.bound_symbols)
    leaves_id = {id(leaf) for leaf in leaves}
    stack = [*roots]
    visited = set()
    while stack:
        node: Node = stack.pop()
        node_id = id(node)
        if node_id in visited:
            # The node was pushed several times before it was first popped
            # (e.g. it is reachable from more than one neighbour); it has
            # already been emitted, so do not process it again.
            continue
        visited.add(node_id)

        # Leaves are highlighted in orange, every other node is white.
        dot.node(str(node_id), node.bsym.sym.name, fillcolor="orange" if node_id in leaves_id else "white")

        # The label only depends on the current node, so compute it once
        # instead of once per child. Only a single TensorProxy output is named.
        output = node.bsym.output
        out_proxy_name = output.name if isinstance(output, TensorProxy) else None

        for child in node.children:
            child_id = id(child)
            dot.edge(str(node_id), str(child_id), label=out_proxy_name)
            # NOTE(review): symbols whose string form starts with "#" are
            # presumably comment symbols; they are not traversed, although the
            # edge to them is still added above -- confirm this is intended.
            if child_id not in visited and not str(child.bsym).startswith("#"):
                stack.append(child)

        for parent in node.parents:
            parent_id = id(parent)
            dot.edge(str(parent_id), str(node_id))
            if parent_id not in visited and not str(parent.bsym).startswith("#"):
                stack.append(parent)

    return dot