107 changes: 92 additions & 15 deletions tests/pytorch/test_numerics.py
@@ -1758,6 +1758,7 @@ def _test_grouped_linear_accuracy(
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute=False,
m_splits_on_device=False,
):
reset_rng_states()
if fp8:
@@ -1776,7 +1777,7 @@ def _test_grouped_linear_accuracy(
if fp8:
split_size = 16
if recipe.mxfp8() or recipe.nvfp4():
split_size = 32
split_size = 128
m = config.max_seqlen_q // split_size
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1]) # Manually add a zero
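
Note: the 128-element granularity matches the align_size that the FP8 padding helper later in this diff uses for MXFP8/NVFP4. A minimal sketch of the round-up rule, with hypothetical token counts:

    def round_up(n: int, align: int) -> int:
        # Same arithmetic as the padding helper below.
        return (n + align - 1) // align * align

    assert round_up(100, 16) == 112   # plain FP8 granularity
    assert round_up(100, 128) == 128  # MXFP8/NVFP4 granularity
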
@@ -1787,17 +1788,20 @@
m_splits = torch.tensor([config.max_seqlen_q])

with autocast(enabled=fp8, recipe=recipe):
if m_splits_on_device:
m_splits = m_splits.to("cuda")
if isinstance(block, GroupedLinear):
m_splits = m_splits * bs
out = block(inp_hidden_states, m_splits.tolist())
out = block(inp_hidden_states, m_splits)
else:
out = torch.cat(
[
block[i](inp)
for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
]
)
loss = out.sum()
target = torch.rand_like(out, device=out.device, dtype=out.dtype)
loss = (out * target).sum()
loss.backward()
if delay_wgrad_compute:
if isinstance(block, GroupedLinear):
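
Note: this hunk changes two things, sketched below with illustrative values (the sketch assumes a CUDA device for the .to("cuda") call):

    import torch

    # (1) Splits as a tensor: GroupedLinear is now exercised with m_splits
    # passed as a tensor rather than a Python list. If the splits originate
    # on the GPU (e.g. from an MoE router), keeping them there avoids the
    # device-to-host copy and stream sync that .tolist() forces.
    m_splits = torch.tensor([128, 256, 128])
    host_splits = m_splits.tolist()      # list convention (host)
    device_splits = m_splits.to("cuda")  # tensor convention (device)

    # (2) Random-target loss: the upstream gradient of `out` becomes
    # `target`, random per element, so weight-gradient errors can no
    # longer cancel the way they can under the all-ones gradient of
    # out.sum().
    out = torch.randn(3, 4, requires_grad=True)
    target = torch.rand_like(out)
    (out * target).sum().backward()
    torch.testing.assert_close(out.grad, target)
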
@@ -1839,6 +1843,8 @@ def test_grouped_linear_accuracy(
delay_wgrad_compute,
parallel_mode=None,
use_cutlass=False,
m_splits_on_device=False,
num_unfuse_wgrad_accumulation=0,
):
fp8 = recipe is not None
if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
@@ -1853,7 +1859,15 @@
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)

if num_unfuse_wgrad_accumulation > 0 and not m_splits_on_device:
pytest.skip("Partial wgrad accumulation is not supported when m_splits_on_device is False")
wgrad_accumulation_mask = None
if fuse_wgrad_accumulation and 0 < num_unfuse_wgrad_accumulation < num_gemms:
wgrad_accumulation_mask = torch.ones(num_gemms, dtype=torch.bool)
indices = list(range(num_gemms))
random.shuffle(indices)
for idx in indices[:num_unfuse_wgrad_accumulation]:
wgrad_accumulation_mask[idx] = False
with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = GroupedLinear(
num_gemms,
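
Note: the new num_unfuse_wgrad_accumulation parameter disables fused wgrad accumulation for a random subset of GEMMs; when no mask is built, the torch.full fallback below gives every GEMM the global fuse_wgrad_accumulation flag, and each reference Linear consumes its own entry via wgrad_accumulation_mask[i]. A self-contained sketch of the selection (random.sample is equivalent to the shuffle-and-slice above; sizes are illustrative):

    import random
    import torch

    num_gemms, num_unfused = 4, 2
    # True -> this GEMM accumulates its wgrad into weight.main_grad (fused);
    # False -> it produces a standalone gradient as usual.
    mask = torch.ones(num_gemms, dtype=torch.bool)
    mask[random.sample(range(num_gemms), num_unfused)] = False
    assert mask.sum() == num_gemms - num_unfused
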
@@ -1864,9 +1878,12 @@
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
wgrad_accumulation_mask=wgrad_accumulation_mask,
delay_wgrad_compute=delay_wgrad_compute,
save_original_input=False,
).eval()
if wgrad_accumulation_mask is None:
wgrad_accumulation_mask = torch.full((num_gemms,), fuse_wgrad_accumulation, dtype=torch.bool)
sequential_linear = torch.nn.ModuleList(
[
Linear(
@@ -1876,9 +1893,9 @@
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
fuse_wgrad_accumulation=wgrad_accumulation_mask[i],
).eval()
for _ in range(num_gemms)
for i in range(num_gemms)
]
)

@@ -1888,7 +1905,7 @@
sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
if bias:
sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
if fuse_wgrad_accumulation:
if wgrad_accumulation_mask[i]:
weight_i = getattr(grouped_linear, f"weight{i}")
weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
@@ -1903,7 +1920,9 @@
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute,
m_splits_on_device,
)

outputs = _test_grouped_linear_accuracy(
grouped_linear,
num_gemms,
@@ -1914,7 +1933,9 @@
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute,
m_splits_on_device,
)


for o, o_ref in zip(outputs, outputs_ref):
if use_cutlass:
@@ -1924,6 +1945,46 @@
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.skipif(
torch.cuda.get_device_capability() != (10, 0),
reason="Only enable CUTLASS device grouped gemm on Blackwell",
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", [recipe.MXFP8BlockScaling()])
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
@pytest.mark.parametrize("num_unfuse_wgrad_accumulation", [0, 1, 2])
def test_grouped_linear_accuracy_cutlass_device(
dtype,
num_gemms,
bs,
model,
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
num_unfuse_wgrad_accumulation,
delay_wgrad_compute,
):
test_grouped_linear_accuracy(
dtype,
num_gemms,
bs,
model,
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
False,  # bias
delay_wgrad_compute,
m_splits_on_device=True,
num_unfuse_wgrad_accumulation=num_unfuse_wgrad_accumulation,
)

@pytest.mark.skipif(
torch.cuda.get_device_capability() != (9, 0),
reason="Only enable CUTLASS grouped gemm on Hopper",
@@ -1968,6 +2029,7 @@ def test_grouped_linear_accuracy_cutlass(
@pytest.mark.parametrize("fuse_wgrad_accumulation", [True])
@pytest.mark.parametrize("bias", [False])
@pytest.mark.parametrize("delay_wgrad_compute", [True])
@pytest.mark.parametrize("m_splits_on_device", all_boolean)
def test_grouped_linear_accuracy_save_original_input(
dtype,
num_gemms,
@@ -1978,13 +2040,16 @@ def test_grouped_linear_accuracy_save_original_input(
fuse_wgrad_accumulation,
bias,
delay_wgrad_compute,
m_splits_on_device,
parallel_mode=None,
):
fp8 = recipe is not None
if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
if m_splits_on_device and (not (fp8 and recipe.mxfp8()) or dtype not in [torch.bfloat16]):
pytest.skip("m_splits_on_device is only supported with MXFP8 recipe and bfloat16 dtype")

config = model_configs[model]
if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2045,6 +2110,7 @@ def test_grouped_linear_accuracy_save_original_input(
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute,
m_splits_on_device,
)
outputs = _test_grouped_linear_accuracy(
grouped_linear,
@@ -2056,6 +2122,7 @@
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute,
m_splits_on_device,
)

# Should be a bit-wise match
@@ -2079,12 +2146,12 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
)


def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
def _test_padding_grouped_linear_accuracy(
block, num_gemms, bs, dtype, config, recipe, fp8=False, m_splits_on_device=False
):

def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
align_size = 16
if recipe.mxfp8() or recipe.nvfp4():
align_size = 32
align_size = 128
padded_tokens_per_expert = [
(num_tokens + align_size - 1) // align_size * align_size
for num_tokens in tokens_per_expert
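
Note: a self-contained sketch of the pad/unpad round trip this helper implements, with illustrative sizes (align mirrors the MXFP8/NVFP4 align_size):

    import torch
    import torch.nn.functional as F

    align = 128
    m_splits = [100, 60]  # tokens per expert
    padded = [(m + align - 1) // align * align for m in m_splits]

    x = torch.randn(sum(m_splits), 8)
    # Pad each expert's rows up to the alignment boundary.
    x_pad = torch.cat([
        F.pad(chunk, (0, 0, 0, p - chunk.shape[0]))
        for chunk, p in zip(torch.split(x, m_splits), padded)
    ])
    # ... the grouped GEMM runs on x_pad with `padded` as its splits ...
    # Unpad: keep only the first m rows of each padded chunk.
    x_round = torch.cat([
        chunk[:m] for chunk, m in zip(torch.split(x_pad, padded), m_splits)
    ])
    torch.testing.assert_close(x_round, x)
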
@@ -2158,7 +2225,7 @@ def _generate_random_numbers(n, total_sum):
padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8(
inp_hidden_states, m_splits
)
padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits)
padded_inp_hidden_states = block(
padded_inp_hidden_states,
torch.tensor(padding_m_splits, device="cuda" if m_splits_on_device else "cpu"),
)
out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits)
else:
out = block(inp_hidden_states, m_splits)
@@ -2181,6 +2248,7 @@
@pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("m_splits_on_device", all_boolean)
def test_padding_grouped_linear_accuracy(
dtype,
num_gemms,
@@ -2189,6 +2257,7 @@
fp8,
recipe,
fp8_model_params,
m_splits_on_device,
parallel_mode=None,
):
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
@@ -2203,6 +2272,8 @@
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)
if m_splits_on_device and (not recipe.mxfp8() or dtype not in [torch.bfloat16]):
pytest.skip("m_splits_on_device is only supported with MXFP8 recipe and bfloat16 dtype")

with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = TorchGroupedLinearWithPadding(
@@ -2238,10 +2309,10 @@ def test_padding_grouped_linear_accuracy(
)

outputs = _test_padding_grouped_linear_accuracy(
grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, m_splits_on_device
)
outputs_ref = _test_padding_grouped_linear_accuracy(
ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, m_splits_on_device
)

# Should be a bit-wise match
@@ -2256,6 +2327,7 @@
@pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", [False])
@pytest.mark.parametrize("m_splits_on_device", all_boolean)
def test_padding_grouped_linear_accuracy_save_original_input(
dtype,
num_gemms,
@@ -2264,6 +2336,7 @@
fp8,
recipe,
fp8_model_params,
m_splits_on_device,
parallel_mode=None,
):
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
@@ -2281,6 +2354,9 @@
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)

if m_splits_on_device and (not recipe.mxfp8() or dtype not in [torch.bfloat16]):
pytest.skip("m_splits_on_device is only supported with MXFP8 recipe and bfloat16 dtype")

with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = TorchGroupedLinearWithPadding(
num_gemms,
@@ -2315,10 +2391,10 @@ def test_padding_grouped_linear_accuracy_save_original_input(
)

outputs = _test_padding_grouped_linear_accuracy(
grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, m_splits_on_device
)
outputs_ref = _test_padding_grouped_linear_accuracy(
ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, m_splits_on_device
)

# Should be a bit-wise match
@@ -2683,6 +2759,7 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
grad = True
single_output = False

m_splits = torch.tensor(m_splits)
if use_cutlass:
os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"

@@ -2857,7 +2934,7 @@ def test_fp8_grouped_gemm(shape, accumulate):
out,
dtype,
get_multi_stream_cublas_workspace(),
m_splits=m_splits,
m_splits=torch.tensor(m_splits),
accumulate=accumulate,
)

6 changes: 6 additions & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -168,6 +168,7 @@ list(APPEND transformer_engine_cuda_sources

list(APPEND transformer_engine_cuda_arch_specific_sources
gemm/cutlass_grouped_gemm.cu
gemm/cutlass_device_grouped_gemm.cu
util/cast.cu
activation/gelu.cu
activation/relu.cu
@@ -232,6 +233,11 @@ set_property(
APPEND
PROPERTY
COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a;-g0")
set_property(
SOURCE gemm/cutlass_device_grouped_gemm.cu
APPEND
PROPERTY
COMPILE_OPTIONS "--generate-code=arch=compute_100a,code=sm_100a;-g0")

# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
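
Note: compute_100a/sm_100a is the architecture-specific target for Blackwell (compute capability 10.0), matching the (10, 0) guard on the new test_grouped_linear_accuracy_cutlass_device test, just as the existing compute_90a/sm_90a target pairs with the Hopper-only (9, 0) guard.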