Commit

option for GLU or normal MLP
haeggee committed Aug 5, 2024
1 parent bcb94cc commit 91acdc0
Showing 2 changed files with 46 additions and 14 deletions.
1 change: 1 addition & 0 deletions src/nanotron/config/models_config.py
@@ -206,6 +206,7 @@ class GPT3MoEConfig:
     num_experts_per_tok: int = 1
     moe_loss_weight: float = 0.01
     moe_z_loss_weight: float = 0.001
+    moe_glu: bool = False
 
 
     def as_gpt3(self) -> GPT3Config:
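For context: the new moe_glu field defaults to False, so existing GPT3MoE configs keep the plain-MLP experts unless they opt in explicitly. Below is a minimal sketch of how the flag is meant to be consumed, using a stand-in dataclass because the other GPT3MoEConfig fields and the surrounding config machinery are not part of this diff:

# Illustrative only: a stripped-down stand-in for GPT3MoEConfig; everything
# except the moe_glu field is an assumption made for the sake of a runnable example.
from dataclasses import dataclass

@dataclass
class ToyMoEConfig:
    num_experts_per_tok: int = 1
    moe_glu: bool = False  # mirrors the field added in this commit

def expert_mlp_kind(config: ToyMoEConfig) -> str:
    # The MoE layer picks GLU experts when the flag is set, plain MLP otherwise.
    return "GLU" if config.moe_glu else "MLP"

print(expert_mlp_kind(ToyMoEConfig()))              # MLP (the default)
print(expert_mlp_kind(ToyMoEConfig(moe_glu=True)))  # GLU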
59 changes: 45 additions & 14 deletions src/nanotron/models/moe.py
Expand Up @@ -230,17 +230,31 @@ def __init__(
self.blocking = 128

if self.experts_per_rank == 1:
self.mlp = MLP(
config=config,
parallel_config=parallel_config,
tp_pg=parallel_context.tp_pg,
)
if config.moe_glu:
self.mlp = GLU(
config=config,
parallel_config=parallel_config,
tp_pg=parallel_context.tp_pg,
)
else:
self.mlp = MLP(
config=config,
parallel_config=parallel_config,
tp_pg=parallel_context.tp_pg,
)
else:
self.mlp = SparseGLU(
config=config,
parallel_config=parallel_config,
parallel_context=parallel_context,
)
if config.moe_glu:
self.mlp = SparseGLU(
config=config,
parallel_config=parallel_config,
parallel_context=parallel_context,
)
else:
self.mlp = SparseMLP(
config=config,
parallel_config=parallel_config,
parallel_context=parallel_context,
)

max_column_index = (self.config.intermediate_size * self.num_experts) // self.blocking
self.transpose_sort_end_bit = max(int(np.ceil(np.log2(max_column_index))), 1)
@@ -630,6 +644,26 @@ def __init__(
             expert_parallel_size=self.expert_pg_size,
         )
 
+        # TODO @nouamane: jit
+        self.act = ACT2FN[config.hidden_act]
+
+    def forward(self, hidden_states, topo):  # [seq_length, batch_size, hidden_dim]
+        merged_states = self.w1(hidden_states)
+        hidden_states = self.w2(self.act(merged_states))
+        return hidden_states
+
+
+class GLU(MLP):
+    def __init__(
+        self,
+        config: Config,
+        parallel_config: Optional[ParallelismArgs],
+        tp_pg: dist.ProcessGroup,
+    ):
+        super().__init__(config, parallel_config, tp_pg)
+        tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
+        tp_linear_async_communication = (
+            parallel_config.tp_linear_async_communication if parallel_config is not None else False
+        )
         self.w3 = ExpertParallel(
             TensorParallelColumnLinear(
                 config.hidden_size,
@@ -641,15 +675,12 @@ def __init__(
             ),
             expert_parallel_size=self.expert_pg_size,
         )
-        # TODO @nouamane: jit
-        self.act = ACT2FN[config.hidden_act]
-
-    def forward(self, hidden_states, topo):  # [seq_length, batch_size, hidden_dim]
+    def forward(self, hidden_states, topo):
         merged_states = self.w1(hidden_states)
         hidden_states = self.w2(self.act(merged_states) * self.w3(hidden_states))
         return hidden_states
 
 
 def inclusive_cumsum(x, dim):
     scalar = ops.inclusive_cumsum(x, dim)
     return scalar.view(1) if not len(scalar.size()) else scalar
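To make the two expert feed-forwards concrete: the plain MLP computes w2(act(w1(x))), while the GLU variant gates the activated up-projection with a third linear, w2(act(w1(x)) * w3(x)). Judging from the diff, this gating was previously baked into MLP itself; the commit splits it out so the plain variant is selectable via moe_glu. Below is a minimal single-device sketch with plain torch.nn.Linear layers and an assumed SiLU activation; the real MLP/GLU classes wrap ExpertParallel tensor-parallel linears and take an extra topo argument, all of which is omitted here:

import torch
from torch import nn

# Minimal sketch, not the nanotron classes: no tensor/expert parallelism, no topo
# argument, and SiLU is an assumed stand-in for ACT2FN[config.hidden_act].
class ToyMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act = nn.SiLU()

    def forward(self, hidden_states):
        return self.w2(self.act(self.w1(hidden_states)))

class ToyGLU(ToyMLP):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__(hidden_size, intermediate_size)
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)  # gating projection

    def forward(self, hidden_states):
        # Same as ToyMLP, but the activated w1 output is multiplied elementwise
        # by a second projection of the input before the down-projection w2.
        return self.w2(self.act(self.w1(hidden_states)) * self.w3(hidden_states))

x = torch.randn(4, 2, 128)            # [seq_length, batch_size, hidden_dim]
print(ToyMLP(128, 512)(x).shape)      # torch.Size([4, 2, 128])
print(ToyGLU(128, 512)(x).shape)      # torch.Size([4, 2, 128])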
