Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sdpa: support attn_mask.requires_grad, support expanded number of heads in attn_mask #1563

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions thunder/executors/sdpaex.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,12 @@ def _scaled_dot_product_efficient_attention_backward_impl(
if attn_mask is None:
grad_input_mask.append(False)
else:
grad_input_mask.append(attn_mask.requires_grad)
# Cannot rely on the requires_grad in the meta function,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

requires_grad of intermediate TensorProxies is ignored in our automatic differentiation code because we haven't done the work of properly threading this property through all computations.
We should remove the ability to query .requires_grad from intermediate TensorProxies completely to avoid similar bugs in the future. This can be achieved by introducing a separate "InputTensorProxy" which has this attribute and removing it from the regular TensorProxy.

# so here the gradient of attn_mask is always calculated
grad_input_mask.append(True)

# Reference: https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/native/transformers/cuda/attention_backward.cu#L394-L415
return torch.ops.aten._scaled_dot_product_efficient_attention_backward(
grad_q, grad_k, grad_v, grad_attn_mask = torch.ops.aten._scaled_dot_product_efficient_attention_backward(
grad_out,
_sdpa_enforce_input_tensor_contiguity(query),
_sdpa_enforce_input_tensor_contiguity(key),
Expand All @@ -319,6 +321,11 @@ def _scaled_dot_product_efficient_attention_backward_impl(
is_causal,
scale=scale,
)
if not utils.same_shape(grad_attn_mask.shape, attn_mask.shape):
# Needs to sum over the number of heads dimension in grad_attn_mask
# if the number of heads in attention mask is expanded in _attention_mask_memory_efficient_helper.
grad_attn_mask = torch.sum(grad_attn_mask, dim=1, keepdim=True)
return grad_q, grad_k, grad_v, grad_attn_mask


sdpea_bwd = sdpa_ex.register_operator(
Expand Down
41 changes: 41 additions & 0 deletions thunder/tests/test_sdpaex_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,44 @@ def fn(*args, **kwargs):

if result is not None:
return result


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=("f32", "f16", "bf16"))
@pytest.mark.parametrize("device,", ["cuda"])
@pytest.mark.parametrize("attn_mask_requires_grad", [True, False])
@requiresCUDA
def test_sdpa_attn_mask(attn_mask_requires_grad, device: str, dtype: torch.dtype):
# Enable math and memory-efficient sdpa options for Volta and prior devices
torch_device = torch.device(device)
if not device_version_support(torch_device, CudaVersion(8, 0), CudaVersion(9, 0)):
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)

def func(q, k, v, atten_mask):
tmp = atten_mask * atten_mask
return torch.nn.functional.scaled_dot_product_attention(q, k, v, tmp)

query = torch.randn(1, 28, 128, 128, dtype=dtype, device=device, requires_grad=True)
key = torch.randn(1, 28, 128, 128, dtype=dtype, device=device, requires_grad=True)
value = torch.randn(1, 28, 128, 128, dtype=dtype, device=device, requires_grad=True)
attn_mask = torch.randn(1, 1, 128, 128, dtype=dtype, device=device, requires_grad=attn_mask_requires_grad)

query1 = query.detach().clone().requires_grad_()
key1 = key.detach().clone().requires_grad_()
value1 = value.detach().clone().requires_grad_()
attn_mask1 = attn_mask.detach().clone().requires_grad_(attn_mask_requires_grad)

expected = func(query, key, value, attn_mask)
output = expected.mean()
output.backward()

jfun = thunder.jit(func)
actual = jfun(query1, key1, value1, attn_mask1)
output = actual.mean()
output.backward()

torch.testing.assert_close(actual, expected)
torch.testing.assert_close(attn_mask1.grad, attn_mask.grad)
torch.testing.assert_close(query.grad, query1.grad)
torch.testing.assert_close(key.grad, key1.grad)
torch.testing.assert_close(value.grad, value1.grad)
Loading