Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ class Config:
multiple_of: int = 128
eps: float = 1e-6
flash: bool = True
num_experts=8,
num_expert_per_tok=2
63 changes: 60 additions & 3 deletions model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,61 @@ def forward(self,x):

return self.dropout(self.w2(F.silu(self.w1(x))*self.w3(x)))

class MoEFFN(nn.Module):
def __init__(self,d_model, hidden_dim, num_experts, num_expert_per_tok,multiple_of, dropout):
super().__init__()
self.d_model = d_model
self.num_experts = num_experts
self.num_expert_per_tok = num_expert_per_tok # Number of experts to route each token to

if hidden_dim is None:
hidden_dim = 4 * d_model
hidden_dim = int(2/3 * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

# Gate network to route tokens to experts
self.gate = nn.Linear(d_model, num_experts, bias=False)

# Create multiple expert FFNs
self.experts = nn.ModuleList([
FFN(d_model, hidden_dim, multiple_of, dropout)
for _ in range(num_experts)
])
def forward(self, x):
B, T, C = x.shape

# Get routing probabilities
route_logits = self.gate(x) # [B, T, num_experts]

# Select top-k experts for each token
route_probs = F.softmax(route_logits, dim=-1)
scores, indices = torch.topk(route_probs, k=self.num_expert_per_tok, dim=-1)

# Normalize the routing probabilities
scores = scores / scores.sum(dim=-1, keepdim=True)

# Initialize output tensor
final_output = torch.zeros_like(x)

# Process tokens through selected experts
for expert_idx in range(self.num_experts):
# Find which positions need this expert
expert_mask = (indices == expert_idx).any(dim=-1) # [B, T]
if not expert_mask.any():
continue

# Get the scores for this expert where it was selected
expert_scores = torch.zeros_like(route_probs[..., 0]) # [B, T]
expert_scores[indices == expert_idx] = scores[indices == expert_idx]

# Process tokens through this expert
expert_output = self.experts[expert_idx](x) # [B, T, C]

# Add weighted output to final result
final_output += expert_output * expert_scores.unsqueeze(-1)

return final_output

class RMSNorm(nn.Module):
def __init__(self,d_model,norm_eps=1e-6):
super().__init__()
Expand All @@ -196,11 +251,13 @@ class TransformerBlock(nn.Module):
def __init__(self,layer_id, config):
super().__init__()
self.attn = Attention(config)
self.ffn = FFN(
self.moe = MoEFFN(
config.d_model,
config.hidden_dim,
config.multiple_of,
config.dropout
num_experts=config.num_experts,
num_expert_per_tok=config.num_expert_per_tok,
multiple_of=config.multiple_of,
dropout=config.dropout
)
self.layer_id = layer_id
self.attn_norm = RMSNorm(config.d_model, config.eps)
Expand Down