Double buffer metrics In Train to avoid gap between steps #390
Conversation
    The logic is that this ensures that Jax is able to queue train_steps and we
    don't block when turning "lazy" Jax arrays into real Python numbers.
    """
    global _buffered_step, _buffered_metrics
Nit: Do these need to be global variables? Where else are they used? Are they global so they are not inputs/outputs of write_metrics?
Not a nit! They need to be stateful so they can be "remembered" -- a pure function can't remember them. In the future, as I hint, we should encapsulate all of the metrics details into a class, and that class can take care of this.
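To make the buffering concrete, here is a minimal, self-contained sketch of the pattern under discussion. The write_scalar_fn callback and the toy loss values are hypothetical stand-ins for the real summary writer; only the two module-level buffers and the one-step delay mirror what the PR does.

```python
import jax.numpy as jnp

_buffered_step = None      # step whose metrics are still waiting to be written
_buffered_metrics = None   # lazy device arrays from that step

def write_metrics(write_scalar_fn, metrics, step):
    """Write the previous step's metrics and buffer the current ones."""
    global _buffered_step, _buffered_metrics
    if _buffered_metrics is not None:
        # Converting to Python floats blocks until the device has produced the
        # values, but for the previous step they are almost always ready, so
        # the host never stalls and JAX keeps queuing train_steps.
        host_metrics = {name: float(value) for name, value in _buffered_metrics.items()}
        for name, value in host_metrics.items():
            write_scalar_fn(name, value, _buffered_step)
    # Keep the current step's metrics as lazy arrays; they are written on the
    # next call.  The final step's metrics still need one flush after the loop.
    _buffered_metrics = metrics
    _buffered_step = step

# Toy usage: each "step" produces a lazy loss value.
if __name__ == "__main__":
    log = lambda name, value, step: print(f"step {step}: {name}={value:.3f}")
    for step in range(3):
        metrics = {"loss": jnp.asarray(1.0 / (step + 1))}
        write_metrics(log, metrics, step)
```

As noted above, a later refactor could move the two globals onto a small metrics-writer class so that this state lives alongside the code that uses it.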
@@ -314,22 +338,22 @@ def train_loop(config, state=None):
raise ValueError("Profiling requested but initial profiling step set past training final step")
last_profiling_step = np.clip(first_profiling_step + config.profiler_steps - 1, first_profiling_step, config.steps - 1)

nextrng = jax.random.fold_in(init_rng, start_step)
example_batch = load_next_batch(data_iterator, None, config)
A tiny question: does this line change the data loading? Previously, we loaded the first batch before entering the loop, and now it happens inside the loop. I remember there was a relevant PR.
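For readers skimming the diff, here is a toy, runnable sketch of the loop shape around these lines. train_step, the iterator, and the exact position of the first batch load are illustrative stand-ins (the loading position is precisely what this question is about), so treat the ordering as a sketch rather than the PR's code.

```python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(state, batch, rng):
    del rng  # a real step would use it for dropout etc.
    loss = jnp.mean((state - batch) ** 2)
    return state - 0.1 * loss, {"loss": loss}

def train_loop(steps=3):
    init_rng = jax.random.PRNGKey(0)
    data_iterator = iter(jnp.ones((steps, 4)))  # stand-in for the real data iterator
    state = jnp.zeros(4)
    for step in range(steps):
        example_batch = next(data_iterator)           # stand-in for load_next_batch
        nextrng = jax.random.fold_in(init_rng, step)  # same RNG pattern as the diff
        state, metrics = train_step(state, example_batch, nextrng)
        # With double-buffered metrics (see the earlier sketch), nothing here
        # blocks on this step's result, so the next iteration is queued immediately.
        print(f"queued step {step}; loss is still a lazy {type(metrics['loss']).__name__}")

if __name__ == "__main__":
    train_loop()
```

The point of the ordering is that the host-side work for the next step (loading the batch, folding the RNG) is issued while the device is still busy, which is where the step-gap reduction reported below comes from.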
Old gap (4.1 ms): https://screenshot.googleplex.com/3H8J9bAkWGAFfQL
New gap (0.092 ms): https://screenshot.googleplex.com/BZzKRkh6dQ875VX
Unfortunately, our reporting is a bit sus:
I honestly don't see a fix for this reporting issue.