
Double buffer metrics In Train to avoid gap between steps #390

Merged
2 commits merged into main on Jan 31, 2024

Conversation

@rwitten (Collaborator) commented on Jan 31, 2024:

Old gap (4.1 ms): https://screenshot.googleplex.com/3H8J9bAkWGAFfQL
New gap (0.092 ms): https://screenshot.googleplex.com/BZzKRkh6dQ875VX

Unfortunately our reporting is a bit sus:

Per train step, total TFLOPs will be 170.71, split as 94.20% learnable weight flops and 5.80% attention flops
completed step: 0, seconds: 5.120, TFLOP/s/device: 33.343, loss: 12.620
To see full metrics 'tensorboard --logdir=gs://runner-maxtext-logs/fake/tensorboard/'
completed step: 1, seconds: 0.811, TFLOP/s/device: 210.381, loss: 12.603
completed step: 2, seconds: 0.609, TFLOP/s/device: 280.094, loss: 10.725
completed step: 3, seconds: 0.996, TFLOP/s/device: 171.314, loss: 9.784
completed step: 4, seconds: 1.207, TFLOP/s/device: 141.395, loss: 9.447
completed step: 5, seconds: 1.202, TFLOP/s/device: 141.962, loss: 9.180
completed step: 6, seconds: 1.204, TFLOP/s/device: 141.783, loss: 9.001
completed step: 7, seconds: 1.205, TFLOP/s/device: 141.693, loss: 8.896
completed step: 8, seconds: 1.205, TFLOP/s/device: 141.641, loss: 8.807
completed step: 9, seconds: 1.206, TFLOP/s/device: 141.548, loss: 8.725

I honestly don't see a fix for this reporting issue.
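
For orientation, the double-buffering pattern this PR introduces can be sketched roughly as follows. This is an illustrative sketch, not the exact MaxText diff: the globals _buffered_step/_buffered_metrics and the write_metrics name appear in the change below, but the TensorBoard-style writer.add_scalar interface and the assumption that metrics is a flat dict of scalar arrays are made up for the example.

import jax

_buffered_step = None
_buffered_metrics = None

def write_metrics(writer, metrics, step):
  """Buffer metrics for `step` and flush the previous step's buffered metrics.

  Writing one step late means the jax.device_get below only touches arrays
  that have already finished computing, so the host never blocks the next
  train_step from being queued.
  """
  global _buffered_step, _buffered_metrics
  if _buffered_metrics is not None:
    host_metrics = jax.device_get(_buffered_metrics)  # values are already materialized
    for name, value in host_metrics.items():
      writer.add_scalar(name, float(value), _buffered_step)
  _buffered_step, _buffered_metrics = step, metrics

One consequence of this pattern is that a final flush is needed after the training loop so the last step's buffered metrics are not dropped.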

The logic is that this ensures that Jax is able to queue train_steps and we
don't block when turning "lazy" Jax arrays into real Python numbers.
"""
global _buffered_step, _buffered_metrics
Collaborator:

Nit: Do these need to be global variables? Where else are they used? Are they global so they are not inputs/outputs of write_metrics?

Collaborator Author (@rwitten):

Not a nit! They need to be stateful so they can be "remembered" -- a pure function can't remember them. In the future, as I hint, we should encapsulate all of the metrics details into a class, and that class can take care of this.
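
As a rough illustration of the class-based encapsulation hinted at here (the MetricsRecorder name and its interface are hypothetical, not existing MaxText code), the one-step buffer could live on an instance instead of module-level globals:

import jax

class MetricsRecorder:
  """Hypothetical wrapper that owns the one-step metrics buffer."""

  def __init__(self, writer):
    self.writer = writer
    self._step = None
    self._metrics = None

  def record(self, metrics, step):
    # Flush last step's (already materialized) metrics, then buffer this step's.
    self.flush()
    self._step, self._metrics = step, metrics

  def flush(self):
    if self._metrics is not None:
      for name, value in jax.device_get(self._metrics).items():
        self.writer.add_scalar(name, float(value), self._step)
      self._metrics = None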

@copybara-service bot merged commit 7eea549 into main on Jan 31, 2024
13 checks passed
@copybara-service bot deleted the rwitten_double_buffer_metrics branch on January 31, 2024 at 18:13
@@ -314,22 +338,22 @@ def train_loop(config, state=None):
raise ValueError("Profiling requested but initial profiling step set past training final step")
last_profiling_step = np.clip(first_profiling_step + config.profiler_steps - 1, first_profiling_step, config.steps - 1)

nextrng = jax.random.fold_in(init_rng, start_step)
example_batch = load_next_batch(data_iterator, None, config)
Collaborator:

A tiny question: does this line change the data loading? Previously we loaded the first batch before entering the loop, and now it is inside the loop. I remember there was a relevant PR.
