Skip to content

Commit

Permalink
sdpa: support attn_mask.requires_grad, support expanded number of hea…
Browse files Browse the repository at this point in the history
…ds in attn_mask (#1482)
  • Loading branch information
kiya00 committed Dec 17, 2024
1 parent 9ff73be commit e9d525e
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
11 changes: 9 additions & 2 deletions thunder/executors/sdpaex.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,12 @@ def _scaled_dot_product_efficient_attention_backward_impl(
if attn_mask is None:
grad_input_mask.append(False)
else:
grad_input_mask.append(attn_mask.requires_grad)
# Cannot rely on the requires_grad in the meta function,
# so here the gradient of attn_mask is always calculated
grad_input_mask.append(True)

# Reference: https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/native/transformers/cuda/attention_backward.cu#L394-L415
return torch.ops.aten._scaled_dot_product_efficient_attention_backward(
grad_q, grad_k, grad_v, grad_attn_mask = torch.ops.aten._scaled_dot_product_efficient_attention_backward(
grad_out,
_sdpa_enforce_input_tensor_contiguity(query),
_sdpa_enforce_input_tensor_contiguity(key),
Expand All @@ -319,6 +321,11 @@ def _scaled_dot_product_efficient_attention_backward_impl(
is_causal,
scale=scale,
)
if not utils.same_shape(grad_attn_mask.shape, attn_mask.shape):
# Needs to sum over the number of heads dimension in grad_attn_mask
# if the number of heads in attention mask is expanded in _attention_mask_memory_efficient_helper.
grad_attn_mask = torch.sum(grad_attn_mask, dim=1, keepdim=True)
return grad_q, grad_k, grad_v, grad_attn_mask


sdpea_bwd = sdpa_ex.register_operator(
Expand Down
41 changes: 41 additions & 0 deletions thunder/tests/test_sdpaex_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,44 @@ def fn(*args, **kwargs):

if result is not None:
return result


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=("f32", "f16", "bf16"))
@pytest.mark.parametrize("device,", ["cuda"])
@pytest.mark.parametrize("attn_mask_requires_grad", [True, False])
@requiresCUDA
def test_sdpa_attn_mask(attn_mask_requires_grad, device: str, dtype: torch.dtype):
# Enable math and memory-efficient sdpa options for Volta and prior devices
torch_device = torch.device(device)
if not device_version_support(torch_device, CudaVersion(8, 0), CudaVersion(9, 0)):
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)

def func(q, k, v, atten_mask):
tmp = atten_mask * atten_mask
return torch.nn.functional.scaled_dot_product_attention(q, k, v, tmp)

query = torch.randn(1, 28, 128, 128, dtype=dtype, device=device, requires_grad=True)
key = torch.randn(1, 28, 128, 128, dtype=dtype, device=device, requires_grad=True)
value = torch.randn(1, 28, 128, 128, dtype=dtype, device=device, requires_grad=True)
attn_mask = torch.randn(1, 1, 128, 128, dtype=dtype, device=device, requires_grad=attn_mask_requires_grad)

query1 = query.detach().clone().requires_grad_()
key1 = key.detach().clone().requires_grad_()
value1 = value.detach().clone().requires_grad_()
attn_mask1 = attn_mask.detach().clone().requires_grad_(attn_mask_requires_grad)

expected = func(query, key, value, attn_mask)
output = expected.mean()
output.backward()

jfun = thunder.jit(func)
actual = jfun(query1, key1, value1, attn_mask1)
output = actual.mean()
output.backward()

torch.testing.assert_close(actual, expected)
torch.testing.assert_close(attn_mask1.grad, attn_mask.grad)
torch.testing.assert_close(query.grad, query1.grad)
torch.testing.assert_close(key.grad, key1.grad)
torch.testing.assert_close(value.grad, value1.grad)

0 comments on commit e9d525e

Please sign in to comment.