[Bugfix] Revert MoE Triton Config Default (vllm-project#12629)
SUMMARY:
* the previous PR that pulled in block-wise configs (https://github.com/vllm-project/vllm/pull/11589/files) also changed the default configuration for FP8
* this broke MoE on L4 GPUs, since there is not enough SHM (shared memory) for the new default configuration (see the rough shared-memory sketch below)
* this reverts the non-block FP8 case to the previous default

Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
robertgshaw2-redhat authored and kerthcet committed Feb 21, 2025
1 parent b0bc1cc commit bc628be
Showing 1 changed file with 11 additions and 30 deletions.
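For context on the L4 failure: the sketch below is a rough, back-of-the-envelope estimate of shared-memory pressure, not vLLM's or Triton's actual accounting. The approx_shm_bytes helper, the ~99 KiB per-block limit assumed for L4 (sm_89), and the [128, 128] block shape are illustrative assumptions; the model only counts the pipelined A/B tiles and ignores quantization scales and any other buffers.

# Rough SHM estimate, assuming a Triton-style pipeline keeps one A tile and one
# B tile per stage in shared memory, at one byte per fp8 element.
def approx_shm_bytes(block_m: int, block_n: int, block_k: int,
                     num_stages: int, bytes_per_elem: int = 1) -> int:
    a_tile = block_m * block_k   # activation tile
    b_tile = block_k * block_n   # weight tile
    return (a_tile + b_tile) * num_stages * bytes_per_elem

L4_SHM_PER_BLOCK = 99 * 1024  # assumed ~99 KiB opt-in limit per thread block on L4 (sm_89)

# Non-block FP8 default introduced by vllm-project#11589 and removed here:
big = approx_shm_bytes(block_m=128, block_n=256, block_k=128, num_stages=4)
print(big, big > L4_SHM_PER_BLOCK)     # 196608 True  -> roughly twice the assumed L4 budget

# Block-wise FP8 path kept by this commit, assuming block_shape == [128, 128]:
small = approx_shm_bytes(block_m=64, block_n=128, block_k=128, num_stages=3)
print(small, small > L4_SHM_PER_BLOCK) # 73728 False -> fits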
vllm/model_executor/layers/fused_moe/fused_moe.py: 11 additions & 30 deletions

@@ -660,36 +660,17 @@ def get_default_config(
     is_marlin: bool,
     block_shape: Optional[List[int]] = None,
 ) -> Dict[str, int]:
-    if dtype == "fp8_w8a8":
-        if block_shape is None:
-            config = {
-                "BLOCK_SIZE_M": 128,
-                "BLOCK_SIZE_N": 256,
-                "BLOCK_SIZE_K": 128,
-                "GROUP_SIZE_M": 32,
-                "num_warps": 8,
-                "num_stages": 4,
-            }
-            if M <= E:
-                config = {
-                    "BLOCK_SIZE_M": 64,
-                    "BLOCK_SIZE_N": 128,
-                    "BLOCK_SIZE_K": 128,
-                    "GROUP_SIZE_M": 1,
-                    "num_warps": 4,
-                    "num_stages": 4,
-                }
-        else:
-            # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
-            # BLOCK_SIZE_K must be divisible by block_shape[1]
-            config = {
-                "BLOCK_SIZE_M": 64,
-                "BLOCK_SIZE_N": block_shape[0],
-                "BLOCK_SIZE_K": block_shape[1],
-                "GROUP_SIZE_M": 32,
-                "num_warps": 4,
-                "num_stages": 3,
-            }
+    if dtype == "fp8_w8a8" and block_shape is not None:
+        # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
+        # BLOCK_SIZE_K must be divisible by block_shape[1]
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": block_shape[0],
+            "BLOCK_SIZE_K": block_shape[1],
+            "GROUP_SIZE_M": 32,
+            "num_warps": 4,
+            "num_stages": 3,
+        }
     else:
         config = {
             "BLOCK_SIZE_M": 64,
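Read as plain Python, the post-revert logic amounts to the minimal sketch below. It is an illustration, not the real vLLM function: the function name and the GENERIC_DEFAULT placeholder are made up, the actual get_default_config takes more parameters, and the values of the generic branch are truncated in the diff above.

from typing import Dict, List, Optional

GENERIC_DEFAULT: Dict[str, int] = {"BLOCK_SIZE_M": 64}  # placeholder; real values are elided above

def fp8_moe_default(dtype: Optional[str],
                    block_shape: Optional[List[int]]) -> Dict[str, int]:
    if dtype == "fp8_w8a8" and block_shape is not None:
        # Block-wise quant: tying tile sizes to the quantization block shape keeps
        # BLOCK_SIZE_N divisible by block_shape[0] and BLOCK_SIZE_K by block_shape[1].
        return {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_shape[0],
            "BLOCK_SIZE_K": block_shape[1],
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 3,
        }
    # Every other case, including non-block FP8, now falls back to the generic default.
    return dict(GENERIC_DEFAULT)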
