clean up and add ckpt tests #179

Merged: 3 commits into main, Dec 23, 2024
Conversation

@samsja (Collaborator) commented Dec 19, 2024

What this PR does:

  • clean up model/opt hash logging
  • add ckpt test using the hash logging
  • remove useless memory arguments
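The hash-based checkpoint test in the first two bullets can be sketched roughly like this (a minimal stdlib-only sketch; the `state_hash` helper and the use of pickled state dicts are assumptions for illustration, not the repo's actual code):

```python
import hashlib
import pickle

def state_hash(state: dict) -> str:
    """Deterministic hash of a state dict; used to compare
    model/optimizer state before saving and after loading."""
    # sort items so the hash does not depend on dict insertion order
    payload = pickle.dumps(sorted(state.items()))
    return hashlib.sha256(payload).hexdigest()

# simulate a checkpoint round-trip: serialize, then deserialize
state = {"layer.weight": [0.1, 0.2, 0.3], "step": 3}
hash_before = state_hash(state)
restored = pickle.loads(pickle.dumps(state))  # stand-in for save/load to disk
assert state_hash(restored) == hash_before
```

The idea is that if the restored state hashes to the same value as the state that was saved, the checkpoint logic preserved every tensor and counter exactly.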

@samsja force-pushed the refactor-test-and-hash branch 3 times, most recently from 0fa14d0 to ca3b1c8 (December 19, 2024 06:20)
@samsja force-pushed the refactor-test-and-hash branch 5 times, most recently from 65e7b74 to 6755478 (December 19, 2024 09:42)
@samsja force-pushed the refactor-test-and-hash branch from 6755478 to dbddd9f (December 19, 2024 09:51)
@samsja changed the title from "use shared function for log model hash" to "clean up and add ckpt tets" (Dec 19, 2024)
@samsja changed the title from "clean up and add ckpt tets" to "clean up and add ckpt tests" (Dec 19, 2024)
@Jackmin801 (Member) left a comment

Nice! Why are we killing GPU memory monitor though?

memory_profiler: MemoryProfilerConfig | None = None

sequence_packing: bool = True
attn_fn: Literal["flash", "sdpa"] | None = None

math_attn: bool = False # slow
Member replied:

What do you think about putting this as an option in attn_fn instead?
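Folding the slow math path into `attn_fn` as a third literal value could look something like this (a hypothetical sketch; the registry values and `resolve_attn` name are placeholders, not the repo's actual attention kernels):

```python
from typing import Callable, Literal

AttnFn = Literal["flash", "sdpa", "math"]

# placeholder implementations standing in for real attention kernels
_ATTN_REGISTRY: dict[str, Callable[[], str]] = {
    "flash": lambda: "flash-attention kernel",
    "sdpa": lambda: "torch scaled_dot_product_attention",
    "math": lambda: "reference math attention (slow, for testing)",
}

def resolve_attn(attn_fn: AttnFn) -> Callable[[], str]:
    """Map the config literal to an implementation, so math_attn
    does not need to be a separate boolean flag."""
    try:
        return _ATTN_REGISTRY[attn_fn]
    except KeyError:
        raise ValueError(f"unknown attn_fn: {attn_fn!r}") from None
```

One config field then selects among all backends, and validation of unknown values comes for free from the `Literal` type plus the registry lookup.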

@samsja (Collaborator, Author) replied:

    # sketch with imports added; the enclosing config class name is assumed
    import warnings
    from typing import Literal

    from pydantic import BaseModel, model_validator

    class TrainConfig(BaseModel):  # hypothetical wrapper class
        attn_fn: Literal["flash", "sdpa"] | None = None

        @model_validator(mode="after")
        def validate_attn_fn(self):
            if self.attn_fn is not None:
                warnings.warn("attn_fn argument is deprecated")
            return self

Hmm, attn_fn is not used anymore; I just kept it to avoid conflicts with old code.

@samsja (Collaborator, Author) added:

PR to remove attn_fn: #180

@@ -164,6 +200,7 @@ def train(config: Config):
config.type_model,
vocab_size=len(tokenizer) if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE,
seq_length=config.data.seq_length,
math_attn=config.train.math_attn,
Member replied:

What do you think about passing attn_fn instead? Would also allow sdpa to be specified

@samsja (Collaborator, Author) commented Dec 20, 2024

> Nice! Why are we killing GPU memory monitor though?

The profiler is enough, I think.

I would kill it unless you have a use case where you need it. I personally never used it even though I added it, haha.

@Jackmin801 (Member) replied:

Yeah, never used it either, haha.

@samsja merged commit 4715633 into main on Dec 23, 2024
2 checks passed
@samsja deleted the refactor-test-and-hash branch on December 23, 2024 04:54