
Commit

Merge branch 'main' into benchmark_dynamo_graph
kiya00 authored Oct 15, 2024
2 parents d254715 + caec01b commit e6ae93f
Showing 3 changed files with 78 additions and 9 deletions.
8 changes: 8 additions & 0 deletions thunder/benchmarks/__init__.py
@@ -18,6 +18,7 @@
from torch.testing import make_tensor

import thunder
+import thunder.dynamo
import thunder.core.devices as Devices
import thunder.core.dtypes as dtypes
import thunder.executors as executors
@@ -707,6 +708,13 @@ def torch_compile_executor(fn: Callable) -> Callable:
    return torch.compile(fn)


+def thunderfx_executor(fn: Callable) -> Callable:
+    torch.backends.cuda.matmul.allow_tf32 = True
+    backend = thunder.dynamo.ThunderCompiler()
+    torch._dynamo.reset()
+    return torch.compile(fn, backend=backend)


def thunder_torch_executor(fn: Callable) -> Callable:
    torch.backends.cuda.matmul.allow_tf32 = True
    return thunder.jit(fn, executors=[thunder.pytorch_executor])
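For reference, a minimal usage sketch of the new executor (not part of this commit; the toy Linear module below is illustrative only). It shows what applying the executor amounts to: the callable is compiled with torch.compile using Thunder's ThunderCompiler Dynamo backend.

import torch
from thunder.benchmarks import thunderfx_executor  # added in this commit

model = torch.nn.Linear(8, 8)
compiled = thunderfx_executor(model)   # torch.compile(model, backend=ThunderCompiler())
out = compiled(torch.randn(2, 8))      # first call triggers Dynamo tracing and Thunder compilation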
21 changes: 12 additions & 9 deletions thunder/benchmarks/targets.py
@@ -33,6 +33,7 @@
    thunder_cudnn_executor,
    thunder_cudnn_nvfuser_executor,
    thunder_executor,
+    thunderfx_executor,
    thunder_sdpa_torch_compile_nvfuser_executor,
    torch_compile_executor,
    torch_executor,
@@ -152,17 +153,19 @@ def interpreter_fwd(module: Callable):
    return fn_


-executors = (
-    torch_executor,
-    torch_compile_executor,
-    thunder_executor,
-)
+executors = (torch_executor, torch_compile_executor, thunder_executor)
executors_ids = (
    "torch",
    "torch.compile",
    "thunder",
)

+torchbench_executors = (*executors, thunderfx_executor)
+torchbench_executors_ids = (
+    *executors_ids,
+    "thunderfx",
+)

apex_executors = (thunder_apex_executor, thunder_apex_nvfuser_executor)
apex_executors_ids = ("thunder+apex-grad", "thunder+apex+nvfuser-grad")

@@ -842,8 +845,8 @@ def test_resnet50(benchmark, executor: Callable, compute_type: ComputeType):
)
@pytest.mark.parametrize(
    "executor,",
-    executors,
-    ids=executors_ids,
+    torchbench_executors,
+    ids=torchbench_executors_ids,
)
@parametrize_compute_type
def test_torchbench(benchmark, module_name, executor, compute_type: ComputeType):
@@ -868,8 +871,8 @@ def test_torchbench(benchmark, module_name, executor, compute_type: ComputeType)
)
@pytest.mark.parametrize(
    "executor,",
-    executors,
-    ids=executors_ids,
+    torchbench_executors,
+    ids=torchbench_executors_ids,
)
@parametrize_compute_type
def test_torchbench_canary(benchmark, module_name, executor, compute_type: ComputeType):
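To illustrate how the widened parametrization expands (a sketch, not part of the commit; the stand-in executors below are trivial identity wrappers), pytest pairs each entry of torchbench_executors with its matching id, so every torchbench test gains a thunderfx variant alongside torch, torch.compile, and thunder:

import pytest

# Trivial stand-ins for the real executors imported in targets.py.
def torch_executor(fn): return fn
def torch_compile_executor(fn): return fn
def thunder_executor(fn): return fn
def thunderfx_executor(fn): return fn

executors = (torch_executor, torch_compile_executor, thunder_executor)
executors_ids = ("torch", "torch.compile", "thunder")
torchbench_executors = (*executors, thunderfx_executor)
torchbench_executors_ids = (*executors_ids, "thunderfx")

# pytest generates test_executor_variants[torch], [torch.compile], [thunder], [thunderfx].
@pytest.mark.parametrize("executor", torchbench_executors, ids=torchbench_executors_ids)
def test_executor_variants(executor):
    fn = executor(lambda x: 2 * x)
    assert fn(3) == 6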
58 changes: 58 additions & 0 deletions thunder/examine/__init__.py
@@ -5,6 +5,7 @@

import thunder
from thunder.core.trace import TraceCtx
+from thunder.core.transforms import bsym_list_to_dag, Node
from thunder.core.proxies import TensorProxy
from thunder.core.symbol import BoundSymbol
from thunder.torch import _torch_to_thunder_function_map
@@ -13,6 +14,7 @@
import torch
from warnings import warn
from itertools import chain
+import importlib


# TODO Maybe make collect_into a set?
@@ -272,3 +274,59 @@ def get_nvfuser_repro(trace: TraceCtx, fusion_name: str, /) -> str:
        raise RuntimeError("The installed version of nvFuser does not support repro generation unless on crash.")

    return get_repro(fusion.last_inputs)


+def make_trace_dot(trace: TraceCtx):
+    """
+    Creates a directed graph of the given trace.
+
+    This function uses graphviz to visualize the computation graph of a trace.
+    Beware, rendering a graph for large traces might take a while.
+    Requires graphviz to be installed; for more information see https://graphviz.readthedocs.io/en/stable/index.html
+
+    Args:
+        trace (TraceCtx): The Thunder trace to be made into a graph.
+
+    Returns:
+        graphviz.Digraph: A graphviz directed graph.
+    """
+    if not importlib.util.find_spec("graphviz"):
+        warn("graphviz is not available. Graph cannot be created.")
+        return
+
+    import graphviz
+
+    node_attr = dict(
+        style="filled", shape="box", align="left", fontsize="10", ranksep="0.1", height="0.2", fontname="monospace"
+    )
+    dot = graphviz.Digraph(
+        node_attr=node_attr,
+        graph_attr=dict(size="10,10"),
+    )
+    dot.strict = True
+
+    roots, leaves = bsym_list_to_dag(trace.bound_symbols)
+    leaves_id = {id(leaf) for leaf in leaves}
+    stack = [*roots]
+    visited = set()
+    while stack:
+        node: Node = stack.pop()
+        node_id = id(node)
+        visited.add(node_id)
+        dot.node(str(node_id), node.bsym.sym.name, fillcolor="orange" if node_id in leaves_id else "white")
+
+        for child in node.children:
+            child_id = id(child)
+            out_proxy_name = node.bsym.output.name if isinstance(node.bsym.output, TensorProxy) else None
+            dot.edge(str(node_id), str(child_id), label=out_proxy_name)
+            if child_id not in visited and not str(child.bsym).startswith("#"):
+                stack.append(child)
+
+        for parent in node.parents:
+            parent_id = id(parent)
+            dot.edge(str(parent_id), str(node_id))
+            if parent_id not in visited and not str(parent.bsym).startswith("#"):
+                stack.append(parent)
+
+    return dot
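The new helper walks the trace's bound symbols as a DAG (via bsym_list_to_dag), adds one graphviz node per bound symbol with leaves highlighted in orange, and labels edges with the name of the producing TensorProxy. A minimal usage sketch, not part of this commit, assuming graphviz is installed and using thunder.last_traces to obtain a trace from a jitted function:

import torch
import thunder
from thunder.examine import make_trace_dot  # added in this commit

def foo(x, y):
    return torch.nn.functional.relu(x @ y)

jfoo = thunder.jit(foo)
jfoo(torch.randn(4, 4), torch.randn(4, 4))

trace = thunder.last_traces(jfoo)[-1]        # most recent trace of the jitted function
dot = make_trace_dot(trace)
if dot is not None:                          # make_trace_dot returns None if graphviz is missing
    dot.render("trace_graph", format="svg")  # writes trace_graph.svg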
