Skip to content

Commit 4651cb3

Browse files
author
maxtext authors
committed
Merge pull request #1099 from AI-Hypercomputer:mattdavidow-jdi-telemetry
PiperOrigin-RevId: 707276379
2 parents 9f908a1 + 8e21573 commit 4651cb3

File tree

3 files changed

+26
-9
lines changed

3 files changed

+26
-9
lines changed

MaxText/configs/base.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,11 @@ grain_worker_count: 1
342342
steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
343343
log_period: 100 # Flushes Tensorboard
344344

345+
jax_distributed_initialization_timeout: 300 # This is the default timeout in https://github.com/jax-ml/jax/blob/main/jax/_src/distributed.py
346+
# Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers
347+
# only to the jax coordination service.
348+
jax_debug_log_modules: "" # Set this to "jax" to enable jax verbose logging such as for the jax coordination service initialization.
349+
345350
# We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
346351
# Learning rate schedule has either two or three parts:
347352
# 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
@@ -477,6 +482,8 @@ prometheus_port: 0
477482
enable_jax_profiler: False
478483
jax_profiler_port: 9999
479484

485+
log_config: True # Prints the config (after defaults have been set by pyconfig logic)
486+
480487
# Checkpoint Structured logging
481488
enable_checkpoint_cloud_logger: False
482489

MaxText/max_utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -224,11 +224,11 @@ def maybe_initialize_jax_distributed_system(raw_keys):
224224
return
225225
if is_gpu_backend(raw_keys):
226226
max_logging.log("Attempting to initialize the jax distributed system for GPU backend...")
227-
initialize_jax_for_gpu()
227+
initialize_jax_for_gpu(raw_keys)
228228
max_logging.log("Jax distributed system initialized on GPU!")
229229
elif is_cpu_backend(raw_keys):
230230
max_logging.log("Attempting to initialize the jax distributed system for CPU backend...")
231-
initialize_jax_for_cpu()
231+
initialize_jax_for_cpu(raw_keys)
232232
max_logging.log("Jax distributed system initialized on CPUs!")
233233
elif (
234234
raw_keys["enable_checkpointing"]
@@ -238,13 +238,13 @@ def maybe_initialize_jax_distributed_system(raw_keys):
238238
) or raw_keys["hardware"] == "gpu_multiprocess":
239239
max_logging.log("Attempting to initialize the jax distributed system...")
240240
if not raw_keys["enable_emergency_checkpoint"]:
241-
jax.distributed.initialize()
241+
jax.distributed.initialize(initialization_timeout=raw_keys["jax_distributed_initialization_timeout"])
242242
else:
243243
initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys)
244244
max_logging.log("Jax distributed system initialized!")
245245

246246

247-
def initialize_jax_for_gpu():
247+
def initialize_jax_for_gpu(raw_keys):
248248
"""Jax distributed initialize for GPUs."""
249249
if os.environ.get("JAX_COORDINATOR_IP") is not None:
250250
coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP"))
@@ -253,11 +253,12 @@ def initialize_jax_for_gpu():
253253
coordinator_address=f"{coordinator_ip}:{coordinator_port}",
254254
num_processes=int(os.getenv("NNODES")),
255255
process_id=int(os.getenv("NODE_RANK")),
256+
initialization_timeout=raw_keys["jax_distributed_initialization_timeout"],
256257
)
257258
max_logging.log(f"JAX global devices: {jax.devices()}")
258259

259260

260-
def initialize_jax_for_cpu():
261+
def initialize_jax_for_cpu(raw_keys):
261262
"""Jax distributed initialize for CPUs. Includes retries until the coordinator is ready."""
262263
coordinator_ip_address = get_coordinator_ip_address()
263264
coordinator_address = coordinator_ip_address + ":1234" # JAX coordinator port used in XPK
@@ -272,6 +273,7 @@ def initialize_jax_for_cpu():
272273
coordinator_address=coordinator_address,
273274
process_id=pid,
274275
num_processes=int(os.environ.get("JAX_PROCESS_COUNT")),
276+
initialization_timeout=raw_keys["jax_distributed_initialization_timeout"],
275277
)
276278

277279

@@ -288,7 +290,11 @@ def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys):
288290
f"Using {process_id} as the process_id and {coordinator_address} as the"
289291
" coordinator_address to initialize JAX distributed runtime..."
290292
)
291-
jax.distributed.initialize(coordinator_address=coordinator_address, process_id=int(process_id))
293+
jax.distributed.initialize(
294+
coordinator_address=coordinator_address,
295+
process_id=int(process_id),
296+
initialization_timeout=raw_keys["jax_distributed_initialization_timeout"],
297+
)
292298
if raw_keys["use_replicator_service"]:
293299
REPLICATOR_FILE = "replicator.yaml"
294300
TEMP_FILE = REPLICATOR_FILE + ".tmp"
@@ -324,7 +330,7 @@ def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys):
324330
"Initializing JAX distributed runtime without args when emergency checkpointing is"
325331
" enabled. This should not happen and your workload may have unexpected behavior."
326332
)
327-
jax.distributed.initialize()
333+
jax.distributed.initialize(initialization_timeout=raw_keys["jax_distributed_initialization_timeout"])
328334

329335
ocp.multihost.initialize_runtime_to_distributed_ids()
330336
ocp.multihost.initialize_distributed_to_device_ids()

MaxText/pyconfig.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,8 @@ def __init__(self, argv: list[str], **kwargs):
345345
validate_no_keys_overwritten_twice(keys_from_env_and_command_line, keys_from_model)
346346

347347
# We initialize the jax distributed system here because it must be done before device backend is initialized.
348+
if raw_keys["jax_debug_log_modules"]:
349+
jax.config.update("jax_debug_log_modules", raw_keys["jax_debug_log_modules"])
348350
max_utils.maybe_initialize_jax_distributed_system(raw_keys)
349351

350352
if raw_keys["jax_cache_dir"]:
@@ -367,8 +369,10 @@ def __init__(self, argv: list[str], **kwargs):
367369
self.keys = raw_keys
368370
keys = [k for k in raw_keys] # pylint: disable=unnecessary-comprehension
369371
keys.sort()
370-
for k in keys:
371-
max_logging.log(f"Config param {k}: {raw_keys[k]}")
372+
373+
if raw_keys["log_config"]:
374+
for k in keys:
375+
max_logging.log(f"Config param {k}: {raw_keys[k]}")
372376

373377
@staticmethod
374378
def user_init(raw_keys):

0 commit comments

Comments
 (0)