FlexAttention example (for mmu_vit mask) #8

Chillee opened this issue Aug 24, 2024 · 4 comments

Chillee commented Aug 24, 2024

We recently released FlexAttention, which automatically generates fused FlashAttention kernels for a diverse range of attention variants.

For example, the mmu_vit mask can be implemented like so:

def mmu_vit_mask(b, h, q_idx, kv_idx, system_prompt_len=0):
    causal_mask = (q_idx >= kv_idx)
    bidirectional_mask = (kv_idx <= system_prompt_len + 1 + 576)
    return causal_mask | bidirectional_mask

And if you benchmark it, FlexAttention is 9x faster than passing a mask to F.scaled_dot_product_attention (which uses xFormers attention).


I believe the other masks can also be implemented (somewhat) straightforwardly.
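As an illustrative aside (not from the original issue), other variants follow the same mask_mod pattern; for instance, a plain causal mask and a causal sliding-window mask could be sketched as:

def causal_mask(b, h, q_idx, kv_idx):
    # Each query attends only to keys at or before its own position.
    return q_idx >= kv_idx

def sliding_window_causal_mask(b, h, q_idx, kv_idx, window=512):
    # Causal attention restricted to a local window; the window size here is a made-up example value.
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= window)

Both take the same (b, h, q_idx, kv_idx) signature and can be passed to create_block_mask just like mmu_vit_mask below.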

Code + Benchmark
import torch
import torch.nn.functional as F
from triton.testing import do_bench
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, create_mask

torch.set_default_device('cuda')

# Compile flex_attention up front so the fused kernel gets generated.
flex_attention = torch.compile(flex_attention, dynamic=False)

# Reference mask: materializes the full (N, 1, L, L) boolean attention mask.
def create_attention_mask_for_mmu_vit(
        sequence,
        return_inverse_mask=False,
        system_prompt_len=0
):
    N, L = sequence.shape
    causal_mask = torch.tril(torch.ones((N, 1, L, L), dtype=torch.bool)).to(sequence.device)
    index = 1 + system_prompt_len + 1 + 576

    causal_mask[:, :, :, :index] = 1
    if return_inverse_mask:
        inverted_mask = 1.0 - causal_mask.type(torch.int64)
        inverted_mask = inverted_mask.masked_fill(
            inverted_mask.to(torch.bool), torch.iinfo(torch.int64).min
        )
        return inverted_mask.to(dtype=torch.bool)
    else:
        return causal_mask

# FlexAttention mask_mod: causal attention, plus full (bidirectional) attention to the prefix.
def mmu_vit_mask(b, h, q_idx, kv_idx, system_prompt_len=0):
    return (q_idx >= kv_idx) | (kv_idx <= system_prompt_len + 1 + 576)

# Compare SDPA with a dense mask against flex_attention with a BlockMask.
def bench_masks(name, xformer_mask_fn, flex_mask_fn):
    B = 4
    L = 4096
    H = 8
    D = 64
    sequence = torch.randn(B, L)
    mask = xformer_mask_fn(sequence)
    block_mask = create_block_mask(flex_mask_fn, B=None, H=None, Q_LEN=L, KV_LEN=L, _compile=True)
    flex_full_mask = create_mask(flex_mask_fn, B, None, L, L)
    assert torch.allclose(mask, flex_full_mask)

    q, k, v = [torch.randn(B, H, L, D, dtype=torch.float16) for _ in range(3)]
    xformer_attn = lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    flex_attn = lambda: flex_attention(q, k, v, block_mask=block_mask)
    print(name)
    print(block_mask)
    print("xformer: ", do_bench(xformer_attn))
    print("flex: ", do_bench(flex_attn))
    assert (xformer_attn() - flex_attn()).abs().max() < 1e-2
    print()

bench_masks("mmu_vit_mask", create_attention_mask_for_mmu_vit, mmu_vit_mask)
Sierkinhane (Collaborator) commented

Great! Thanks for your efforts! We will try it in the coming days :)

wusize commented Aug 24, 2024

Hi! It seems that the length and location of the image sequence are fixed in this example. Is it possible to allow images to appear anywhere in the overall sequence, at any resolution?

Chillee (Author) commented Aug 24, 2024

@wusize Good question! In this case, the mask I copied has a fixed prefix. But, one cool aspect of FlexAttention is that it can access "captured" tensors.

For example, if you have an image of arbitrary size, you could write:

# image_size: a captured tensor of shape [B] holding the image length for each batch element
def mmu_vit_dynamic_mask(b, h, q_idx, kv_idx, system_prompt_len=0):
    causal_mask = (q_idx >= kv_idx)
    bidirectional_mask = (kv_idx <= system_prompt_len + 1 + image_size[b])
    return causal_mask | bidirectional_mask
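
A rough usage sketch (the concrete sizes and image_size values below are illustrative assumptions, not from the thread) of how the captured tensor and the corresponding BlockMask might be set up:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, L, D = 4, 8, 4096, 64  # illustrative shapes
# Hypothetical per-batch image lengths, captured by mmu_vit_dynamic_mask above.
image_size = torch.tensor([576, 256, 576, 448], device='cuda')

# B is passed explicitly since the mask now differs per batch element.
block_mask = create_block_mask(mmu_vit_dynamic_mask, B=B, H=None, Q_LEN=L, KV_LEN=L, _compile=True)

q, k, v = [torch.randn(B, H, L, D, device='cuda', dtype=torch.float16) for _ in range(3)]
out = flex_attention(q, k, v, block_mask=block_mask)

Note that the BlockMask would need to be rebuilt whenever image_size changes (e.g. for a new batch with different image resolutions).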

Supporting an arbitrary number of images is a bit more involved, since it's harder to "know" whether you're within the range of any of the images. However, with a bit of precomputation, this is also fairly straightforward: for every query token, we precompute the "beginning" and "end" of any image it might belong to.

So, for example, if a text token is at position 568, then bidirectional_starts[568] = 568 and bidirectional_ends[568] = 568. However, if there's an image from 512 to 763, then bidirectional_starts[568] = 512 and bidirectional_ends[568] = 763.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, create_mask

S = 2048
image_begin_ends = [(0, 128), (512, 1024), (1530, 1890)]
# For a text token, start == end == its own position; for an image token,
# start/end span the whole image, enabling bidirectional attention within it.
bidirectional_starts = torch.arange(S, device='cuda')
bidirectional_ends = torch.arange(S, device='cuda')
for image_begin, image_end in image_begin_ends:
    bidirectional_starts[image_begin:image_end] = image_begin
    bidirectional_ends[image_begin:image_end] = image_end

def images_mask(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    bidirectional_mask = (kv_idx < bidirectional_ends[q_idx]) & (kv_idx > bidirectional_starts[q_idx])
    return causal_mask | bidirectional_mask

block_mask = create_block_mask(images_mask, B=None, H=None, Q_LEN=S, KV_LEN=S, _compile=True)
print(block_mask)

In my benchmarking, this is about 6x faster (using torch.compile(mode="max-autotune-no-cudagraphs") seems to improve performance by about 30% in this case), and of course, it also doesn't fully materialize the mask.
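
For reference, a minimal sketch of what that compilation setup could look like (the shapes and the flex_attention_ma name are illustrative assumptions; S and block_mask come from the snippet above):

# Compile flex_attention with extra autotuning, as mentioned above.
flex_attention_ma = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")

B, H, D = 4, 8, 64  # illustrative shapes
q, k, v = [torch.randn(B, H, S, D, device='cuda', dtype=torch.float16) for _ in range(3)]
out = flex_attention_ma(q, k, v, block_mask=block_mask)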

Sierkinhane (Collaborator) commented Aug 31, 2024

@Chillee Hi, Horace. Thank you for the suggestions. I have already implemented the attention mask required in Show-o using FlexAttention. It will be updated in our repository soon.
