Description
`nn.functional.scaled_dot_product_attention` is a very efficient implementation of attention.
It is significantly faster and more memory-efficient than the naive implementation, and adopting it shouldn't require any new dependencies or any changes outside the module.
From the documentation:
There are currently three supported implementations of scaled dot product attention:
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- Memory-Efficient Attention
- A PyTorch implementation defined in C++ matching the above formulation
The function may call optimized kernels for improved performance when using the CUDA backend. For all other backends, the PyTorch implementation will be used.
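For reference, SDPA with default arguments is numerically equivalent to the naive formulation used in the module. A minimal sketch (not taken from the issue, dropout omitted):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Small random tensors shaped (batch, heads, seq_len, head_dim)
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))

# Naive attention, as currently written in the module
scale = q.shape[-1] ** -0.5
attn = ((q * scale) @ k.transpose(-2, -1)).softmax(dim=-1)
naive = attn @ v

# Fused implementation; its default scale is also 1/sqrt(head_dim)
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(naive, fused, atol=1e-5))
```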
The following snippets:
waldo/deepcheat/VideoMAEv2/models/modeling_finetune.py, lines 185 to 191 in 30dcb63:

```python
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
```
waldo/deepcheat/VideoMAEv2/models/modeling_finetune.py, lines 124 to 135 in 30dcb63:

```python
attn = (
    F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
# torch.log(torch.tensor(1. / 0.01)) = 4.6052
logit_scale = torch.clamp(self.scale, max=4.6052).exp()
attn = attn * logit_scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
```
could be simplified to:

```python
x = F.scaled_dot_product_attention(
    q,
    k,
    v,
    scale=self.scale,
    dropout_p=(self.attn_drop.p if self.training else 0.0)
).transpose(1, 2).reshape(B, N, -1)
```

and

```python
# scale= expects a Python float, so the tensor logit scale is
# folded into q instead and the internal scaling is disabled
logit_scale = torch.clamp(self.scale, max=4.6052).exp()
x = F.scaled_dot_product_attention(
    F.normalize(q, dim=-1) * logit_scale,
    F.normalize(k, dim=-1),
    v,
    scale=1.0,
    dropout_p=(self.attn_drop.p if self.training else 0.0)
).transpose(1, 2).reshape(B, N, -1)
```

(Be aware, I didn't test the proposed snippets.)