From 55557db8b1dd097269f39ce77b6369a66bc9f32b Mon Sep 17 00:00:00 2001
From: NuojCheng
Date: Sat, 8 Nov 2025 01:04:04 +0000
Subject: [PATCH] Temporarily hide step metrics on profiler (de)activation steps

---
 src/MaxText/configs/base.yml |   1 +
 src/MaxText/metric_logger.py | 116 +++++++++++++++++++++++++----------
 2 files changed, 84 insertions(+), 33 deletions(-)

diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index fe85db62e..a18ef20c3 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -624,6 +624,7 @@ upload_all_profiler_results: False
 skip_first_n_steps_for_profiler: 1
 # Profile for a small number of steps to avoid a large profile file size.
 profiler_steps: 5
+hide_profiler_step_metric: False # If set to true, hides step time and throughput metrics on profiler (de)activation steps.
 profile_cleanly: True # If set to true, adds a block_until_ready on train state which aligns the profile for each step.
 profile_periodically_period: -1 # If set to a positive integer, profile every profile_periodically_period steps.
 # This is useful to debug scenarios where performance is changing.
diff --git a/src/MaxText/metric_logger.py b/src/MaxText/metric_logger.py
index c3999a455..f64b3ce52 100644
--- a/src/MaxText/metric_logger.py
+++ b/src/MaxText/metric_logger.py
@@ -104,43 +104,93 @@ def write_metrics(self, metrics, step, is_training=True):
 
   def log_metrics(self, metrics, step, is_training):
     """Logs metrics via max_logging."""
     if is_training:
-      loss = metrics["scalar"]["learning/loss"]
-      # Do not show flops and tokens during batch size rampup
-      if step >= self.config.rampup_end_step:
-        log_message = (
-            f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
-            f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
-            f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
-            f"total_weights: {metrics['scalar']['learning/total_weights']}, "
-            f"loss: {loss:.3f}"
-        )
-      else:
-        log_message = (
-            "[Rampup Batch Size Phase]: "
-            f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
-            f"total_weights: {metrics['scalar']['learning/total_weights']}, "
-            f"loss: {loss:.3f}"
-        )
-
-      if self.config.mtp_num_layers > 0:
-        mtp_loss = metrics["scalar"].get("learning/mtp_loss", 0.0)
-        main_model_loss = loss - mtp_loss
-        log_message += f", main_model_loss: {main_model_loss:.3f}, mtp_loss: {mtp_loss:.3f}"
-
+      self._log_training_metrics(metrics, step)
+    else:
+      self._log_eval_metrics(metrics, step)
+
+  def _log_training_metrics(self, metrics, step):
+    """Handles training-specific metric logging."""
+    # Hide step metrics on profiler (de)activation steps, where profiler overhead distorts step timing.
+    # TODO(b/456828037): Switch to subprocess profiling to avoid timing artifacts at boundary steps.
+    scalars = metrics["scalar"]
+    loss = scalars["learning/loss"]
+    is_rampup = step < self.config.rampup_end_step
+    is_metric_hidden_step = self.config.hide_profiler_step_metric and self._is_profiler_boundary_step(step)
+
+    # Build the log message incrementally.
+    log_parts = []
+    if is_rampup:
+      log_parts.append("[Rampup Batch Size Phase]")
+
+    if is_metric_hidden_step:
+      log_parts.append(
+          f"completed profiler activation/deactivation step: {step}",
+      )
     else:
-      log_message = (
-          f"eval metrics after step: {step},"
-          f" loss={metrics['scalar']['eval/avg_loss']:.3f},"
-          f" total_weights={metrics['scalar']['eval/total_weights']}"
+      log_parts.extend(
+          [
+              f"completed step: {step}",
+              f"seconds: {scalars['perf/step_time_seconds']:.3f}",
+          ]
       )
-      if self.config.mtp_num_layers > 0:
-        log_message += (
-            f", avg_mtp_loss={metrics['scalar']['eval/avg_mtp_loss']:.3f},"
-            f" avg_mtp_acceptance_rate={metrics['scalar']['eval/avg_mtp_acceptance_rate_percent']:.2f}%"
-        )
+    # Add throughput metrics only outside rampup and outside hidden profiler boundary steps.
+    # TODO(b/452468482): Enable performance metric (TFLOPs, Tokens/s) tracking during batch size rampup.
+    if not is_rampup and not is_metric_hidden_step:
+      log_parts.extend(
+          [
+              f"TFLOP/s/device: {scalars['perf/per_device_tflops_per_sec']:.3f}",
+              f"Tokens/s/device: {scalars['perf/per_device_tokens_per_sec']:.3f}",
+          ]
+      )
+
+    log_parts.extend(
+        [
+            f"total_weights: {scalars['learning/total_weights']}",
+            f"loss: {loss:.3f}",
+        ]
+    )
+
+    if self.config.mtp_num_layers > 0:
+      mtp_loss = scalars.get("learning/mtp_loss", 0.0)
+      log_parts.append(f"main_model_loss: {loss - mtp_loss:.3f}")
+      log_parts.append(f"mtp_loss: {mtp_loss:.3f}")
+
+    max_logging.log(", ".join(log_parts))
+
+  def _log_eval_metrics(self, metrics, step):
+    """Handles evaluation-specific metric logging."""
+    scalars = metrics["scalar"]
+    log_parts = [
+        f"eval metrics after step: {step}",
+        f"loss={scalars['eval/avg_loss']:.3f}",
+        f"total_weights={scalars['eval/total_weights']}",
+    ]
+
+    if self.config.mtp_num_layers > 0:
+      log_parts.extend(
+          [
+              f"avg_mtp_loss={scalars['eval/avg_mtp_loss']:.3f}",
+              f"avg_mtp_acceptance_rate={scalars['eval/avg_mtp_acceptance_rate_percent']:.2f}%",
+          ]
+      )
 
-    max_logging.log(log_message)
+    max_logging.log(", ".join(log_parts))
+
+  def _is_profiler_boundary_step(self, step):
+    """Returns True if the current step is a profiler start/stop boundary that should be hidden."""
+    if len(self.config.profiler) == 0:
+      return False
+    skip_steps = self.config.skip_first_n_steps_for_profiler
+    profiler_steps = self.config.profiler_steps
+    # Steps immediately before/at the start, and at/immediately after the end, of profiling.
+    boundary_steps = {
+        skip_steps,
+        skip_steps + 1,
+        skip_steps + profiler_steps,
+        skip_steps + profiler_steps + 1,
+    }
+    return step in boundary_steps
 
   def write_metrics_locally(self, metrics, step):
     """Writes metrics locally for testing."""
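With the base.yml defaults above (skip_first_n_steps_for_profiler: 1, profiler_steps: 5), _is_profiler_boundary_step flags steps 1, 2, 6, and 7, so those four steps log "completed profiler activation/deactivation step: N" in place of step time and throughput while keeping total_weights and loss. A minimal standalone sketch of that arithmetic (hypothetical helper name, config values hardcoded for illustration; not part of this patch):

  def profiler_boundary_steps(skip_first_n_steps_for_profiler, profiler_steps):
    """Mirrors _is_profiler_boundary_step: the two steps around profiler start and the two around profiler stop."""
    skip = skip_first_n_steps_for_profiler
    return {skip, skip + 1, skip + profiler_steps, skip + profiler_steps + 1}

  print(sorted(profiler_boundary_steps(1, 5)))  # [1, 2, 6, 7] with the base.yml defaults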
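The flag defaults to False, so existing logs are unchanged unless it is set, and the boundary check returns False whenever config.profiler is empty, so it only takes effect while a profiler is configured. Assuming MaxText's usual key=value config-override syntax (invocation shown for illustration, not taken from this patch), it would be enabled with something like:

  python3 -m MaxText.train src/MaxText/configs/base.yml run_name=profiled_run profiler=xplane hide_profiler_step_metric=True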