diff --git a/thunder/tests/test_inplace_functionalization.py b/thunder/tests/test_inplace_functionalization.py
index 6f88f1f8eb..36a05190f7 100644
--- a/thunder/tests/test_inplace_functionalization.py
+++ b/thunder/tests/test_inplace_functionalization.py
@@ -22,6 +22,7 @@
 )
 from thunder.tests.opinfos import opinfos, OpInfo, make_number, SampleInput
 from thunder.tests.make_tensor import make_tensor, make_tensor_like
+from thunder.tests.litgpt_model import GPT, Config
 from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place
 
 if TYPE_CHECKING:
@@ -812,3 +813,36 @@ def f_copy(a):
     expected = f(x_ref)
     torch.testing.assert_close(actual, expected)
     torch.testing.assert_close(x, x_ref)
+
+
+@instantiate(
+    dtypes=(dtypes.float32,),
+    executors=(TorchCompileExecutor,),
+)
+def test_adamw_with_pythia14m(executor, device, _):
+    config = Config.from_name("pythia-14m")
+    model = GPT(config).to(device=device)
+
+    params = list(model.parameters())
+    adamw = torch.optim.AdamW(params)
+    ref_params = [p.clone().detach() for p in params]
+    ref_adamw = torch.optim.AdamW(ref_params)
+
+    jitted_step = torch.compile(adamw.step, backend=thunder.dynamo.ThunderCompiler())
+
+    with torch.no_grad():
+        for p, ref_p in zip(params, ref_params):
+            grad = make_tensor_like(p)
+            p.grad = grad
+            ref_p.grad = grad.clone().detach()
+
+    for i in range(3):
+        if i > 0:
+            with torch.no_grad():
+                for p, ref_p in zip(params, ref_params):
+                    p.grad.copy_(make_tensor_like(p))
+                    ref_p.grad.copy_(p.grad)
+        jitted_step()
+        ref_adamw.step()
+
+    torch.testing.assert_close(ref_params, params)