diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index 12641996188f..d859e17cd5c3 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -25,6 +25,7 @@
 import triton
 import triton.language as tl
 
+from vllm.platforms import current_platform
 from vllm.utils import is_navi
 
 torch_dtype: tl.constexpr = torch.float16
@@ -391,7 +392,7 @@ def get_autotune_configs():
 
 autotune_configs, autotune_keys = get_autotune_configs()
 
-float8_info = torch.finfo(torch.float8_e4m3fnuz)
+float8_info = torch.finfo(current_platform.fp8_dtype())
 
 
 @triton.autotune(
@@ -834,7 +835,7 @@ def forward(
         if fp8_scales is not None:
             use_fp8 = True
             (q_scale, k_scale, v_scale, p_scale, o_scale) = fp8_scales
-            float8 = torch.float8_e4m3fnuz
+            float8 = current_platform.fp8_dtype()
 
             def check_and_convert(t, scale):
                 if t.dtype != float8:
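
Note (not part of the patch): the change replaces the hard-coded torch.float8_e4m3fnuz with the platform-dispatched current_platform.fp8_dtype(), so the finfo limits and the cast target follow whatever FP8 format the running hardware uses. A minimal sketch of that idea is below; it assumes fp8_dtype() returns the platform's native FP8 type, and the helper to_fp8 is a hypothetical illustration of the scale-clamp-cast pattern, not the module's exact check_and_convert.

    # Sketch only: shows why dispatching the FP8 dtype through the platform
    # layer matters, instead of assuming the ROCm-only float8_e4m3fnuz.
    import torch
    from vllm.platforms import current_platform

    float8 = current_platform.fp8_dtype()   # platform's native FP8 format
    float8_info = torch.finfo(float8)       # min/max now match that format

    def to_fp8(t: torch.Tensor, scale: float) -> torch.Tensor:
        # Hypothetical helper: scale, clamp into the representable FP8
        # range for this platform, then cast; FP8 inputs pass through.
        if t.dtype != float8:
            return (t * scale).clamp(min=float8_info.min,
                                     max=float8_info.max).to(float8)
        return t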