From e37bec26c9ff40d7a790d4f8b48bc94524f6385b Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Wed, 18 Dec 2024 13:17:21 +0100
Subject: [PATCH] move adhoc executor first for priority in autograd (#1569)

---
 thunder/__init__.py        |  2 +-
 thunder/tests/test_grad.py | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/thunder/__init__.py b/thunder/__init__.py
index e3aa5ac2aa..eceacc4cad 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -335,7 +335,7 @@ def jit(
     # Resolve names of executors
     executors = resolve_executors(executors)
     ad_hoc_executor = extend.AdHocExecutor()
-    executors = (*executors, ad_hoc_executor)
+    executors = (ad_hoc_executor, *executors)
 
     # TODO: verify that tutorials don't have false positives and enable warning by default
     # # Make sharp_edges == warn default if not supplied and if in the general jit
diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index 4520f517b1..9b68b37e7f 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -1853,3 +1853,36 @@ def forward(x):
     actual_grad = torch.autograd.grad(actual, x, grad_o)
     expected_grad = torch.autograd.grad(expected, x, grad_o)
     torch.testing.assert_close(actual_grad, expected_grad)
+
+
+@instantiate(
+    dtypes=NOTHING,
+)
+def test_adhoc_executor_grad(executor, device, _):
+    import torch
+    import thunder
+
+    x = torch.ones(2, device=device, requires_grad=True)
+
+    class Sin(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, x):
+            ctx.save_for_backward(x)
+            return torch.sin(x)
+
+        @staticmethod
+        def backward(ctx, g):
+            (x,) = ctx.saved_tensors
+            return g * torch.cos(x) * 200
+
+    def func(x):
+        return Sin.apply(x)
+
+    cfunc = thunder.jit(func)
+    actual = cfunc(x)
+    (actual_gr,) = torch.autograd.grad(actual.sum(), x)
+    expected = func(x)
+    (expected_gr,) = torch.autograd.grad(expected.sum(), x)
+
+    torch.testing.assert_close(actual, expected)
+    torch.testing.assert_close(actual_gr, expected_gr)
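
Why the ordering matters: Thunder treats the executor tuple as a priority
list, consulting executors in order when it looks up an implementation for
each symbol, so the first executor that claims a symbol wins. With the old
ordering, (*executors, ad_hoc_executor), a gradient registered on the ad hoc
executor for a custom torch.autograd.Function could be shadowed by an earlier
executor that also handles the underlying op (here, sin). The factor of 200
in the test's backward acts as a sentinel: actual and expected gradients only
match if the custom Function's backward is actually the one autograd runs.
Below is a minimal sketch of this first-match dispatch rule, using
hypothetical helper names (pick_implementation, implementation_for) rather
than Thunder's actual internals:

    # Sketch of priority-ordered executor lookup. Hypothetical names; this
    # illustrates the rule "earlier in the tuple wins", not Thunder's real
    # resolution code.
    def pick_implementation(executors, symbol):
        for ex in executors:
            impl = ex.implementation_for(symbol)  # hypothetical lookup
            if impl is not None:
                return impl
        return None  # no executor claimed the symbol; decompose it instead

    # With executors = (ad_hoc_executor, *others), the ad hoc registration
    # is found first, so the custom backward (g * torch.cos(x) * 200 in the
    # test above) is what autograd ends up using.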