From 8c5905fd1a93145e690791a7c7a3c3e10b16b32b Mon Sep 17 00:00:00 2001
From: nikitaved
Date: Thu, 27 Jun 2024 16:24:32 +0200
Subject: [PATCH] CUDAGraphExecutor - add tests for the graph caching
 mechanism (#670)

---
 thunder/tests/test_networks.py | 35 ++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py
index e9bde4209f..dd6cda053d 100644
--- a/thunder/tests/test_networks.py
+++ b/thunder/tests/test_networks.py
@@ -83,6 +83,12 @@ def test_nanogpt_complete_cudagraphs(executor, device, dtype):
 
     tom = executor.make_callable(gpt, use_cudagraphs=True, disable_torch_autograd=True)
 
+    # Checking graph cache stats
+    from thunder.executors.cudagraphex import build_cuda_graph
+
+    # Cache stats before test runs
+    build_graph_stats_old = build_cuda_graph.cache_info()
+
     for _ in range(2):
         idx = make((4, 64), dtype=torch.int64, low=0, high=255)
         torch_result = gpt(idx)
@@ -90,11 +96,22 @@ def test_nanogpt_complete_cudagraphs(executor, device, dtype):
         thunder_result = tom(idx)
         assert_close(torch_result, thunder_result)
 
+    # Cache stats after test runs
+    build_graph_stats_new = build_cuda_graph.cache_info()
+    # We ran only a single (forward) graph several times.
+    # Test that at most 1 cache miss happened after the runs.
+    assert (build_graph_stats_new.misses - build_graph_stats_old.misses) <= 1
+
     # Check we really run CUDAGraphExecutor {
     assert tom._lc_cd.use_cudagraphs == True
     assert _there_is_cudagraph_sym(thunder.last_traces(tom)[-1])
     # }
 
+    # Clear the cache if this test was the first to populate it
+    # TODO: merge with the cache of the thunder.jit callable
+    if build_graph_stats_old.misses == 0:
+        build_cuda_graph.cache_clear()
+
 
 @instantiate(dtypes=(thunder.float32,), devicetypes=(thunder.devices.DeviceType.CUDA,))
 @requiresCUDA
@@ -107,6 +124,12 @@ def test_nanogpt_complete_cuda_graphs_autograd(executor, device, dtype):
     gpt = nanogpt_model.GPT(config).to(device=device, dtype=tdtype)
     cmodel = executor.make_callable(gpt, use_cudagraphs=True)
 
+    # Checking graph cache stats
+    from thunder.executors.cudagraphex import build_cuda_graph
+
+    # Cache stats before test runs
+    build_graph_stats_old = build_cuda_graph.cache_info()
+
     # Multiple runs to test whether static buffers are properly updated
     for i in range(3):
         x = make_tensor((4, 64), dtype=torch.int64, low=0, high=255, device=device)
@@ -121,12 +144,24 @@ def test_nanogpt_complete_cuda_graphs_autograd(executor, device, dtype):
         assert_close(torch_result, thunder_result)
         assert_close(torch_grads, thunder_grads)
 
+    # Cache stats after test runs
+    build_graph_stats_new = build_cuda_graph.cache_info()
+    # We ran at most two graphs (forward and backward) several times.
+    # Test that at most 2 cache misses happened after the runs
+    # (at most one per graph).
+    assert (build_graph_stats_new.misses - build_graph_stats_old.misses) <= 2
+
     # Check we really run CUDAGraphExecutor {
     assert cmodel._lc_cd.use_cudagraphs == True
     assert _there_is_cudagraph_sym(thunder.last_traces(cmodel)[-1])
     assert _there_is_cudagraph_sym(thunder.last_backward_traces(cmodel)[-1])
     # }
 
+    # Clear the cache if this test was the first to populate it
+    # TODO: merge with the cache of the thunder.jit callable
+    if build_graph_stats_old.misses == 0:
+        build_cuda_graph.cache_clear()
+
 
 @instantiate(dtypes=(thunder.float32,))
 def test_nanogpt_csa(executor, device, dtype):
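-- 
Note: the assertions in this patch depend on `build_cuda_graph` exposing
`cache_info()` and `cache_clear()`, which is the interface of a
`functools.lru_cache`-wrapped callable. The sketch below is a minimal,
self-contained illustration of that caching pattern, not Thunder's actual
implementation; `build_graph` and its string key are hypothetical stand-ins
for whatever the executor really keys its cache on:

    from functools import lru_cache

    @lru_cache(maxsize=None)
    def build_graph(region_key: str) -> str:
        # Stand-in for CUDA graph capture: only a previously unseen key
        # triggers a rebuild, which shows up as a miss in cache_info().
        return f"captured-graph-for-{region_key}"

    before = build_graph.cache_info()
    for _ in range(3):
        build_graph("forward")  # first call misses; later calls hit
    after = build_graph.cache_info()
    # Same check as the forward-only test: repeated runs of one graph
    # should cost at most one build.
    assert after.misses - before.misses <= 1

    # Same cleanup as the tests: reset the cache only if it started empty,
    # so a shared cache populated elsewhere is left untouched.
    if before.misses == 0:
        build_graph.cache_clear()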