diff --git a/MaxText/metric_logger.py b/MaxText/metric_logger.py index 9be1e5eb7..1106dd4f7 100644 --- a/MaxText/metric_logger.py +++ b/MaxText/metric_logger.py @@ -67,6 +67,8 @@ def write_metrics(self, running_gcs_metrics, metrics, step, is_training=True): steps_to_write = step if metrics_to_write: + self.log_metrics(metrics_to_write, steps_to_write, is_training) + if self.config.enable_tensorboard: self.write_metrics_to_tensorboard(metrics_to_write, steps_to_write, is_training) @@ -76,6 +78,22 @@ def write_metrics(self, running_gcs_metrics, metrics, step, is_training=True): if self.config.gcs_metrics and jax.process_index() == 0: running_gcs_metrics = self.write_metrics_for_gcs(metrics_to_write, steps_to_write, running_gcs_metrics, is_training) + def _log_metrics(self, metrics, step): + """Logs metrics via max_logging""" + max_logging.log( + f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, " + f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, " + f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, " + f"total_weights: {metrics['scalar']['learning/total_weights']}, " + f"loss: {metrics['scalar']['learning/loss']:.3f}" + ) + + def log_metrics(self, metrics, step, is_training): + """Prints metrics""" + with jax.spmd_mode("allow_all"): + if is_training: + self._log_metrics(metrics, step) + def write_metrics_locally(self, metrics, step): """Writes metrics locally for testing""" with open(self.config.metrics_file, "a", encoding="utf8") as local_metrics_file: @@ -115,14 +133,6 @@ def write_metrics_to_tensorboard(self, metrics, step, is_training): if is_training: full_log = step % self.config.log_period == 0 - max_logging.log( - f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, " - f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, " - f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, " - f"total_weights: {metrics['scalar']['learning/total_weights']}, " - f"loss: {metrics['scalar']['learning/loss']:.3f}" - ) - if full_log and jax.process_index() == 0: max_logging.log(f"To see full metrics 'tensorboard --logdir={self.config.tensorboard_dir}'") self.writer.flush()