
Use torch.nn.functional.scaled_dot_product_attention for better performance #580

@MarcelLieb

Description


nn.functional.scaled_dot_product_attention is a very efficient implementation of attention.
It is significantly faster and more memory-efficient than the naive implementation, and switching to it shouldn't require any new dependencies or any changes outside the module.

From the documentation:

There are currently three supported implementations of scaled dot product attention:

- FlashAttention
- Memory-Efficient Attention
- A PyTorch implementation defined in C++

The function may call optimized kernels for improved performance when using the CUDA backend. For all other backends, the PyTorch implementation will be used.

The following snippets:

```python
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
```

```python
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
# torch.log(torch.tensor(1. / 0.01)) = 4.6052
logit_scale = torch.clamp(self.scale, max=4.6052).exp()
attn = attn * logit_scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
```

could be simplified to (note that the `scale` argument requires PyTorch 2.1 or newer):

```python
x = F.scaled_dot_product_attention(
    q,
    k,
    v,
    scale=self.scale,
    dropout_p=(self.attn_drop.p if self.training else 0.0),
).transpose(1, 2).reshape(B, N, -1)
```

and

```python
x = F.scaled_dot_product_attention(
    F.normalize(q, dim=-1),
    F.normalize(k, dim=-1),
    v,
    scale=torch.clamp(self.scale, max=4.6052).exp(),
    dropout_p=(self.attn_drop.p if self.training else 0.0),
).transpose(1, 2).reshape(B, N, -1)
```

(Note that I haven't tested the proposed snippets.)
