
Commit

Merge pull request #390 from google:rwitten_double_buffer_metrics
PiperOrigin-RevId: 603085925
maxtext authors committed Jan 31, 2024
2 parents 8ac9c4d + d1d16d6, commit 7eea549
Showing 1 changed file with 33 additions and 13 deletions.
46 changes: 33 additions & 13 deletions MaxText/train.py
@@ -107,8 +107,34 @@ def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
   })
   metrics['scalar'].update({'learning/current_learning_rate': lr })
 
+_buffered_step = None
+_buffered_metrics = None
+def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
+  """Entry point for all metrics writing in Train's Main.
+     TODO: would be better as a Class in the future (that initialized all state!)
+     To avoid introducing an unnecessary dependency, we "double buffer" -- we hold
+     onto the last metrics and step and only publish when we receive a new metrics and step.
+     The logic is that this ensures that Jax is able to queue train_steps and we
+     don't block when turning "lazy" Jax arrays into real Python numbers.
+  """
+  global _buffered_step, _buffered_metrics
+
+  if _buffered_metrics is not None:
+    if _buffered_step is None:
+      raise ValueError(f"When writing metrics, {_buffered_step=} was none")
+    write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config)
+
+    if config.metrics_file:
+      max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)
 
-def write_metrics(writer, metrics, step, config):
+    if config.gcs_metrics and jax.process_index() == 0:
+      running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics)
+
+  _buffered_step = step
+  _buffered_metrics = metrics
+
+def write_metrics_to_tensorboard(writer, metrics, step, config):
   """ Writes metrics to tensorboard"""
   with jax.spmd_mode('allow_all'):
     if jax.process_index() == 0:
@@ -303,8 +329,6 @@ def train_loop(config, state=None):
     static_argnums=static_argnums,
     donate_argnums=donate_argnums)
 
-  last_step_completion = datetime.datetime.now()
-
   local_metrics_file = open(config.metrics_file, 'a', encoding="utf8") if config.metrics_file else None
   running_gcs_metrics = [] if config.gcs_metrics else None
 
@@ -314,22 +338,22 @@ def train_loop(config, state=None):
     raise ValueError("Profiling requested but initial profiling step set past training final step")
   last_profiling_step = np.clip(first_profiling_step + config.profiler_steps - 1, first_profiling_step, config.steps - 1)
 
-  nextrng = jax.random.fold_in(init_rng, start_step)
-  example_batch = load_next_batch(data_iterator, None, config)
+  example_batch = None
+  last_step_completion = datetime.datetime.now()
 
   for step in np.arange(start_step, config.steps):
     if step == first_profiling_step:
       max_utils.activate_profiler(config)
 
+    example_batch = load_next_batch(data_iterator, example_batch, config)
+    nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
       state, metrics = p_train_step(
           state, example_batch, nextrng
       )
 
-    example_batch = load_next_batch(data_iterator, example_batch, config)
-    nextrng = jax.random.fold_in(init_rng, step+1)
     new_time = datetime.datetime.now()
     record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step))
-    write_metrics(writer, metrics, step, config)
     last_step_completion = new_time
 
     if checkpoint_manager is not None:
@@ -340,11 +364,7 @@ def train_loop(config, state=None):
         checkpoint_manager.wait_until_finished()
         sys.exit()
 
-    if config.metrics_file:
-      max_utils.write_metrics_locally(metrics, step, config, local_metrics_file)
-
-    if config.gcs_metrics and jax.process_index() == 0:
-      running_gcs_metrics = max_utils.write_metrics_for_gcs(metrics, step, config, running_gcs_metrics)
+    write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config)
 
     if step == last_profiling_step:
       max_utils.deactivate_profiler(config)
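The docstring above only sketches the double-buffering idea, so here is a minimal, self-contained illustration of the same pattern. The MetricsBuffer class, its publish callback, and the toy loop are hypothetical names invented for this sketch, not MaxText code; MaxText itself keeps the buffered step and metrics in the module-level globals _buffered_step and _buffered_metrics and publishes them through write_metrics_to_tensorboard, max_utils.write_metrics_locally, and max_utils.write_metrics_for_gcs.

# Minimal sketch of the "double buffer" metrics pattern described in the
# docstring above. MetricsBuffer and publish are illustrative names only.
import jax.numpy as jnp


class MetricsBuffer:
  """Holds the previous step's (possibly still lazy) metrics and publishes them one step late."""

  def __init__(self, publish):
    self._publish = publish   # callback that turns metrics into host-side numbers
    self._step = None
    self._metrics = None

  def write(self, metrics, step):
    # Publish the *previous* step's metrics first. By the time step N's metrics
    # arrive here, step N-1's device work has long since been enqueued (and
    # usually finished), so materializing its lazy arrays does not stall the
    # stream of queued train steps.
    if self._metrics is not None:
      self._publish(self._metrics, self._step)
    self._step, self._metrics = step, metrics

  def flush(self):
    # Call once after the loop so the final buffered step is not dropped.
    if self._metrics is not None:
      self._publish(self._metrics, self._step)
      self._step = self._metrics = None


if __name__ == "__main__":
  buf = MetricsBuffer(lambda m, s: print(f"step {s}: loss={float(m['loss']):.3f}"))
  for step in range(3):
    metrics = {"loss": jnp.asarray(1.0 / (step + 1))}  # stand-in for p_train_step output
    buf.write(metrics, step)  # publishes the previous step's metrics, buffers this one
  buf.flush()

The ordering mirrors the diff: write_metrics is called only after the current step's p_train_step has been dispatched, so converting the previous step's lazy JAX arrays into Python numbers blocks on work that has almost certainly already completed.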
