diff --git a/thunder/executors/sdpaex.py b/thunder/executors/sdpaex.py
index e69afcb05c..3fc6eab9e9 100644
--- a/thunder/executors/sdpaex.py
+++ b/thunder/executors/sdpaex.py
@@ -321,10 +321,12 @@ def _scaled_dot_product_efficient_attention_backward_impl(
         is_causal,
         scale=scale,
     )
-    if not utils.same_shape(grad_attn_mask.shape, attn_mask.shape):
+
+    if attn_mask is not None and not utils.same_shape(grad_attn_mask.shape, attn_mask.shape):
         # Needs to sum over the number of heads dimension in grad_attn_mask
         # if the number of heads in attention mask is expanded in _attention_mask_memory_efficient_helper.
         grad_attn_mask = torch.sum(grad_attn_mask, dim=1, keepdim=True)
+
     return grad_q, grad_k, grad_v, grad_attn_mask


diff --git a/thunder/tests/test_sdpaex_executor.py b/thunder/tests/test_sdpaex_executor.py
index 6052835ba6..b0a4a7f6f8 100644
--- a/thunder/tests/test_sdpaex_executor.py
+++ b/thunder/tests/test_sdpaex_executor.py
@@ -35,16 +35,17 @@ def device_version_support(
     "dtype", [torch.float32, torch.bfloat16, torch.float16], ids=("float32", "bfloat16", "float16")
 )
 @pytest.mark.parametrize("device,", ["cuda"])
+@pytest.mark.parametrize("requires_grad", [True, False])
 @requiresCUDA
-def test_sdpa(device: str, dtype: torch.dtype):
+def test_sdpa(device: str, dtype: torch.dtype, requires_grad: bool):
     batch = 10
     seq_len = 128
     num_heads = 4
     dim_per_head = 32

-    query = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda")
-    key = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda")
-    value = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda")
+    query = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda", requires_grad=requires_grad)
+    key = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda", requires_grad=requires_grad)
+    value = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda", requires_grad=requires_grad)

     def fn(query, key, value):
         return torch.nn.functional.scaled_dot_product_attention(query, key, value)
@@ -63,6 +64,12 @@ def fn(query, key, value):
         bsym.sym.name == "sdpaex_grad_forward_scaled_dot_product_efficient_attention"
         for bsym in extrace.bound_symbols
     )
+    if requires_grad:
+        grad_output = torch.rand_like(thunder_result)
+        actual_grads = torch.autograd.grad(thunder_result, (query, key, value), grad_outputs=grad_output)
+        expected_grads = torch.autograd.grad(torch_result, (query, key, value), grad_outputs=grad_output)
+        torch.testing.assert_close(actual_grads, expected_grads)
+

 @requiresCUDA
 def test_sdpa_autocast_flash():
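
Note (not part of the diff): a minimal sketch of why summing grad_attn_mask over dim=1 with keepdim=True is the right reduction when the attention mask is broadcast across heads in the forward pass. The gradient of a broadcast tensor is the sum of the per-head gradients over the expanded dimension. The tensor names and shapes below are illustrative only, not taken from the Thunder code.

import torch

batch, num_heads, seq_len = 2, 4, 8

# Mask with a single (broadcastable) heads dimension, before any expansion.
attn_mask = torch.randn(batch, 1, seq_len, seq_len, requires_grad=True)

# Attention scores, as if the mask had been expanded to every head in the forward pass.
scores = torch.randn(batch, num_heads, seq_len, seq_len)
(scores + attn_mask).sum().backward()

# A per-head mask gradient, reduced over the heads dimension with keepdim=True,
# matches what autograd produces for the broadcast dimension.
grad_per_head = torch.ones(batch, num_heads, seq_len, seq_len)
torch.testing.assert_close(attn_mask.grad, torch.sum(grad_per_head, dim=1, keepdim=True))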