From e9d525ed1fc9d2b0cc4d542ef951635c0df6edad Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Tue, 17 Dec 2024 13:13:13 +0100 Subject: [PATCH] sdpa: support attn_mask.requires_grad, support expanded number of heads in attn_mask (#1482) --- thunder/executors/sdpaex.py | 11 +++++-- thunder/tests/test_sdpaex_executor.py | 41 +++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/thunder/executors/sdpaex.py b/thunder/executors/sdpaex.py index b912acaed9..e69afcb05c 100644 --- a/thunder/executors/sdpaex.py +++ b/thunder/executors/sdpaex.py @@ -301,10 +301,12 @@ def _scaled_dot_product_efficient_attention_backward_impl( if attn_mask is None: grad_input_mask.append(False) else: - grad_input_mask.append(attn_mask.requires_grad) + # Cannot rely on the requires_grad in the meta function, + # so here the gradient of attn_mask is always calculated + grad_input_mask.append(True) # Reference: https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/native/transformers/cuda/attention_backward.cu#L394-L415 - return torch.ops.aten._scaled_dot_product_efficient_attention_backward( + grad_q, grad_k, grad_v, grad_attn_mask = torch.ops.aten._scaled_dot_product_efficient_attention_backward( grad_out, _sdpa_enforce_input_tensor_contiguity(query), _sdpa_enforce_input_tensor_contiguity(key), @@ -319,6 +321,11 @@ def _scaled_dot_product_efficient_attention_backward_impl( is_causal, scale=scale, ) + if not utils.same_shape(grad_attn_mask.shape, attn_mask.shape): + # Needs to sum over the number of heads dimension in grad_attn_mask + # if the number of heads in attention mask is expanded in _attention_mask_memory_efficient_helper. + grad_attn_mask = torch.sum(grad_attn_mask, dim=1, keepdim=True) + return grad_q, grad_k, grad_v, grad_attn_mask sdpea_bwd = sdpa_ex.register_operator( diff --git a/thunder/tests/test_sdpaex_executor.py b/thunder/tests/test_sdpaex_executor.py index 1d940dd7be..6052835ba6 100644 --- a/thunder/tests/test_sdpaex_executor.py +++ b/thunder/tests/test_sdpaex_executor.py @@ -163,3 +163,44 @@ def fn(*args, **kwargs): if result is not None: return result + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=("f32", "f16", "bf16")) +@pytest.mark.parametrize("device,", ["cuda"]) +@pytest.mark.parametrize("attn_mask_requires_grad", [True, False]) +@requiresCUDA +def test_sdpa_attn_mask(attn_mask_requires_grad, device: str, dtype: torch.dtype): + # Enable math and memory-efficient sdpa options for Volta and prior devices + torch_device = torch.device(device) + if not device_version_support(torch_device, CudaVersion(8, 0), CudaVersion(9, 0)): + torch.backends.cuda.enable_mem_efficient_sdp(True) + torch.backends.cuda.enable_math_sdp(True) + + def func(q, k, v, atten_mask): + tmp = atten_mask * atten_mask + return torch.nn.functional.scaled_dot_product_attention(q, k, v, tmp) + + query = torch.randn(1, 28, 128, 128, dtype=dtype, device=device, requires_grad=True) + key = torch.randn(1, 28, 128, 128, dtype=dtype, device=device, requires_grad=True) + value = torch.randn(1, 28, 128, 128, dtype=dtype, device=device, requires_grad=True) + attn_mask = torch.randn(1, 1, 128, 128, dtype=dtype, device=device, requires_grad=attn_mask_requires_grad) + + query1 = query.detach().clone().requires_grad_() + key1 = key.detach().clone().requires_grad_() + value1 = value.detach().clone().requires_grad_() + attn_mask1 = attn_mask.detach().clone().requires_grad_(attn_mask_requires_grad) + + expected = func(query, key, value, attn_mask) + output = expected.mean() + output.backward() + + jfun = thunder.jit(func) + actual = jfun(query1, key1, value1, attn_mask1) + output = actual.mean() + output.backward() + + torch.testing.assert_close(actual, expected) + torch.testing.assert_close(attn_mask1.grad, attn_mask.grad) + torch.testing.assert_close(query.grad, query1.grad) + torch.testing.assert_close(key.grad, key1.grad) + torch.testing.assert_close(value.grad, value1.grad)