From c78de1154acfc76c84d6648ffb4106c0106c4708 Mon Sep 17 00:00:00 2001 From: RoshaniN Date: Wed, 31 Jan 2024 00:39:36 +0000 Subject: [PATCH] Adding CPU support for standalone dataloader and checkpointer. --- .github/workflows/UnitTests.yml | 4 +-- MaxText/configs/base.yml | 3 ++ MaxText/max_utils.py | 54 +++++++++++++++++++++++++++--- MaxText/pyconfig.py | 5 ++- MaxText/standalone_checkpointer.py | 10 ++---- MaxText/standalone_dataloader.py | 19 ++++------- 6 files changed, 66 insertions(+), 29 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 2afa2f9644..041a0c49d4 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -89,11 +89,11 @@ jobs: - name: Test standalone_dataloader.py run: | docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ - 'python3 MaxText/standalone_dataloader.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=100 enable_checkpointing=false' + 'python3 MaxText/standalone_dataloader.py MaxText/configs/base.yml run_name=standalone_dataloader_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=100 enable_checkpointing=false' - name: Test standalone_checkpointer.py run: | docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ - 'python3 MaxText/standalone_checkpointer.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=200 checkpoint_period=50 enable_checkpointing=True async_checkpointing=False' + 'python3 MaxText/standalone_checkpointer.py MaxText/configs/base.yml run_name=standalone_checkpointer_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=200 checkpoint_period=50 enable_checkpointing=True async_checkpointing=False' - name: Test int8_training run: | docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 083b89dbd7..94870f8c0d 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -95,6 +95,9 @@ base_output_directory: "" # Jax cache directory jax_cache_dir: "~/jax_cache" +# Hardware +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' and 'cpu' + # Parallelism mesh_axes: ['data', 'fsdp', 'sequence', 'tensor', 'autoregressive'] logical_axis_rules: [ diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 537c833f8b..0bc3139e24 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -19,6 +19,8 @@ import checkpointing import common_types import functools +import time +import socket import max_logging @@ -166,15 +168,59 @@ def upload_blob(destination_gcs_name, source_file_name): blob = bucket.blob(prefix_name) blob.upload_from_filename(source_file_name) -def initialize_jax_distributed_system(): +def maybe_initialize_jax_distributed_system(raw_keys): """ The best recipe to initialize the Jax Distributed System has varied over time. We keep a layer of indirection in MaxText to avoid breaking the call sites unnecessarily. Currently jax.distributed.initialize() fully works as expected! + + For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments. """ - max_logging.log("Attempting to initialize the jax distributed system...") - jax.distributed.initialize() - max_logging.log("Jax distributed system initialized!") + if (raw_keys["enable_checkpointing"] and raw_keys["async_checkpointing"] + and raw_keys["compile_topology_num_slices"]==-1): + max_logging.log("Attempting to initialize the jax distributed system...") + jax.distributed.initialize() + max_logging.log("Jax distributed system initialized!") + elif is_cpu_backend(raw_keys): + max_logging.log("Attempting to initialize the jax distributed system for CPU backend...") + initialize_jax_for_cpu() + max_logging.log("Jax distributed system initialized on CPUs!") + + + +def initialize_jax_for_cpu(): + """Jax distributed initialize for CPUs. Includes retries until the coordinator is ready. + """ + if os.environ.get("JAX_COORDINATOR_ADDRESS") is not None: + coordinator_address = os.environ.get("JAX_COORDINATOR_ADDRESS") + coordinator_found = False + lookup_attempt = 1 + max_coordinator_lookups = 50 + while not coordinator_found and lookup_attempt <= max_coordinator_lookups: + try: + ip_address = socket.gethostbyname(coordinator_address) + coordinator_found = True + except socket.gaierror: + print(f"Failed to recognize coordinator address {coordinator_address} on attempt {lookup_attempt}, retrying...") + lookup_attempt += 1 + time.sleep(5) + + ip_address = socket.gethostbyname(coordinator_address) + coordinator_address = ip_address + ":1234" # JAX coordinator port used in XPK + # Env variables to be set in XPK or otherwise + job_index = int(os.environ.get("JOB_INDEX")) + job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX")) + processes_in_job = int(os.environ.get("PROCESSES_IN_JOB")) + pid = job_index * processes_in_job + job_completion_index + max_logging.log(f" Jax process id is {pid} ") + # Explicit initialize is needed only for CPUs + jax.distributed.initialize(coordinator_address=coordinator_address, + process_id=pid, + num_processes=int(os.environ.get("JAX_PROCESS_COUNT"))) + +def is_cpu_backend(raw_keys): + """Determine whether Maxtext is intended to run on a CPU backend.""" + return raw_keys["hardware"] == 'cpu' def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_type): """Evaluates unspecified DCN/ICI parallelism values""" diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 34357f5ea5..547fda1876 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -97,7 +97,7 @@ def _load_kwargs(self, argv: list[str], **kwargs): return args_dict def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv, **kwargs) -> list[str]: - ''' Update model config from environemnt and command line + ''' Update model config from environment and command line ''' raw_data_from_cmd_line = self._load_kwargs(argv, **kwargs) updated_keys = [] @@ -169,8 +169,7 @@ def user_init(raw_keys): '''Transformations between the config data and configs used at runtime''' # We initialize the jax distributed system here because it must be done before device backend is initialized. - if raw_keys["enable_checkpointing"] and raw_keys["async_checkpointing"] and raw_keys["compile_topology_num_slices"]==-1: - max_utils.initialize_jax_distributed_system() + max_utils.maybe_initialize_jax_distributed_system(raw_keys) if raw_keys["run_name"] == "": raw_keys["run_name"] = os.environ.get("JOBSET_NAME") #using XPK default diff --git a/MaxText/standalone_checkpointer.py b/MaxText/standalone_checkpointer.py index e5d445768b..8398094991 100644 --- a/MaxText/standalone_checkpointer.py +++ b/MaxText/standalone_checkpointer.py @@ -37,8 +37,6 @@ from layers import models -from jax.experimental.compilation_cache import compilation_cache as cc - Transformer = models.Transformer def checkpoint_loop(config, state=None): @@ -55,7 +53,7 @@ def checkpoint_loop(config, state=None): unboxed_abstract_state, state_mesh_annotations = max_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) # A barrier to sync all hosts before starting to restore checkpoint - jax.experimental.multihost_utils.sync_global_devices("dummy1") + jax.experimental.multihost_utils.sync_global_devices("Barrier before load") checkpoint_load_start = datetime.datetime.now() with nn_partitioning.axis_rules(config.logical_axis_rules): state, _ = checkpointing.load_state_if_possible(checkpoint_manager, @@ -78,7 +76,7 @@ def checkpoint_loop(config, state=None): if checkpoint_manager is not None: start_time = datetime.datetime.now() # A barrier to sync all hosts before starting to save checkpoint - jax.experimental.multihost_utils.sync_global_devices("dummy2") + jax.experimental.multihost_utils.sync_global_devices("Barrier before save") if checkpoint_manager.save(step, state): checkpoint_manager.wait_until_finished() end_time = datetime.datetime.now() @@ -89,13 +87,11 @@ def checkpoint_loop(config, state=None): return state def main(argv: Sequence[str]) -> None: - jax.config.update('jax_default_prng_impl', 'unsafe_rbg') + jax.config.update('jax_cpu_enable_gloo_collectives', True) os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" - os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" pyconfig.initialize(argv) config = pyconfig.config validate_train_config(config) - cc.initialize_cache(os.path.expanduser(config.jax_cache_dir)) print(f"Found {jax.device_count()} devices.") print(f"Found {jax.process_count()} processes.") print(f"Found {jax.devices()} devices.") diff --git a/MaxText/standalone_dataloader.py b/MaxText/standalone_dataloader.py index fedc606c1b..bb3ac34aa2 100644 --- a/MaxText/standalone_dataloader.py +++ b/MaxText/standalone_dataloader.py @@ -27,21 +27,14 @@ import numpy as np import pyconfig -from jax.sharding import Mesh -import max_utils from train import validate_train_config, get_first_step, load_next_batch, setup_train_loop -from jax.experimental.compilation_cache import compilation_cache as cc - def data_load_loop(config, state=None): """Main data loader loop. Loads batches of data for each training step. """ - _, _, _, _, _, _, _, data_iterator, state = setup_train_loop(config) - - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) + _, _, _, _, _, mesh, _, data_iterator, state = setup_train_loop(config) example_batch = None start = datetime.datetime.now() @@ -50,25 +43,25 @@ def data_load_loop(config, state=None): jax.block_until_ready(example_batch) first_end = datetime.datetime.now() time_to_load_first_batch = first_end-start - max_logging.log(f"First step completed in {time_to_load_first_batch} seconds") + if jax.process_index() == 0: + max_logging.log(f"STANDALONE DATALOADER : First step completed in {time_to_load_first_batch} seconds, on host 0") for _ in np.arange(start_step+1, config.steps): example_batch = load_next_batch(data_iterator, example_batch, config, mesh) jax.block_until_ready(example_batch) # wait until the last batch is read end = datetime.datetime.now() - max_logging.log(f"{config.steps} batches loaded in {end-start} seconds, on host {jax.process_index()}") + if jax.process_index() == 0: + max_logging.log(f"STANDALONE DATALOADER : {config.steps} batches loaded in {end-start} seconds, on host 0") return state def main(argv: Sequence[str]) -> None: - jax.config.update('jax_default_prng_impl', 'unsafe_rbg') - os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + jax.config.update('jax_cpu_enable_gloo_collectives', True) os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" pyconfig.initialize(argv) config = pyconfig.config validate_train_config(config) - cc.initialize_cache(os.path.expanduser(config.jax_cache_dir)) max_logging.log(f"Found {jax.device_count()} devices.") max_logging.log(f"Found {jax.process_count()} processes.") max_logging.log(f"Found {jax.devices()} devices.")