diff --git a/dependencies/requirements/base_requirements/requirements.txt b/dependencies/requirements/base_requirements/requirements.txt index bc3e68345..df36e06ee 100644 --- a/dependencies/requirements/base_requirements/requirements.txt +++ b/dependencies/requirements/base_requirements/requirements.txt @@ -4,6 +4,7 @@ array-record cloud-accelerator-diagnostics cloud-tpu-diagnostics datasets +drjax flax gcsfs google-api-python-client diff --git a/dependencies/requirements/generated_requirements/cuda12-requirements.txt b/dependencies/requirements/generated_requirements/cuda12-requirements.txt index cb04b19ad..ae72525d3 100644 --- a/dependencies/requirements/generated_requirements/cuda12-requirements.txt +++ b/dependencies/requirements/generated_requirements/cuda12-requirements.txt @@ -4,7 +4,7 @@ absl-py>=2.3.1 aiofiles>=25.1.0 aiohappyeyeballs>=2.6.1 -aiohttp>=3.13.1 +aiohttp>=3.13.2 aiosignal>=1.4.0 annotated-doc>=0.0.3 annotated-types>=0.7.0 @@ -40,13 +40,14 @@ dill>=0.4.0 distlib>=0.4.0 dm-tree>=0.1.9 docstring-parser>=0.17.0 +drjax>=0.1.4 editdistance>=0.8.1 einops>=0.8.1 einshape>=1.0 etils>=1.13.0 evaluate>=0.4.6 execnet>=2.1.1 -fastapi>=0.120.1 +fastapi>=0.120.2 filelock>=3.20.0 flatbuffers>=25.9.23 flax>=0.12.0 @@ -55,11 +56,11 @@ frozenlist>=1.8.0 fsspec>=2025.9.0 gast>=0.6.0 gcsfs>=2025.9.0 -google-api-core>=2.28.0 +google-api-core>=2.28.1 google-api-python-client>=2.185.0 google-auth-httplib2>=0.2.0 google-auth-oauthlib>=1.2.2 -google-auth>=2.41.1 +google-auth>=2.42.0 google-benchmark>=1.9.4 google-cloud-aiplatform>=1.122.0 google-cloud-appengine-logging>=1.7.0 @@ -195,7 +196,7 @@ python-dateutil>=2.9.0.post0 pytype>=2024.10.11 pytz>=2025.2 pyyaml>=6.0.3 -qwix>=0.1.1 +qwix>=0.1.2 regex>=2025.10.23 requests-oauthlib>=2.0.0 requests>=2.32.5 @@ -214,7 +215,7 @@ simplejson>=3.20.2 six>=1.17.0 sniffio>=1.3.1 sortedcontainers>=2.4.0 -starlette>=0.48.0 +starlette>=0.49.1 sympy>=1.14.0 tabulate>=0.9.0 tenacity>=9.1.2 @@ -248,7 +249,7 @@ tzdata>=2025.2 uritemplate>=4.2.0 urllib3>=2.5.0 uvicorn>=0.38.0 -virtualenv>=20.35.3 +virtualenv>=20.35.4 wadler-lindig>=0.1.7 websockets>=15.0.1 werkzeug>=3.1.3 diff --git a/dependencies/requirements/generated_requirements/tpu-requirements.txt b/dependencies/requirements/generated_requirements/tpu-requirements.txt index 8e6d948c8..9cc7eaeb8 100644 --- a/dependencies/requirements/generated_requirements/tpu-requirements.txt +++ b/dependencies/requirements/generated_requirements/tpu-requirements.txt @@ -4,14 +4,14 @@ absl-py>=2.3.1 aiofiles>=25.1.0 aiohappyeyeballs>=2.6.1 -aiohttp>=3.13.1 +aiohttp>=3.13.2 aiosignal>=1.4.0 annotated-doc>=0.0.3 annotated-types>=0.7.0 antlr4-python3-runtime>=4.9.3 anyio>=4.11.0 aqtp>=0.9.0 -array-record>=0.8.1 +array-record>=0.8.2 astroid>=4.0.1 astunparse>=1.6.3 attrs>=25.4.0 @@ -40,13 +40,14 @@ dill>=0.4.0 distlib>=0.4.0 dm-tree>=0.1.9 docstring-parser>=0.17.0 +drjax>=0.1.4 editdistance>=0.8.1 einops>=0.8.1 einshape>=1.0 etils>=1.13.0 evaluate>=0.4.6 execnet>=2.1.1 -fastapi>=0.120.0 +fastapi>=0.120.2 filelock>=3.20.0 flatbuffers>=25.9.23 flax>=0.12.0 @@ -55,11 +56,11 @@ frozenlist>=1.8.0 fsspec>=2025.9.0 gast>=0.6.0 gcsfs>=2025.9.0 -google-api-core>=2.27.0 +google-api-core>=2.28.1 google-api-python-client>=2.185.0 google-auth-httplib2>=0.2.0 google-auth-oauthlib>=1.2.2 -google-auth>=2.41.1 +google-auth>=2.42.0 google-benchmark>=1.9.4 google-cloud-aiplatform>=1.122.0 google-cloud-appengine-logging>=1.7.0 @@ -85,7 +86,7 @@ gviz-api>=1.10.0 h11>=0.16.0 h5py>=3.15.1 hf-transfer>=0.1.9 -hf-xet>=1.1.10 ; platform_machine == 
'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64' +hf-xet>=1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64' httpcore>=1.0.9 httplib2>=0.31.0 httpx>=0.28.1 @@ -184,11 +185,11 @@ pyproject-hooks>=1.2.0 pytest-xdist>=3.8.0 pytest>=8.4.2 python-dateutil>=2.9.0.post0 -python-dotenv>=1.1.1 +python-dotenv>=1.2.1 pytype>=2024.10.11 pytz>=2025.2 pyyaml>=6.0.3 -qwix>=0.1.1 +qwix>=0.1.2 regex>=2025.10.23 requests-oauthlib>=2.0.0 requests>=2.32.5 @@ -207,7 +208,7 @@ simplejson>=3.20.2 six>=1.17.0 sniffio>=1.3.1 sortedcontainers>=2.4.0 -starlette>=0.48.0 +starlette>=0.49.1 sympy>=1.14.0 tabulate>=0.9.0 tenacity>=9.1.2 @@ -238,7 +239,7 @@ tzdata>=2025.2 uritemplate>=4.2.0 urllib3>=2.5.0 uvicorn>=0.38.0 -virtualenv>=20.35.3 +virtualenv>=20.35.4 wadler-lindig>=0.1.7 websockets>=15.0.1 werkzeug>=3.1.3 diff --git a/dependencies/requirements/requirements.txt b/dependencies/requirements/requirements.txt index 36471cf55..6420642bb 100644 --- a/dependencies/requirements/requirements.txt +++ b/dependencies/requirements/requirements.txt @@ -4,6 +4,7 @@ array-record cloud-accelerator-diagnostics cloud-tpu-diagnostics datasets +drjax>=0.1.4 flax gcsfs google-api-python-client diff --git a/dependencies/requirements/requirements_with_jax_ai_image.txt b/dependencies/requirements/requirements_with_jax_ai_image.txt index 993f0e6d8..ca07bb324 100644 --- a/dependencies/requirements/requirements_with_jax_ai_image.txt +++ b/dependencies/requirements/requirements_with_jax_ai_image.txt @@ -1,6 +1,7 @@ # Requirements for Building the MaxText Docker Image # These requirements are additional to the dependencies present in the JAX AI base image. datasets @ https://github.com/huggingface/datasets/archive/6790e138c00b87a1ddc72184f89e7814cf784360.zip +drjax>=0.1.4 flax>=0.11.0 google-api-python-client google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 9ee7f39ba..bb7f6b708 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -359,7 +359,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' # Parallelism shard_mode: "auto" # can be either auto or explicit -mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'] +mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], @@ -435,6 +435,7 @@ logical_axis_rules: [ ['paged_kv_head_dim_size', []], ['dense_layers', []], ['moe_layers', []], + ['diloco', 'diloco'], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] @@ -447,6 +448,7 @@ sharding_tolerance: 0.02 # value to auto-shard based on available slices and devices. 
# By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. +dcn_diloco_parallelism: 1 dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: 1 dcn_fsdp_transpose_parallelism: 1 @@ -459,6 +461,7 @@ dcn_tensor_sequence_parallelism: 1 # never recommended dcn_pipeline_parallelism: 1 dcn_expert_parallelism: 1 dcn_autoregressive_parallelism: 1 # never recommended +ici_diloco_parallelism: 1 ici_data_parallelism: 1 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded ici_fsdp_transpose_parallelism: 1 @@ -644,6 +647,12 @@ enable_data_shuffling: True data_shuffle_seed: 0 init_weights_seed: 0 +# DiLoCo params. +enable_diloco: False +diloco_sync_period: 36 +diloco_outer_lr: 0.3 +diloco_outer_momentum: 0.9 + # You may disable clipping by setting gradient_clipping_threshold to zero. gradient_clipping_threshold: 1.0 diff --git a/src/MaxText/diloco.py b/src/MaxText/diloco.py new file mode 100644 index 000000000..7d137e466 --- /dev/null +++ b/src/MaxText/diloco.py @@ -0,0 +1,201 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An implementation of Distributed Low-Communication (DiLoCo) training. + +This module contains implementations of: + +- DiLoCo: Distributed Low-Communication Training of Language Models + https://arxiv.org/abs/2311.08105 +- Streaming DiLoCo with overlapping communication: Towards a Distributed Free Lunch + https://arxiv.org/abs/2501.18512 +""" + +from collections.abc import Sequence +from typing import Any, Callable + +import drjax +from flax import struct +from flax.training import train_state +import jax +import jax.numpy as jnp +from jaxtyping import Array, Int32, Key, PyTree, UInt32 +import optax + +from MaxText import pyconfig + +Batch = Any +Params = PyTree +Metrics = PyTree +OptState = optax.OptState +InnerOptStates = optax.OptState +PRNGKey = Key[Array, ""] | UInt32[Array, "2"] +Step = Int32[Array, ""] + + +class DiLoCoTrainState(struct.PyTreeNode): + """The state of the DiLoCo training process. + + Attributes: + inner_state: A `flax.training.train_state.TrainState` of the state for each + step of the inner optimization. All arrays are expected to have a leading + dimension with size of the number of diloco replicas so that training + steps can be mapped over this dimension. + outer_params: A PyTree of the global model weights. These mirror the params + sub-PyTree in `inner_state`, but without the leading replica dimension. + outer_opt_state: The state for the outer Nesterov momentum optimizer. + step: The step counter of the training process. + """ + + inner_state: train_state.TrainState + outer_params: Params + outer_opt_state: OptState + step: Step + + +def reshape_first_axis_with_diloco(num_diloco_replicas: int, pytree: PyTree) -> PyTree: + """Reshapes the first dimension of each array in the PyTree to include a DiLoCo axis. 
+ + This function takes a batch of data represented as a PyTree + and reshapes the leading dimension of each array within it. The purpose is + to introduce a new 'diloco' axis, which is used for distributing data + across DiLoCo replicas. + + Args: + num_diloco_replicas: The number of DiLoCo replicas. This determines the + size of the new leading dimension. + pytree: The input PyTree, where each array is expected to have a batch + dimension as its first axis. + + Returns: + A new PyTree with the same structure as the input, but with each array's + first dimension reshaped to `(num_diloco_replicas, original_batch_dim // num_diloco_replicas, ...)`. + The sharding specification is also updated to include the 'diloco' axis. + """ + + def extend_pspec(pspec: jax.sharding.PartitionSpec | Sequence[str | Sequence[str]] = ()) -> jax.sharding.PartitionSpec: + if tuple(*pspec)[0] == "diloco": + # pull out diloco axis if already present + return jax.sharding.PartitionSpec("diloco", (*pspec[0][1:],), (*pspec[1:],)) + return jax.sharding.PartitionSpec("diloco", *pspec) + + def reshape_for_diloco(arr): + batch_dim, *example_shape = arr.shape + diloco_shape = (num_diloco_replicas, batch_dim // num_diloco_replicas, *example_shape) + s = arr.sharding + s = jax.sharding.NamedSharding(mesh=s.mesh, spec=extend_pspec(s.spec)) + return jax.lax.with_sharding_constraint(jnp.reshape(arr, shape=diloco_shape), s) + + return jax.tree.map(reshape_for_diloco, pytree) + + + def build_diloco_state( + config: "pyconfig.HyperParameters", + initialize_state: Callable[[], train_state.TrainState], + ) -> tuple[DiLoCoTrainState, PyTree]: + """Given a non-DiLoCo train state, construct a DiLoCo training state.""" + outer_optimizer = optax.sgd( + config.diloco_outer_lr, + momentum=config.diloco_outer_momentum, + nesterov=True, + ) + + @drjax.program(placements={"diloco": config.num_diloco_replicas}) + def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]: + state = initialize_state() + # Inner state must be broadcast across clients. + inner_state = drjax.broadcast(state) + # Outer state retains a single copy of the model parameters and optimizer state. + outer_params = state.params + outer_opt_state = outer_optimizer.init(outer_params) + outer_opt_state_sharding = jax.tree_util.tree_map(lambda x: x.sharding, outer_opt_state) + return ( + DiLoCoTrainState( + inner_state=inner_state, outer_params=outer_params, outer_opt_state=outer_opt_state, step=state.step + ), + outer_opt_state_sharding, + ) + + return init_diloco_state() + + + def build_diloco_train_step( + config: pyconfig.HyperParameters, + train_step: Callable[[train_state.TrainState, Batch, PRNGKey], tuple[train_state.TrainState, Metrics]], + ) -> Callable[[DiLoCoTrainState, Batch, PRNGKey], tuple[DiLoCoTrainState, Metrics]]: + """Convert a local state and train step into DiLoCo-compatible versions. + + This is an implementation of the original (non-streaming) DiLoCo algorithm + which syncs all model parameters across the replicas every + `config.diloco_sync_period` steps, treating the difference accumulated over + non-sync steps as a pseudo gradient and applying SGD with Nesterov momentum on + the "global" model. + + Args: + config: The config used to set up training. + train_step: A local train step. This will be executed independently within + each replica. 
+ """ + outer_optimizer = optax.sgd( + config.diloco_outer_lr, + momentum=config.diloco_outer_momentum, + nesterov=True, + ) + + def synchronize(state): + # Calculate the delta between the current replica's state and the global + # state (since last synchronization). + broadcast_outer_params = drjax.broadcast(state.outer_params) + model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params) + # Treat the average delta as the outer optimizer's gradient and apply to + # the global (outer) model params. + averaged_pseudo_grad = drjax.reduce_mean(model_delta) + updates, new_opt_state = outer_optimizer.update(averaged_pseudo_grad, state.outer_opt_state, state.outer_params) + new_outer_params = optax.apply_updates(state.outer_params, updates) + # Replace inner model params with the new global model params. + # NOTE: inner optimizer state is retained despite the change in parameters, + # see section 6.1 in https://arxiv.org/pdf/2311.08105. + new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state) + return state.replace( + outer_params=new_outer_params, + outer_opt_state=new_opt_state, + inner_state=new_inner_state, + ) + + def typed_reduce_mean(in_tree): + total = drjax.reduce_sum(in_tree) + avg = jax.tree.map(lambda x: (x / config.num_diloco_replicas).astype(x.dtype), total) + return avg + + @drjax.program(placements={"diloco": config.num_diloco_replicas}) + def diloco_train_step(state, batch, prng): + # Broadcast the RNG across replicas. + broadcast_rng = drjax.broadcast(prng) + inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng)) + avg_metrics = typed_reduce_mean(metrics) + state = state.replace( + inner_state=inner_state, + step=inner_state.step[0], + ) + # Either synchronize the model, or no-op, depending on whether the current + # step falls on the synchronization period. 
+ state = jax.lax.cond( + inner_state.step[0] % config.diloco_sync_period == 0, + synchronize, + lambda x: x, # no-op + state, + ) + return state, avg_metrics + + return diloco_train_step diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py index 6cb8b1e10..70388ee8c 100644 --- a/src/MaxText/maxtext_utils.py +++ b/src/MaxText/maxtext_utils.py @@ -125,7 +125,14 @@ def get_reorder_callable(cp_size, shard_mode): def get_shaped_batch(config): """Return the shape of the batch - this is what eval_shape would return for the output of create_data_iterator, but eval_shape doesn't work, see b/306901078.""" - batch_shape = (config.global_batch_size_to_load, config.max_target_length) + if config.enable_diloco: + batch_shape = ( + config.num_diloco_replicas, + config.global_batch_size_to_load // config.num_diloco_replicas, + config.max_target_length, + ) + else: + batch_shape = (config.global_batch_size_to_load, config.max_target_length) shaped_batch = {} shaped_batch["inputs"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) shaped_batch["inputs_position"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) diff --git a/src/MaxText/pyconfig.py b/src/MaxText/pyconfig.py index dc68080c7..2c65e3e2a 100644 --- a/src/MaxText/pyconfig.py +++ b/src/MaxText/pyconfig.py @@ -912,6 +912,7 @@ def update_model_vars(base_config_path, raw_keys, config_name: str, keys_from_en def create_parallelisms_list(raw_keys): ici_parallelism = [ + raw_keys["ici_diloco_parallelism"], raw_keys["ici_data_parallelism"], raw_keys["ici_pipeline_parallelism"], raw_keys["ici_fsdp_parallelism"], @@ -926,6 +927,7 @@ def create_parallelisms_list(raw_keys): raw_keys["ici_autoregressive_parallelism"], ] dcn_parallelism = [ + raw_keys["dcn_diloco_parallelism"], raw_keys["dcn_data_parallelism"], raw_keys["dcn_pipeline_parallelism"], raw_keys["dcn_fsdp_parallelism"], @@ -941,6 +943,7 @@ def create_parallelisms_list(raw_keys): ] raw_keys["ici_parallelism"] = ici_parallelism raw_keys["dcn_parallelism"] = dcn_parallelism + raw_keys["num_diloco_replicas"] = int(raw_keys["ici_diloco_parallelism"] * raw_keys["dcn_diloco_parallelism"]) return raw_keys diff --git a/src/MaxText/sharding.py b/src/MaxText/sharding.py index 8a5e7d338..7ba8b894d 100644 --- a/src/MaxText/sharding.py +++ b/src/MaxText/sharding.py @@ -30,7 +30,14 @@ def get_input_data_sharding(config, mesh): """Get the input data sharding for the model""" - return nn.logical_to_mesh_sharding(P(*config.input_data_sharding_logical_axes), mesh, config.logical_axis_rules) + data_sharding = nn.logical_to_mesh_sharding( + P(*config.input_data_sharding_logical_axes), mesh, config.logical_axis_rules + ) + if config.enable_diloco: + data_sharding = jax.tree_util.tree_map( + lambda s: jax.sharding.NamedSharding(s.mesh, P("diloco", *s.spec)), data_sharding + ) + return data_sharding def maybe_shard_with_name(inputs, named_sharding, shard_mode): diff --git a/src/MaxText/train.py b/src/MaxText/train.py index 48550a1d9..84bbf5bc9 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -64,6 +64,7 @@ maybe_record_goodput, ) from MaxText.vertex_tensorboard import VertexTensorboardManager +from MaxText import diloco # Placeholder: internal from MaxText.gradient_accumulation import gradient_accumulation_loss_and_grad @@ -377,9 +378,50 @@ def train_loop(config, recorder, state=None): params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) - p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( - config, model, mesh, 
state, state_mesh_shardings, train_step, eval_step, eval_data_iterator, params_shardings - ) + if config.enable_diloco: + train_step_partial = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings) + with mesh: + state, outer_opt_state_sharding = diloco.build_diloco_state(config, lambda: state) + + # Create state_mesh_shardings for the DiLoCoTrainState + def add_diloco_to_sharding(pytree): + """ + Recursively traverses a PyTree and prepends 'diloco' to the PartitionSpec + of every NamedSharding leaf. + """ + + def map_fn(leaf): + if isinstance(leaf, jax.sharding.NamedSharding): + new_spec = jax.sharding.PartitionSpec("diloco", *leaf.spec) + return jax.sharding.NamedSharding(mesh=leaf.mesh, spec=new_spec) + return leaf + + return jax.tree_util.tree_map(map_fn, pytree) + + inner_state_shardings = add_diloco_to_sharding(state_mesh_shardings) + diloco_state_shardings = diloco.DiLoCoTrainState( + inner_state_shardings, + state_mesh_shardings.params, + outer_opt_state_sharding, + jax.sharding.NamedSharding(mesh=state_mesh_shardings.step.mesh, spec=jax.sharding.PartitionSpec()), + ) + + diloco_train_step = diloco.build_diloco_train_step(config, train_step_partial) + p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( + config, + model, + mesh, + state, + diloco_state_shardings, + diloco_train_step, + eval_step, + eval_data_iterator, + params_shardings, + ) + else: + p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( + config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator, params_shardings + ) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): shaped_batch = maxtext_utils.get_shaped_batch(config) @@ -395,7 +437,10 @@ def train_loop(config, recorder, state=None): metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) # Write train config params, num model params, and XLA flags to tensorboard - metric_logger.write_setup_info_to_tensorboard(state.params) + if config.enable_diloco: + metric_logger.write_setup_info_to_tensorboard(state.outer_params) + else: + metric_logger.write_setup_info_to_tensorboard(state.params) try: last_step_completion = datetime.datetime.now() @@ -404,6 +449,8 @@ def train_loop(config, recorder, state=None): with jax.profiler.StepTraceAnnotation("train", step_num=step): example_batch = data_loader.load_next_batch() + if config.enable_diloco: + example_batch = diloco.reshape_first_axis_with_diloco(config.num_diloco_replicas, example_batch) # Reshard data from loaded sharding to performant activation sharding example_batch = sharding.maybe_shard_with_name( example_batch, diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py index 4d6ff3d3b..0fca08f33 100644 --- a/src/MaxText/train_utils.py +++ b/src/MaxText/train_utils.py @@ -78,15 +78,22 @@ def create_training_tools(config, model, mesh): def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings): """Returns a JIT-compiled train step function, which is loaded from a file if specified in the config.""" - ( - functional_train, - in_shardings, - out_shardings, - static_argnums, - donate_argnums, - ) = maxtext_utils.get_functional_train_with_signature( - train_step, data_sharding, state_mesh_shardings, model, config, params_shardings - ) + if config.enable_diloco: + functional_train = train_step + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + out_shardings 
= (state_mesh_shardings, None) # State, metrics + static_argnums = () # We partial out the static argnums of model and config + donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory. + else: + ( + functional_train, + in_shardings, + out_shardings, + static_argnums, + donate_argnums, + ) = maxtext_utils.get_functional_train_with_signature( + train_step, data_sharding, state_mesh_shardings, model, config, params_shardings + ) # Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit if config.compiled_trainstep_file != "": diff --git a/tests/diloco_test.py b/tests/diloco_test.py new file mode 100644 index 000000000..bbaa38aab --- /dev/null +++ b/tests/diloco_test.py @@ -0,0 +1,273 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the DiLoCo implementation in diloco.py""" + +from collections.abc import Mapping +import dataclasses +import unittest +from typing import Any + +import chex +from flax.experimental import nnx +from flax.training import train_state +import jax +import jax.numpy as jnp +import jax.sharding +import numpy as np +import optax +import pytest + +from MaxText import diloco +from MaxText import pyconfig + + +class SimpleNNXModel(nnx.Module): + """A minimal linear model for testing.""" + + def __init__(self, *, rngs: nnx.Rngs): + self.dense = nnx.Linear( + 2, + 1, + kernel_init=nnx.initializers.constant(jnp.asarray([[2.0], [1.0]])), + bias_init=nnx.initializers.ones_init(), + rngs=rngs, + ) + + def __call__(self, x): + return self.dense(x) + + +@dataclasses.dataclass +class _TestConfig: + """A fake config for testing.""" + + keys: Mapping[str, Any] + + +class DiLoCoTest(unittest.TestCase): + + @pytest.mark.tpu_only + def test_diloco_training_simulation_with_mesh(self): + """Runs a simulation of DiLoCo training on a mesh and asserts correctness.""" + num_replicas = 2 + num_steps = 4 + + devices = jax.devices() + if len(devices) < num_replicas: + self.skipTest(f"Test requires {num_replicas} devices, but only {len(devices)} are available.") + + mesh_devices = np.array(devices[:num_replicas]).reshape(1, num_replicas) + mesh = jax.sharding.Mesh(mesh_devices, axis_names=("data", "diloco")) + + test_config = pyconfig.HyperParameters( + config=_TestConfig( + keys={ + "num_diloco_replicas": num_replicas, + "diloco_outer_momentum": 0.9, + "diloco_outer_lr": 1.0, + "diloco_sync_period": num_steps - 1, + } + ) + ) + + with mesh: + tx = optax.sgd(learning_rate=0.1) + rngs = nnx.Rngs(params=jax.random.key(seed=42)) + model = SimpleNNXModel(rngs=rngs) + graphdef, params = nnx.split(model) + + def nnx_apply_fn(params, inputs): + model_replica = nnx.merge(graphdef, params) + return model_replica(inputs) + + # 
Vmap this new wrapper function + vmapped_apply = jax.vmap(nnx_apply_fn, in_axes=(None, 0)) + + def _test_train_step(state: train_state.TrainState, batch, prng_key: diloco.PRNGKey): + """A simple MSE loss train step to enable numerics testing.""" + del prng_key + + def loss_fn(params, batch): + inputs, labels = batch + logits = vmapped_apply(params, inputs) + residual = logits - labels + sq_residual = jnp.square(residual) + msq_residual = jnp.mean(sq_residual) + return msq_residual + + loss, grad = jax.value_and_grad(loss_fn)(state.params, batch) + return state.apply_gradients(grads=grad), loss + + initial_test_state = train_state.TrainState.create( + apply_fn=vmapped_apply, + params=params, + tx=tx, + ) + + diloco_test_state, _ = diloco.build_diloco_state(test_config, lambda: initial_test_state) + chex.assert_equal(diloco_test_state.step, 0) + chex.assert_trees_all_equal(diloco_test_state.outer_params, initial_test_state.params) + + diloco_train_step = diloco.build_diloco_train_step(test_config, _test_train_step) + inputs = jnp.array( + [ + [[0.0, 1.0], [1.0, 0.0]], # First replica inputs. + [[1.0, 0.0], [0.0, 1.0]], # Second replica inputs. + ] + ) + labels = jnp.array( + [ + [[1.0], [2.0]], # First replica labels. + [[2.0], [3.0]], # Second replica labels. + ] + ) + + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, "diloco")) + inputs = jax.device_put(inputs, sharding) + labels = jax.device_put(labels, sharding) + + # Run the first step (no synchronization). + # Replica 0: + # Data: [[0, 1], [1, 0]] + # Labels: [[1], [2]] + # Weights: w = [[2], [1]] + # Bias: b = [1] + # Loss = mean((y - pred)^2) = + # = mean( ([[1], [2]] - (x . w + b)) ^ 2 ) ) + # = mean( ([[1], [2]] - ([[0, 1], [1, 0]] . [[2], [1]] + [1])) ^ 2 ) + # = mean( ([[1], [2]] - [[2], [3]]) ^ 2 ) + # = mean( ([-1, 1]) ^ 2 ) = mean( [1, 1] ) + # = 1.0 + # + # Replica 1: + # Data: [[1, 0], [0, 1]] + # Labels: [[2], [3]] + # Weights: w = [[2], [1]] + # Bias: b = [1] + # Loss = mean((y - pred)^2) = + # = mean( ([[2], [3]] - (x . w + b)) ^ 2 ) ) + # = mean( ([[2], [3]] - ([[1, 0], [0, 1]] . [[2], [1]] + [1])) ^ 2 ) + # = mean( ([[2], [3]] - [[3], [2]]) ^ 2 ) + # = mean( ([-1, 1]) ^ 2 ) = mean( [1, 1] ) + # = 1.0 + diloco_test_state, loss = diloco_train_step(diloco_test_state, (inputs, labels), jax.random.key(seed=42)) + chex.assert_equal(diloco_test_state.step, 1.0) + chex.assert_equal(loss, 1.0) + # Assert no updates to the global model yet (no synchronization) + chex.assert_trees_all_equal(diloco_test_state.outer_params, initial_test_state.params) + + # Run the second step (no synchronization). + # Replica 0: + # Data: [[0, 1], [1, 0]] + # Labels: [[1], [2]] + # Weights: w = [[1.9], [0.9]] + # Bias: b = [0.8] + # Loss = mean((y - pred)^2) = + # = mean( ([[1], [2]] - (x . w + b)) ^ 2 ) ) + # = mean( ([[1], [2]] - ([[0, 1], [1, 0]] . [[1.9], [0.9]] + [0.8])) ^ 2 ) + # = mean( ([[1], [2]] - [[1.7], [2.7]]) ^ 2 ) + # = mean( ([-0.7, 0.7]) ^ 2 ) = mean( [0.49, 0.49] ) + # = 0.49 + # + # Replica 1: + # Data: [[1, 0], [0, 1]] + # Labels: [[2], [3]] + # Weights: w = [[1.9], [1.1]] + # Bias: b = [1] + # Loss = mean((y - pred)^2) = + # = mean( ([[2], [3]] - (x . w + b)) ^ 2 ) ) + # = mean( ([[2], [3]] - ([[1, 0], [0, 1]] . 
[[1.9], [1.1]] + [1])) ^ 2 ) + # = mean( ([[2], [3]] - [[2.9], [2.1]]) ^ 2 ) + # = mean( ([-0.9, 0.9]) ^ 2 ) = mean( [0.81, 0.81] ) + # = 0.81 + diloco_test_state, loss = diloco_train_step(diloco_test_state, (inputs, labels), jax.random.key(seed=42)) + chex.assert_equal(diloco_test_state.step, 2.0) + chex.assert_trees_all_close(loss, 0.65) + # Assert no updates to the global model yet (no synchronization) + chex.assert_trees_all_equal(diloco_test_state.outer_params, initial_test_state.params) + + # Run the third step, which synchronizes afterwards. + # Replica 0: + # Data: [[0, 1], [1, 0]] + # Labels: [[1], [2]] + # Weights: w = [[1.83], [0.83]] + # Bias: b = [0.66] + # Loss = mean((y - pred)^2) = + # = mean( ([[1], [2]] - (x . w + b)) ^ 2 ) ) + # = mean( ([[1], [2]] - ([[0, 1], [1, 0]] . [[1.83], [0.83]] + [0.66])) ^ 2 ) + # = mean( ([[1], [2]] - [[1.49], [2.49]]) ^ 2 ) + # = mean( ([-0.49, 0.49]) ^ 2 ) = mean( [0.2401, 0.2401] ) + # = 0.2401 + # + # Replica 1: + # Data: [[1, 0], [0, 1]] + # Labels: [[2], [3]] + # Weights: w = [[1.81], [1.19]] + # Bias: b = [1.] + # Loss = mean((y - pred)^2) = + # = mean( ([[2], [3]] - (x . w + b)) ^ 2 ) ) + # = mean( ([[2], [3]] - ([[1, 0], [0, 1]] . [[1.81], [1.19]] + [1])) ^ 2 ) + # = mean( ([[2], [3]] - [[2.81], [2.19]]) ^ 2 ) + # = mean( ([-0.81, 0.81]) ^ 2 ) = mean( [0.6561, 0.6561] ) + # = 0.6561 + # + # After these are averaged, the model differences are computed to create a + # pseudo-gradient update to the outer_params and applied via a momentum + # based outer optimizer. + diloco_test_state, loss = diloco_train_step(diloco_test_state, (inputs, labels), jax.random.key(seed=42)) + chex.assert_equal(diloco_test_state.step, 3.0) + chex.assert_trees_all_close(loss, 0.4481) + # Assert that inner and outer parameters are all equal now that + # synchronization has happened. + chex.assert_trees_all_equal( + diloco_test_state.outer_params, + jax.tree.map(lambda arr: arr[0, ...], diloco_test_state.inner_state.params), + ) + chex.assert_trees_all_equal( + diloco_test_state.outer_params, + jax.tree.map(lambda arr: arr[1, ...], diloco_test_state.inner_state.params), + ) + + # Run the fourth step (no synchronization). + # Replica 0: + # Data: [[0, 1], [1, 0]] + # Labels: [[1], [2]] + # Weights: w = [[1.5345], [1.0494]] + # Bias: b = [0.5839] + # Loss = mean((y - pred)^2) = + # = mean( ([[1], [2]] - (x . w + b)) ^ 2 ) ) + # = mean( ([[1], [2]] - ([[0, 1], [1, 0]] . [[1.5345], [1.0494]]] + [0.5839])) ^ 2 ) + # = mean( ([[1], [2]] - [[1.6333], [2.1184]]) ^ 2 ) + # = mean( ([-0.6333, 0.1184]) ^ 2 ) = mean( [0.4010, 0.0140] ) + # ~ 0.2075 + # + # Replica 1: + # Data: [[1, 0], [0, 1]] + # Labels: [[2], [3]] + # Weights: w = [[1.5345], [1.0494]] + # Bias: b = [0.5839] + # Loss = mean((y - pred)^2) = + # = mean( ([[2], [3]] - (x . w + b)) ^ 2 ) ) + # = mean( ([[2], [3]] - ([[1, 0], [0, 1]] . [[1.5345], [1.0494]] + [0.5839])) ^ 2 ) + # = mean( ([[2], [3]] - [[2.1184], [1.6333]]) ^ 2 ) + # = mean( ([-0.1184, 1.3667]) ^ 2 ) = mean( [0.0140, 1.8678] ) + # ~ 0.94 + step_three_outer_params = diloco_test_state.outer_params + diloco_test_state, loss = diloco_train_step(diloco_test_state, (inputs, labels), jax.random.key(seed=42)) + chex.assert_equal(diloco_test_state.step, 4.0) + chex.assert_trees_all_close(loss, 0.574244) + # Assert no updates to the global model since previous step (no + # synchronization). + chex.assert_trees_all_equal(diloco_test_state.outer_params, step_three_outer_params)
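
For reference, the outer synchronization performed by `synchronize` in diloco.py reduces to a plain SGD-with-Nesterov-momentum update on a pseudo-gradient: the mean over replicas of (outer params minus inner params). The sketch below restates that update with standard optax calls on a toy parameter tree. It is illustrative only and not part of the patch; the names and numeric values (outer_params, inner_params, 0.8, 0.6) are invented for the example, and only the optax calls and the 0.3 / 0.9 defaults mirror the change above.

# Illustrative sketch of the DiLoCo outer update (not part of the patch).
import jax.numpy as jnp
import optax

# diloco_outer_lr / diloco_outer_momentum defaults from base.yml above.
outer_opt = optax.sgd(0.3, momentum=0.9, nesterov=True)

outer_params = {"w": jnp.array(1.0)}                           # global ("outer") model
inner_params = [{"w": jnp.array(0.8)}, {"w": jnp.array(0.6)}]  # per-replica params after local steps

# Pseudo-gradient: mean over replicas of (outer - inner), matching
# model_delta / averaged_pseudo_grad in build_diloco_train_step.
pseudo_grad = {"w": jnp.mean(jnp.stack([outer_params["w"] - p["w"] for p in inner_params]))}

opt_state = outer_opt.init(outer_params)
updates, opt_state = outer_opt.update(pseudo_grad, opt_state, outer_params)
outer_params = optax.apply_updates(outer_params, updates)
# The updated outer params are then written back into every replica's inner state,
# while each replica keeps its own inner optimizer state (section 6.1 of the DiLoCo paper).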