Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sdpaex : fix backward when attn_mask is not provided #1578

Merged
merged 1 commit into from
Dec 20, 2024

Conversation

kshitij12345
Copy link
Collaborator

@kshitij12345 kshitij12345 commented Dec 20, 2024

Repro

import torch
import thunder
from thunder.executors.sdpaex import sdpa_ex

batch = 10
seq_len = 128
num_heads = 4
dim_per_head = 32

requires_grad=True
query = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda", requires_grad=requires_grad)
key = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda", requires_grad=requires_grad)
value = torch.randn([batch, seq_len, num_heads, dim_per_head], device="cuda", requires_grad=requires_grad)

def fn(query, key, value):
    return torch.nn.functional.scaled_dot_product_attention(query, key, value)

cfn = thunder.jit(fn, executors=[sdpa_ex])

thunder_result = cfn(query, key, value)

grad_output = torch.rand_like(thunder_result)
grads = torch.autograd.grad(thunder_result, (query, key, value), grad_outputs=grad_output)

Error

File "thunder.backward_fn_2", line 19, in backward_fn
  File "/home/kkalambarkar/lightning-thunder/thunder/executors/sdpaex.py", line 325, in _scaled_dot_product_efficient_attention_backward_impl
    if not utils.same_shape(grad_attn_mask.shape, attn_mask.shape):
AttributeError: 'NoneType' object has no attribute 'shape'

Problem -

if not utils.same_shape(grad_attn_mask.shape, attn_mask.shape):
# Needs to sum over the number of heads dimension in grad_attn_mask
# if the number of heads in attention mask is expanded in _attention_mask_memory_efficient_helper.
grad_attn_mask = torch.sum(grad_attn_mask, dim=1, keepdim=True)

It is possible that here grad_attn_mask and attn_mask are both None, so calling .shape leads to error.

@kshitij12345 kshitij12345 requested a review from kiya00 December 20, 2024 14:19
@kshitij12345 kshitij12345 marked this pull request as ready for review December 20, 2024 14:56
Copy link
Collaborator

@mruberry mruberry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mruberry mruberry enabled auto-merge (squash) December 20, 2024 16:57
@mruberry mruberry merged commit c392c35 into Lightning-AI:main Dec 20, 2024
44 checks passed
@kshitij12345 kshitij12345 deleted the fix-sdpaex-no-attn-case branch December 20, 2024 17:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants