From f559b96af8773ddee9ed930c0b44f7c0dec75734 Mon Sep 17 00:00:00 2001 From: Mihirsinh Chauhan <112346682+MihirsinhChauhan@users.noreply.github.com> Date: Sun, 23 Feb 2025 17:22:31 +0530 Subject: [PATCH] MOE implementation --- model/config.py | 2 ++ model/model.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/model/config.py b/model/config.py index 4bc7750..0f1173e 100644 --- a/model/config.py +++ b/model/config.py @@ -14,3 +14,5 @@ class Config: multiple_of: int = 128 eps: float = 1e-6 flash: bool = True + num_experts=8, + num_expert_per_tok=2 \ No newline at end of file diff --git a/model/model.py b/model/model.py index 1170cef..391ae34 100644 --- a/model/model.py +++ b/model/model.py @@ -177,6 +177,61 @@ def forward(self,x): return self.dropout(self.w2(F.silu(self.w1(x))*self.w3(x))) +class MoEFFN(nn.Module): + def __init__(self,d_model, hidden_dim, num_experts, num_expert_per_tok,multiple_of, dropout): + super().__init__() + self.d_model = d_model + self.num_experts = num_experts + self.num_expert_per_tok = num_expert_per_tok # Number of experts to route each token to + + if hidden_dim is None: + hidden_dim = 4 * d_model + hidden_dim = int(2/3 * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + # Gate network to route tokens to experts + self.gate = nn.Linear(d_model, num_experts, bias=False) + + # Create multiple expert FFNs + self.experts = nn.ModuleList([ + FFN(d_model, hidden_dim, multiple_of, dropout) + for _ in range(num_experts) + ]) + def forward(self, x): + B, T, C = x.shape + + # Get routing probabilities + route_logits = self.gate(x) # [B, T, num_experts] + + # Select top-k experts for each token + route_probs = F.softmax(route_logits, dim=-1) + scores, indices = torch.topk(route_probs, k=self.num_expert_per_tok, dim=-1) + + # Normalize the routing probabilities + scores = scores / scores.sum(dim=-1, keepdim=True) + + # Initialize output tensor + final_output = torch.zeros_like(x) + + # Process tokens through selected experts + for expert_idx in range(self.num_experts): + # Find which positions need this expert + expert_mask = (indices == expert_idx).any(dim=-1) # [B, T] + if not expert_mask.any(): + continue + + # Get the scores for this expert where it was selected + expert_scores = torch.zeros_like(route_probs[..., 0]) # [B, T] + expert_scores[indices == expert_idx] = scores[indices == expert_idx] + + # Process tokens through this expert + expert_output = self.experts[expert_idx](x) # [B, T, C] + + # Add weighted output to final result + final_output += expert_output * expert_scores.unsqueeze(-1) + + return final_output + class RMSNorm(nn.Module): def __init__(self,d_model,norm_eps=1e-6): super().__init__() @@ -196,11 +251,13 @@ class TransformerBlock(nn.Module): def __init__(self,layer_id, config): super().__init__() self.attn = Attention(config) - self.ffn = FFN( + self.moe = MoEFFN( config.d_model, config.hidden_dim, - config.multiple_of, - config.dropout + num_experts=config.num_experts, + num_expert_per_tok=config.num_expert_per_tok, + multiple_of=config.multiple_of, + dropout=config.dropout ) self.layer_id = layer_id self.attn_norm = RMSNorm(config.d_model, config.eps)