clean up and add ckpt tests #179
Nice! Why are we killing GPU memory monitor though?
memory_profiler: MemoryProfilerConfig | None = None

sequence_packing: bool = True
attn_fn: Literal["flash", "sdpa"] | None = None

math_attn: bool = False  # slow
What do you think about putting this as an option in attn_fn instead?
attn_fn: Literal["flash", "sdpa"] | None = None

@model_validator(mode="after")
def validate_attn_fn(self):
    if self.attn_fn is not None:
        warnings.warn("attn_fn argument is deprecated")
    return self
Hmm, attn_fn is not used anymore. I just kept it to avoid conflicts with old code.
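For anyone who wants to try the keep-but-warn pattern from the snippet above in isolation, here is a minimal sketch. It assumes a stdlib dataclass in place of the pydantic model, so `TrainConfig` and the `__post_init__` hook are illustrative stand-ins, not the code in this PR.

```python
import warnings
from dataclasses import dataclass
from typing import Literal, Optional


@dataclass
class TrainConfig:  # hypothetical stand-in for the real pydantic config
    attn_fn: Optional[Literal["flash", "sdpa"]] = None  # deprecated, ignored
    math_attn: bool = False  # slow

    def __post_init__(self):
        # Plays the role of the pydantic "after" validator: the old field is
        # still accepted so existing configs load, but the user is warned.
        if self.attn_fn is not None:
            warnings.warn(
                "attn_fn argument is deprecated",
                DeprecationWarning,
                stacklevel=2,
            )
```

Constructing `TrainConfig(attn_fn="flash")` emits the warning, while `TrainConfig()` stays silent, which is the behavior old configs would rely on until the field is removed.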
pr to remove attn_fn #180
@@ -164,6 +200,7 @@ def train(config: Config):
    config.type_model,
    vocab_size=len(tokenizer) if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE,
    seq_length=config.data.seq_length,
    math_attn=config.train.math_attn,
What do you think about passing attn_fn instead? That would also allow sdpa to be specified.
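One way this suggestion could look, as a rough sketch only: a single `attn_fn` option that gains a `"math"` choice, with the boolean flag folded into it. The `"math"` literal, the `AttnFn` alias, and the `resolve_attn_fn` helper are all hypothetical names for illustration and do not exist in the PR.

```python
from typing import Literal, Optional

# "math" is a hypothetical third choice alongside the existing two.
AttnFn = Literal["flash", "sdpa", "math"]


def resolve_attn_fn(attn_fn: Optional[AttnFn], math_attn: bool = False) -> Optional[AttnFn]:
    """Fold the boolean math_attn flag into the single attn_fn option."""
    if math_attn:
        # Reject contradictory settings instead of silently picking one.
        if attn_fn not in (None, "math"):
            raise ValueError(f"math_attn=True conflicts with attn_fn={attn_fn!r}")
        return "math"
    # No flag set: pass the explicit choice (or None) through unchanged.
    return attn_fn
```

With something like this, the model constructor would take one argument instead of two, and `sdpa` stays selectable.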
The profiler is enough, I think. I would kill it unless you have a use case where you need it. I personally never used it even though I added it, haha.
Yeah, never used it either, haha.
What this PR does: