@@ -104,43 +104,80 @@ def write_metrics(self, metrics, step, is_training=True):
   def log_metrics(self, metrics, step, is_training):
     """Logs metrics via max_logging."""
     if is_training:
-      loss = metrics["scalar"]["learning/loss"]
-      # Do not show flops and tokens during batch size rampup
-      if step >= self.config.rampup_end_step:
-        log_message = (
-            f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
-            f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
-            f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
-            f"total_weights: {metrics['scalar']['learning/total_weights']}, "
-            f"loss: {loss:.3f}"
-        )
-      else:
-        log_message = (
-            "[Rampup Batch Size Phase]: "
-            f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
-            f"total_weights: {metrics['scalar']['learning/total_weights']}, "
-            f"loss: {loss:.3f}"
-        )
-
-      if self.config.mtp_num_layers > 0:
-        mtp_loss = metrics["scalar"].get("learning/mtp_loss", 0.0)
-        main_model_loss = loss - mtp_loss
-        log_message += f", main_model_loss: {main_model_loss:.3f}, mtp_loss: {mtp_loss:.3f}"
-
+      self._log_training_metrics(metrics, step)
     else:
-      log_message = (
-          f"eval metrics after step: {step},"
-          f" loss={metrics['scalar']['eval/avg_loss']:.3f},"
-          f" total_weights={metrics['scalar']['eval/total_weights']}"
-      )
-
-      if self.config.mtp_num_layers > 0:
-        log_message += (
-            f", avg_mtp_loss={metrics['scalar']['eval/avg_mtp_loss']:.3f},"
-            f" avg_mtp_acceptance_rate={metrics['scalar']['eval/avg_mtp_acceptance_rate_percent']:.2f}%"
-        )
-
-    max_logging.log(log_message)
+      self._log_eval_metrics(metrics, step)
+
+  def _log_training_metrics(self, metrics, step):
+    """Handles training-specific metric logging."""
+    # Skip logging during profiler activation/deactivation steps.
+    # TODO(b/456828037): Switch to subprocess profiling to avoid timing artifacts at boundary steps.
+    if self._is_profiler_boundary_step(step):
+      return
+
+    scalars = metrics["scalar"]
+    loss = scalars["learning/loss"]
+    is_rampup = step < self.config.rampup_end_step
+
+    # Build the log message from individual parts.
+    log_parts = []
+    if is_rampup:
+      log_parts.append("[Rampup Batch Size Phase]")
+
+    log_parts.extend([
+        f"completed step: {step}",
+        f"seconds: {scalars['perf/step_time_seconds']:.3f}",
+    ])
+
+    # Do not show performance metrics during the batch size rampup phase.
+    # TODO(b/452468482): Enable performance metric (TFLOPs, Tokens/s) tracking during batch size rampup.
+    if not is_rampup:
+      log_parts.extend([
+          f"TFLOP/s/device: {scalars['perf/per_device_tflops_per_sec']:.3f}",
+          f"Tokens/s/device: {scalars['perf/per_device_tokens_per_sec']:.3f}",
+      ])
+
+    log_parts.extend([
+        f"total_weights: {scalars['learning/total_weights']}",
+        f"loss: {loss:.3f}",
+    ])
+
+    if self.config.mtp_num_layers > 0:
+      mtp_loss = scalars.get("learning/mtp_loss", 0.0)
+      log_parts.append(f"main_model_loss: {loss - mtp_loss:.3f}")
+      log_parts.append(f"mtp_loss: {mtp_loss:.3f}")
+
+    max_logging.log(", ".join(log_parts))
+
+  def _log_eval_metrics(self, metrics, step):
+    """Handles evaluation-specific metric logging."""
+    scalars = metrics["scalar"]
+    log_parts = [
+        f"eval metrics after step: {step}",
+        f"loss={scalars['eval/avg_loss']:.3f}",
+        f"total_weights={scalars['eval/total_weights']}",
+    ]
+
+    if self.config.mtp_num_layers > 0:
+      log_parts.extend([
+          f"avg_mtp_loss={scalars['eval/avg_mtp_loss']:.3f}",
+          f"avg_mtp_acceptance_rate={scalars['eval/avg_mtp_acceptance_rate_percent']:.2f}%",
+      ])
+
+    max_logging.log(", ".join(log_parts))
+
+  def _is_profiler_boundary_step(self, step):
+    """Returns True if the step is a profiler start/stop boundary whose metrics should be hidden."""
+    skip_steps = self.config.skip_first_n_steps_for_profiler
+    profiler_steps = self.config.profiler_steps
+    # Steps immediately before/at the start and at/immediately after the end of profiling.
+    boundary_steps = {
+        skip_steps,
+        skip_steps + 1,
+        skip_steps + profiler_steps,
+        skip_steps + profiler_steps + 1,
+    }
+    return step in boundary_steps
 
   def write_metrics_locally(self, metrics, step):
     """Writes metrics locally for testing."""
0 commit comments
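The new boundary filter can also be exercised in isolation. A small sketch with the two config fields flattened into plain arguments; the values skip_first_n_steps_for_profiler=5 and profiler_steps=3 are illustrative, not defaults from the codebase:

def is_profiler_boundary_step(step, skip_steps=5, profiler_steps=3):
  """Standalone copy of _is_profiler_boundary_step, config fields as arguments."""
  boundary_steps = {
      skip_steps,                       # the two steps around profiler start
      skip_steps + 1,
      skip_steps + profiler_steps,      # the two steps around profiler stop
      skip_steps + profiler_steps + 1,
  }
  return step in boundary_steps

print([s for s in range(12) if is_profiler_boundary_step(s)])
# [5, 6, 8, 9] -- training metrics for these steps are skipped entirely

With those values, steps 5 and 6 straddle profiler activation and steps 8 and 9 straddle deactivation, so their potentially distorted step times never reach the log.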