Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple printing logs to stdout from tensorboard #1376

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions MaxText/metric_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def write_metrics(self, running_gcs_metrics, metrics, step, is_training=True):
steps_to_write = step

if metrics_to_write:
self.log_metrics(metrics_to_write, steps_to_write, is_training)

if self.config.enable_tensorboard:
self.write_metrics_to_tensorboard(metrics_to_write, steps_to_write, is_training)

Expand All @@ -76,6 +78,22 @@ def write_metrics(self, running_gcs_metrics, metrics, step, is_training=True):
if self.config.gcs_metrics and jax.process_index() == 0:
running_gcs_metrics = self.write_metrics_for_gcs(metrics_to_write, steps_to_write, running_gcs_metrics, is_training)

def _log_metrics(self, metrics, step):
  """Emits a one-line summary of a completed step through max_logging.

  Args:
    metrics: mapping whose "scalar" entry holds the values read below
      (step time, per-device TFLOP/s and tokens/s, total weights, loss).
      A missing key raises KeyError.
    step: step number included at the start of the message.
  """
  # Look the scalar table up once instead of once per field.
  scalars = metrics["scalar"]
  summary = (
      f"completed step: {step}, seconds: {scalars['perf/step_time_seconds']:.3f}, "
      f"TFLOP/s/device: {scalars['perf/per_device_tflops_per_sec']:.3f}, "
      f"Tokens/s/device: {scalars['perf/per_device_tokens_per_sec']:.3f}, "
      f"total_weights: {scalars['learning/total_weights']}, "
      f"loss: {scalars['learning/loss']:.3f}"
  )
  max_logging.log(summary)

def log_metrics(self, metrics, step, is_training):
  """Logs the step summary for training steps; does nothing otherwise.

  Args:
    metrics: metrics mapping forwarded unchanged to `_log_metrics`.
    step: step number forwarded unchanged to `_log_metrics`.
    is_training: when falsy, nothing is logged.
  """
  # The context is entered for every call, training or not, as before.
  with jax.spmd_mode("allow_all"):
    if not is_training:
      return
    self._log_metrics(metrics, step)
Comment on lines +93 to +95
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to gate the logging with `with jax.spmd_mode("allow_all")` and `if metrics_to_write`? I see a duplicate PR was created that doesn't gate on either of those. Are the extra conditions needed here?


def write_metrics_locally(self, metrics, step):
"""Writes metrics locally for testing"""
with open(self.config.metrics_file, "a", encoding="utf8") as local_metrics_file:
Expand Down Expand Up @@ -115,14 +133,6 @@ def write_metrics_to_tensorboard(self, metrics, step, is_training):
if is_training:
full_log = step % self.config.log_period == 0

max_logging.log(
f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
f"total_weights: {metrics['scalar']['learning/total_weights']}, "
f"loss: {metrics['scalar']['learning/loss']:.3f}"
)

if full_log and jax.process_index() == 0:
max_logging.log(f"To see full metrics 'tensorboard --logdir={self.config.tensorboard_dir}'")
self.writer.flush()
Loading