Skip to content

Commit 80d050a

Browse files
authored
align model seq len with data seq len (#26)
1 parent ac78db1 commit 80d050a

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

src/zeroband/models/llama/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
}
6262

6363

64-
def get_model(name_model: str, type_model: str, vocab_size: int) -> tuple[Transformer, ModelArgs]:
64+
def get_model(name_model: str, type_model: str, vocab_size: int, seq_length: int) -> tuple[Transformer, ModelArgs]:
6565
"""get the transformer model"""
6666

6767
if type_model == "llama2":
@@ -72,4 +72,5 @@ def get_model(name_model: str, type_model: str, vocab_size: int) -> tuple[Transf
7272
raise ValueError(f"Model type {type_model} not supported")
7373

7474
config.vocab_size = vocab_size
75+
config.max_seq_len = seq_length
7576
return Transformer(config), config

src/zeroband/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def train(config: Config):
116116
vocab_size=tokenizer.vocab_size
117117
if config.name_model != "debugmodel" or not config.data.fake
118118
else TEST_VOCAB_SIZE,
119+
seq_length=config.data.seq_length,
119120
)
120121

121122
if config.train.log_model_hash:

0 commit comments

Comments
 (0)