diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py
index 0dfbbf44..075419dd 100644
--- a/src/nanotron/config/models_config.py
+++ b/src/nanotron/config/models_config.py
@@ -206,6 +206,7 @@ class GPT3MoEConfig:
     num_experts_per_tok: int = 1
     moe_loss_weight: float = 0.01
     moe_z_loss_weight: float = 0.001
+    moe_glu: bool = False
 
     def as_gpt3(self) -> GPT3Config:
diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py
index 10f00194..6ef9f0f9 100644
--- a/src/nanotron/models/moe.py
+++ b/src/nanotron/models/moe.py
@@ -230,17 +230,31 @@ def __init__(
         self.blocking = 128
 
         if self.experts_per_rank == 1:
-            self.mlp = MLP(
-                config=config,
-                parallel_config=parallel_config,
-                tp_pg=parallel_context.tp_pg,
-            )
+            if config.moe_glu:
+                self.mlp = GLU(
+                    config=config,
+                    parallel_config=parallel_config,
+                    tp_pg=parallel_context.tp_pg,
+                )
+            else:
+                self.mlp = MLP(
+                    config=config,
+                    parallel_config=parallel_config,
+                    tp_pg=parallel_context.tp_pg,
+                )
         else:
-            self.mlp = SparseGLU(
-                config=config,
-                parallel_config=parallel_config,
-                parallel_context=parallel_context,
-            )
+            if config.moe_glu:
+                self.mlp = SparseGLU(
+                    config=config,
+                    parallel_config=parallel_config,
+                    parallel_context=parallel_context,
+                )
+            else:
+                self.mlp = SparseMLP(
+                    config=config,
+                    parallel_config=parallel_config,
+                    parallel_context=parallel_context,
+                )
 
         max_column_index = (self.config.intermediate_size * self.num_experts) // self.blocking
         self.transpose_sort_end_bit = max(int(np.ceil(np.log2(max_column_index))), 1)
@@ -630,6 +644,26 @@ def __init__(
             expert_parallel_size=self.expert_pg_size,
         )
 
+        # TODO @nouamane: jit
+        self.act = ACT2FN[config.hidden_act]
+
+    def forward(self, hidden_states, topo):  # [seq_length, batch_size, hidden_dim]
+        merged_states = self.w1(hidden_states)
+        hidden_states = self.w2(self.act(merged_states))
+        return hidden_states
+
+class GLU(MLP):
+    def __init__(
+        self,
+        config: Config,
+        parallel_config: Optional[ParallelismArgs],
+        tp_pg: dist.ProcessGroup,
+    ):
+        super().__init__(config, parallel_config, tp_pg)
+        tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
+        tp_linear_async_communication = (
+            parallel_config.tp_linear_async_communication if parallel_config is not None else False
+        )
         self.w3 = ExpertParallel(
             TensorParallelColumnLinear(
                 config.hidden_size,
@@ -641,15 +675,12 @@ def __init__(
             ),
             expert_parallel_size=self.expert_pg_size,
         )
-        # TODO @nouamane: jit
-        self.act = ACT2FN[config.hidden_act]
 
     def forward(self, hidden_states, topo):  # [seq_length, batch_size, hidden_dim]
         merged_states = self.w1(hidden_states)
         hidden_states = self.w2(self.act(merged_states) * self.w3(hidden_states))
         return hidden_states
-
 
 def inclusive_cumsum(x, dim):
     scalar = ops.inclusive_cumsum(x, dim)
     return scalar.view(1) if not len(scalar.size()) else scalar
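
For reference, the functional difference the new `moe_glu` flag toggles reduces to whether the expert's up-projection is gated by a third linear layer before the down-projection. Below is a minimal, single-GPU sketch of that math only; it omits nanotron's `ExpertParallel`/`TensorParallelColumnLinear` wrappers, the `topo` argument, and `ACT2FN` (GELU stands in for `config.hidden_act`), and the class names `SimpleMLPExpert`/`SimpleGLUExpert` are illustrative, not part of the repo.

```python
# Sketch of the two expert variants selected by `moe_glu` (illustrative only).
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleMLPExpert(nn.Module):
    """moe_glu = False: y = w2(act(w1(x)))"""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.w2(F.gelu(self.w1(hidden_states)))


class SimpleGLUExpert(SimpleMLPExpert):
    """moe_glu = True: y = w2(act(w1(x)) * w3(x)), a gated (GLU-style) expert."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__(hidden_size, intermediate_size)
        # Extra gate projection, mirroring the `w3` added in GLU.__init__ above.
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.w2(F.gelu(self.w1(hidden_states)) * self.w3(hidden_states))


if __name__ == "__main__":
    x = torch.randn(8, 4, 512)  # [seq_length, batch_size, hidden_dim]
    for expert in (SimpleMLPExpert(512, 2048), SimpleGLUExpert(512, 2048)):
        assert expert(x).shape == x.shape
```

As in the diff, the gated variant subclasses the plain MLP and only adds the gate projection and a new `forward`, so the w1/w2 setup (and, in the real code, the tensor/expert-parallel plumbing) is defined once.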