Skip to content

Commit 2bac65a

Browse files
committed
Cleanup MPI processes on test failures
1 parent fede35b commit 2bac65a

File tree

1 file changed

+116
-91
lines changed

1 file changed

+116
-91
lines changed

tests/comm/test_mnnvl_a2a.py

Lines changed: 116 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -562,8 +562,13 @@ def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k):
562562
)
563563
except Exception as e:
564564
traceback.print_exc()
565+
comm.allgather(e)
565566
raise e
566567

568+
exceptions = comm.allgather(None)
569+
if any(exceptions):
570+
raise next(x for x in exceptions if x is not None)
571+
567572
# Gather results from all ranks
568573
all_results = comm.allgather(result)
569574

@@ -638,111 +643,131 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
638643
num_experts_per_rank = 8
639644
workspace_size_per_rank = 512 * 1024 * 1024
640645

641-
mapping = Mapping(
642-
rank=rank,
643-
moe_ep_size=world_size,
644-
tp_size=world_size,
645-
world_size=world_size,
646-
)
646+
try:
647+
mapping = Mapping(
648+
rank=rank,
649+
moe_ep_size=world_size,
650+
tp_size=world_size,
651+
world_size=world_size,
652+
)
647653

648-
local_num_tokens = all_num_tokens[rank]
649-
max_num_tokens = max(all_num_tokens)
654+
local_num_tokens = all_num_tokens[rank]
655+
max_num_tokens = max(all_num_tokens)
650656

651-
# Generate inputs
652-
token_selected_experts = generate_token_selected_experts(
653-
local_num_tokens, ep_size, num_experts_per_rank, top_k
654-
)
657+
# Generate inputs
658+
token_selected_experts = generate_token_selected_experts(
659+
local_num_tokens, ep_size, num_experts_per_rank, top_k
660+
)
655661

656-
payloads, expert_id_payload_index = make_bfloat16_payloads(
657-
local_num_tokens, hidden_size, top_k, rank, token_selected_experts
658-
)
662+
payloads, expert_id_payload_index = make_bfloat16_payloads(
663+
local_num_tokens, hidden_size, top_k, rank, token_selected_experts
664+
)
659665

660-
hidden_states = payloads[0]
661-
token_final_scales = payloads[2]
666+
hidden_states = payloads[0]
667+
token_final_scales = payloads[2]
668+
669+
# Compute reference (single-GPU MoE)
670+
all_experts = torch.cat(
671+
[
672+
create_experts(
673+
num_experts_per_rank, hidden_size, r, "cuda", dtype=torch.bfloat16
674+
)
675+
for r in range(ep_size)
676+
],
677+
dim=0,
678+
)
662679

663-
# Compute reference (single-GPU MoE)
664-
all_experts = torch.cat(
665-
[
666-
create_experts(
667-
num_experts_per_rank, hidden_size, r, "cuda", dtype=torch.bfloat16
668-
)
669-
for r in range(ep_size)
670-
],
671-
dim=0,
672-
)
680+
rank_experts = create_experts(
681+
num_experts_per_rank, hidden_size, rank, "cuda", dtype=torch.bfloat16
682+
)
673683

674-
rank_experts = create_experts(
675-
num_experts_per_rank, hidden_size, rank, "cuda", dtype=torch.bfloat16
676-
)
684+
reference_output = fake_moe(
685+
hidden_states,
686+
token_selected_experts,
687+
token_final_scales,
688+
all_experts,
689+
is_ep=False,
690+
)
677691

678-
reference_output = fake_moe(
679-
hidden_states,
680-
token_selected_experts,
681-
token_final_scales,
682-
all_experts,
683-
is_ep=False,
684-
)
692+
torch.cuda.synchronize()
685693

686-
torch.cuda.synchronize()
694+
# Initialize MoeAlltoAll
695+
MoeAlltoAll._WORKSPACE = None
696+
moe_a2a = MoeAlltoAll(
697+
mapping=mapping,
698+
max_num_tokens=max_num_tokens,
699+
top_k=top_k,
700+
num_experts=ep_size * num_experts_per_rank,
701+
workspace_size_per_rank=workspace_size_per_rank,
702+
)
687703

688-
# Initialize MoeAlltoAll
689-
MoeAlltoAll._WORKSPACE = None
690-
moe_a2a = MoeAlltoAll(
691-
mapping=mapping,
692-
max_num_tokens=max_num_tokens,
693-
top_k=top_k,
694-
num_experts=ep_size * num_experts_per_rank,
695-
workspace_size_per_rank=workspace_size_per_rank,
696-
)
704+
# Dispatch
705+
recv_tensors = moe_a2a.dispatch(
706+
token_selected_experts=token_selected_experts,
707+
input_payloads=payloads,
708+
runtime_max_tokens_per_rank=max_num_tokens,
709+
)
697710

