Skip to content

Commit db22fce

Browse files
committed
Fix coderabbit nits
1 parent d375afe commit db22fce

File tree

4 files changed

+12
-13
lines changed

4 files changed

+12
-13
lines changed

csrc/trtllm_moe_alltoall.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ Tuple<Array<int64_t>, Array<int64_t>, int64_t> moeA2ADispatchOp(
188188

189189
TVM_FFI_ICHECK(totalBytesNeeded % elementSize == 0)
190190
<< "Misaligned payload buffer " << i << " with element size " << elementSize
191-
<< ". Consider putting ordering payloads by minimum element size";
191+
<< ". Consider reordering payloads by largest to smallest element size";
192192
}
193193

194194
auto* workspaceBase = static_cast<uint8_t*>(workspace.data_ptr());

flashinfer/comm/trtllm_moe_alltoall.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
supporting multiple payloads per collective operation.
66
"""
77

8-
# TODO Review
9-
108
from dataclasses import dataclass
119
from types import SimpleNamespace
1210
from typing import Optional
@@ -648,9 +646,9 @@ def get_combine_payload_tensor_in_workspace(
648646

649647
__all__ = [
650648
"MoeAlltoAll",
651-
"moe_a2a_initialize",
652-
"moe_a2a_dispatch",
653649
"moe_a2a_combine",
654-
"moe_a2a_sanitize_expert_ids",
650+
"moe_a2a_dispatch",
655651
"moe_a2a_get_workspace_size_per_rank",
652+
"moe_a2a_initialize",
653+
"moe_a2a_sanitize_expert_ids",
656654
]

tests/comm/test_mnnvl_moe_alltoall.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,18 @@ def safe_run(func, *args, **kwargs):
3838
comm = MPI.COMM_WORLD
3939
try:
4040
func(*args, **kwargs)
41-
except MPIExit as e:
42-
raise e
43-
except Exception as e:
41+
except MPIExit:
42+
raise
43+
except Exception:
4444
traceback.print_exc()
4545
comm.allgather(True)
46-
raise e
46+
raise
4747

4848

4949
@pytest.fixture(autouse=True)
5050
def setup_test():
5151
torch.manual_seed(0x1234)
52+
yield
5253

5354

5455
def compute_target_rank_id(expert_id, num_experts_per_rank):
@@ -154,7 +155,7 @@ def fake_moe(
154155
# Process each token
155156
for token_idx in range(num_tokens):
156157
results = []
157-
# For each expert selected for this token/
158+
# For each expert selected for this token
158159
for k in range(top_k):
159160
expert_id = token_selected_experts[token_idx, k].item()
160161
if is_ep:

tests/comm/test_trtllm_moe_alltoall.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
@pytest.fixture(autouse=True, scope="session")
2626
def setup_test_environment():
27-
"""Set up test environment and warm up JIT compilation."""
27+
"""Set up torch seed for deterministic tests."""
2828
torch.manual_seed(0xD5)
2929
yield
3030

@@ -410,7 +410,7 @@ def fake_moe(
410410
# Process each token
411411
for token_idx in range(num_tokens):
412412
results = []
413-
# For each expert selected for this token/
413+
# For each expert selected for this token
414414
for k in range(top_k):
415415
expert_id = token_selected_experts[token_idx, k].item()
416416
if is_ep and not (

0 commit comments

Comments
 (0)