Skip to content

Commit

Permalink
[feat][output head]
Browse files (browse the repository at this point in the history)
  • Loading branch information
kyegomez committed Jan 29, 2025
1 parent cb0f69a commit e34ecb9
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions mamba_r1/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mamba_ssm import Mamba
from safetensors.torch import load_model
from transformers import AutoTokenizer

+ from zeta.nn import OutputHead

class RotaryEmbedding(nn.Module):
def __init__(
Expand Down Expand Up @@ -330,11 +330,14 @@ def __init__(
expand: int = 2,
num_experts: int = 8,
expert_dim: int = 2048,
vocab_size: int = 32000,
max_seq_len: int = 2048,
):
super().__init__()
self.dim = dim
self.depth = depth
self.vocab_size = vocab_size
self.max_seq_len = max_seq_len

# Embeddings
self.rotary_emb = RotaryEmbedding(
Expand Down Expand Up @@ -401,7 +404,7 @@ def forward(
elif return_loss:
return torch.nn.CrossEntropyLoss()(x, correct_answer)
else:
- return x
+ return OutputHead(self.dim, vocab_size=self.vocab_size)(x)

def generate(
self,
Expand Down

0 comments on commit e34ecb9

Please sign in to comment.