698-
# Dispatch
699-
recv_tensors = moe_a2a.dispatch(
700-
token_selected_experts=token_selected_experts,
701-
input_payloads=payloads,
702-
runtime_max_tokens_per_rank=max_num_tokens,
703-
)
711+
# Unpack received tensors
712+
hidden_states_recv = recv_tensors[0] # [ep_size, max_tokens, hidden_size]
713+
token_selected_experts_recv = recv_tensors[1] # [ep_size, max_tokens, top_k]
714+
token_final_scales_recv = recv_tensors[2] # [ep_size, max_tokens, top_k]
704715

705-
# Unpack received tensors
706-
hidden_states_recv = recv_tensors[0] # [ep_size, max_tokens, hidden_size]
707-
token_selected_experts_recv = recv_tensors[1] # [ep_size, max_tokens, top_k]
708-
token_final_scales_recv = recv_tensors[2] # [ep_size, max_tokens, top_k]
716+
# Get workspace-backed tensor for output
717+
moe_output = moe_a2a.get_combine_payload_tensor_in_workspace(
718+
runtime_max_tokens_per_rank=max_num_tokens,
719+
hidden_size=hidden_size,
720+
dtype=torch.bfloat16,
721+
)
722+
moe_output.zero_()
723+
724+
# Process each rank's tokens with local experts
725+
moe_output.copy_(
726+
fake_moe(
727+
hidden_states_recv.view(
728+
ep_size * max_num_tokens, hidden_states_recv.shape[-1]
729+
),
730+
token_selected_experts_recv.view(
731+
ep_size * max_num_tokens, token_selected_experts_recv.shape[-1]
732+
),
733+
token_final_scales_recv.view(
734+
ep_size * max_num_tokens, token_final_scales_recv.shape[-1]
735+
),
736+
rank_experts, # experts for current rank
737+
is_ep=True,
738+
ep_rank=rank,
739+
num_experts_per_rank=num_experts_per_rank,
740+
).view(ep_size, max_num_tokens, hidden_size)
741+
)
742+
except Exception as e:
743+
traceback.print_exc()
744+
comm.allgather(e)
745+
raise e
709746

710-
# Get workspace-backed tensor for output
711-
moe_output = moe_a2a.get_combine_payload_tensor_in_workspace(
712-
runtime_max_tokens_per_rank=max_num_tokens,
713-
hidden_size=hidden_size,
714-
dtype=torch.bfloat16,
715-
)
716-
moe_output.zero_()
717-
718-
# Process each rank's tokens with local experts
719-
moe_output.copy_(
720-
fake_moe(
721-
hidden_states_recv.view(
722-
ep_size * max_num_tokens, hidden_states_recv.shape[-1]
723-
),
724-
token_selected_experts_recv.view(
725-
ep_size * max_num_tokens, token_selected_experts_recv.shape[-1]
726-
),
727-
token_final_scales_recv.view(
728-
ep_size * max_num_tokens, token_final_scales_recv.shape[-1]
729-
),
730-
rank_experts, # experts for current rank
731-
is_ep=True,
732-
ep_rank=rank,
733-
num_experts_per_rank=num_experts_per_rank,
734-
).view(ep_size, max_num_tokens, hidden_size)
735-
)
747+
exceptions = comm.allgather(None)
748+
if any(exceptions):
749+
raise next(x for x in exceptions if x is not None)
736750

737-
# Combine
738-
combined_output = moe_a2a.combine(
739-
payload=moe_output,
740-
runtime_max_tokens_per_rank=max_num_tokens,
741-
payload_in_workspace=True,
742-
)
751+
try:
752+
# Combine
753+
combined_output = moe_a2a.combine(
754+
payload=moe_output,
755+
runtime_max_tokens_per_rank=max_num_tokens,
756+
payload_in_workspace=True,
757+
)
758+
759+
# Verify against reference
760+
torch.testing.assert_close(
761+
combined_output, reference_output, rtol=1e-2, atol=1e-2
762+
)
763+
except Exception as e:
764+
traceback.print_exc()
765+
comm.allgather(e)
766+
raise e
743767

744-
# Verify against reference
745-
torch.testing.assert_close(combined_output, reference_output, rtol=1e-2, atol=1e-2)
768+
exceptions = comm.allgather(None)
769+
if any(exceptions):
770+
raise next(x for x in exceptions if x is not None)
746771

747772

748773
if __name__ == "__main__":

0 commit comments

Comments
 (0)