sdpaex: fix backward when attn_mask is not provided #1578

Merged 1 commit on Dec 20, 2024
4 changes: 3 additions & 1 deletion thunder/executors/sdpaex.py
@@ -321,10 +321,12 @@ def _scaled_dot_product_efficient_attention_backward_impl(
         is_causal,
         scale=scale,
     )
-    if not utils.same_shape(grad_attn_mask.shape, attn_mask.shape):
+
+    if attn_mask is not None and not utils.same_shape(grad_attn_mask.shape, attn_mask.shape):
         # Needs to sum over the number of heads dimension in grad_attn_mask
         # if the number of heads in attention mask is expanded in _attention_mask_memory_efficient_helper.
         grad_attn_mask = torch.sum(grad_attn_mask, dim=1, keepdim=True)
+
     return grad_q, grad_k, grad_v, grad_attn_mask


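For reference, the intent of the new guard can be sketched as a standalone helper. The name _reduce_grad_attn_mask below is hypothetical, and a plain shape comparison stands in for utils.same_shape; the point is only that the attn_mask is not None check has to run before the mask's shape is inspected, since this backward path is also reached when no mask was passed to the forward.

import torch

def _reduce_grad_attn_mask(grad_attn_mask, attn_mask):
    # No mask was given in the forward, so there is no mask gradient to
    # reduce; touching attn_mask.shape here would raise AttributeError.
    if attn_mask is None:
        return grad_attn_mask
    # If the mask was expanded over the heads dimension for the kernel,
    # its gradient comes back with the expanded shape and must be summed
    # back down to the original number of heads.
    if grad_attn_mask.shape != attn_mask.shape:
        grad_attn_mask = torch.sum(grad_attn_mask, dim=1, keepdim=True)
    return grad_attn_mask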
15 changes: 11 additions & 4 deletions thunder/tests/test_sdpaex_executor.py
@@ -35,16 +35,17 @@ def device_version_support(
     "dtype", [torch.float32, torch.bfloat16, torch.float16], ids=("float32", "bfloat16", "float16")
 )
 @pytest.mark.parametrize("device,", ["cuda"])
+@pytest.mark.parametrize("requires_grad", [True, False])
 @requiresCUDA
-def test_sdpa(device: str, dtype: torch.dtype):
+def test_sdpa(device: str, dtype: torch.dtype, requires_grad: bool):
     batch = 10
     seq_len = 128
     num_heads = 4
     dim_per_head = 32

-    query = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda")
-    key = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda")
-    value = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda")
+    query = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda", requires_grad=requires_grad)
+    key = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda", requires_grad=requires_grad)
+    value = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda", requires_grad=requires_grad)

     def fn(query, key, value):
         return torch.nn.functional.scaled_dot_product_attention(query, key, value)
@@ -63,6 +64,12 @@ def fn(query, key, value):
         bsym.sym.name == "sdpaex_grad_forward_scaled_dot_product_efficient_attention" for bsym in extrace.bound_symbols
     )

+    if requires_grad:
+        grad_output = torch.rand_like(thunder_result)
+        actual_grads = torch.autograd.grad(thunder_result, (query, key, value), grad_outputs=grad_output)
+        expected_grads = torch.autograd.grad(torch_result, (query, key, value), grad_outputs=grad_output)
+        torch.testing.assert_close(actual_grads, expected_grads)
+

 @requiresCUDA
 def test_sdpa_autocast_flash():
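As a minimal eager-only sketch of what the new requires_grad branch checks (gradients for query, key, and value when no attention mask is passed), assuming a CUDA device; the actual test compares a thunder-compiled function against this kind of eager reference rather than calling SDPA directly:

import torch
import torch.nn.functional as F

# Same shapes as in the test: [batch, seq_len, num_heads, dim_per_head].
query, key, value = (
    torch.randn(10, 128, 4, 32, device="cuda", requires_grad=True) for _ in range(3)
)
out = F.scaled_dot_product_attention(query, key, value)  # no attn_mask provided
grad_output = torch.rand_like(out)
grads = torch.autograd.grad(out, (query, key, value), grad_outputs=grad_output)
assert all(g is not None for g in grads)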