diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index fb81364740..78772e636c 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -1395,7 +1395,7 @@ def test_populate_grads_nanogpt(executor, device, dtype): from thunder.benchmarks import NanoGPTBenchmark, NanoGPTConfig # NOTE Currently setting dropout to zero for reproducibility, other settings taken from gpt2 config - config = NanoGPTConfig(dropout=0, n_layer=12, n_head=12, n_embd=768) + config = NanoGPTConfig(dropout=0, n_layer=6, n_head=6, n_embd=768) bench = NanoGPTBenchmark(config=config, requires_grad=True, device=device, dtype=dtype) model = bench.fn()