@@ -562,8 +562,13 @@ def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k):
562562 )
563563 except Exception as e :
564564 traceback .print_exc ()
565+ comm .allgather (e )
565566 raise e
566567
568+ exceptions = comm .allgather (None )
569+ if any (exceptions ):
570+                raise next(x for x in exceptions if x is not None)
571+
567572 # Gather results from all ranks
568573 all_results = comm .allgather (result )
569574
@@ -638,111 +643,131 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
638643 num_experts_per_rank = 8
639644 workspace_size_per_rank = 512 * 1024 * 1024
640645
641- mapping = Mapping (
642- rank = rank ,
643- moe_ep_size = world_size ,
644- tp_size = world_size ,
645- world_size = world_size ,
646- )
646+ try :
647+ mapping = Mapping (
648+ rank = rank ,
649+ moe_ep_size = world_size ,
650+ tp_size = world_size ,
651+ world_size = world_size ,
652+ )
647653
648- local_num_tokens = all_num_tokens [rank ]
649- max_num_tokens = max (all_num_tokens )
654+ local_num_tokens = all_num_tokens [rank ]
655+ max_num_tokens = max (all_num_tokens )
650656
651- # Generate inputs
652- token_selected_experts = generate_token_selected_experts (
653- local_num_tokens , ep_size , num_experts_per_rank , top_k
654- )
657+ # Generate inputs
658+ token_selected_experts = generate_token_selected_experts (
659+ local_num_tokens , ep_size , num_experts_per_rank , top_k
660+ )
655661
656- payloads , expert_id_payload_index = make_bfloat16_payloads (
657- local_num_tokens , hidden_size , top_k , rank , token_selected_experts
658- )
662+ payloads , expert_id_payload_index = make_bfloat16_payloads (
663+ local_num_tokens , hidden_size , top_k , rank , token_selected_experts
664+ )
659665
660- hidden_states = payloads [0 ]
661- token_final_scales = payloads [2 ]
666+ hidden_states = payloads [0 ]
667+ token_final_scales = payloads [2 ]
668+
669+ # Compute reference (single-GPU MoE)
670+ all_experts = torch .cat (
671+ [
672+ create_experts (
673+ num_experts_per_rank , hidden_size , r , "cuda" , dtype = torch .bfloat16
674+ )
675+ for r in range (ep_size )
676+ ],
677+ dim = 0 ,
678+ )
662679
663- # Compute reference (single-GPU MoE)
664- all_experts = torch .cat (
665- [
666- create_experts (
667- num_experts_per_rank , hidden_size , r , "cuda" , dtype = torch .bfloat16
668- )
669- for r in range (ep_size )
670- ],
671- dim = 0 ,
672- )
680+ rank_experts = create_experts (
681+ num_experts_per_rank , hidden_size , rank , "cuda" , dtype = torch .bfloat16
682+ )
673683
674- rank_experts = create_experts (
675- num_experts_per_rank , hidden_size , rank , "cuda" , dtype = torch .bfloat16
676- )
684+ reference_output = fake_moe (
685+ hidden_states ,
686+ token_selected_experts ,
687+ token_final_scales ,
688+ all_experts ,
689+ is_ep = False ,
690+ )
677691
678- reference_output = fake_moe (
679- hidden_states ,
680- token_selected_experts ,
681- token_final_scales ,
682- all_experts ,
683- is_ep = False ,
684- )
692+ torch .cuda .synchronize ()
685693
686- torch .cuda .synchronize ()
694+ # Initialize MoeAlltoAll
695+ MoeAlltoAll ._WORKSPACE = None
696+ moe_a2a = MoeAlltoAll (
697+ mapping = mapping ,
698+ max_num_tokens = max_num_tokens ,
699+ top_k = top_k ,
700+ num_experts = ep_size * num_experts_per_rank ,
701+ workspace_size_per_rank = workspace_size_per_rank ,
702+ )
687703
688- # Initialize MoeAlltoAll
689- MoeAlltoAll ._WORKSPACE = None
690- moe_a2a = MoeAlltoAll (
691- mapping = mapping ,
692- max_num_tokens = max_num_tokens ,
693- top_k = top_k ,
694- num_experts = ep_size * num_experts_per_rank ,
695- workspace_size_per_rank = workspace_size_per_rank ,
696- )
704+ # Dispatch
705+ recv_tensors = moe_a2a .dispatch (
706+ token_selected_experts = token_selected_experts ,
707+ input_payloads = payloads ,
708+ runtime_max_tokens_per_rank = max_num_tokens ,
709+ )
697710
698- # Dispatch
699- recv_tensors = moe_a2a .dispatch (
700- token_selected_experts = token_selected_experts ,
701- input_payloads = payloads ,
702- runtime_max_tokens_per_rank = max_num_tokens ,
703- )
711+ # Unpack received tensors
712+ hidden_states_recv = recv_tensors [0 ] # [ep_size, max_tokens, hidden_size]
713+ token_selected_experts_recv = recv_tensors [1 ] # [ep_size, max_tokens, top_k]
714+ token_final_scales_recv = recv_tensors [2 ] # [ep_size, max_tokens, top_k]
704715
705- # Unpack received tensors
706- hidden_states_recv = recv_tensors [0 ] # [ep_size, max_tokens, hidden_size]
707- token_selected_experts_recv = recv_tensors [1 ] # [ep_size, max_tokens, top_k]
708- token_final_scales_recv = recv_tensors [2 ] # [ep_size, max_tokens, top_k]
716+ # Get workspace-backed tensor for output
717+ moe_output = moe_a2a .get_combine_payload_tensor_in_workspace (
718+ runtime_max_tokens_per_rank = max_num_tokens ,
719+ hidden_size = hidden_size ,
720+ dtype = torch .bfloat16 ,
721+ )
722+ moe_output .zero_ ()
723+
724+ # Process each rank's tokens with local experts
725+ moe_output .copy_ (
726+ fake_moe (
727+ hidden_states_recv .view (
728+ ep_size * max_num_tokens , hidden_states_recv .shape [- 1 ]
729+ ),
730+ token_selected_experts_recv .view (
731+ ep_size * max_num_tokens , token_selected_experts_recv .shape [- 1 ]
732+ ),
733+ token_final_scales_recv .view (
734+ ep_size * max_num_tokens , token_final_scales_recv .shape [- 1 ]
735+ ),
736+ rank_experts , # experts for current rank
737+ is_ep = True ,
738+ ep_rank = rank ,
739+ num_experts_per_rank = num_experts_per_rank ,
740+ ).view (ep_size , max_num_tokens , hidden_size )
741+ )
742+ except Exception as e :
743+ traceback .print_exc ()
744+ comm .allgather (e )
745+ raise e
709746
710- # Get workspace-backed tensor for output
711- moe_output = moe_a2a .get_combine_payload_tensor_in_workspace (
712- runtime_max_tokens_per_rank = max_num_tokens ,
713- hidden_size = hidden_size ,
714- dtype = torch .bfloat16 ,
715- )
716- moe_output .zero_ ()
717-
718- # Process each rank's tokens with local experts
719- moe_output .copy_ (
720- fake_moe (
721- hidden_states_recv .view (
722- ep_size * max_num_tokens , hidden_states_recv .shape [- 1 ]
723- ),
724- token_selected_experts_recv .view (
725- ep_size * max_num_tokens , token_selected_experts_recv .shape [- 1 ]
726- ),
727- token_final_scales_recv .view (
728- ep_size * max_num_tokens , token_final_scales_recv .shape [- 1 ]
729- ),
730- rank_experts , # experts for current rank
731- is_ep = True ,
732- ep_rank = rank ,
733- num_experts_per_rank = num_experts_per_rank ,
734- ).view (ep_size , max_num_tokens , hidden_size )
735- )
747+ exceptions = comm .allgather (None )
748+ if any (exceptions ):
749+        raise next(x for x in exceptions if x is not None)
736750
737- # Combine
738- combined_output = moe_a2a .combine (
739- payload = moe_output ,
740- runtime_max_tokens_per_rank = max_num_tokens ,
741- payload_in_workspace = True ,
742- )
751+ try :
752+ # Combine
753+ combined_output = moe_a2a .combine (
754+ payload = moe_output ,
755+ runtime_max_tokens_per_rank = max_num_tokens ,
756+ payload_in_workspace = True ,
757+ )
758+
759+ # Verify against reference
760+ torch .testing .assert_close (
761+ combined_output , reference_output , rtol = 1e-2 , atol = 1e-2
762+ )
763+ except Exception as e :
764+ traceback .print_exc ()
765+ comm .allgather (e )
766+ raise e
743767
744- # Verify against reference
745- torch .testing .assert_close (combined_output , reference_output , rtol = 1e-2 , atol = 1e-2 )
768+ exceptions = comm .allgather (None )
769+ if any (exceptions ):
770+        raise next(x for x in exceptions if x is not None)
746771
747772
748773if __name__ == "__main__" :
0 commit comments