From 390c5d561b55ca317082993abb9e19cf5b284600 Mon Sep 17 00:00:00 2001
From: Rafi Witten
Date: Wed, 31 Jan 2024 17:38:06 +0000
Subject: [PATCH] Double buffer metrics in Train to avoid gap between steps

---
 MaxText/train.py | 50 +++++++++++++++++++++++++++++++++---------------
 1 file changed, 35 insertions(+), 15 deletions(-)

diff --git a/MaxText/train.py b/MaxText/train.py
index 37a992d22..f3c2fb103 100644
--- a/MaxText/train.py
+++ b/MaxText/train.py
@@ -36,7 +36,7 @@
 import max_utils
 import maxtext_utils
 import max_logging
-from maxtext import optimizers
+import optimizers
 import pyconfig
 from input_pipeline.input_pipeline_interface import create_data_iterator_with_tokenizer
 
@@ -107,9 +107,35 @@ def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
   })
   metrics['scalar'].update({'learning/current_learning_rate': lr })
 
+_buffered_step = None
+_buffered_metrics = None
+def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
+  """Entry point for all metrics writing in Train's Main.
+  TODO: would be better as a Class in the future (that initializes all state!)
 
-def write_metrics(writer, metrics, step, config):
-  """Writes metrics to tensorboard"""
+  To avoid introducing an unnecessary dependency, we "double buffer" -- we hold
+  onto the last metrics and step and only publish them when we receive a new
+  metrics and step. This ensures that Jax is able to queue train_steps and we
+  don't block when turning "lazy" Jax arrays into real Python numbers.
+  """
+  global _buffered_step, _buffered_metrics
+
+  if _buffered_metrics is not None:
+    if _buffered_step is None:
+      raise ValueError(f"When writing metrics, {_buffered_step=} was None")
+    write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config)
+
+    if config.metrics_file:
+      max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)
+
+    if config.gcs_metrics and jax.process_index() == 0:
+      running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics)
+
+  _buffered_step = step
+  _buffered_metrics = metrics
+
+def write_metrics_to_tensorboard(writer, metrics, step, config):
+  """Writes metrics to tensorboard"""
   with jax.spmd_mode('allow_all'):
     if jax.process_index() == 0:
       for metric_name in metrics.get("scalar",[]):
@@ -303,8 +329,6 @@ def train_loop(config, state=None):
       static_argnums=static_argnums,
       donate_argnums=donate_argnums)
 
-  last_step_completion = datetime.datetime.now()
-
   local_metrics_file = open(config.metrics_file, 'a', encoding="utf8") if config.metrics_file else None
   running_gcs_metrics = [] if config.gcs_metrics else None
 
@@ -314,22 +338,22 @@ def train_loop(config, state=None):
     raise ValueError("Profiling requested but initial profiling step set past training final step")
   last_profiling_step = np.clip(first_profiling_step + config.profiler_steps - 1, first_profiling_step, config.steps - 1)
 
-  nextrng = jax.random.fold_in(init_rng, start_step)
-  example_batch = load_next_batch(data_iterator, None, config)
+  example_batch = None
+  last_step_completion = datetime.datetime.now()
+
   for step in np.arange(start_step, config.steps):
     if step == first_profiling_step:
       max_utils.activate_profiler(config)
 
+    example_batch = load_next_batch(data_iterator, example_batch, config)
+    nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
       state, metrics = p_train_step(
           state, example_batch, nextrng
       )
 
-    example_batch = load_next_batch(data_iterator, example_batch, config)
-    nextrng = jax.random.fold_in(init_rng, step+1)
     new_time = datetime.datetime.now()
     record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step))
-    write_metrics(writer, metrics, step, config)
     last_step_completion = new_time
 
     if checkpoint_manager is not None:
@@ -340,11 +364,7 @@ checkpoint_manager.wait_until_finished()
       sys.exit()
 
-    if config.metrics_file:
-      max_utils.write_metrics_locally(metrics, step, config, local_metrics_file)
-
-    if config.gcs_metrics and jax.process_index() == 0:
-      running_gcs_metrics = max_utils.write_metrics_for_gcs(metrics, step, config, running_gcs_metrics)
+    write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config)
 
     if step == last_profiling_step:
       max_utils.deactivate_profiler(config)
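
Editor's note, not part of the patch: the double-buffering rationale in
write_metrics's docstring can be reproduced in isolation. The sketch below is
a minimal, self-contained Python/JAX example; toy_train_step and publish are
hypothetical stand-ins for p_train_step and the TensorBoard/local-file/GCS
writers, not MaxText APIs. JAX dispatches jitted computations asynchronously,
and converting a device array to a Python float blocks the host until that
computation finishes; publishing the previous step's metrics while the
current step is still running keeps the device queue full.

    import jax
    import jax.numpy as jnp

    @jax.jit
    def toy_train_step(x):
      # Stand-in for p_train_step: returns a "lazy" device array immediately;
      # the actual computation runs asynchronously on the accelerator.
      return jnp.mean(jnp.sin(x) ** 2)

    def publish(step, metrics):
      # Stand-in for the metrics writers. float() forces a device-to-host
      # transfer and blocks until `metrics` is actually computed.
      print(f"step {step}: loss={float(metrics):.6f}")

    buffered_step, buffered_metrics = None, None
    x = jnp.ones((1024, 1024))
    for step in range(5):
      metrics = toy_train_step(x)                  # dispatch step N, non-blocking
      if buffered_metrics is not None:
        publish(buffered_step, buffered_metrics)   # block on step N-1 while N runs
      buffered_step, buffered_metrics = step, metrics

    # Unlike the loop in this patch, the sketch flushes the final step's
    # metrics; in the patch they remain buffered when the loop exits.
    if buffered_metrics is not None:
      publish(buffered_step, buffered_metrics)

This mirrors the patch's move of the write path from immediately after
p_train_step to a single write_metrics call at the end of the loop body: the
expensive host-side read is always one step behind the device, so the gap
between dispatched steps disappears.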