Skip to content

Commit 375a26b

Browse files
authored
Enable SplitK and fix autotuner for trtllm fp4 fused moe (#1548)
Enable splitK for trtllm-gen fused moe. Make autotuner for trtllm-gen fp4 fused moe more robust. Add autotuner for trtllm-gen fused moe test. <!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 3c1e8d7 commit 375a26b

File tree

5 files changed

+41
-24
lines changed

5 files changed

+41
-24
lines changed

csrc/trtllm_batched_gemm_runner.cu

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,6 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
104104
tileSize == mOptions.tileSize &&
105105
options.mUseShuffledMatrixA == mOptions.useShuffledMatrixA &&
106106
options.mLayoutA == mOptions.weightLayout) {
107-
// FIXME: Disable split-k for swiglu for now.
108-
if (static_cast<batchedGemm::gemmGatedAct::ActType>(mOptions.actType) ==
109-
batchedGemm::gemmGatedAct::ActType::SwiGlu &&
110-
options.mClusterDimZ != 1) {
111-
continue;
112-
}
113-
114107
if (options.mFusedAct) {
115108
if (options.mActType != static_cast<batchedGemm::gemmGatedAct::ActType>(mOptions.actType)) {
116109
continue;

flashinfer/artifacts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def get_available_cubin_files(source, retries=3, delay=5, timeout=10):
111111
class ArtifactPath:
112112
TRTLLM_GEN_FMHA: str = "037e528e719ec3456a7d7d654f26b805e44c63b1/fmha/trtllm-gen/"
113113
TRTLLM_GEN_BMM: str = (
114-
"037e528e719ec3456a7d7d654f26b805e44c63b1/batched_gemm-8704aa4-ba3b00d/"
114+
"e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0/batched_gemm-45beda1-ee6a802/"
115115
)
116116
TRTLLM_GEN_GEMM: str = (
117117
"037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e/"
@@ -125,7 +125,7 @@ class MetaInfoHash:
125125
"0ff77215b86997665cf75973e13cd2932f551d46b4e008f851d32d47e1d9560f"
126126
)
127127
TRTLLM_GEN_BMM: str = (
128-
"34bdfe7acfd49f5fb8b48e06d56e6a5ad88b951c730552f228fc5f614f7632a8"
128+
"c98b4ce69a39fd41556d67033c30ea814ef76b0a2fe16e798e55baf0104acc34"
129129
)
130130
DEEPGEMM: str = "69aa277b7f3663ed929e73f9c57301792b8c594dac15a465b44a5d151b6a1d50"
131131
TRTLLM_GEN_GEMM: str = (

flashinfer/fused_moe/core.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
TunableRunner,
3030
TuningConfig,
3131
)
32+
from ..jit.core import logger
3233
from ..jit import JitSpec
3334
from ..jit import env as jit_env
3435
from ..jit import (
@@ -1104,9 +1105,14 @@ def get_valid_tactics(
11041105
num_tokens,
11051106
)
11061107
if instance_key not in MoERunner.valid_tactics_dict:
1107-
MoERunner.valid_tactics_dict[instance_key] = (
1108-
moe_op.trtllm_get_valid_moe_configs(*instance_key)
1109-
)
1108+
try:
1109+
valid_tactics = moe_op.trtllm_get_valid_moe_configs(*instance_key)
1110+
except Exception as e:
1111+
logger.debug(
1112+
f"[Autotuner]: Failed to get valid tactics for {instance_key}. Error occurred: {e}"
1113+
)
1114+
return []
1115+
MoERunner.valid_tactics_dict[instance_key] = valid_tactics
11101116
return MoERunner.valid_tactics_dict[instance_key]
11111117

11121118
def forward(

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,22 +1085,34 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in
10851085
")");
10861086
}
10871087

1088+
// Number of iterations in K dimension after padding.
1089+
// Note the perCtaK in each CTA in the splitK group are padded to the same number of iterations.
1090+
// E.g., K = 512, TileK = 128, numSlicesForSplitK = 3. Then the padded K is
1091+
//
1092+
// ceil(512 / (128*3)) * (128*3) = 768
1093+
//
1094+
int const paddedK = divUpMul(options.mK, options.mTileK * options.mNumSlicesForSplitK);
1095+
int const perCtaK = paddedK / options.mNumSlicesForSplitK;
1096+
// However, number of iterations is clamped to multiples of tileK within individual CTAs
1097+
// E.g., K = 448, TileK = 64, numSlicesForSplitK = 4.
1098+
//
1099+
// paddedK = 512
1100+
// perCtaK = 128
1101+
// clampedPerCtaK for CTA 0, 1, 2 = 128
1102+
// clampedPerCtaK for CTA 3 = 64
1103+
int const paddingForK = paddedK - options.mK;
1104+
int const clampedAndPaddedPerCtaK = divUpMul(perCtaK - paddingForK, options.mTileK);
10881105
if (options.mUseUnrollLoop2xForMma) {
1089-
// Number of iterations in K dimension after padding.
1090-
// Note the perCtaK in each CTA in the splitK group are padded to the same number of iterations.
1091-
// E.g., K = 512, TileK = 128, numSlicesForSplitK = 3. Then the padded K is
1092-
//
1093-
// ceil(512 / (128*3)) * (128*3) = 768
1106+
// Check that the padded K and clamped padded K (K rounded to next multiple of tileK) is a
1107+
// multiple of 2*TileK when UnrollLoop2x is enabled. This is to avoid deadlock when mma runs
1108+
// even-numbered loop while the other warps run odd-numbered loop.
10941109
//
1095-
int paddedK = divUpMul(options.mK, options.mTileK * options.mNumSlicesForSplitK);
1096-
// Check that the padded K (K rounded to next multiple of tileK) is a multiple of 2*TileK when
1097-
// UnrollLoop2x is enabled. This is to avoid deadlock when mma runs even-numbered loop while the
1098-
// other warps run odd-numbered loop.
1099-
//
1100-
bool notSupported = (paddedK / options.mNumSlicesForSplitK) % (options.mTileK * 2) != 0;
1110+
bool notSupported = (perCtaK % (options.mTileK * 2) != 0) ||
1111+
(clampedAndPaddedPerCtaK % (options.mTileK * 2) != 0);
11011112
if (notSupported) {
11021113
TLLM_LOG_WARNING("Size K / splitK must be a multiple of TileK * 2. Found TileK=",
11031114
options.mTileK, " and K=", options.mK, " (paddedK=", paddedK,
1115+
" clampedAndPaddedPerCtaK=", clampedAndPaddedPerCtaK,
11041116
") and numSlicesForSplitK=", options.mNumSlicesForSplitK,
11051117
". Disabling unrollLoop2xForMma.");
11061118
if (updateOptions) {
@@ -1110,6 +1122,11 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in
11101122
}
11111123
}
11121124
}
1125+
if (options.mNumSlicesForSplitK > 1) {
1126+
TLLM_CHECK_ERROR(
1127+
perCtaK * (options.mNumSlicesForSplitK - 1) < options.mK,
1128+
"K must be greater than perCtaK * (numSlicesForSplitK - 1) to ensure each CTA has work");
1129+
}
11131130

11141131
if (!isBlackwell && options.mTileScheduler == TileScheduler::Persistent) {
11151132
// TODO(anchengc): will be supported in upcoming MRs.

tests/test_trtllm_gen_fused_moe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
reorder_rows_for_gated_act_gemm,
3434
shuffle_matrix_a,
3535
)
36+
from flashinfer.autotuner import autotune
3637
from flashinfer.fp4_quantization import block_scale_interleave
3738
from flashinfer.fused_moe import (
3839
WeightLayout,
@@ -105,7 +106,7 @@ def capture(self, hidden_states_sample, **runtime_args):
105106
self.input_tensor = hidden_states_sample.clone()
106107

107108
# Warmup
108-
with torch.cuda.stream(torch_stream):
109+
with torch.cuda.stream(torch_stream), autotune(True):
109110
for _ in range(1):
110111
self._run_moe_computation(runtime_args)
111112

0 commit comments

Comments
 (0)