From f8150598909d74ca6c63da9dfd4198ec272c3e21 Mon Sep 17 00:00:00 2001 From: Wen Chen Date: Mon, 11 Mar 2024 09:12:24 +0000 Subject: [PATCH] fix ut bugs --- msamp/common/utils/device.py | 15 +-------------- msamp/operators/gemm/gemm.py | 4 +++- tests/te/test_te_replacer.py | 6 +++++- 3 files changed, 9 insertions(+), 16 deletions(-) diff --git a/msamp/common/utils/device.py b/msamp/common/utils/device.py index eedfb98d..ce308dc2 100644 --- a/msamp/common/utils/device.py +++ b/msamp/common/utils/device.py @@ -14,24 +14,11 @@ class GPUType(Enum): class Device: """Device class for different hardwares.""" - @staticmethod - def is_fp8_supported(): - """Check whether the device support FP8 or not. - - Return: - boolean: return True if the device support FP8 precision. - """ - gpu_name = torch.cuda.get_device_name().lower() - if 'h100' in gpu_name: - return True - - return False - @staticmethod def get_gpu_type(): """Get the GPU type.""" if torch.cuda.device_count() > 0: - device_name = torch.cuda.get_device_name(0) + device_name = torch.cuda.get_device_name(0).upper() if "NVIDIA" in device_name: return GPUType.NVIDIA elif "AMD" in device_name: diff --git a/msamp/operators/gemm/gemm.py b/msamp/operators/gemm/gemm.py index e7834c20..a3e6b246 100644 --- a/msamp/operators/gemm/gemm.py +++ b/msamp/operators/gemm/gemm.py @@ -5,6 +5,7 @@ import torch import torch.nn.functional as F +import transformer_engine.pytorch as te from msamp.common.dtype import Dtypes from msamp.common.utils import Device @@ -118,7 +119,8 @@ def fp8_gemm( bias = (bias if bias is not None else cls._empty_tensor) # here out is padded, and src_out is the original one. - if Device.is_fp8_supported(): + is_fp8_avaiable, _ = te.fp8.check_fp8_support() + if is_fp8_avaiable: tew.te_gemm( mat_a.value, a_meta.scale_inv, diff --git a/tests/te/test_te_replacer.py b/tests/te/test_te_replacer.py index 9c56e411..4917f2f9 100644 --- a/tests/te/test_te_replacer.py +++ b/tests/te/test_te_replacer.py @@ -9,11 +9,12 @@ import torch import torch.distributed as dist -from torch.testing._internal.common_distributed import MultiProcessTestCase, skip_if_lt_x_gpu, requires_nccl +from torch.testing._internal.common_distributed import MultiProcessTestCase, skip_if_lt_x_gpu, requires_nccl, skip_if_rocm import transformer_engine.pytorch as te from transformer_engine.common.recipe import Format, DelayedScaling from tests.helper import decorator +from msamp.common.utils import Device, GPUType from msamp import deepspeed from msamp.nn import ScalingParameter from msamp.optim import LBAdamW @@ -136,6 +137,9 @@ def world_size(self): @decorator.cuda_test def test_fp8_ddp_with_te(self): """Test FP8DistributedDataParallel with TransformerEngine.""" + if Device.get_gpu_type() == GPUType.AMD: + # This UT is time out when running on MI300x and we have reported to AMD. + return hidden_size = 4096 ffn_hidden_size = 16384 num_attention_heads = 32