diff --git a/mamba_r1/model.py b/mamba_r1/model.py
index 7ff2cdb..27ee7ea 100644
--- a/mamba_r1/model.py
+++ b/mamba_r1/model.py
@@ -11,7 +11,7 @@
 from mamba_ssm import Mamba
 from safetensors.torch import load_model
 from transformers import AutoTokenizer
-
+from zeta.nn import OutputHead
 
 class RotaryEmbedding(nn.Module):
     def __init__(
@@ -330,11 +330,14 @@ def __init__(
         expand: int = 2,
         num_experts: int = 8,
         expert_dim: int = 2048,
+        vocab_size: int = 32000,
         max_seq_len: int = 2048,
     ):
         super().__init__()
         self.dim = dim
         self.depth = depth
+        self.vocab_size = vocab_size
+        self.max_seq_len = max_seq_len
 
         # Embeddings
         self.rotary_emb = RotaryEmbedding(
@@ -401,7 +404,7 @@ def forward(
         elif return_loss:
             return torch.nn.CrossEntropyLoss()(x, correct_answer)
         else:
-            return x
+            return OutputHead(self.dim, vocab_size=self.vocab_size)(x)
 
     def generate(
         self,
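
Review note (editor's sketch, not part of the commit): as added, the `else` branch constructs a fresh `OutputHead` on every forward pass, so its projection weights are randomly re-initialized each call, are never registered as parameters of the model, and therefore never train. The usual pattern is to build the head once in `__init__` and reuse it in `forward`. Below is a minimal, self-contained sketch of that pattern; `TinyHeadDemo` is a hypothetical stand-in for the model class above, and the head is modeled as a plain `nn.Linear` projection to the vocabulary since the exact internals of `zeta.nn.OutputHead` may differ.

import torch
from torch import nn

class TinyHeadDemo(nn.Module):
    def __init__(self, dim: int = 512, vocab_size: int = 32000):
        super().__init__()
        self.dim = dim
        self.vocab_size = vocab_size
        # Built once here: the head's weights are registered as model
        # parameters, move with .to(device), and are seen by the optimizer.
        self.output_head = nn.Linear(dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reuses the same trained projection on every call, instead of
        # instantiating a new head (with fresh random weights) per forward.
        return self.output_head(x)  # -> [batch, seq, vocab_size] logits

model = TinyHeadDemo()
logits = model(torch.randn(2, 16, 512))
print(logits.shape)  # torch.Size([2, 16, 32000])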