Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ jobs:
- name: Test standalone_dataloader.py
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/standalone_dataloader.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=100 enable_checkpointing=false'
'python3 MaxText/standalone_dataloader.py MaxText/configs/base.yml run_name=standalone_dataloader_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=100 enable_checkpointing=false'
- name: Test standalone_checkpointer.py
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/standalone_checkpointer.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=200 checkpoint_period=50 enable_checkpointing=True async_checkpointing=False'
'python3 MaxText/standalone_checkpointer.py MaxText/configs/base.yml run_name=standalone_checkpointer_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=200 checkpoint_period=50 enable_checkpointing=True async_checkpointing=False'
- name: Test int8_training
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
Expand Down
3 changes: 3 additions & 0 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ base_output_directory: ""
# Jax cache directory
jax_cache_dir: "~/jax_cache"

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' and 'cpu'

# Parallelism
mesh_axes: ['data', 'fsdp', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
Expand Down
54 changes: 50 additions & 4 deletions MaxText/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import checkpointing
import common_types
import functools
import time
import socket

import max_logging

Expand Down Expand Up @@ -166,15 +168,59 @@ def upload_blob(destination_gcs_name, source_file_name):
blob = bucket.blob(prefix_name)
blob.upload_from_filename(source_file_name)

def initialize_jax_distributed_system():
def maybe_initialize_jax_distributed_system(raw_keys):
  """ The best recipe to initialize the Jax Distributed System has varied over time. We keep a layer of
  indirection in MaxText to avoid breaking the call sites unnecessarily.

  Currently jax.distributed.initialize() fully works as expected!

  For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments.
  """
  # Async checkpointing on a real (non-AOT-compile) run needs the distributed service.
  wants_distributed_checkpointing = (raw_keys["enable_checkpointing"]
                                     and raw_keys["async_checkpointing"]
                                     and raw_keys["compile_topology_num_slices"] == -1)
  if wants_distributed_checkpointing:
    max_logging.log("Attempting to initialize the jax distributed system...")
    jax.distributed.initialize()
    max_logging.log("Jax distributed system initialized!")
  elif is_cpu_backend(raw_keys):
    # CPU backends cannot auto-discover the coordinator; use the explicit path.
    max_logging.log("Attempting to initialize the jax distributed system for CPU backend...")
    initialize_jax_for_cpu()
    max_logging.log("Jax distributed system initialized on CPUs!")



def initialize_jax_for_cpu():
  """Jax distributed initialize for CPUs. Includes retries until the coordinator is ready.

  Reads its configuration from environment variables (set by XPK or otherwise):
    JAX_COORDINATOR_ADDRESS: hostname of the coordinator. If unset, this is a no-op.
    JOB_INDEX, JOB_COMPLETION_INDEX, PROCESSES_IN_JOB: used to derive this process's id.
    JAX_PROCESS_COUNT: total number of Jax processes.

  Raises:
    RuntimeError: if the coordinator hostname cannot be resolved after all retries.
  """
  coordinator_address = os.environ.get("JAX_COORDINATOR_ADDRESS")
  if coordinator_address is None:
    # No coordinator configured (e.g. single-process run): nothing to initialize.
    return
  max_coordinator_lookups = 50
  ip_address = None
  for lookup_attempt in range(1, max_coordinator_lookups + 1):
    try:
      ip_address = socket.gethostbyname(coordinator_address)
      break
    except socket.gaierror:
      max_logging.log(f"Failed to recognize coordinator address {coordinator_address} on attempt {lookup_attempt}, retrying...")
      time.sleep(5)
  if ip_address is None:
    # Previously this fell through to an unguarded gethostbyname that re-raised a
    # bare gaierror; fail with an explicit, actionable error instead.
    raise RuntimeError(
        f"Coordinator address {coordinator_address} could not be resolved "
        f"after {max_coordinator_lookups} attempts.")
  coordinator_address = ip_address + ":1234" # JAX coordinator port used in XPK
  # Env variables to be set in XPK or otherwise
  job_index = int(os.environ.get("JOB_INDEX"))
  job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX"))
  processes_in_job = int(os.environ.get("PROCESSES_IN_JOB"))
  pid = job_index * processes_in_job + job_completion_index
  max_logging.log(f" Jax process id is {pid} ")
  # Explicit initialize is needed only for CPUs
  jax.distributed.initialize(coordinator_address=coordinator_address,
                             process_id=pid,
                             num_processes=int(os.environ.get("JAX_PROCESS_COUNT")))

def is_cpu_backend(raw_keys):
  """Determine whether Maxtext is intended to run on a CPU backend."""
  hardware = raw_keys["hardware"]
  return hardware == 'cpu'

def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_type):
"""Evaluates unspecified DCN/ICI parallelism values"""
Expand Down
5 changes: 2 additions & 3 deletions MaxText/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _load_kwargs(self, argv: list[str], **kwargs):
return args_dict

