CUDAGraphExecutor - add tests for the graph caching mechanism (#670)
nikitaved authored Jun 27, 2024
1 parent 395dfe1 commit 8c5905f
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions thunder/tests/test_networks.py
@@ -83,18 +83,35 @@ def test_nanogpt_complete_cudagraphs(executor, device, dtype):

    tom = executor.make_callable(gpt, use_cudagraphs=True, disable_torch_autograd=True)

    # Check the graph cache stats
    from thunder.executors.cudagraphex import build_cuda_graph

    # Cache stats before the test runs
    build_graph_stats_old = build_cuda_graph.cache_info()

    for _ in range(2):
        idx = make((4, 64), dtype=torch.int64, low=0, high=255)
        torch_result = gpt(idx)

        thunder_result = tom(idx)
        assert_close(torch_result, thunder_result)

    # Cache stats after the test runs
    build_graph_stats_new = build_cuda_graph.cache_info()
    # We ran only a single (forward) graph several times.
    # Check that at most one cache miss happened during the runs.
    assert (build_graph_stats_new.misses - build_graph_stats_old.misses) <= 1

    # Check that we really run the CUDAGraphExecutor {
    assert tom._lc_cd.use_cudagraphs == True
    assert _there_is_cudagraph_sym(thunder.last_traces(tom)[-1])
    # }

    # Clear the cache, but only if it was unused before this test ran
    # (i.e., this test owns its contents).
    # TODO: merge with the cache of the thunder.jit callable
    if build_graph_stats_old.misses == 0:
        build_cuda_graph.cache_clear()
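Note: `build_cuda_graph` exposes `cache_info()` and `cache_clear()`, which matches the `functools.lru_cache` interface. Below is a minimal sketch of the accounting the assertion above relies on, assuming an `lru_cache`-style memoized builder; the real signature of `build_cuda_graph` in `thunder.executors.cudagraphex` is not shown in this diff.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def build_graph(key):
    # Stand-in for an expensive CUDA graph capture; the body runs only on a cache miss.
    return object()


stats_old = build_graph.cache_info()  # CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
for _ in range(2):
    build_graph("forward")  # first call: miss (capture); second call: hit (cached graph)
stats_new = build_graph.cache_info()
assert stats_new.misses - stats_old.misses <= 1  # one graph -> at most one capture
build_graph.cache_clear()  # drop cached graphs, as the test does during cleanup
```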


@instantiate(dtypes=(thunder.float32,), devicetypes=(thunder.devices.DeviceType.CUDA,))
@requiresCUDA
Expand All @@ -107,6 +124,12 @@ def test_nanogpt_complete_cuda_graphs_autograd(executor, device, dtype):
    gpt = nanogpt_model.GPT(config).to(device=device, dtype=tdtype)
    cmodel = executor.make_callable(gpt, use_cudagraphs=True)

    # Check the graph cache stats
    from thunder.executors.cudagraphex import build_cuda_graph

    # Cache stats before the test runs
    build_graph_stats_old = build_cuda_graph.cache_info()

    # Multiple runs to test whether static buffers are properly updated
    for i in range(3):
        x = make_tensor((4, 64), dtype=torch.int64, low=0, high=255, device=device)
Expand All @@ -121,12 +144,24 @@ def test_nanogpt_complete_cuda_graphs_autograd(executor, device, dtype):
        assert_close(torch_result, thunder_result)
        assert_close(torch_grads, thunder_grads)

    # Cache stats after the test runs
    build_graph_stats_new = build_cuda_graph.cache_info()
    # We ran at most two graphs (forward and backward) several times.
    # Check that at most two cache misses happened during the runs
    # (at most one per graph).
    assert (build_graph_stats_new.misses - build_graph_stats_old.misses) <= 2

    # Check that we really run the CUDAGraphExecutor {
    assert cmodel._lc_cd.use_cudagraphs == True
    assert _there_is_cudagraph_sym(thunder.last_traces(cmodel)[-1])
    assert _there_is_cudagraph_sym(thunder.last_backward_traces(cmodel)[-1])
    # }

    # Clear the cache, but only if it was unused before this test ran
    # (i.e., this test owns its contents).
    # TODO: merge with the cache of the thunder.jit callable
    if build_graph_stats_old.misses == 0:
        build_cuda_graph.cache_clear()


@instantiate(dtypes=(thunder.float32,))
def test_nanogpt_csa(executor, device, dtype):
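For context on what a cached graph build typically does: the sketch below captures and replays a CUDA graph with PyTorch's public `torch.cuda.CUDAGraph` API. It is illustrative only, not Thunder's actual `build_cuda_graph` implementation; names such as `capture` and `static_x` are hypothetical.

```python
import torch


def capture(fn, example_input):
    # Warm up on a side stream so lazy initialization doesn't pollute the capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        fn(example_input)
    torch.cuda.current_stream().wait_stream(s)

    # Capture: kernels launched inside this context are recorded, not run eagerly.
    static_x = example_input.clone()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_y = fn(static_x)
    return graph, static_x, static_y


# Replay with new data: copy into the static input buffer, then replay.
# graph, static_x, static_y = capture(model, x)
# static_x.copy_(new_x); graph.replay(); result = static_y.clone()
```

Replaying reuses the recorded kernels and static buffers, which is why the tests above re-run the same callable several times while asserting that at most one capture (cache miss) occurs per graph.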
