Commit 6d30a0e

temp remove profiler (de)activation logging
1 parent: bfdb7ed

1 file changed: +73 -36

src/MaxText/metric_logger.py (+73 -36)
@@ -104,43 +104,80 @@ def write_metrics(self, metrics, step, is_training=True):
   def log_metrics(self, metrics, step, is_training):
     """Logs metrics via max_logging."""
     if is_training:
-      loss = metrics["scalar"]["learning/loss"]
-      # Do not show flops and tokens during batch size rampup
-      if step >= self.config.rampup_end_step:
-        log_message = (
-            f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
-            f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
-            f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
-            f"total_weights: {metrics['scalar']['learning/total_weights']}, "
-            f"loss: {loss:.3f}"
-        )
-      else:
-        log_message = (
-            "[Rampup Batch Size Phase]: "
-            f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
-            f"total_weights: {metrics['scalar']['learning/total_weights']}, "
-            f"loss: {loss:.3f}"
-        )
-
-      if self.config.mtp_num_layers > 0:
-        mtp_loss = metrics["scalar"].get("learning/mtp_loss", 0.0)
-        main_model_loss = loss - mtp_loss
-        log_message += f", main_model_loss: {main_model_loss:.3f}, mtp_loss: {mtp_loss:.3f}"
-
+      self._log_training_metrics(metrics, step)
     else:
-      log_message = (
-          f"eval metrics after step: {step},"
-          f" loss={metrics['scalar']['eval/avg_loss']:.3f},"
-          f" total_weights={metrics['scalar']['eval/total_weights']}"
-      )
-
-      if self.config.mtp_num_layers > 0:
-        log_message += (
-            f", avg_mtp_loss={metrics['scalar']['eval/avg_mtp_loss']:.3f},"
-            f" avg_mtp_acceptance_rate={metrics['scalar']['eval/avg_mtp_acceptance_rate_percent']:.2f}%"
-        )
-
-    max_logging.log(log_message)
+      self._log_eval_metrics(metrics, step)
+
+  def _log_training_metrics(self, metrics, step):
+    """Handles training-specific metric logging."""
+    # Skip logging if in profiler activation/deactivation steps
+    # TODO(b/456828037): Switch to subprocess profiling to avoid timing artifacts at boundary steps.
+    if self._is_profiler_boundary_step(step):
+      return
+
+    scalars = metrics["scalar"]
+    loss = scalars["learning/loss"]
+    is_rampup = step < self.config.rampup_end_step
+
+    # Start building the log parts
+    log_parts = []
+    if is_rampup:
+      log_parts.append("[Rampup Batch Size Phase]")
+
+    log_parts.extend([
+        f"completed step: {step}",
+        f"seconds: {scalars['perf/step_time_seconds']:.3f}",
+    ])
+
+    # Add performance metrics only if strictly NOT in rampup phase
+    # TODO(b/452468482): Enable performance metric (TFLOPs, Tokens/s) tracking during batch size rampup.
+    if not is_rampup:
+      log_parts.extend([
+          f"TFLOP/s/device: {scalars['perf/per_device_tflops_per_sec']:.3f}",
+          f"Tokens/s/device: {scalars['perf/per_device_tokens_per_sec']:.3f}",
+      ])
+
+    log_parts.extend([
+        f"total_weights: {scalars['learning/total_weights']}",
+        f"loss: {loss:.3f}",
+    ])
+
+    if self.config.mtp_num_layers > 0:
+      mtp_loss = scalars.get("learning/mtp_loss", 0.0)
+      log_parts.append(f"main_model_loss: {loss - mtp_loss:.3f}")
+      log_parts.append(f"mtp_loss: {mtp_loss:.3f}")
+
+    max_logging.log(", ".join(log_parts))
+
+  def _log_eval_metrics(self, metrics, step):
+    """Handles evaluation-specific metric logging."""
+    scalars = metrics["scalar"]
+    log_parts = [
+        f"eval metrics after step: {step}",
+        f"loss={scalars['eval/avg_loss']:.3f}",
+        f"total_weights={scalars['eval/total_weights']}",
+    ]
+
+    if self.config.mtp_num_layers > 0:
+      log_parts.extend([
+          f"avg_mtp_loss={scalars['eval/avg_mtp_loss']:.3f}",
+          f"avg_mtp_acceptance_rate={scalars['eval/avg_mtp_acceptance_rate_percent']:.2f}%",
+      ])
+
+    max_logging.log(", ".join(log_parts))
+
+  def _is_profiler_boundary_step(self, step):
+    """Determines if the current step is a profiler start/stop boundary that should be hidden."""
+    skip_steps = self.config.skip_first_n_steps_for_profiler
+    profiler_steps = self.config.profiler_steps
+    # Steps immediately before/at start, and at/immediately after end of profiling
+    boundary_steps = {
+        skip_steps,
+        skip_steps + 1,
+        skip_steps + profiler_steps,
+        skip_steps + profiler_steps + 1,
+    }
+    return step in boundary_steps
 
   def write_metrics_locally(self, metrics, step):
     """Writes metrics locally for testing."""