def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv, **kwargs) -> list[str]:
''' Update model config from environemnt and command line
''' Update model config from environment and command line
'''
raw_data_from_cmd_line = self._load_kwargs(argv, **kwargs)
updated_keys = []
Expand Down Expand Up @@ -169,8 +169,7 @@ def user_init(raw_keys):
'''Transformations between the config data and configs used at runtime'''

# We initialize the jax distributed system here because it must be done before device backend is initialized.
if raw_keys["enable_checkpointing"] and raw_keys["async_checkpointing"] and raw_keys["compile_topology_num_slices"]==-1:
max_utils.initialize_jax_distributed_system()
max_utils.maybe_initialize_jax_distributed_system(raw_keys)

if raw_keys["run_name"] == "":
raw_keys["run_name"] = os.environ.get("JOBSET_NAME") #using XPK default
Expand Down
10 changes: 3 additions & 7 deletions MaxText/standalone_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@

from layers import models

from jax.experimental.compilation_cache import compilation_cache as cc

Transformer = models.Transformer

def checkpoint_loop(config, state=None):
Expand All @@ -55,7 +53,7 @@ def checkpoint_loop(config, state=None):
unboxed_abstract_state, state_mesh_annotations = max_utils.get_abstract_state(model, tx,
config, init_rng, mesh, is_training=True)
# A barrier to sync all hosts before starting to restore checkpoint
jax.experimental.multihost_utils.sync_global_devices("dummy1")
jax.experimental.multihost_utils.sync_global_devices("Barrier before load")
checkpoint_load_start = datetime.datetime.now()
with nn_partitioning.axis_rules(config.logical_axis_rules):
state, _ = checkpointing.load_state_if_possible(checkpoint_manager,
Expand All @@ -78,7 +76,7 @@ def checkpoint_loop(config, state=None):
if checkpoint_manager is not None:
start_time = datetime.datetime.now()
# A barrier to sync all hosts before starting to save checkpoint
jax.experimental.multihost_utils.sync_global_devices("dummy2")
jax.experimental.multihost_utils.sync_global_devices("Barrier before save")
if checkpoint_manager.save(step, state):
checkpoint_manager.wait_until_finished()
end_time = datetime.datetime.now()
Expand All @@ -89,13 +87,11 @@ def checkpoint_loop(config, state=None):
return state

def main(argv: Sequence[str]) -> None:
jax.config.update('jax_default_prng_impl', 'unsafe_rbg')
jax.config.update('jax_cpu_enable_gloo_collectives', True)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
pyconfig.initialize(argv)
config = pyconfig.config
validate_train_config(config)
cc.initialize_cache(os.path.expanduser(config.jax_cache_dir))
print(f"Found {jax.device_count()} devices.")
print(f"Found {jax.process_count()} processes.")
print(f"Found {jax.devices()} devices.")
Expand Down
19 changes: 6 additions & 13 deletions MaxText/standalone_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,14 @@
import numpy as np

import pyconfig
from jax.sharding import Mesh
import max_utils
from train import validate_train_config, get_first_step, load_next_batch, setup_train_loop

from jax.experimental.compilation_cache import compilation_cache as cc


def data_load_loop(config, state=None):
"""Main data loader loop.
Loads batches of data for each training step.
"""
_, _, _, _, _, _, _, data_iterator, state = setup_train_loop(config)

devices_array = max_utils.create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)
_, _, _, _, _, mesh, _, data_iterator, state = setup_train_loop(config)
example_batch = None

start = datetime.datetime.now()
Expand All @@ -50,25 +43,25 @@ def data_load_loop(config, state=None):
jax.block_until_ready(example_batch)
first_end = datetime.datetime.now()
time_to_load_first_batch = first_end-start
max_logging.log(f"First step completed in {time_to_load_first_batch} seconds")
if jax.process_index() == 0:
max_logging.log(f"STANDALONE DATALOADER : First step completed in {time_to_load_first_batch} seconds, on host 0")

for _ in np.arange(start_step+1, config.steps):
example_batch = load_next_batch(data_iterator, example_batch, config, mesh)

jax.block_until_ready(example_batch) # wait until the last batch is read
end = datetime.datetime.now()
max_logging.log(f"{config.steps} batches loaded in {end-start} seconds, on host {jax.process_index()}")
if jax.process_index() == 0:
max_logging.log(f"STANDALONE DATALOADER : {config.steps} batches loaded in {end-start} seconds, on host 0")
return state


def main(argv: Sequence[str]) -> None:
jax.config.update('jax_default_prng_impl', 'unsafe_rbg')
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
jax.config.update('jax_cpu_enable_gloo_collectives', True)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
pyconfig.initialize(argv)
config = pyconfig.config
validate_train_config(config)
cc.initialize_cache(os.path.expanduser(config.jax_cache_dir))
max_logging.log(f"Found {jax.device_count()} devices.")
max_logging.log(f"Found {jax.process_count()} processes.")
max_logging.log(f"Found {jax.devices()} devices.")
Expand Down