FlexAttention example (for mmu_vit mask) #8
Comments
Great! Thanks for your efforts! We will try it in the coming days :)
Hi! It seems that the length and location of the image sequence are fixed in the example. Is it possible to allow the images to be at any place in the overall sequence, with any resolution?
@wusize Good question! In this case, the mask I copied has a fixed prefix. But one cool aspect of FlexAttention is that it can access "captured" tensors. For example, if you have an image of arbitrary size, you could write a mask along the lines of the sketch below.
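A minimal sketch of that idea (the original snippet isn't preserved on this page, and the names `image_start` / `image_end` are assumptions): the mask closes over two captured scalar tensors marking where the image sits, so the image can live anywhere in the sequence and have any length.

```python
import torch

# Hypothetical captured tensors marking the image's span in the sequence.
# Because these are tensors (not Python ints), the generated kernel reads
# them at runtime, so the image's position and size can vary per call.
image_start = torch.tensor(32, device="cuda")
image_end = torch.tensor(32 + 576, device="cuda")

def mmu_vit_mask(b, h, q_idx, kv_idx):
    # Tokens inside the image span attend to each other bidirectionally;
    # everything else falls back to causal attention.
    q_in_image = (q_idx >= image_start) & (q_idx < image_end)
    kv_in_image = (kv_idx >= image_start) & (kv_idx < image_end)
    return (q_in_image & kv_in_image) | (q_idx >= kv_idx)
```

One caveat: the `BlockMask` built from this (via `create_block_mask`) still needs to be rebuilt when the span changes, but the generated kernel itself can be reused.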
Supporting an arbitrary number of images is a bit less trivial, since it's harder to "know" whether you're within the range of any of the images. However, with a bit of precomputation, this is also fairly straightforward. Basically, for every query token, we precompute the "beginning" and "end" of any image it might belong to. So, for example, if a text token is at position 568, its precomputed span can simply be empty, so the image check never fires for it; a sketch of this follows.
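A sketch of that precomputation, assuming hypothetical `image_begin` / `image_end` tensors (the image spans below are made up for illustration):

```python
import torch

# Hypothetical precomputation: for every position, the [begin, end) span of
# the image it belongs to. Text tokens get an empty span (begin == end == 0),
# so the "same image" test below is always False for them.
S = 1024
image_begin = torch.zeros(S, dtype=torch.long, device="cuda")
image_end = torch.zeros(S, dtype=torch.long, device="cuda")
# Example: two images occupying [100, 356) and [600, 856).
for lo, hi in [(100, 356), (600, 856)]:
    image_begin[lo:hi] = lo
    image_end[lo:hi] = hi

def multi_image_mask(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx
    # kv falls inside the image span that q belongs to -> bidirectional.
    same_image = (kv_idx >= image_begin[q_idx]) & (kv_idx < image_end[q_idx])
    return causal | same_image
```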
In my benchmarking, this is about 6x faster.
@Chillee Hi, Horace. Thank you for the suggestions. I have already implemented the attention mask required in Show-o using FlexAttention. It will be updated in our repository soon.
We recently released FlexAttention, which automatically generates fused FlashAttention kernels for a diverse range of attention variants. For example, the `mmu_vit` mask can be implemented like so (sketched below). And if you benchmark it, we see that FlexAttention is 9x faster than passing a mask to `F.scaled_dot_product_attention` (which uses xformers attention). I believe the other masks can also be implemented (somewhat) straightforwardly.
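The code block from the original post isn't preserved on this page; below is a sketch of what a fixed-prefix `mmu_vit` mask plus the SDPA baseline could look like. `IMAGE_LEN` and the shapes here are assumptions, not the values behind the 9x number.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Assumed layout: a fixed prefix of IMAGE_LEN image tokens, then text.
IMAGE_LEN = 576  # hypothetical prefix length
B, H, S, D = 1, 8, 1152, 64

def mmu_vit_mask(b, h, q_idx, kv_idx):
    # Bidirectional attention inside the image prefix, causal elsewhere.
    in_image = (q_idx < IMAGE_LEN) & (kv_idx < IMAGE_LEN)
    return in_image | (q_idx >= kv_idx)

block_mask = create_block_mask(mmu_vit_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
           for _ in range(3))

out_flex = flex_attention(q, k, v, block_mask=block_mask)

# Baseline: the same mask materialized as a dense boolean tensor for SDPA.
idx = torch.arange(S, device="cuda")
causal = idx[:, None] >= idx[None, :]
in_image = (idx[:, None] < IMAGE_LEN) & (idx[None, :] < IMAGE_LEN)
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=in_image | causal)
```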
Code + Benchmark