sdpaex : fix backward when attn_mask is not provided
kshitij12345 committed Dec 20, 2024
1 parent 35ca2e9 commit b3b1d66
Showing 2 changed files with 14 additions and 5 deletions.
4 changes: 3 additions & 1 deletion thunder/executors/sdpaex.py
@@ -321,10 +321,12 @@ def _scaled_dot_product_efficient_attention_backward_impl(
         is_causal,
         scale=scale,
     )
-    if not utils.same_shape(grad_attn_mask.shape, attn_mask.shape):
+
+    if attn_mask is not None and not utils.same_shape(grad_attn_mask.shape, attn_mask.shape):
         # Needs to sum over the number of heads dimension in grad_attn_mask
         # if the number of heads in attention mask is expanded in _attention_mask_memory_efficient_helper.
         grad_attn_mask = torch.sum(grad_attn_mask, dim=1, keepdim=True)
+
     return grad_q, grad_k, grad_v, grad_attn_mask


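For context on the fix above: `torch.nn.functional.scaled_dot_product_attention` accepts `attn_mask=None`, and the old branch unconditionally read `attn_mask.shape`, which fails when no mask is passed. The sketch below illustrates the guarded reduction in isolation; it is a minimal, hypothetical stand-in (the helper name is invented, and a plain shape comparison replaces thunder's `utils.same_shape`), not the executor's actual code.

```python
import torch

def maybe_reduce_grad_attn_mask(grad_attn_mask, attn_mask):
    # Hypothetical stand-in for the guarded branch in the commit: only touch
    # attn_mask.shape when a mask was actually supplied, and sum the per-head
    # gradient back down when the mask's head dimension was broadcast.
    if attn_mask is not None and grad_attn_mask.shape != attn_mask.shape:
        grad_attn_mask = torch.sum(grad_attn_mask, dim=1, keepdim=True)
    return grad_attn_mask

# A single-head mask broadcast over 4 heads yields a per-head gradient that
# must be reduced back to the mask's shape ...
grad = torch.ones(2, 4, 8, 8)   # (batch, heads, q_len, kv_len)
mask = torch.zeros(2, 1, 8, 8)  # mask with one head, expanded internally
print(maybe_reduce_grad_attn_mask(grad, mask).shape)  # torch.Size([2, 1, 8, 8])

# ... while attn_mask=None now simply skips the reduction instead of raising
# AttributeError on attn_mask.shape.
print(maybe_reduce_grad_attn_mask(grad, None).shape)  # torch.Size([2, 4, 8, 8])
```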
15 changes: 11 additions & 4 deletions thunder/tests/test_sdpaex_executor.py
@@ -35,16 +35,17 @@ def device_version_support(
     "dtype", [torch.float32, torch.bfloat16, torch.float16], ids=("float32", "bfloat16", "float16")
 )
 @pytest.mark.parametrize("device,", ["cuda"])
+@pytest.mark.parametrize("requires_grad", [True, False])
 @requiresCUDA
-def test_sdpa(device: str, dtype: torch.dtype):
+def test_sdpa(device: str, dtype: torch.dtype, requires_grad: bool):
     batch = 10
     seq_len = 128
     num_heads = 4
     dim_per_head = 32
 
-    query = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda")
-    key = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda")
-    value = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda")
+    query = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda", requires_grad=requires_grad)
+    key = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda", requires_grad=requires_grad)
+    value = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda", requires_grad=requires_grad)
 
     def fn(query, key, value):
         return torch.nn.functional.scaled_dot_product_attention(query, key, value)
@@ -63,6 +63,12 @@ def fn(query, key, value):
         bsym.sym.name == "sdpaex_grad_forward_scaled_dot_product_efficient_attention" for bsym in extrace.bound_symbols
     )
 
+    if requires_grad:
+        grad_output = torch.rand_like(thunder_result)
+        actual_grads = torch.autograd.grad(thunder_result, (query, key, value), grad_outputs=grad_output)
+        expected_grads = torch.autograd.grad(torch_result, (query, key, value), grad_outputs=grad_output)
+        torch.testing.assert_close(actual_grads, expected_grads)
+
 
 @requiresCUDA
 def test_sdpa_autocast_flash():
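The new `requires_grad` branch in the test pushes one shared cotangent through both the thunder-compiled and the eager graph with `torch.autograd.grad` and compares the input gradients. Below is a standalone sketch of that checking pattern using plain PyTorch only; the manual reference implementation and the tolerances are assumptions for illustration, not part of the repository's test.

```python
import torch
import torch.nn.functional as F

def manual_sdpa(q, k, v):
    # Hypothetical eager reference: plain softmax attention with SDPA's
    # default 1/sqrt(head_dim) scaling, no mask, no dropout.
    scale = q.shape[-1] ** -0.5
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return attn @ v

device = "cuda" if torch.cuda.is_available() else "cpu"
q, k, v = (torch.randn(2, 4, 16, 8, device=device, requires_grad=True) for _ in range(3))

actual = F.scaled_dot_product_attention(q, k, v)
expected = manual_sdpa(q, k, v)

# Same pattern as the test: backprop a shared random cotangent through both
# graphs and compare the gradients of every input.
grad_output = torch.rand_like(actual)
actual_grads = torch.autograd.grad(actual, (q, k, v), grad_outputs=grad_output)
expected_grads = torch.autograd.grad(expected, (q, k, v), grad_outputs=grad_output)
torch.testing.assert_close(actual_grads, expected_grads, rtol=1e-4, atol=1e-4)
```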
