Update gemma_backbone.py for sharding config. #1491
Conversation
PTAL again.
This LGTM, though we still might want to check with other folks to help decide between the two.
Ack, I will leave the PR here; feel free to merge it when ready.
looks good!
* Update gemma_backbone.py for sharding config.
* Update unit test and fix format.
* Update sharding spec for gemma based on gemma training.
This PR addresses #1464.
The new setting is based on the internal Gemma training script.
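For context on how this sharding config gets consumed: the layout map returned by `GemmaBackbone.get_layout_map` plugs into the Keras distribution API. The snippet below is a minimal sketch of that setup, assuming a 1x8 mesh on a TPU v3-8 and the `gemma_2b_en` preset; the exact `ModelParallel` signature may vary slightly across Keras 3 versions.

```python
import keras
import keras_nlp

# Build a 1x8 device mesh for a TPU v3-8: "batch" for data parallelism,
# "model" for the model-parallel sharding configured in gemma_backbone.py.
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8),
    axis_names=("batch", "model"),
    devices=keras.distribution.list_devices(),
)

# The layout map carries the sharding spec that this PR updates.
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(device_mesh)

# Shard the model weights along the "model" axis; keep data on "batch".
distribution = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch"
)
keras.distribution.set_distribution(distribution)

# Any Gemma preset loaded after this point picks up the sharded layout.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
```

With the distribution set globally, the benchmark below simply runs `generate` and LoRA finetuning on the sharded model; no further per-layer configuration is needed.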
Here is a perf benchmark on TPU v3-8 (smaller values are better):

| Setting | generate (ms per 100 tokens) | finetune with LoRA (ms/step) |
|---|---|---|
| Baseline (current setting) | 1342 | 125 |
| This PR | 1245 | 64 |

That is roughly a 7% speedup for generation and about 2x for LoRA finetuning.