make cudagraphs a transform #977

Merged: 5 commits, Aug 20, 2024
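
For context when reading the diff below: this PR removes the `use_cudagraphs=True` flag (passing it to `thunder.jit` now raises a `RuntimeError`) and replaces it with a `CUDAGraphTransform` passed via `transforms=`. A minimal before/after sketch of the user-facing change; the model and input here are illustrative and not taken from the PR:

import torch
import thunder
from thunder.transforms.cudagraph import CUDAGraphTransform

with torch.device("cuda"):
    model = torch.nn.Linear(2, 3)  # illustrative model
    inp = torch.randn(1, 2)

# Before this PR:
#   jm = thunder.jit(model, use_cudagraphs=True)   # now raises RuntimeError
# After this PR, CUDA graphs are requested via a transform:
jm = thunder.jit(model, transforms=[CUDAGraphTransform()])
out = jm(inp)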
21 changes: 6 additions & 15 deletions thunder/__init__.py
@@ -311,6 +311,12 @@ def jit(
if executors is None:
executors = compile_options.pop("executors_list")

if "early_transforms" in compile_options:
raise RuntimeError("early_transforms= has been absorbed by transforms=")

if compile_options.get("use_cudagraphs") is not None:
raise RuntimeError("use_cudagraphs is replaced by using thunder.transforms.CUDAGraphTransform")

# Resolves interpreter option
interpretation = resolve_interpretation_option(interpretation)
interpreter: Callable
@@ -341,16 +347,13 @@ def jit(

# TODO RC1 Refine the compile data option to remove unused options
# TODO: refine options
# NOTE(fixme): use_cudagraphs is being absorbed into compile_options
use_cudagraphs = compile_options.get("use_cudagraphs", False)
cd = CompileData(
fn=fn,
langctx=langctx,
executors_list=executors,
cache_option=cache,
sharp_edges=sharp_edges,
using_jit=True,
use_cudagraphs=use_cudagraphs,
disable_torch_autograd_support=disable_torch_autograd,
use_rematerialization=False,
only_execute_prims=False,
@@ -692,16 +695,6 @@ def get_computation_and_inputs(*args, **kwargs):
computation_traces.extend(extraces)
computation_trc = computation_traces[-1]

if cd.use_cudagraphs:
from thunder.executors.cudagraphex import cudagraphex

computation_trc = cudagraphex.fusion_pass(computation_trc)
computation_traces.append(computation_trc)

if backward_trc is not None:
backward_trc = cudagraphex.fusion_pass(backward_trc, num_static_inputs=len(backward_trc.args[0][0]))
backward_traces.append(backward_trc)

if backward_trc is None:
computation_trc = thunder.executors.passes.del_last_used(computation_trc)

@@ -835,7 +828,6 @@ def compile(
langctx: None | Any = None,
executors_list: None | Sequence[Executor] = None,
cache_mode: None | str | CACHE_OPTIONS = None,
use_cudagraphs: bool = False,
disable_torch_autograd_support: bool = False,
use_rematerialization: bool = False,
only_execute_prims: bool = False,
@@ -847,7 +839,6 @@
langctx=langctx,
executors_list=executors_list,
cache_option=cache_mode,
use_cudagraphs=use_cudagraphs,
disable_torch_autograd_support=disable_torch_autograd_support,
use_rematerialization=use_rematerialization,
only_execute_prims=only_execute_prims,
4 changes: 3 additions & 1 deletion thunder/benchmarks/__init__.py
@@ -28,6 +28,7 @@
from thunder.executors.transformer_engineex import transformer_engine_ex, TE_AVAILABLE
from thunder.executors.sdpaex import sdpa_ex
from thunder.executors.torch_compile import torch_compile_cat_ex, torch_compile_ex
from thunder.transforms.cudagraph import CUDAGraphTransform
from thunder.tests import nanogpt_model, hf_bart_self_attn, litgpt_model
from thunder.tests.litgpt_model import Config as LitGPTConfig
from thunder.tests.make_tensor import make_tensor, make_tensor_like
@@ -954,7 +955,8 @@ def default_thunder_cudagraphs_executor(fn: Callable) -> Callable:
executors_list.append("apex_xentropy")

executors_list.extend((executors.NVFUSER, executors.TORCH))
return thunder.jit(fn, executors=executors_list, use_cudagraphs=True, disable_torch_autograd=True)
transforms = [CUDAGraphTransform()]
return thunder.jit(fn, executors=executors_list, transforms=transforms, disable_torch_autograd=True)


#
2 changes: 0 additions & 2 deletions thunder/common.py
@@ -195,7 +195,6 @@ def __init__(
using_jit: bool = False,
only_execute_prims: bool = False,
disable_preprocessing: bool = False,
use_cudagraphs: bool = False,
disable_torch_autograd_support: bool = False,
use_rematerialization: bool = False,
debug_log: None | StringIO = None,
@@ -253,7 +252,6 @@ def __init__(
self.only_execute_prims = only_execute_prims
self.disable_preprocessing = disable_preprocessing
self.use_rematerialization = use_rematerialization
self.use_cudagraphs = use_cudagraphs
self.disable_torch_autograd_support = disable_torch_autograd_support
self.debug_log = debug_log

1 change: 0 additions & 1 deletion thunder/extend/__init__.py
@@ -354,7 +354,6 @@ def get_all_executors() -> tuple[Executor, ...]:
# manually import all native executors to let them register themselves
from thunder.executors import (
apexex,
cudagraphex,
cudnn_layernormex,
cudnnex,
nvfuserex,
25 changes: 0 additions & 25 deletions thunder/tests/test_cudagraphs_executor.py

This file was deleted.

1 change: 0 additions & 1 deletion thunder/tests/test_extend.py
@@ -137,7 +137,6 @@ def test_get_all_executors_includes_all_native_executors():
"torchcompile_cat",
"python",
"transformer_engine",
"cudagraphex",
}
if package_available("triton"):
# `triton` maybe installed on a system without GPU.
24 changes: 12 additions & 12 deletions thunder/tests/test_networks.py
@@ -87,13 +87,13 @@ def test_nanogpt_complete_cudagraphs(executor, device, dtype):

# Creates a nanoGPT model with a smaller size than any of the default options for testing
# NOTE Sets dropout to zero for reproducibility
config = nanogpt_model.GPTConfig(dropout=0, block_size=512, n_layer=6, n_head=6, n_embd=768)
gpt = nanogpt_model.GPT(config).to(device=device, dtype=tdtype)
config = nanogpt_model.GPTConfig(dropout=0, block_size=512, n_layer=4, n_head=6, n_embd=768)
gpt = nanogpt_model.GPT(config).to(device=device, dtype=tdtype).requires_grad_(False).eval()

tom = executor.make_callable(gpt, use_cudagraphs=True, disable_torch_autograd=True)
from thunder.transforms.cudagraph import CUDAGraphTransform, build_cuda_graph

# Checking graph cache stats
from thunder.executors.cudagraphex import build_cuda_graph
cgtransform = CUDAGraphTransform()
tom = executor.make_callable(gpt, transforms=[cgtransform], disable_torch_autograd=True)

# Cache stats before test runs
build_graph_stats_old = build_cuda_graph.cache_info()
@@ -111,8 +111,7 @@ def test_nanogpt_complete_cudagraphs(executor, device, dtype):
# Test that at most 1 cache miss happened after the runs.
assert (build_graph_stats_new.misses - build_graph_stats_old.misses) <= 1

# Check we really run CUDAGraphExecutor {
assert tom._lc_cd.use_cudagraphs == True
# Check we really use CUDA graphs {
assert _there_is_cudagraph_sym(thunder.last_traces(tom)[-1])
# }

@@ -124,18 +123,20 @@ def test_nanogpt_complete_cudagraphs(executor, device, dtype):

@instantiate(dtypes=(thunder.float32,), devicetypes=(thunder.devices.DeviceType.CUDA,))
@requiresCUDA
def test_nanogpt_complete_cuda_graphs_autograd(executor, device, dtype):
def test_nanogpt_complete_cudagraphs_autograd(executor, device, dtype):
tdtype = ttorch.to_torch_dtype(dtype)

# Creates a nanoGPT model with a smaller size than any of the default options for testing
# NOTE Sets dropout to zero for reproducibility
config = nanogpt_model.GPTConfig(dropout=0, block_size=512, n_layer=6, n_head=6, n_embd=768)
gpt = nanogpt_model.GPT(config).to(device=device, dtype=tdtype)
cmodel = executor.make_callable(gpt, use_cudagraphs=True)

# Checking graph cache stats
from thunder.executors.cudagraphex import build_cuda_graph
from thunder.transforms.cudagraph import CUDAGraphTransform, build_cuda_graph

cgtransform = CUDAGraphTransform()
cmodel = executor.make_callable(gpt, transforms=[cgtransform])

# Checking graph cache stats
# Cache stats before test runs
build_graph_stats_old = build_cuda_graph.cache_info()

@@ -161,7 +162,6 @@ def test_nanogpt_complete_cuda_graphs_autograd(executor, device, dtype):
assert (build_graph_stats_new.misses - build_graph_stats_old.misses) <= 2

# Check we really run CUDAGraphExecutor {
assert cmodel._lc_cd.use_cudagraphs == True
assert _there_is_cudagraph_sym(thunder.last_traces(cmodel)[-1])
assert _there_is_cudagraph_sym(thunder.last_backward_traces(cmodel)[-1])
# }
19 changes: 19 additions & 0 deletions thunder/tests/test_transforms.py
@@ -264,6 +264,25 @@ def post_callback(bsym, *args, **kwargs):
assert debug_headers == expected_headers


@requiresCUDA
def test_cudagraph_warmup_runs_with_correct_buffers():
"""
Tests whether newly-created buffers are being properly initialized.
Otherwise we should expect failures because of incorrect values.
"""

from thunder.transforms.cudagraph import CUDAGraphTransform

weights = torch.tensor([0, 10, 3, 0], device="cuda", dtype=torch.float)

def f(x):
return torch.multinomial(x, num_samples=3, replacement=True)

jf = thunder.jit(f, transforms=[CUDAGraphTransform()])
jf(weights)
jf(weights)


@requiresCUDA
def test_materialization_init():
from thunder.transforms import MaterializationTransform
thunder/executors/cudagraphex.py → thunder/transforms/cudagraph.py
@@ -7,6 +7,7 @@
import torch

from thunder import trace
from thunder.core.transform_common import Transform
from thunder.core.transforms import eval_trace
from thunder.extend import FusionExecutor, register_executor
from thunder.core import utils, prims
@@ -97,7 +98,6 @@ def __call__(self, *args):
static_input.copy_(arg)

graph.replay()

return static_outputs


@@ -130,9 +130,13 @@ def region_fn():
return region_trace.python_callable()


class CUDAGraphExecutor(FusionExecutor):
def __init__(self, name: Hashable):
super().__init__(name, version=torch.version.cuda)
class CUDAGraphTransform(Transform):
"""
Transform to fuse operations into CUDA graphs post optimization.

This class provides the basic infrastructure, but it is expected that you might subclass this transform
in order to override ``can_fuse`` or other methods.
"""

def fuse(self, region: Region, fusion_counter: int, num_static_inputs: None | int = None) -> BoundSymbol:
inputs = [unvariableify(inp) for inp in region.inputs]
@@ -143,8 +147,8 @@ def fuse(self, region: Region, fusion_counter: int, num_static_inputs: None | in
region.bound_symbols = _del_last_used(region.bound_symbols, outputs)

fusion_name = f"CUDAGraph{fusion_counter}"
fusion_callable: Callable = make_callable(f"{fusion_name}_fn", region.bound_symbols, inputs, outputs)
fusion_callable = CUDAGraphCallable(fusion_callable, num_static_inputs)
fusible_callable: Callable = make_callable(f"{fusion_name}_fn", region.bound_symbols, inputs, outputs)
fusion_callable = CUDAGraphCallable(fusible_callable, num_static_inputs)

fusion_sym = Symbol(fusion_name, meta=None, is_fusion=True, executor=self)
fusion_bsym = BoundSymbol(
@@ -191,7 +195,7 @@ def can_fuse(self, bsym: BoundSymbol):

return True

def fusion_pass(self, trace: TraceCtx, num_static_inputs: None | int = None) -> TraceCtx:
def transform_trace_post_optimization(self, trace: TraceCtx, **kwargs) -> TraceCtx:
start_time_ns: int = time.perf_counter_ns()

def _should_fuse(a: Node, b: Node):
@@ -230,7 +234,7 @@ def _can_fuse_node(n: Node):
fused_bsyms.extend(bsyms)
else:
region = Region(producers, consumers, bsyms)
fusion_bsym: BoundSymbol = self.fuse(region, fusion_counter, num_static_inputs)
fusion_bsym: BoundSymbol = self.fuse(region, fusion_counter)
Comment on lines -233 to +237

@nikitaved (Contributor) — Aug 16, 2024
we are missing num_static_inputs here for a reason? If we get several non-isomorphic graphs, we should probably think about how to handle this parameter better, maybe through a callback. But that's not relevant now...

@t-vi (Collaborator, Author) — Aug 16, 2024
I do think that this is one of the things where the default transform is not ideal, but it is not something that was used before: the parameter was there with a default, but there was no way of providing it.
(And I am not sure that it is correct to have a nontrivial trace-global parameter for it either.)

@nikitaved (Contributor) — Aug 16, 2024
It is used in the backward pass. We are not losing it there, are we?

@t-vi (Collaborator, Author)
Ouch, right.

@nikitaved (Contributor) — Aug 16, 2024
But yes, the design of this parameter is so-so as it was not expected to have had graph breaks...

@t-vi (Collaborator, Author)
but which were the args covered by it?

@nikitaved (Contributor)
So, this is the removed code:

            if cd.use_cudagraphs:
                from thunder.executors.cudagraphex import cudagraphex

                computation_trc = cudagraphex.fusion_pass(computation_trc)
                computation_traces.append(computation_trc)

                if backward_trc is not None:
                    backward_trc = cudagraphex.fusion_pass(backward_trc, num_static_inputs=len(backward_trc.args[0][0]))
                    backward_traces.append(backward_trc)

@t-vi (Collaborator, Author)
So it seems that these were the saved-for-backwards of (???). Is the assumption here that they were either static in the forward (parameters) or copied to the input area of the forward CUDA graph?

But with the old code:

import torch, thunder
with torch.device("cuda"):
    m = torch.nn.Linear(2, 3)
    inp = torch.randn(1, 2, requires_grad=True)
jm = thunder.jit(m, use_cudagraphs=True)

res = jm(inp)
grads = torch.autograd.grad(res.sum(), (inp, *m.parameters()))

the forward has no CUDA graph, so in that case, having the input as fixed is not really correct (admittedly, a corner case).

@t-vi (Collaborator, Author)
I guess the happy part is that we will get the parameters as static anyway, and we will have to look into the buffers for our own good...

fusion_counter += 1
fused_bsyms.append(fusion_bsym)

@@ -244,7 +248,3 @@ def _can_fuse_node(n: Node):
fused_trace.set_provenance(TraceProvenance(f"CUDAGraph fusion (took {elapsed_time_ms} milliseconds)"))

return fused_trace


cudagraphex = CUDAGraphExecutor(name="cudagraphex")
register_executor(cudagraphex)