
Commit aefb94f

add autocast for conv1/2/3d
1 parent 08d8347 commit aefb94f

File tree: 3 files changed, +110 −2 lines changed


thunder/core/transforms.py

Lines changed: 70 additions & 0 deletions

@@ -3856,6 +3856,76 @@ def autocast_linear_rule(a, w, bias, dtype):
     return _linear_autocast_impl(a, w, bias, dtype)
 
 
+def _convolution_autocast_impl(a, w, bias, *other_args, dtype):
+    if bias is None:
+        # Don't pass `bias` to maybe_downcast_to.
+        downcast_args = maybe_downcast_to(dtype, (a, w)) + (bias,)
+    else:
+        downcast_args = maybe_downcast_to(dtype, (a, w, bias))
+
+    return prims.convolution(*downcast_args, *other_args)
+
+
+@register_autocast_rule("torch.nn.functional.conv1d")
+def autocast_ltorch_conv1d_rule(
+    a: TensorProxy,
+    /,
+    weight: TensorProxy,
+    bias: TensorProxy | None = None,
+    stride: int | Sequence[int] = 1,
+    padding: int | Sequence[int] | str = 0,
+    dilation: int = 1,
+    groups: int = 1,
+    *,
+    dtype,
+) -> TensorProxy:
+    from thunder.torch import _conv_helper
+
+    return _conv_helper(
+        1, a, weight, bias, stride, padding, dilation, groups, conv_function=_convolution_autocast_impl, dtype=dtype
+    )
+
+
+@register_autocast_rule("torch.nn.functional.conv2d")
+def autocast_ltorch_conv2d_rule(
+    a: TensorProxy,
+    /,
+    weight: TensorProxy,
+    bias: TensorProxy | None = None,
+    stride: int | Sequence[int] = 1,
+    padding: int | Sequence[int] | str = 0,
+    dilation: int = 1,
+    groups: int = 1,
+    *,
+    dtype,
+) -> TensorProxy:
+    from thunder.torch import _conv_helper
+
+    return _conv_helper(
+        2, a, weight, bias, stride, padding, dilation, groups, conv_function=_convolution_autocast_impl, dtype=dtype
+    )
+
+
+@register_autocast_rule("torch.nn.functional.conv3d")
+def autocast_ltorch_conv3d_rule(
+    a: TensorProxy,
+    /,
+    weight: TensorProxy,
+    bias: TensorProxy | None = None,
+    stride: int | Sequence[int] = 1,
+    padding: int | Sequence[int] | str = 0,
+    dilation: int = 1,
+    groups: int = 1,
+    *,
+    dtype,
+) -> TensorProxy:
+    from thunder.torch import _conv_helper
+
+    return _conv_helper(
+        3, a, weight, bias, stride, padding, dilation, groups, conv_function=_convolution_autocast_impl, dtype=dtype
+    )
+
+
 @register_autocast_rule("torch.nn.functional.scaled_dot_product_attention")
 def autocast_scaled_dot_product_attention(
     query,
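The `bias is None` branch in `_convolution_autocast_impl` above keeps the positional argument list of `prims.convolution` intact: only the tensors that actually exist are downcast, and the `None` placeholder is re-appended afterwards. A minimal eager-mode sketch of the same pattern (the `maybe_downcast_to_sketch` helper here is a hypothetical stand-in; Thunder's real `maybe_downcast_to` operates on `TensorProxy` objects inside a trace):

import torch

def maybe_downcast_to_sketch(dtype, tensors):
    # Hypothetical stand-in: downcast floating-point tensors, leave everything else alone.
    return tuple(t.to(dtype) if t.is_floating_point() else t for t in tensors)

a = torch.randn(1, 2, 8)
w = torch.randn(3, 2, 4)
bias = None

if bias is None:
    # Downcast only the tensors that exist, then re-append the None placeholder.
    args = maybe_downcast_to_sketch(torch.bfloat16, (a, w)) + (bias,)
else:
    args = maybe_downcast_to_sketch(torch.bfloat16, (a, w, bias))

out = torch.nn.functional.conv1d(*args)
print(out.dtype)  # torch.bfloat16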

thunder/tests/test_autocast.py

Lines changed: 26 additions & 0 deletions

@@ -169,3 +169,29 @@ def foo(x, w):
     jit_out = jfoo(x, w)
 
     torch.testing.assert_close(eager_out, jit_out)
+
+
+@pytest.mark.parametrize("dim", [1, 2, 3])
+def test_autocast_convolution(dim):
+    conv_fn = getattr(torch.nn.functional, f"conv{dim}d")
+
+    def foo(x, w, b=None):
+        return conv_fn(x, w, b)
+
+    x = torch.randn(1, 2, *(dim * (8,)))
+    w = torch.randn(3, 2, *(dim * (4,)))
+    b = torch.randn(3)
+
+    jfoo = thunder.jit(foo)
+
+    with torch.autocast("cpu", torch.bfloat16):
+        eager_out = foo(x, w, b)
+        jit_out = jfoo(x, w, b)
+
+    torch.testing.assert_close(eager_out, jit_out)
+
+    with torch.autocast("cpu", torch.bfloat16):
+        eager_out = foo(x, w)
+        jit_out = jfoo(x, w)
+
+    torch.testing.assert_close(eager_out, jit_out)
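Run outside the test suite, the behavior the new test asserts can be reproduced directly. A sketch assuming lightning-thunder is installed (the bfloat16 result dtype is implied by `assert_close` matching the eager autocast output, since `assert_close` checks dtypes by default):

import torch
import thunder

def conv(x, w, b=None):
    return torch.nn.functional.conv2d(x, w, b)

jconv = thunder.jit(conv)

x = torch.randn(1, 2, 8, 8)
w = torch.randn(3, 2, 4, 4)

# Under autocast, the registered conv2d rule downcasts the inputs,
# so the jitted output matches the eager autocast output.
with torch.autocast("cpu", torch.bfloat16):
    out = jconv(x, w)

print(out.dtype)  # torch.bfloat16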

thunder/torch/__init__.py

Lines changed: 14 additions & 2 deletions

@@ -3816,6 +3816,9 @@ def _conv_helper(
     padding: int | Sequence[int] | str = 0,
     dilation: int | Sequence[int] = 1,
     groups: int = 1,
+    *,
+    conv_function=clang.convolution,
+    **extra_kwargs,
 ) -> TensorProxy:
     # a, weight rank check
     utils.check(dim + 1 <= a.ndim <= dim + 2, lambda: f"{a.ndim=} should be either {dim + 1} or {dim + 2}")
@@ -3875,8 +3878,17 @@ def pad_lo_hi_dilation_seq():
     padding, a = process_padding_str(int_to_seq(padding), stride, dilation, a)
     # }
 
-    res = clang.convolution(
-        a, weight, bias, stride, padding, dilation, False, (0,) * dim, groups  # transposed # output_padding
+    res = conv_function(
+        a,
+        weight,
+        bias,
+        stride,
+        padding,
+        dilation,
+        False,  # transposed
+        (0,) * dim,  # output_padding
+        groups,
+        **extra_kwargs,
     )
     return res
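The `_conv_helper` change is a small dependency-injection hook: the default `conv_function=clang.convolution` preserves the original behavior, while the autocast rules pass `_convolution_autocast_impl` plus a `dtype` keyword that travels through `**extra_kwargs` untouched. A self-contained sketch of the pattern, with illustrative names rather than Thunder's (it uses `torch.convolution`, the aten-level entry point, as the default path):

import torch

def default_conv(a, w, bias, stride, padding, dilation, transposed, output_padding, groups):
    # Default path, analogous to clang.convolution in the diff.
    return torch.convolution(a, w, bias, stride, padding, dilation, transposed, output_padding, groups)

def conv_helper(dim, a, w, bias=None, stride=(1,), padding=(0,), dilation=(1,), groups=1,
                *, conv_function=default_conv, **extra_kwargs):
    # Keyword-only hook: extra kwargs are forwarded untouched to the injected callable.
    return conv_function(a, w, bias, stride, padding, dilation, False, (0,) * dim, groups, **extra_kwargs)

def autocast_conv(a, w, bias, *other_args, dtype):
    # Autocast variant: downcast the tensor inputs, then defer to the default path.
    a, w = a.to(dtype), w.to(dtype)
    bias = bias.to(dtype) if bias is not None else None
    return default_conv(a, w, bias, *other_args)

x = torch.randn(1, 2, 8)
w = torch.randn(3, 2, 4)

print(conv_helper(1, x, w).dtype)  # torch.float32
print(conv_helper(1, x, w, conv_function=autocast_conv, dtype=torch.bfloat16).dtype)  # torch.bfloat16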
