From 9364572dd1f28d01650772535f88e1fd6982845a Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Sat, 14 Dec 2024 03:58:51 +0900
Subject: [PATCH] nvfp8 with Hopper+ check

Signed-off-by: Masaki Kozuki
---
 thunder/executors/nvfuserex_impl.py | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index 4e5a204f8c..2b0589367b 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -101,6 +101,17 @@
 }
 
 
+_lcfp8_to_nvfp8_map: dict[dtypes.dtype, DataType] = {
+    dtypes.float8_e5m2: DataType.Float8_e5m2,
+    dtypes.float8_e5m2_: DataType.Float8_e5m2,
+    dtypes.float8_e4m3fn: DataType.Float8_e4m3fn,
+    dtypes.float8_e4m3fn_: DataType.Float8_e4m3fn,
+}
+
+
+_lcdtype_to_nvdtype_map.update(_lcfp8_to_nvfp8_map)
+
+
 def lcdtype_to_nvdtype(lcdtype: type | dtypes.dtype) -> DataType:
     return _lcdtype_to_nvdtype_map[lcdtype]
 
@@ -144,7 +155,14 @@ def is_supported_devicetype(devicetype: DeviceType) -> bool:
     return devicetype is DeviceType.CUDA
 
 
-_low_precision_floats = (dtypes.float16, dtypes.float16_, dtypes.bfloat16, dtypes.bfloat16_)
+_low_precision_floats = (dtypes.float16, dtypes.float16_, dtypes.bfloat16, dtypes.bfloat16_) + tuple(
+    _lcfp8_to_nvfp8_map.keys()
+)
+
+
+def device_supports_fp8() -> bool:
+    cuda_major, _ = torch.cuda.get_device_capability()
+    return cuda_major > 8
 
 
 def is_supported_dtype(dtype: type | dtypes.dtype, *, allow_low_precision_floats: bool = True) -> bool:
@@ -154,7 +172,7 @@ def is_supported_dtype(dtype: type | dtypes.dtype, *, allow_low_precision_floats
         if dtype in _low_precision_floats:
             return False
 
-    return dtype in _lcdtype_to_nvdtype_map
+    return dtype in _lcdtype_to_nvdtype_map and (device_supports_fp8() if dtype in _lcfp8_to_nvfp8_map else True)
 
 
 def is_supported_tensor(a: TensorProxy, *, allow_low_precision_floats: bool = True) -> bool:
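
Reviewer note (not part of the patch): a minimal, self-contained sketch of
the gating this change introduces, runnable on any PyTorch install. The
dtype names and the support map below are illustrative stand-ins for
thunder's dtypes module and _lcdtype_to_nvdtype_map; only
torch.cuda.get_device_capability() and torch.cuda.is_available() are real
PyTorch APIs.

    import torch

    # Hypothetical stand-ins for thunder's dtype objects and nvFuser's
    # support map; strings keep the sketch dependency-free.
    _fp8_dtypes = {"float8_e5m2", "float8_e4m3fn"}
    _supported_dtypes = {"float16", "bfloat16", "float32"} | _fp8_dtypes

    def device_supports_fp8() -> bool:
        # Same check as the patch: compute capability major version 9+
        # (Hopper and newer); guarded here so the sketch also runs on
        # CPU-only machines, unlike the patch which assumes CUDA.
        if not torch.cuda.is_available():
            return False
        cuda_major, _ = torch.cuda.get_device_capability()
        return cuda_major > 8

    def is_supported_dtype(dtype: str) -> bool:
        # fp8 dtypes carry the extra device gate; every other dtype only
        # needs to be in the support map, exactly as before the patch.
        return dtype in _supported_dtypes and (
            device_supports_fp8() if dtype in _fp8_dtypes else True
        )

    # Expected behavior:
    #   A100 (SM 8.0): is_supported_dtype("float16") -> True,
    #                  is_supported_dtype("float8_e5m2") -> False
    #   H100 (SM 9.0): both -> True

One design note worth confirming: keying on the major version alone
(cuda_major > 8) admits Hopper and future architectures, but it excludes
Ada (SM 8.9) even though Ada has fp8 tensor cores. That appears to match
the "Hopper+" intent in the subject line.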