attach cudagraph cache to the transform
t-vi committed Aug 20, 2024
1 parent d526304 commit 0907eab
Showing 2 changed files with 125 additions and 100 deletions.
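
In effect, this commit moves the CUDA graph cache off the module-level lru_cache on build_cuda_graph and onto a CUDAGraphRunner instance owned by each CUDAGraphTransform. A minimal usage sketch of the new layout, assuming the thunder.jit entry point and placeholder model/x (neither appears in this diff; the tests below go through executor.make_callable instead):

import thunder
from thunder.transforms.cudagraph import CUDAGraphTransform

cgtransform = CUDAGraphTransform()
jitted = thunder.jit(model, transforms=[cgtransform])  # model and x are placeholders
out = jitted(x)

# After this commit the recorded graphs live on the transform's own runner,
# keyed by (function name, static-input mask, argument descriptor), instead of
# in a process-wide functools.lru_cache on build_cuda_graph.
print(len(cgtransform.cuda_graph_runner.cuda_graph_cache))
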
37 changes: 7 additions & 30 deletions thunder/tests/test_networks.py
@@ -90,35 +90,24 @@ def test_nanogpt_complete_cudagraphs(executor, device, dtype):
config = nanogpt_model.GPTConfig(dropout=0, block_size=512, n_layer=4, n_head=6, n_embd=768)
gpt = nanogpt_model.GPT(config).to(device=device, dtype=tdtype).requires_grad_(False).eval()

from thunder.transforms.cudagraph import CUDAGraphTransform, build_cuda_graph
from thunder.transforms.cudagraph import CUDAGraphTransform

cgtransform = CUDAGraphTransform()
tom = executor.make_callable(gpt, transforms=[cgtransform], disable_torch_autograd=True)

# Cache stats before test runs
build_graph_stats_old = build_cuda_graph.cache_info()

for _ in range(2):
idx = make((4, 64), dtype=torch.int64, low=0, high=255)
torch_result = gpt(idx)

thunder_result = tom(idx)
assert_close(torch_result, thunder_result)

# Cache stats after test runs
build_graph_stats_new = build_cuda_graph.cache_info()
# We ran only a single (forward) graph several times.
# Test that at most 1 cache miss happened after the runs.
assert (build_graph_stats_new.misses - build_graph_stats_old.misses) <= 1
# Test that only 1 cache entry was created during the runs.
assert len(cgtransform.cuda_graph_runner.cuda_graph_cache) == 1

# Check we really use CUDA graphs {
# Check we really use CUDA graphs
assert _there_is_cudagraph_sym(thunder.last_traces(tom)[-1])
# }

# Let's clear cache if run only in tests
# TODO: merge with the cache of the thunder.jit callable
if build_graph_stats_old.misses == 0:
build_cuda_graph.cache_clear()


@instantiate(dtypes=(thunder.float32,), devicetypes=(thunder.devices.DeviceType.CUDA,))
@@ -131,15 +120,11 @@ def test_nanogpt_complete_cudagraphs_autograd(executor, device, dtype):
config = nanogpt_model.GPTConfig(dropout=0, block_size=512, n_layer=6, n_head=6, n_embd=768)
gpt = nanogpt_model.GPT(config).to(device=device, dtype=tdtype)

from thunder.transforms.cudagraph import CUDAGraphTransform, build_cuda_graph
from thunder.transforms.cudagraph import CUDAGraphTransform

cgtransform = CUDAGraphTransform()
cmodel = executor.make_callable(gpt, transforms=[cgtransform])

# Checking graph cache stats
# Cache stats before test runs
build_graph_stats_old = build_cuda_graph.cache_info()

# Multiple runs to test whether static buffers are properly updated
for i in range(3):
x = make_tensor((4, 64), dtype=torch.int64, low=0, high=255, device=device)
@@ -154,22 +139,14 @@ def test_nanogpt_complete_cudagraphs_autograd(executor, device, dtype):
assert_close(torch_result, thunder_result)
assert_close(torch_grads, thunder_grads)

# Cache stats after test runs
build_graph_stats_new = build_cuda_graph.cache_info()
# We ran only at most two (forward and backward) graphs several times.
# Test that at most 2 cache misses happened after the runs
# (at most one per each graph)
assert (build_graph_stats_new.misses - build_graph_stats_old.misses) <= 2
assert len(cgtransform.cuda_graph_runner.cuda_graph_cache) == 2

# Check we really run CUDAGraphExecutor {
# Check we have CUDAGraph symbols in forward and backward
assert _there_is_cudagraph_sym(thunder.last_traces(cmodel)[-1])
assert _there_is_cudagraph_sym(thunder.last_backward_traces(cmodel)[-1])
# }

# Let's clear cache if run only in tests
# TODO: merge with the cache of the thunder.jit callable
if build_graph_stats_old.misses == 0:
build_cuda_graph.cache_clear()


@instantiate(dtypes=(thunder.float32,), executors=all_test_executors_and_dynamo)
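
For context on the counts asserted above: in the runner added below, each entry in cuda_graph_cache is keyed by the region's function name, its static-input mask, and a descriptor of the call arguments (see make_cache_key), so repeated calls with identical shapes reuse one entry per recorded graph. A rough inspection sketch, assuming the cgtransform from the first test:

runner = cgtransform.cuda_graph_runner
for (fn_name, static_mask, args_desc), (graph, static_ins, static_outs) in runner.cuda_graph_cache.items():
    # One entry per recorded graph: a single forward graph in the first test,
    # forward plus backward in the autograd test.
    print(fn_name, sum(static_mask), type(graph).__name__)
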
188 changes: 118 additions & 70 deletions thunder/transforms/cudagraph.py
@@ -24,7 +24,6 @@ class ArgsDescriptor:
sizes: tuple
strides: tuple
non_tensor_args: tuple
args: list = field(hash=False, repr=False, compare=False)


def to_arg_descriptor(*args):
@@ -35,63 +34,75 @@ def extract_descriptor(arg):
return type(arg), None, None, arg

dtypes, sizes, strides, non_tensor_args = zip(*map(extract_descriptor, args))
return ArgsDescriptor(dtypes, sizes, strides, non_tensor_args, args)
return ArgsDescriptor(dtypes, sizes, strides, non_tensor_args)


@lru_cache
def build_cuda_graph(
fn: Callable, args_descriptor: ArgsDescriptor, static_args_mask: tuple[bool, ...]
) -> tuple[torch.cuda.CUDAGraph, Sequence[torch.Tensor | Any], Sequence[torch.Tensor | Any]]:
class CUDAGraphRunner:
def __init__(self):
self.cuda_graph_cache = {}
self.python_callables = {} # fn_name -> (callable, static_inputs_mask)
self.trace_symbols = {} # fn_name -> (bsyms, inputs, outputs)
self.name_counter = 1

def get_static_buffer(x):
def get_static_buffer(self, x):
if isinstance(x, torch.Tensor):
return torch.empty_like(x).copy_(x)
return x

args = args_descriptor.args

# Warmup
torch.cuda.synchronize()
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(stream):
static_inputs = tuple(
get_static_buffer(arg) if not is_static else arg for arg, is_static in zip(args, static_args_mask)
)
for _ in range(3):
fn(*static_inputs)

stream.synchronize()
torch.cuda.current_stream().wait_stream(stream)
torch.cuda.synchronize()

# Record
# NOTE: We are using default private pool here, but it is possibly better to
# use a custom pool for better memory management. See CUDA Graphs Tree in
# PyTorch's Inductor: torch/_inductor/cudagraph_trees.py
# Design doc: https://docs.google.com/document/d/1ZrxLGWz7T45MSX6gPsL6Ln4t0eZCSfWewtJ_qLd_D0E/view
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
static_outputs = fn(*static_inputs)

return graph, static_inputs, static_outputs
def build_cuda_graph(
self, fn: Callable, args: list[Any], static_args_mask: tuple[bool, ...]
) -> tuple[torch.cuda.CUDAGraph, Sequence[torch.Tensor | Any], Sequence[torch.Tensor | Any]]:

# Warmup
torch.cuda.synchronize()
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(stream):
static_inputs = tuple(
self.get_static_buffer(arg) if not is_static else arg for arg, is_static in zip(args, static_args_mask)
)
for _ in range(3):
fn(*static_inputs)

stream.synchronize()
torch.cuda.current_stream().wait_stream(stream)
torch.cuda.synchronize()

# Record
# NOTE: We are using default private pool here, but it is possibly better to
# use a custom pool for better memory management. See CUDA Graphs Tree in
# PyTorch's Inductor: torch/_inductor/cudagraph_trees.py
# Design doc: https://docs.google.com/document/d/1ZrxLGWz7T45MSX6gPsL6Ln4t0eZCSfWewtJ_qLd_D0E/view
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
static_outputs = fn(*static_inputs)

return graph, static_inputs, static_outputs

def make_static_inputs_mask(self, fn_name, *args):
_, static_inputs_mask = self.python_callables[fn_name]
if static_inputs_mask is None:
static_inputs_mask = tuple(isinstance(arg, torch.nn.Parameter) for arg in args)
return static_inputs_mask

def make_cache_key(self, fn_name, *args):
# if args_descriptor included torch.nn.Parameter-ness, we could
# use the static_inputs_mask or None as the key.
return (fn_name, self.make_static_inputs_mask(fn_name, *args), to_arg_descriptor(*args))

class CUDAGraphCallable:
def __init__(self, fn: Callable, num_static_inputs: None | int = None):
self.fn = fn
self.num_static_inputs = num_static_inputs
def call_cuda_graph(self, fn_name, *args):
fn, num_static_inputs = self.python_callables[fn_name]

def __call__(self, *args):
if self.num_static_inputs is not None:
static_inputs_mask = (True,) * self.num_static_inputs + (False,) * (len(args) - self.num_static_inputs)
else:
static_inputs_mask = tuple(isinstance(arg, torch.nn.Parameter) for arg in args)
cache_key = self.make_cache_key(fn_name, *args)

args_descriptor = to_arg_descriptor(*args)
cache_entry = self.cuda_graph_cache.get(cache_key)
if cache_entry is None:
static_inputs_mask = self.make_static_inputs_mask(fn_name, *args)
cache_entry = self.build_cuda_graph(fn, args, static_inputs_mask)
self.cuda_graph_cache[cache_key] = cache_entry

graph, static_inputs, static_outputs = build_cuda_graph(self.fn, args_descriptor, static_inputs_mask)
graph, static_inputs, static_outputs = cache_entry

for static_input, arg in utils.safe_zip(static_inputs, args):
if id(static_input) != id(arg) and isinstance(static_input, torch.Tensor) and isinstance(arg, torch.Tensor):
@@ -100,44 +111,79 @@ def __call__(self, *args):
graph.replay()
return static_outputs

def make_python_callable_from_symbols(
self,
fn_name: str,
bsyms: list[BoundSymbol],
inputs: list[Proxy],
outputs: list[Proxy],
) -> Callable:

def make_callable(
fn_name: str,
bsyms: list[BoundSymbol],
inputs: list[Proxy],
outputs: list[Proxy],
) -> Callable:
from inspect import Parameter, Signature
from inspect import Parameter, Signature

region_fn_params = (
Parameter(getattr(param, "name", f"arg{i}"), Parameter.POSITIONAL_ONLY) for i, param in enumerate(inputs)
)
region_fn_params = (
Parameter(getattr(param, "name", f"arg{i}"), Parameter.POSITIONAL_ONLY) for i, param in enumerate(inputs)
)

region_fn_signature = Signature(region_fn_params)
region_fn_signature = Signature(region_fn_params)

def region_fn():
pass
def region_fn():
pass

region_fn.__signature__ = region_fn_signature
region_fn.__name__ = fn_name
region_fn.__signature__ = region_fn_signature
region_fn.__name__ = fn_name

region_trace = TraceCtx(region_fn)
region_trace.bound_symbols = bsyms
region_trace.args = inputs
region_trace.kwargs = {}
region_trace.bound_symbols.append(prims.python_return.bind(outputs, output=()))
region_trace = TraceCtx(region_fn)
region_trace.bound_symbols = bsyms
region_trace.args = inputs
region_trace.kwargs = {}
region_trace.bound_symbols.append(prims.python_return.bind(outputs, output=()))
return region_trace.python_callable()

return region_trace.python_callable()
def make_cuda_graph_callable_from_symbols(
self,
fn_name: str,
bsyms: list[BoundSymbol],
inputs: list[Proxy],
outputs: list[Proxy],
num_static_inputs: int | None = None,
) -> Callable:

if num_static_inputs is not None:
static_inputs_mask = (True,) * num_static_inputs + (False,) * (len(inputs) - num_static_inputs)
else:
static_inputs_mask = None

x_fn_name = f"{fn_name}_{self.name_counter}"
self.name_counter += 1

self.python_callables[x_fn_name] = (
self.make_python_callable_from_symbols(fn_name, bsyms, inputs, outputs),
static_inputs_mask,
)
self.trace_symbols[x_fn_name] = (bsyms, inputs, outputs)

def callable(*args):
return self.call_cuda_graph(x_fn_name, *args)

callable.__name__ = fn_name
callable.__qualname__ = fn_name

return callable


class CUDAGraphTransform(Transform):
"""
Transform to fuse operations into CUDA graphs post optimization.
This class provides the basic infrastructure, but it is expected that you might subclass this transform
This class provides the basic infrastructure, but it is expected that you might subclass this transform
in order to override ``can_fuse`` or other methods.
"""

def __init__(self):
super().__init__()
self.cuda_graph_runner = CUDAGraphRunner()

def fuse(self, region: Region, fusion_counter: int, num_static_inputs: None | int = None) -> BoundSymbol:
inputs = [unvariableify(inp) for inp in region.inputs]
outputs = [unvariableify(out) for out in region.outputs]
@@ -147,8 +193,10 @@ def fuse(self, region: Region, fusion_counter: int, num_static_inputs: None | in
region.bound_symbols = _del_last_used(region.bound_symbols, outputs)

fusion_name = f"CUDAGraph{fusion_counter}"
fusible_callable: Callable = make_callable(f"{fusion_name}_fn", region.bound_symbols, inputs, outputs)
fusion_callable = CUDAGraphCallable(fusible_callable, num_static_inputs)

fusion_callable = self.cuda_graph_runner.make_cuda_graph_callable_from_symbols(
f"{fusion_name}_fn", region.bound_symbols, inputs, outputs, num_static_inputs
)

fusion_sym = Symbol(fusion_name, meta=None, is_fusion=True, executor=self)
fusion_bsym = BoundSymbol(
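
As the docstring above suggests, the transform is meant to be subclassed when the default fusion policy needs adjusting. A minimal sketch of such a subclass, under the assumption that can_fuse receives a BoundSymbol and returns a bool (its signature is not part of this diff, and the excluded op names are purely illustrative):

from thunder.transforms.cudagraph import CUDAGraphTransform

class SelectiveCUDAGraphTransform(CUDAGraphTransform):
    # Hypothetical policy: keep certain ops out of the recorded graphs.
    def can_fuse(self, bsym):
        if bsym.sym.name in {"uniform", "embedding"}:  # illustrative names
            return False
        return super().can_fuse(bsym)

transform = SelectiveCUDAGraphTransform()

Because the cache now lives on the runner created in __init__, two such transform instances no longer share recorded graphs.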
