Use master weights for bfloat16 FusedAdam when master_weights=True #1731

Open · wants to merge 1 commit into master
Conversation

cbcase (Contributor) commented on Sep 22, 2023

As mentioned in #1728, the FusedAdam optimizer ignores master_weights=True for bfloat16 parameters. This PR fixes that oversight. I have confirmed that the behavior now matches a "by hand" implementation of master weights (manually copying between the bf16 parameters and their fp32 copies), with vanilla torch.optim.AdamW applied to the fp32 copy.
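
For reference, a minimal sketch of the kind of "by hand" comparison described above (not the PR's actual validation script). It assumes FusedAdam accepts the capturable=True and master_weights=True keyword arguments, as exercised by the existing testGradScalerCapturableMaster test; the model, step count, and tolerances are illustrative only.

```python
# Sketch: compare a bf16 model stepped by FusedAdam(master_weights=True)
# against an identical bf16 model whose fp32 master copies are updated with
# vanilla torch.optim.AdamW and copied back by hand.
import torch
from apex.optimizers import FusedAdam

torch.manual_seed(0)

# bf16 model driven by FusedAdam with fp32 master weights
model_fused = torch.nn.Linear(16, 16).cuda().bfloat16()
fused_opt = FusedAdam(model_fused.parameters(), lr=1e-3, weight_decay=0.0,
                      capturable=True, master_weights=True)

# identical bf16 model, updated "by hand" through fp32 master copies
model_ref = torch.nn.Linear(16, 16).cuda().bfloat16()
model_ref.load_state_dict(model_fused.state_dict())
master_params = [p.detach().clone().float().requires_grad_(True)
                 for p in model_ref.parameters()]
ref_opt = torch.optim.AdamW(master_params, lr=1e-3, weight_decay=0.0)

for _ in range(10):
    x = torch.randn(4, 16, device="cuda", dtype=torch.bfloat16)

    fused_opt.zero_grad()
    model_fused(x).sum().backward()
    fused_opt.step()

    ref_opt.zero_grad()
    model_ref(x).sum().backward()
    # hand-copy bf16 grads onto the fp32 masters, step, then copy back
    for p, mp in zip(model_ref.parameters(), master_params):
        mp.grad = p.grad.float()
        p.grad = None
    ref_opt.step()
    with torch.no_grad():
        for p, mp in zip(model_ref.parameters(), master_params):
            p.copy_(mp.to(torch.bfloat16))

# with master weights actually applied, the two runs should agree closely
for p, q in zip(model_fused.parameters(), model_ref.parameters()):
    torch.testing.assert_close(p, q, rtol=1e-2, atol=1e-2)
```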

cbcase (Contributor, Author) commented on Oct 16, 2023

Ping @minitu, looks like you added this support originally -- could you take a look? Thanks

minitu (Contributor) commented on Oct 17, 2023

LGTM; we only looked at adding master weights for FP16 AMP at the time of the original PR.
@crcrpar, could you review this as well?

crcrpar (Collaborator) left a review comment

Looks good, but could you add a test case of a bfloat16 model with fp32 master weights to `def testGradScalerCapturableMaster(self):`?
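
The requested coverage might look roughly like the following. This is only a sketch, not the actual apex test harness: the real change would extend testGradScalerCapturableMaster in the existing test file, and the helper name, model, step count, and tolerances here are placeholders.

```python
# Sketch: a bfloat16 model stepped with FusedAdam(capturable=True,
# master_weights=True) compared against an fp32 reference model stepped with
# torch.optim.AdamW, with relaxed tolerances to absorb bf16 forward/backward
# error.
import torch
from apex.optimizers import FusedAdam


def check_bf16_master_weights(steps: int = 5) -> None:
    torch.manual_seed(0)
    ref = torch.nn.Sequential(
        torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1)
    ).cuda()
    model = torch.nn.Sequential(
        torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1)
    ).cuda().bfloat16()
    model.load_state_dict({k: v.bfloat16() for k, v in ref.state_dict().items()})

    opt = FusedAdam(model.parameters(), lr=1e-3, weight_decay=0.0,
                    capturable=True, master_weights=True)
    ref_opt = torch.optim.AdamW(ref.parameters(), lr=1e-3, weight_decay=0.0)

    for _ in range(steps):
        x = torch.randn(4, 8, device="cuda")

        opt.zero_grad()
        model(x.bfloat16()).sum().backward()
        opt.step()

        ref_opt.zero_grad()
        ref(x).sum().backward()
        ref_opt.step()

    # with fp32 master weights, the bf16 run should track the fp32 run
    for p, rp in zip(model.parameters(), ref.parameters()):
        torch.testing.assert_close(p.float(), rp, rtol=1e-2, atol=1e-2)
```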
