You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: MaxText/configs/base.yml
+7Lines changed: 7 additions & 0 deletions
Original file line number
Diff line number
Diff line change
@@ -342,6 +342,11 @@ grain_worker_count: 1
342
342
steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
343
343
log_period: 100 # Flushes Tensorboard
344
344
345
+
jax_distributed_initialization_timeout: 300 # This is the default timeout in https://github.com/jax-ml/jax/blob/main/jax/_src/distributed.py
346
+
# Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers
347
+
# only to the jax coordination service.
348
+
jax_debug_log_modules: "" # Set this to "jax" to enable jax verbose logging such as for the jax coordination service initialization.
349
+
345
350
# We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
346
351
# Learning rate schedule has either two or three parts:
347
352
# 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
@@ -477,6 +482,8 @@ prometheus_port: 0
477
482
enable_jax_profiler: False
478
483
jax_profiler_port: 9999
479
484
485
+
log_config: True # Prints the config (after defaults have been set by pyconfig logic)
0 commit comments