add autocast for conv1/2/3d (#797)
t-vi authored Jul 18, 2024
1 parent f160111 commit 10a4efb
Showing 3 changed files with 127 additions and 2 deletions.
70 changes: 70 additions & 0 deletions thunder/core/transforms.py
@@ -3856,6 +3856,76 @@ def autocast_linear_rule(a, w, bias, dtype):
    return _linear_autocast_impl(a, w, bias, dtype)


def _convolution_autocast_impl(a, w, bias, *other_args, dtype):
    if bias is None:
        # Don't pass `bias` to maybe_downcast_to.
        downcast_args = maybe_downcast_to(dtype, (a, w)) + (bias,)
    else:
        downcast_args = maybe_downcast_to(dtype, (a, w, bias))

    return prims.convolution(*downcast_args, *other_args)


@register_autocast_rule("torch.nn.functional.conv1d")
def autocast_ltorch_conv1d_rule(
a: TensorProxy,
/,
weight: TensorProxy,
bias: TensorProxy | None = None,
stride: int | Sequence[int] = 1,
padding: int | Sequence[int] | str = 0,
dilation: int = 1,
groups: int = 1,
*,
dtype,
) -> TensorProxy:
from thunder.torch import _conv_helper

return _conv_helper(
1, a, weight, bias, stride, padding, dilation, groups, conv_function=_convolution_autocast_impl, dtype=dtype
)


@register_autocast_rule("torch.nn.functional.conv2d")
def autocast_ltorch_conv2d_rule(
a: TensorProxy,
/,
weight: TensorProxy,
bias: TensorProxy | None = None,
stride: int | Sequence[int] = 1,
padding: int | Sequence[int] | str = 0,
dilation: int = 1,
groups: int = 1,
*,
dtype,
) -> TensorProxy:
from thunder.torch import _conv_helper

return _conv_helper(
2, a, weight, bias, stride, padding, dilation, groups, conv_function=_convolution_autocast_impl, dtype=dtype
)


@register_autocast_rule("torch.nn.functional.conv3d")
def autocast_ltorch_conv3d_rule(
a: TensorProxy,
/,
weight: TensorProxy,
bias: TensorProxy | None = None,
stride: int | Sequence[int] = 1,
padding: int | Sequence[int] | str = 0,
dilation: int = 1,
groups: int = 1,
*,
dtype,
) -> TensorProxy:
from thunder.torch import _conv_helper

return _conv_helper(
3, a, weight, bias, stride, padding, dilation, groups, conv_function=_convolution_autocast_impl, dtype=dtype
)


@register_autocast_rule("torch.nn.functional.scaled_dot_product_attention")
def autocast_scaled_dot_product_attention(
    query,
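
Note: the three rules above route torch.nn.functional.conv1d/2d/3d through _conv_helper with _convolution_autocast_impl as the convolution implementation. Input, weight and (when present) bias are downcast to the autocast dtype; a None bias is forwarded as-is rather than being passed to maybe_downcast_to. A minimal usage sketch, mirroring the new test below — the names foo and jfoo are illustrative, not part of this change:

import torch
import thunder

def foo(x, w):
    return torch.nn.functional.conv2d(x, w)

jfoo = thunder.jit(foo)
x = torch.rand(1, 2, 8, 8)
w = torch.rand(3, 2, 4, 4)

with torch.autocast("cpu", torch.float16):
    out = jfoo(x, w)  # the autocast rule downcasts x and w before the convolution

print(out.dtype)  # expected to match eager autocast behaviour (float16 here)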
43 changes: 43 additions & 0 deletions thunder/tests/test_autocast.py
@@ -169,3 +169,46 @@ def foo(x, w):
    jit_out = jfoo(x, w)

    torch.testing.assert_close(eager_out, jit_out)


@pytest.mark.parametrize("dim", [1, 2, 3])
@pytest.mark.parametrize("requires_grad", [False, True])
def test_autocast_convolution(dim, requires_grad):
conv_fn = getattr(torch.nn.functional, f"conv{dim}d")

def foo(x, w, b=None):
return conv_fn(x, w, b)

x = torch.rand(1, 2, *(dim * (8,)), requires_grad=requires_grad)
w = torch.rand(3, 2, *(dim * (4,)), requires_grad=requires_grad)
b = torch.rand(3, requires_grad=requires_grad)
go = torch.rand(1, 3, *(dim * (5,)))

jfoo = thunder.jit(foo)

with torch.autocast("cpu", torch.float16):
eager_out = foo(x, w, b)
jit_out = jfoo(x, w, b)

torch.testing.assert_close(eager_out, jit_out)

if requires_grad:
eager_grads = torch.autograd.grad(eager_out, [x, w, b], go)
jit_grads = torch.autograd.grad(jit_out, [x, w, b], go)

for eg, jg in zip(eager_grads, jit_grads):
torch.testing.assert_close(eg, jg, rtol=1e-2, atol=1e-2)

with torch.autocast("cpu", torch.float16):
eager_out = foo(x, w)
jit_out = jfoo(x, w)

torch.testing.assert_close(eager_out, jit_out)

if requires_grad:
go = torch.randn_like(eager_out)
eager_grads = torch.autograd.grad(eager_out, [x, w], go)
jit_grads = torch.autograd.grad(jit_out, [x, w], go)

for eg, jg in zip(eager_grads, jit_grads):
torch.testing.assert_close(eg, jg, rtol=1e-2, atol=1e-2)
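
Note: the test covers all three dims, with and without bias, and with and without gradients; the loose tolerances (rtol=atol=1e-2) allow for float16 rounding differences between the eager and jitted backward paths. The gradient-output shape follows from the standard convolution size formula — a quick check, not part of the test:

def conv_out_size(n, k, stride=1, padding=0, dilation=1):
    # standard formula for the spatial output size of a convolution
    return (n + 2 * padding - dilation * (k - 1) - 1) // stride + 1

# input extent 8, kernel 4, stride 1, no padding -> output extent 5, hence go's shape
assert conv_out_size(8, 4) == 5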
16 changes: 14 additions & 2 deletions thunder/torch/__init__.py
@@ -3816,6 +3816,9 @@ def _conv_helper(
    padding: int | Sequence[int] | str = 0,
    dilation: int | Sequence[int] = 1,
    groups: int = 1,
    *,
    conv_function=clang.convolution,
    **extra_kwargs,
) -> TensorProxy:
    # a, weight rank check
    utils.check(dim + 1 <= a.ndim <= dim + 2, lambda: f"{a.ndim=} should be either {dim + 1} or {dim + 2}")
@@ -3875,8 +3878,17 @@ def pad_lo_hi_dilation_seq():
    padding, a = process_padding_str(int_to_seq(padding), stride, dilation, a)
    # }

-    res = clang.convolution(
-        a, weight, bias, stride, padding, dilation, False, (0,) * dim, groups  # transposed  # output_padding
+    res = conv_function(
+        a,
+        weight,
+        bias,
+        stride,
+        padding,
+        dilation,
+        False,  # transposed
+        (0,) * dim,  # output_padding
+        groups,
+        **extra_kwargs,
    )
    return res
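
Note: this change makes _conv_helper's convolution implementation pluggable. The autocast rules pass conv_function=_convolution_autocast_impl together with dtype=dtype, and the helper forwards the dtype via **extra_kwargs while keeping the rank, stride, padding and dilation handling in one place. A stripped-down sketch of the forwarding pattern, with generic names rather than thunder's API:

def default_conv(a, w, bias):
    return ("conv", a, w, bias)

def autocast_conv(a, w, bias, *, dtype):
    # a real rule would downcast a and w (and bias, when present) to dtype first
    return ("conv", f"cast({a}, {dtype})", f"cast({w}, {dtype})", bias)

def helper(a, w, bias, *, conv_function=default_conv, **extra_kwargs):
    # shared argument handling would live here
    return conv_function(a, w, bias, **extra_kwargs)

print(helper("a", "w", "b"))                                                # default path
print(helper("a", "w", "b", conv_function=autocast_conv, dtype="float16"))  # autocast path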

