fix the merge
wangbluo committed Aug 21, 2024
Parent: 2d362ac, commit: eb5ba40
Showing 3 changed files with 10 additions and 2 deletions.
colossalai/shardformer/modeling/deepseek.py: 3 additions, 1 deletion

@@ -146,7 +146,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

    # [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
    dist.all_to_all_single(
-       output_split_sizes, input_split_sizes, group=self.ep_group, fp8_communication=fp8_communication
+       output_split_sizes,
+       input_split_sizes,
+       group=self.ep_group,
    )

    with torch.no_grad():
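For context, the call above is plain torch.distributed.all_to_all_single, which has no fp8_communication keyword; that is presumably why the merge fix drops the stray argument. Below is a minimal, hedged sketch (helper name and setup assumed, not taken from ColossalAI) of how expert-parallel ranks exchange per-rank token counts this way; it assumes a process group that supports all-to-all (e.g. NCCL) is already initialized.

import torch
import torch.distributed as dist

def exchange_split_sizes(input_split_sizes: torch.Tensor, ep_group=None) -> torch.Tensor:
    # input_split_sizes: 1-D tensor of length world_size, one token count per peer rank.
    # Each rank sends element i to rank i and receives one element from every rank,
    # so afterwards output_split_sizes[i] is how many tokens rank i will send to us.
    output_split_sizes = torch.empty_like(input_split_sizes)
    dist.all_to_all_single(
        output_split_sizes,
        input_split_sizes,
        group=ep_group,
    )
    return output_split_sizes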
colossalai/shardformer/modeling/mixtral.py: 5 additions, 0 deletions

@@ -68,6 +68,11 @@ def setup_process_groups(
    self.ep_size = dist.get_world_size(ep_group)
    self.ep_rank = dist.get_rank(ep_group)
    self.ep_group = ep_group
+   self.fp8_communication = fp8_communication
+
+   if self.num_experts % self.ep_size != 0:
+       raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
+
    self.num_experts_per_ep = self.num_experts // self.ep_size
    self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
    held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
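The divisibility check added above guards the even split of experts across expert-parallel ranks that the following lines perform. A standalone sketch of that partitioning logic (hypothetical helper name, not part of the Mixtral module):

def partition_experts(num_experts: int, ep_size: int, ep_rank: int) -> range:
    # Mirrors the setup above: experts are split evenly across EP ranks,
    # and each rank holds a contiguous block starting at ep_rank * block_size.
    if num_experts % ep_size != 0:
        raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
    num_experts_per_ep = num_experts // ep_size
    expert_start_idx = ep_rank * num_experts_per_ep
    return range(expert_start_idx, expert_start_idx + num_experts_per_ep)

# For example, 8 experts over 4 EP ranks: rank 2 holds experts 4 and 5.
assert list(partition_experts(8, 4, 2)) == [4, 5]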
colossalai/shardformer/policies/qwen2.py: 2 additions, 1 deletion

@@ -119,6 +119,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
    SubModuleReplacementDescription(
        suffix="self_attn.q_proj",
        target_module=Linear1D_Col,
+       kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
    ),
    SubModuleReplacementDescription(
        suffix="self_attn.k_proj",
@@ -319,7 +320,7 @@ def module_policy(self):
    setattr(self.shard_config, "causal_lm", True)

    if self.shard_config.enable_tensor_parallelism:
-       # add a new item for causal lm
+       # add a new item for casual lm
        new_item = {
            Qwen2ForCausalLM: ModulePolicyDescription(
                sub_module_replacement=[
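The kwargs line added in the first hunk threads the sequence-parallel mode and the fp8_communication flag into the column-parallel replacement for q_proj. Below is a hedged sketch of building such a policy entry on its own; the import paths and the example seq_parallel_mode value are assumptions about the ColossalAI layout, not taken from this diff.

from colossalai.shardformer.layer import Linear1D_Col
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription

# Replace self_attn.q_proj with a column-parallel linear, forwarding the extra
# constructor arguments that the policy change above introduces.
q_proj_replacement = SubModuleReplacementDescription(
    suffix="self_attn.q_proj",
    target_module=Linear1D_Col,
    kwargs=dict(
        seq_parallel_mode="split_gather",  # example value; the real policy passes sp_mode
        fp8_communication=True,            # mirrors self.shard_config.fp8_communication
    ),
)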
