Commit 696fced

[fp8] fix missing fp8_comm flag in mixtral (#6057)
1 parent: a35a078

2 files changed: +7 −1 lines

colossalai/shardformer/modeling/mixtral.py

Lines changed: 6 additions & 1 deletion
@@ -31,6 +31,7 @@
     all_to_all_uneven,
 )
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.quantization.fp8 import all_reduce_fp8
 from colossalai.shardformer.layer._operation import (
     all_to_all_comm,
     gather_forward_split_backward,

@@ -142,7 +143,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         for i in range(1, self.ep_size):
             activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
         activate_experts = (activate_experts > 0).float()
-        dist.all_reduce(activate_experts, group=self.moe_dp_group)
+
+        if self.fp8_communication:
+            all_reduce_fp8(activate_experts, group=self.moe_dp_group)
+        else:
+            dist.all_reduce(activate_experts, group=self.moe_dp_group)
 
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
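
For context, the change boils down to the dispatch pattern sketched below. This is an illustrative helper, not code from the repository: the name reduce_activated_experts and the standalone-function form are assumptions, while all_reduce_fp8 and its group= keyword are taken from the diff above.

import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_reduce_fp8


def reduce_activated_experts(
    activate_experts: torch.Tensor,
    moe_dp_group: dist.ProcessGroup,
    fp8_communication: bool = False,
) -> None:
    # Hypothetical helper mirroring the patched forward(): all-reduce the
    # expert-activation mask across the MoE data-parallel group, using the
    # fp8 collective only when the flag passed down by the policy is set.
    if fp8_communication:
        all_reduce_fp8(activate_experts, group=moe_dp_group)
    else:
        dist.all_reduce(activate_experts, group=moe_dp_group)

Keeping the full-precision dist.all_reduce as the default preserves the previous behaviour whenever fp8_communication is left off.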

colossalai/shardformer/policies/mixtral.py

Lines changed: 1 addition & 0 deletions
@@ -178,6 +178,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                         "ep_group": self.shard_config.ep_group,
                         "tp_group": self.shard_config.tensor_parallel_process_group,
                         "moe_dp_group": self.shard_config.moe_dp_group,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 )
             ],
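
The policy-side change simply threads the existing shard-config flag into the kwargs used to build the substituted MoE block. A minimal sketch of that flow, under stated assumptions, follows; DummyShardConfig, EPMoEBlock and build_module_kwargs are placeholders invented for illustration, not the real ColossalAI classes.

from dataclasses import dataclass


@dataclass
class DummyShardConfig:
    # Placeholder for ShardConfig; only the flag relevant to this commit is modelled.
    fp8_communication: bool = False


class EPMoEBlock:
    # Placeholder for the expert-parallel MoE block the policy substitutes in.
    def __init__(self, ep_group=None, tp_group=None, moe_dp_group=None, fp8_communication=False):
        self.moe_dp_group = moe_dp_group
        # forward() later reads this flag to choose between dist.all_reduce and all_reduce_fp8.
        self.fp8_communication = fp8_communication


def build_module_kwargs(config: DummyShardConfig) -> dict:
    # Mirrors the one-line policy fix: the flag now travels alongside the process groups.
    return {
        "ep_group": None,
        "tp_group": None,
        "moe_dp_group": None,
        "fp8_communication": config.fp8_communication,
    }


block = EPMoEBlock(**build_module_kwargs(DummyShardConfig(fp8_communication=True)))
assert block.fp8_communication is True

Before this commit the flag was missing from these kwargs, so the Mixtral MoE block always fell back to the full-precision all-reduce even when fp8 communication was enabled in the shard config.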
