[RFC] Un-fused softmax for short-query(decode) attention #707

Open
yzh119 opened this issue Dec 30, 2024 · 1 comment

Comments

@yzh119
Collaborator

yzh119 commented Dec 30, 2024

Motivation

FlashAttention saves the memory traffic for softmax scores, which is critical for training and the prefill stage of inference, but not necessarily for decoding.

Proposal

Consider the case where we don't fuse the two matrix multiplications but instead use two kernels (a NumPy sketch follows the list):

  1. The first kernel computes the first GEMM $Q\cdot K$ and the first stage (compute $M, D$) of online softmax (check this note).
  2. The second kernel computes the second stage of online softmax (divide $P$ by $D$) and computes the second GEMM $P\cdot V$.
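A minimal NumPy sketch of this two-kernel decomposition (single head, illustrative shapes; not FlashInfer's actual API or kernel code):

```python
import numpy as np

L_q, L_kv, D = 1, 4096, 128              # decode-like shapes (illustrative)
q = np.random.randn(L_q, D)
k = np.random.randn(L_kv, D)
v = np.random.randn(L_kv, D)

# Kernel 1: first GEMM Q·K^T plus the first stage of online softmax
# (rowmax M and denominator D); P, M, D would be written to global memory.
s = q @ k.T                              # (L_q, L_kv) attention scores
m = s.max(axis=-1, keepdims=True)        # rowmax M
p = np.exp(s - m)                        # unnormalized probabilities
d = p.sum(axis=-1, keepdims=True)        # denominator D

# Kernel 2: second stage of online softmax (divide P by D) plus the
# second GEMM P·V; P and D are read back from global memory.
o = (p / d) @ v                          # (L_q, D) attention output

# Reference: plain softmax attention in one pass, for a correctness check.
ref = (np.exp(s - m) / np.exp(s - m).sum(-1, keepdims=True)) @ v
assert np.allclose(o, ref)
```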

Analysis

The I/O of each matrix in the unfused kernels is:

$$ size(QO) = 2L_q\cdot H_{qo}\cdot D $$

$$ size(KV) = 2L_{kv}\cdot H_{kv}\cdot D $$

$$ size(softmax) = L_q\cdot L_{kv}\cdot H_{qo} $$

The ratio of softmax I/O (doubled, because we have to write it in the first kernel and read it back in the second) to the entire I/O is:

$$ \frac{1}{1 + \frac{H_{kv}}{H_{qo}}\frac{D}{L_{q}} + \frac{D}{L_{kv}}} $$

For the decoding case, $L_q = 1$, so the softmax I/O ratio cannot exceed $1/(1+\frac{H_{kv}}{H_{qo}}D)$. Since $D$ is usually large (e.g., $128$) and $\frac{H_{kv}}{H_{qo}}$ is the reciprocal of the GQA group size (usually 4 or 8), FlashAttention-style fusion is actually not very useful here: it saves at most about 1/16 of the I/O.
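For concreteness, a quick numeric check of the bound above (the head counts and KV length below are illustrative assumptions, not measurements):

```python
def softmax_io_ratio(L_q, L_kv, H_qo, H_kv, D):
    # Element counts from the size() formulas above.
    qo = 2 * L_q * H_qo * D            # read Q + write O
    kv = 2 * L_kv * H_kv * D           # read K + read V
    softmax = 2 * L_q * L_kv * H_qo    # written by kernel 1, read by kernel 2
    return softmax / (softmax + qo + kv)

# Decode: L_q = 1, GQA group size 8 (H_kv / H_qo = 1/8), D = 128.
print(softmax_io_ratio(L_q=1, L_kv=8192, H_qo=32, H_kv=4, D=128))  # ~0.0588, just under 1/17
```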

Benefits

But what are the benefits of un-fused kernels?

  1. Less register pressure, so we can use larger tiles in both stages.
  2. We can use hardware instructions for in-place reduction. More explicitly, there are PTX instructions such as red.global (for earlier GPUs such as Ampere) and cp.reduce.async (for Hopper) that perform dst @ global = RedOp(dst @ global, src @ smem); cp.reduce.async can leverage TMA and thus further reduce the epilogue time.
    • NOTE: we can also leverage cp.reduce.async for FlashAttention-style attention aggregation, but we need to know the rowmax $M$ beforehand (as in Flash-Decoding++) so that we don't have to multiply the dst data in global memory by a renormalization term (there is no NVGPU instruction that supports dst @ global = RedOp(alpha * (dst @ global), src @ smem)). A NumPy sketch of this aggregation scheme follows the list.
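A NumPy sketch of the aggregation described in the NOTE above, assuming the global rowmax $M$ is known a priori; the in-place `+=` stands in for what red.global / cp.reduce.async would do in hardware (illustrative only, not FlashInfer code):

```python
import numpy as np

L_kv, D, num_chunks = 1024, 128, 4
q = np.random.randn(D)
k = np.random.randn(L_kv, D)
v = np.random.randn(L_kv, D)

s = k @ q                        # (L_kv,) scores for a single decode query
M = s.max()                      # assume the rowmax is known beforehand

# Each KV chunk contributes exp(s_c - M) @ v_c and sum(exp(s_c - M)); with a
# shared M, combining chunks is a plain in-place add (RedOp = add), with no
# renormalization of the destination required.
o_acc, d_acc = np.zeros(D), 0.0
for s_c, v_c in zip(np.split(s, num_chunks), np.split(v, num_chunks)):
    p_c = np.exp(s_c - M)
    o_acc += p_c @ v_c           # dst = dst + src
    d_acc += p_c.sum()           # dst = dst + src

ref = (np.exp(s - M) / np.exp(s - M).sum()) @ v
assert np.allclose(o_acc / d_acc, ref)
```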

Required Changes

What we need to support in FlashInfer (a layout sketch follows the list):

  1. Implement gather-GEMM for both kernels.
  2. Use a RaggedTensor for softmax storage, with proper padding, so that both kernels can easily read/write it.
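A hypothetical sketch of the ragged layout for the intermediate softmax scores (function name and padding policy are illustrative, not FlashInfer's actual RaggedTensor API):

```python
import numpy as np

def softmax_indptr(qo_lens, kv_lens, num_qo_heads, pad_to=16):
    """Element offsets of each request's softmax block inside one flat buffer."""
    sizes = [ql * (-(-kl // pad_to) * pad_to) * num_qo_heads  # pad L_kv for aligned access
             for ql, kl in zip(qo_lens, kv_lens)]
    return np.concatenate(([0], np.cumsum(sizes)))

# Three short-query requests with different KV lengths.
indptr = softmax_indptr(qo_lens=[1, 1, 4], kv_lens=[1000, 37, 512], num_qo_heads=32)
# Kernel 1 writes request i's scores to buffer[indptr[i]:indptr[i+1]];
# kernel 2 reads them back from the same range.
print(indptr)  # e.g. [0 32256 33792 99328]
```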

Concerns

  • Maintenance overhead in the future
  • If future models keep increasing the GQA group size, the softmax I/O will gradually dominate, and it will no longer make sense to leave the kernels unfused.

I’m not convinced it’s worth implementing unfused softmax attention kernels, so I opened this issue to spark a discussion. Any feedback is appreciated!
cc @tqchen @JamesLim-sy @merrymercy @spectrometerHBH

@spectrometerHBH

spectrometerHBH commented Dec 30, 2024

Having $L_q$ in the 8–32 range seems more common in serving settings.

Regardless, in the $L_q=1$ case,

  1. the QK kernel is mem-bound anyway.
  2. I feel the PV kernel mainly benefits from the shift from mem-bound to compute-bound.
    $L_q \ll L_{kv}$, so the register pressure and caching of Q have a marginal effect here imo.

I asked ChatGPT for the arithmetic intensity of fused attention; it says it is $\frac{L_q L_{kv}}{L_q + L_{kv}}$, which I suppose is true. So when $L_q = 1$ it is less than 1, and hence the kernel is mem-bound in this case.
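As a rough sanity check (ignoring the softmax FLOPs and treating $D$, head counts, and element size as constants): the fused kernel performs about $4 L_q L_{kv} D$ FLOPs for the two GEMMs while moving about $2(L_q + L_{kv})D$ elements, so

$$ \mathrm{AI} \approx \frac{4 L_q L_{kv} D}{2(L_q + L_{kv})D} = \frac{2 L_q L_{kv}}{L_q + L_{kv}}, $$

which agrees with the quoted expression up to a constant factor; for $L_q = 1$ it is $O(1)$, far below the roofline ridge point.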

Splitting one mem-bound kernel into one mem-bound kernel plus one compute-bound kernel, with only a marginal increase in total I/O, makes sense to me.
