Skip to content

Commit

Permalink
[vllm] Add support for FP8 in Triton FA kernel (#301)
Browse files Browse the repository at this point in the history
* [vllm] Add support for FP8 in Triton FA kernel

Adding support for FP8 (E4M3) in Triton FA kernel, including per-tensor
scaling factors.

Test:
1. Patched rocm_flash_attn.py to call FA kernel with scaling factors
   (https://gist.github.com/ilia-cher/216762889331cefeb158634a651b2fac)
2. Run the benchmark:
   python3 benchmark_latency.py --model \
      /data/models/Llama-3.1-8B-Instruct-FP8-KV \
      --input-len 8192 \
      --output-len 1 \
      --batch-size 32 \
      --enforce-eager \
      --num-iters 10 \
      --num-iters-warmup 2 \
      --enable-chunked-prefill False \
      --dtype float16
Before:
Avg latency: 6.418297152221203 seconds
10% percentile latency: 6.380122036673129 seconds
25% percentile latency: 6.390297698322684 seconds
50% percentile latency: 6.404989298898727 seconds
75% percentile latency: 6.421127524343319 seconds
90% percentile latency: 6.4394324975088235 seconds
99% percentile latency: 6.562963163470849 seconds

After:
Avg latency: 5.162057781498879 seconds
10% percentile latency: 5.1219399653142315 seconds
25% percentile latency: 5.135780334530864 seconds
50% percentile latency: 5.151887209853157 seconds
75% percentile latency: 5.158517300733365 seconds
90% percentile latency: 5.184290232090279 seconds
99% percentile latency: 5.314461483638734 seconds

3. (Sanity) check using
   https://gist.github.com/ilia-cher/951a3d011a8bafa7c5180fbc3a151a57

4. (follow up in scaling factors loading PR) P3L perplexity check

* (linter)
  • Loading branch information
ilia-cher authored Dec 4, 2024
1 parent 18ef0a0 commit 97fd542
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 12 deletions.
1 change: 1 addition & 0 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ def forward(
self.scale,
attn_masks[0][None]
if attn_masks is not None else None,
None,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
Expand Down
91 changes: 79 additions & 12 deletions vllm/attention/ops/triton_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def _attn_fwd_inner(
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
PADDED_HEAD: tl.constexpr,
USE_FP8: tl.constexpr,
qk_scale,
p_descale,
):
# loop over k, v, and update accumulator
for start_n in range(block_min, block_max, BLOCK_N):
Expand Down Expand Up @@ -145,6 +148,8 @@ def _attn_fwd_inner(
qk = tl.where(causal_mask, qk, float("-inf"))
# -- compute qk ----
qk += tl.dot(q, k)
if USE_FP8:
qk *= qk_scale
if bias_ptr is not None:
bias = load_fn(bias_ptr, False, MASK_STEPS
and (n_extra_tokens != 0), "zero")
Expand Down Expand Up @@ -196,7 +201,12 @@ def _attn_fwd_inner(
l_i = l_i * alpha + l_ij
# update m_i and l_i
m_i = m_ij

if USE_FP8:
p *= p_descale

acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)

V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
if bias_ptr is not None:
Expand Down Expand Up @@ -292,18 +302,20 @@ def _attn_fwd_inner(
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config(
{
"BLOCK_M": 16,
"BLOCK_N": 16,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
# Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
# triton.Config(
# {
# "BLOCK_M": 16,
# "BLOCK_N": 16,
# "waves_per_eu": 1,
# "PRE_LOAD_V": False,
# },
# num_stages=1,
# num_warps=4,
# ),
],
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'],
)
@triton.jit
def attn_fwd(
Expand All @@ -312,6 +324,12 @@ def attn_fwd(
V,
bias,
sm_scale,
q_scale,
k_scale,
v_scale,
p_scale,
p_descale,
o_descale,
L,
Out,
stride_qz: tl.int64,
Expand Down Expand Up @@ -354,6 +372,7 @@ def attn_fwd(
BIAS_TYPE: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
USE_FP8: tl.constexpr,
):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
Expand Down Expand Up @@ -507,7 +526,12 @@ def attn_fwd(
qk_scale = sm_scale * 1.44269504089
# Q is loaded once at the beginning and shared by all N blocks.
q = load_fn(Q_block_ptr, True, padded_head, "zero")
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
if not USE_FP8:
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
acc_scale = 1.0
else:
qk_scale *= q_scale * k_scale
acc_scale = p_scale * v_scale

# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
Expand Down Expand Up @@ -562,6 +586,9 @@ def attn_fwd(
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
padded_head,
USE_FP8,
qk_scale,
p_descale,
)
block_min = block_max
block_max = n_blocks * BLOCK_N
Expand Down Expand Up @@ -608,8 +635,14 @@ def attn_fwd(
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
padded_head,
USE_FP8,
qk_scale,
p_descale,
)
# epilogue

if USE_FP8:
acc *= acc_scale
acc = acc / l_i[:, None]
if ENABLE_DROPOUT:
acc = acc / (1 - dropout_p)
Expand All @@ -620,6 +653,8 @@ def attn_fwd(
end_m_idx = (start_m + 1) * BLOCK_M
start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k
if USE_FP8:
acc *= o_descale
acc = acc.to(Out.type.element_ty)
if IS_CAUSAL: # noqa: SIM102
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
Expand Down Expand Up @@ -710,7 +745,29 @@ def forward(
causal=False,
sm_scale=1.0,
bias=None,
fp8_scales=None,
):
if fp8_scales is not None:
use_fp8 = True
(q_scale, k_scale, v_scale, p_scale, o_scale) = fp8_scales
float8 = torch.float8_e4m3fnuz

def check_and_convert(t, scale):
if t.dtype != float8:
finfo = torch.finfo(float8)
descale = 1.0 / scale
ts = (t * descale).clamp(min=finfo.min, max=finfo.max)
return ts.to(float8)
else:
return t

q = check_and_convert(q, q_scale)
k = check_and_convert(k, k_scale)
v = check_and_convert(v, v_scale)
else:
use_fp8 = False
q_scale = k_scale = v_scale = p_scale = o_scale = 1.0

if o is None:
o = torch.empty_like(q, dtype=v.dtype)

Expand Down Expand Up @@ -773,12 +830,21 @@ def forward(
else:
bias_strides = (0, 0, 0, 0)

p_descale = 1.0 / p_scale
o_descale = 1.0 / o_scale

attn_fwd[grid](
q,
k,
v,
bias,
sm_scale,
q_scale,
k_scale,
v_scale,
p_scale,
p_descale,
o_descale,
None,
o,
*q_strides,
Expand All @@ -803,6 +869,7 @@ def forward(
BIAS_TYPE=0 if bias is None else 1,
ENABLE_DROPOUT=False,
RETURN_ENCODED_SOFTMAX=False,
USE_FP8=use_fp8,
)

ctx.grid = grid
Expand Down

0 comments on commit 97fd542

Please sign in to comment.