diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index bb8da5be61..3cc8ff4d4b 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -3856,6 +3856,76 @@ def autocast_linear_rule(a, w, bias, dtype):
     return _linear_autocast_impl(a, w, bias, dtype)
 
 
+def _convolution_autocast_impl(a, w, bias, *other_args, dtype):
+    if bias is None:
+        # Don't pass `bias` to maybe_downcast_to.
+        downcast_args = maybe_downcast_to(dtype, (a, w)) + (bias,)
+    else:
+        downcast_args = maybe_downcast_to(dtype, (a, w, bias))
+
+    return prims.convolution(*downcast_args, *other_args)
+
+
+@register_autocast_rule("torch.nn.functional.conv1d")
+def autocast_ltorch_conv1d_rule(
+    a: TensorProxy,
+    /,
+    weight: TensorProxy,
+    bias: TensorProxy | None = None,
+    stride: int | Sequence[int] = 1,
+    padding: int | Sequence[int] | str = 0,
+    dilation: int = 1,
+    groups: int = 1,
+    *,
+    dtype,
+) -> TensorProxy:
+    from thunder.torch import _conv_helper
+
+    return _conv_helper(
+        1, a, weight, bias, stride, padding, dilation, groups, conv_function=_convolution_autocast_impl, dtype=dtype
+    )
+
+
+@register_autocast_rule("torch.nn.functional.conv2d")
+def autocast_ltorch_conv2d_rule(
+    a: TensorProxy,
+    /,
+    weight: TensorProxy,
+    bias: TensorProxy | None = None,
+    stride: int | Sequence[int] = 1,
+    padding: int | Sequence[int] | str = 0,
+    dilation: int = 1,
+    groups: int = 1,
+    *,
+    dtype,
+) -> TensorProxy:
+    from thunder.torch import _conv_helper
+
+    return _conv_helper(
+        2, a, weight, bias, stride, padding, dilation, groups, conv_function=_convolution_autocast_impl, dtype=dtype
+    )
+
+
+@register_autocast_rule("torch.nn.functional.conv3d")
+def autocast_ltorch_conv3d_rule(
+    a: TensorProxy,
+    /,
+    weight: TensorProxy,
+    bias: TensorProxy | None = None,
+    stride: int | Sequence[int] = 1,
+    padding: int | Sequence[int] | str = 0,
+    dilation: int = 1,
+    groups: int = 1,
+    *,
+    dtype,
+) -> TensorProxy:
+    from thunder.torch import _conv_helper
+
+    return _conv_helper(
+        3, a, weight, bias, stride, padding, dilation, groups, conv_function=_convolution_autocast_impl, dtype=dtype
+    )
+
+
 @register_autocast_rule("torch.nn.functional.scaled_dot_product_attention")
 def autocast_scaled_dot_product_attention(
     query,
diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py
index 9900f3975a..7cfedbe5b6 100644
--- a/thunder/tests/test_autocast.py
+++ b/thunder/tests/test_autocast.py
@@ -169,3 +169,46 @@ def foo(x, w):
     jit_out = jfoo(x, w)
 
     torch.testing.assert_close(eager_out, jit_out)
+
+
+@pytest.mark.parametrize("dim", [1, 2, 3])
+@pytest.mark.parametrize("requires_grad", [False, True])
+def test_autocast_convolution(dim, requires_grad):
+    conv_fn = getattr(torch.nn.functional, f"conv{dim}d")
+
+    def foo(x, w, b=None):
+        return conv_fn(x, w, b)
+
+    x = torch.rand(1, 2, *(dim * (8,)), requires_grad=requires_grad)
+    w = torch.rand(3, 2, *(dim * (4,)), requires_grad=requires_grad)
+    b = torch.rand(3, requires_grad=requires_grad)
+    go = torch.rand(1, 3, *(dim * (5,)))
+
+    jfoo = thunder.jit(foo)
+
+    with torch.autocast("cpu", torch.float16):
+        eager_out = foo(x, w, b)
+        jit_out = jfoo(x, w, b)
+
+    torch.testing.assert_close(eager_out, jit_out)
+
+    if requires_grad:
+        eager_grads = torch.autograd.grad(eager_out, [x, w, b], go)
+        jit_grads = torch.autograd.grad(jit_out, [x, w, b], go)
+
+        for eg, jg in zip(eager_grads, jit_grads):
+            torch.testing.assert_close(eg, jg, rtol=1e-2, atol=1e-2)
+
+    with torch.autocast("cpu", torch.float16):
+        eager_out = foo(x, w)
+        jit_out = jfoo(x, w)
+
+    torch.testing.assert_close(eager_out, jit_out)
+
+    if requires_grad:
+        go = torch.randn_like(eager_out)
+        eager_grads = torch.autograd.grad(eager_out, [x, w], go)
+        jit_grads = torch.autograd.grad(jit_out, [x, w], go)
+
+        for eg, jg in zip(eager_grads, jit_grads):
+            torch.testing.assert_close(eg, jg, rtol=1e-2, atol=1e-2)
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 4d7214dbe5..ac195ea456 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -3816,6 +3816,9 @@ def _conv_helper(
     padding: int | Sequence[int] | str = 0,
     dilation: int | Sequence[int] = 1,
     groups: int = 1,
+    *,
+    conv_function=clang.convolution,
+    **extra_kwargs,
 ) -> TensorProxy:
     # a, weight rank check
     utils.check(dim + 1 <= a.ndim <= dim + 2, lambda: f"{a.ndim=} should be either {dim + 1} or {dim + 2}")
@@ -3875,8 +3878,17 @@ def pad_lo_hi_dilation_seq():
     padding, a = process_padding_str(int_to_seq(padding), stride, dilation, a)
     # }
 
-    res = clang.convolution(
-        a, weight, bias, stride, padding, dilation, False, (0,) * dim, groups  # transposed  # output_padding
+    res = conv_function(
+        a,
+        weight,
+        bias,
+        stride,
+        padding,
+        dilation,
+        False,  # transposed
+        (0,) * dim,  # output_padding
+        groups,
+        **extra_kwargs,
     )
 
     return res
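
A minimal end-to-end sketch of the behavior this patch enables (not part of the diff itself). The shapes and the ("cpu", float16) autocast pairing are taken from test_autocast_convolution above; the function names f/jf are illustrative:

# Usage sketch, assuming the patch is applied: under torch.autocast, a
# thunder.jit-compiled convolution hits the new conv2d autocast rule, which
# downcasts its floating-point inputs via maybe_downcast_to before calling
# prims.convolution, so the result should match eager PyTorch.
import torch
import thunder


def f(x, w):
    return torch.nn.functional.conv2d(x, w)


jf = thunder.jit(f)

x = torch.rand(1, 2, 8, 8)
w = torch.rand(3, 2, 4, 4)

with torch.autocast("cpu", torch.float16):
    eager_out = f(x, w)
    jit_out = jf(x, w)

# The compiled convolution should run in the autocast dtype, like eager.
assert jit_out.dtype == eager_out.dtype
torch.testing.assert_close(eager_out, jit_out)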