Skip to content

Commit

Permalink
fix ut bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
wenchenvincent committed Mar 11, 2024
1 parent 1efec4a commit f815059
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 16 deletions.
15 changes: 1 addition & 14 deletions msamp/common/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,11 @@ class GPUType(Enum):

class Device:
"""Device class for different hardwares."""
@staticmethod
def is_fp8_supported():
"""Check whether the device support FP8 or not.
Return:
boolean: return True if the device support FP8 precision.
"""
gpu_name = torch.cuda.get_device_name().lower()
if 'h100' in gpu_name:
return True

return False

@staticmethod
def get_gpu_type():
"""Get the GPU type."""
if torch.cuda.device_count() > 0:
device_name = torch.cuda.get_device_name(0)
device_name = torch.cuda.get_device_name(0).upper()
if "NVIDIA" in device_name:
return GPUType.NVIDIA
elif "AMD" in device_name:
Expand Down
4 changes: 3 additions & 1 deletion msamp/operators/gemm/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch
import torch.nn.functional as F
import transformer_engine.pytorch as te

from msamp.common.dtype import Dtypes
from msamp.common.utils import Device
Expand Down Expand Up @@ -118,7 +119,8 @@ def fp8_gemm(
bias = (bias if bias is not None else cls._empty_tensor)

# here out is padded, and src_out is the original one.
if Device.is_fp8_supported():
is_fp8_avaiable, _ = te.fp8.check_fp8_support()
if is_fp8_avaiable:
tew.te_gemm(
mat_a.value,
a_meta.scale_inv,
Expand Down
6 changes: 5 additions & 1 deletion tests/te/test_te_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@

import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import MultiProcessTestCase, skip_if_lt_x_gpu, requires_nccl
from torch.testing._internal.common_distributed import MultiProcessTestCase, skip_if_lt_x_gpu, requires_nccl, skip_if_rocm
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

from tests.helper import decorator
from msamp.common.utils import Device, GPUType
from msamp import deepspeed
from msamp.nn import ScalingParameter
from msamp.optim import LBAdamW
Expand Down Expand Up @@ -136,6 +137,9 @@ def world_size(self):
@decorator.cuda_test
def test_fp8_ddp_with_te(self):
"""Test FP8DistributedDataParallel with TransformerEngine."""
if Device.get_gpu_type() == GPUType.AMD:
# This UT is time out when running on MI300x and we have reported to AMD.
return
hidden_size = 4096
ffn_hidden_size = 16384
num_attention_heads = 32
Expand Down

0 comments on commit f815059

Please sign in to comment.