diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 7b363606e8..f354e9abf0 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -148,9 +148,40 @@ def _device_put_transform(a: TensorProxy, device: devices.Device) -> TensorProxy def no_autocast(fn): - fn = torch.autocast(device_type="cpu", enabled=False, cache_enabled=False)(fn) - fn = torch.autocast(device_type="cuda", enabled=False, cache_enabled=False)(fn) - return fn + """ + A decorator that disables torch.autocast for the duration of the decorated + function. + + In Thunder this is useful when you want to ensure that the generated + function is not run with PyTorch's autocast enabled to execute exactly as + generated. + + Args: + fn: The function to decorate. + + Returns: + The decorated function. + """ + # This decorator intentionally does not use the torch.autocast decorator + # because it is much slower than the implementation here. This is because + # the torch.autocast decorator has a lot more overhead to support various + # features that are not needed in Thunder. + from torch import set_autocast_enabled + + prev_cpu = torch.is_autocast_cpu_enabled() + prev = torch.is_autocast_enabled() + + @wraps(fn) + def no_autocast_fn(*args, **kwargs): + try: + set_autocast_enabled("cpu", False) + set_autocast_enabled("cuda", False) + return fn(*args, **kwargs) + finally: + set_autocast_enabled("cpu", prev_cpu) + set_autocast_enabled("cuda", prev) + + return no_autocast_fn # diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py index 2abab28d34..2a01959f4e 100644 --- a/thunder/tests/test_autocast.py +++ b/thunder/tests/test_autocast.py @@ -85,11 +85,10 @@ def func(): trace = thunder.trace()(func) python_callable = trace.python_callable() - # 3 unwraps for: + # 2 unwraps for: # @no_grad() - # @autocast(device_type="cpu", ...) - # @autocast(device_type="cuda", ...) - cfunc = python_callable.__wrapped__.__wrapped__.__wrapped__ + # @no_autocast + cfunc = python_callable.__wrapped__.__wrapped__ b1, b2 = python_callable() assert b1 is False assert b2 is False @@ -107,10 +106,10 @@ def func(): with torch.autocast(device_type=devicetype, dtype=test_dtype): b1, b2 = python_callable() b3, b4 = cfunc() - assert b1 is False - assert b2 is False - assert b3 is (True if torch_device.type == "cuda" else False) - assert b4 is (True if torch_device.type == "cpu" else False) + assert not b1 + assert not b2 + assert not b3 + assert not b4 @instantiate( @@ -141,24 +140,6 @@ def func(a, b): assert output.dtype == (torch.float16 if torch_device.type == "cuda" else torch.bfloat16) -# Disabling on windows temporarily, until our windows runners source the -# appropriate visual studio config. -@pytest.mark.skipif(not is_inductor_supported() or platform.system() == "Windows", reason="inductor unsupported") -def test_torch_compile_autocast(): - """Checks if our autocast decorator plays well with ``torch.compile``""" - - @no_autocast - def fn(x, y): - return x + y - - a = torch.randn(2, 2) - b = torch.randn(2, 2) - cfn = torch.compile(fn, fullgraph=True) - actual = cfn(a, b) - expected = a + b - torch.testing.assert_close(actual, expected) - - def test_autocast_mixed_dtype_inputs(): def foo(x, w): return torch.nn.functional.linear(x, w)