1 change: 1 addition & 0 deletions src/MaxText/configs/base.yml
@@ -624,6 +624,7 @@ upload_all_profiler_results: False
 skip_first_n_steps_for_profiler: 1
 # Profile for a small number of steps to avoid a large profile file size.
 profiler_steps: 5
+hide_profiler_step_metric: False
 profile_cleanly: True # If set to true, adds a block_until_ready on train state which aligns the profile for each step.
 profile_periodically_period: -1 # If set to a positive integer, profile every profile_periodically_period steps.
 # This is useful to debug scenarios where performance is changing.
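
With the defaults shown above (skip_first_n_steps_for_profiler: 1, profiler_steps: 5), the new hide_profiler_step_metric flag only affects the four steps that border profiler activation and deactivation. The sketch below is not part of the change; it simply mirrors the boundary-step arithmetic introduced in the metric_logger.py diff further down, to show which steps would get the reduced log line once the flag is set to True (in base.yml or, presumably, via the usual key=value command-line override).

# Illustration only -- mirrors the boundary-step arithmetic added in metric_logger.py below.
skip_first_n_steps_for_profiler = 1  # default from base.yml above
profiler_steps = 5                   # default from base.yml above

# Steps whose timing is distorted by profiler start/stop; their perf metrics are
# hidden from the log when hide_profiler_step_metric is True.
boundary_steps = {
    skip_first_n_steps_for_profiler,
    skip_first_n_steps_for_profiler + 1,
    skip_first_n_steps_for_profiler + profiler_steps,
    skip_first_n_steps_for_profiler + profiler_steps + 1,
}
print(sorted(boundary_steps))  # [1, 2, 6, 7]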
116 changes: 83 additions & 33 deletions src/MaxText/metric_logger.py
@@ -104,43 +104,93 @@ def write_metrics(self, metrics, step, is_training=True):
   def log_metrics(self, metrics, step, is_training):
     """Logs metrics via max_logging."""
     if is_training:
-      loss = metrics["scalar"]["learning/loss"]
-      # Do not show flops and tokens during batch size rampup
-      if step >= self.config.rampup_end_step:
-        log_message = (
-            f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
-            f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
-            f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
-            f"total_weights: {metrics['scalar']['learning/total_weights']}, "
-            f"loss: {loss:.3f}"
-        )
-      else:
-        log_message = (
-            "[Rampup Batch Size Phase]: "
-            f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
-            f"total_weights: {metrics['scalar']['learning/total_weights']}, "
-            f"loss: {loss:.3f}"
-        )
-
-      if self.config.mtp_num_layers > 0:
-        mtp_loss = metrics["scalar"].get("learning/mtp_loss", 0.0)
-        main_model_loss = loss - mtp_loss
-        log_message += f", main_model_loss: {main_model_loss:.3f}, mtp_loss: {mtp_loss:.3f}"
+      self._log_training_metrics(metrics, step)
     else:
-      log_message = (
-          f"eval metrics after step: {step},"
-          f" loss={metrics['scalar']['eval/avg_loss']:.3f},"
-          f" total_weights={metrics['scalar']['eval/total_weights']}"
-      )
-      if self.config.mtp_num_layers > 0:
-        log_message += (
-            f", avg_mtp_loss={metrics['scalar']['eval/avg_mtp_loss']:.3f},"
-            f" avg_mtp_acceptance_rate={metrics['scalar']['eval/avg_mtp_acceptance_rate_percent']:.2f}%"
-        )
-
-    max_logging.log(log_message)
+      self._log_eval_metrics(metrics, step)
+
+  def _log_training_metrics(self, metrics, step):
+    """Handles training-specific metric logging."""
+    # Skip logging if in profiler activation/deactivation steps
+    # TODO(b/456828037): Switch to subprocess profiling to avoid timing artifacts at boundary steps.
+    scalars = metrics["scalar"]
+    loss = scalars["learning/loss"]
+    is_rampup = step < self.config.rampup_end_step
+    is_metric_hidden_step = self.config.hide_profiler_step_metric and self._is_profiler_boundary_step(step)
+
+    # Start building the log parts
+    log_parts = []
+    if is_rampup:
+      log_parts.append("[Rampup Batch Size Phase]")
+
+    if is_metric_hidden_step:
+      log_parts.append(
+          f"completed profiler activation/deactivation step: {step}",
+      )
+    else:
+      log_parts.extend(
+          [
+              f"completed step: {step}",
+              f"seconds: {scalars['perf/step_time_seconds']:.3f}",
+          ]
+      )
+
+    # Add performance metrics only if strictly NOT in rampup phase
+    # TODO(b/452468482): Enable performance metric (TFLOPs, Tokens/s) tracking during batch size rampup.
+    if not is_rampup and not is_metric_hidden_step:
+      log_parts.extend(
+          [
+              f"TFLOP/s/device: {scalars['perf/per_device_tflops_per_sec']:.3f}",
+              f"Tokens/s/device: {scalars['perf/per_device_tokens_per_sec']:.3f}",
+          ]
+      )
+
+    log_parts.extend(
+        [
+            f"total_weights: {scalars['learning/total_weights']}",
+            f"loss: {loss:.3f}",
+        ]
+    )
+
+    if self.config.mtp_num_layers > 0:
+      mtp_loss = scalars.get("learning/mtp_loss", 0.0)
+      log_parts.append(f"main_model_loss: {loss - mtp_loss:.3f}")
+      log_parts.append(f"mtp_loss: {mtp_loss:.3f}")
+
+    max_logging.log(", ".join(log_parts))
+
+  def _log_eval_metrics(self, metrics, step):
+    """Handles evaluation-specific metric logging."""
+    scalars = metrics["scalar"]
+    log_parts = [
+        f"eval metrics after step: {step}",
+        f"loss={scalars['eval/avg_loss']:.3f}",
+        f"total_weights={scalars['eval/total_weights']}",
+    ]
+
+    if self.config.mtp_num_layers > 0:
+      log_parts.extend(
+          [
+              f"avg_mtp_loss={scalars['eval/avg_mtp_loss']:.3f}",
+              f"avg_mtp_acceptance_rate={scalars['eval/avg_mtp_acceptance_rate_percent']:.2f}%",
+          ]
+      )
+
+    max_logging.log(", ".join(log_parts))

+  def _is_profiler_boundary_step(self, step):
+    """Determines if the current step is a profiler start/stop boundary that should be hidden."""
+    if len(self.config.profiler) == 0:
+      return False
+    skip_steps = self.config.skip_first_n_steps_for_profiler
+    profiler_steps = self.config.profiler_steps
+    # Steps immediately before/at start, and at/immediately after end of profiling
+    boundary_steps = {
+        skip_steps,
+        skip_steps + 1,
+        skip_steps + profiler_steps,
+        skip_steps + profiler_steps + 1,
+    }
+    return step in boundary_steps

   def write_metrics_locally(self, metrics, step):
     """Writes metrics locally for testing."""