Motivation
FlashAttention saves memory access to softmax scores, which is critical for training and the prefill stage of inference, but not necessarily for decoding.
Proposal
Consider the case where we don't fuse the two matrix multiplications, but instead use two kernels:
The first kernel computes the first GEMM $Q\cdot K^\top$ and the first stage (compute $M, D$) of online softmax (check this note).
The second kernel computes the second stage of online softmax (divide $P$ by $D$) and the second GEMM $P\cdot V$.
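To spell the split out (this is just one way to read the two stages; the exact kernel schedule is an implementation detail):
$$ \text{Kernel 1:}\quad S = Q\cdot K^\top,\qquad M = \operatorname{rowmax}(S),\qquad D = \operatorname{rowsum}\big(\exp(S - M)\big) $$
$$ \text{Kernel 2:}\quad P = \exp(S - M)\,/\,D,\qquad O = P\cdot V $$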
Analysis
The I/O of each matrix in the unfused kernels is:
$$ size(QO) = 2L_q\cdot H_{qo}\cdot D $$
$$ size(KV) = 2L_{kv}\cdot H_{kv}\cdot D $$
$$ size(softmax) = L_q\cdot L_{kv} \cdot H_{qo} $$
The ratio of softmax I/O (we double the I/O of softmax because we have to write it in the first kernel and read it in the second kernel) in the entire I/O is:
$$ ratio = \frac{2\cdot size(softmax)}{size(QO) + size(KV) + 2\cdot size(softmax)} = \frac{2 L_q L_{kv} H_{qo}}{2 L_q H_{qo} D + 2 L_{kv} H_{kv} D + 2 L_q L_{kv} H_{qo}} $$
For the decoding case, $L_q$ is $1$, so the ratio of softmax I/O can't exceed $1/(1+\frac{H_{kv}}{H_{qo}}D)$. $D$ is usually a large value such as $128$, and $\frac{H_{kv}}{H_{qo}}$ is the reciprocal of the GQA group ratio (usually 4 or 8), so FlashAttention-style fusion is actually not very useful here, saving at most about 1/16 of the I/O.
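Plugging in the typical decoding numbers above ($D = 128$, GQA group ratio $H_{qo}/H_{kv} = 8$, $L_q = 1$, $L_{kv} \to \infty$):
$$ ratio \le \frac{1}{1 + \frac{H_{kv}}{H_{qo}}\cdot D} = \frac{1}{1 + \frac{128}{8}} = \frac{1}{17} \approx 6\% $$
With a group ratio of 4, the bound drops further to $1/33 \approx 3\%$.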
Benefits
But what are the benefits of unfused kernels?
Less register pressure, so we can use larger tiles for both stages.
We can use hardware instructions for in-place reduction. More explicitly, there are PTX instructions such as `red.global` (for earlier GPUs such as Ampere) and `cp.reduce.async` (for Hopper) that perform in-place reduction: `dst @ global = RedOp(dst @ global, src @ smem)`; `cp.reduce.async` can leverage TMA and thus further reduce the epilogue time.
NOTE: we can also leverage `cp.reduce.async` for FlashAttention-style attention aggregation, but we need to know the rowmax $M$ beforehand (as in Flash-Decoding++) so that we don't need to multiply the dst data in global memory by a renormalization term (there is no NVGPU instruction that supports `dst @ global = RedOp(alpha * (dst @ global), src @ smem)`).
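As a minimal sketch (hypothetical names, not FlashInfer code): if the $P\cdot V$ kernel is split along the KV dimension, each split's partial output can simply be summed into $O$ in global memory, because $P$ is already divided by $D$. A plain `atomicAdd` is enough for illustration; with its return value discarded it typically compiles down to a `red.global.add` instruction, and on Hopper the same epilogue could use `cp.reduce.async` with TMA instead.

```cuda
// Sketch of the epilogue of the second (P·V) kernel: each CTA holds a partial
// output tile for its KV split and reduces it into the global output in place.
// Because P is already normalized by D, partial outputs from different splits
// can simply be summed -- no renormalization term is needed, unlike a
// FlashAttention-style combine step.
__device__ __forceinline__ void epilogue_reduce(float* o_global,        // this CTA's slice of O in global memory
                                                const float* o_partial, // this CTA's partial P·V results
                                                int n) {
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    // With the return value discarded, atomicAdd typically lowers to a
    // red.global.add instruction (a fire-and-forget reduction).
    atomicAdd(&o_global[i], o_partial[i]);
  }
}
```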
Required Changes
What do we need to support in FlashInfer:
Implement gather-GEMM for both kernels.
Use RaggedTensor for softmax storage with proper padding so that we can easily read/write it in both kernels.
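A minimal sketch of what a padded ragged softmax layout could look like (hypothetical helper with a CSR-style `indptr`; not FlashInfer's actual RaggedTensor API): each request's KV length is rounded up to the KV tile size so both kernels can address its score rows with aligned tile loads/stores.

```cuda
// Hypothetical layout: request i's softmax scores occupy rows
// [indptr[i], indptr[i+1]) of a [total_padded_kv_len, num_qo_heads * L_q]
// buffer, with each request's kv length padded to a multiple of tile_kv.
#include <cstdint>
#include <vector>

std::vector<int64_t> build_softmax_indptr(const std::vector<int64_t>& kv_len,
                                          int64_t tile_kv) {
  std::vector<int64_t> indptr(kv_len.size() + 1, 0);
  for (size_t i = 0; i < kv_len.size(); ++i) {
    // Round each request's kv length up to the kv tile size.
    int64_t padded = (kv_len[i] + tile_kv - 1) / tile_kv * tile_kv;
    indptr[i + 1] = indptr[i] + padded;
  }
  return indptr;  // indptr.back() is the total number of padded score rows
}
```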
Concerns
Maintenance overhead in the future
If future models keep increasing the GQA group ratio, the softmax I/O ratio will gradually dominate, and it will no longer make sense to leave the kernels unfused.
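For example, with $D = 128$ and group ratio $g = H_{qo}/H_{kv}$, the bound above is $1/(1 + 128/g)$: roughly $1/17$ at $g = 8$, $1/9$ at $g = 16$, and $1/3$ at $g = 64$, at which point fusion clearly pays off again.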
I’m not convinced it’s worth implementing unfused softmax attention kernels, so I opened this issue to spark a discussion. Any feedback is appreciated!
cc @tqchen @JamesLim-sy @merrymercy @spectrometerHBH
Having $L_q$ be 8~32 seems more common in serving settings.
Regardless, in the $L_q=1$ case,
the QK kernel is mem-bound anyway.
I feel the PV kernel mainly benefits from the shift from mem-bound to compute-bound. $L_q \ll L_{kv}$, so the register pressure and caching of $Q$ play a marginal role here imo.
I asked ChatGPT for the arithmetic intensity of fused attention; it says it's $\frac{L_q L_{kv}}{L_q + L_{kv}}$, which I suppose is true. So when $L_q=1$, it is less than 1 and hence mem-bound in this case.
Splitting one mem-bound kernel into one mem-bound kernel plus one compute-bound kernel, with only a marginal increase in total I/O, makes sense to me.
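For what it's worth, a rough check matches that expression: per qo head with fp16 I/O, the two GEMMs cost about $4 L_q L_{kv} D$ FLOPs, while the $Q/K/V/O$ traffic is about $2\,\text{bytes} \times (2L_qD + 2L_{kv}D)$, giving
$$ \text{intensity} \approx \frac{4 L_q L_{kv} D}{4D(L_q + L_{kv})} = \frac{L_q L_{kv}}{L_q + L_{kv}} \ \text{FLOP/byte}. $$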