1 change: 1 addition & 0 deletions checkpoint/CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 - #v1 Add `use_load_and_broadcast` option.
+- Add multi-tier checkpointing support for Pathways.
 
 ### Removed
 
6 changes: 4 additions & 2 deletions checkpoint/orbax/checkpoint/_src/multihost/dispatchers.py
@@ -335,9 +335,11 @@ def _cp_wrapper(inp: PyTree) -> PyTree:
         input_arrays, abstract=True
     )
     cpu_result_specs = self._transform_pytree_shardings(result_specs)
-    _cp_wrapper.specialize(out_specs_fn=lambda _: cpu_result_specs)
+    specialized_wrapper = _cp_wrapper.specialize(
+        out_specs_fn=lambda _: cpu_result_specs
+    )
 
-    result = _cp_wrapper(self.to_colocated_python(input_arrays))
+    result = specialized_wrapper(self.to_colocated_python(input_arrays))
     return self._to_final_specs(result, result_specs)
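The fix above works because `specialize` returns a new specialized callable rather than mutating the wrapper in place; the old code discarded that return value, so the output specs never took effect. A minimal standalone sketch of the corrected pattern (the no-op `_copy` function and the identity `out_specs_fn` are illustrative, not part of this PR):

import jax
import numpy as np
from jax.experimental import colocated_python

def _copy(x: jax.Array) -> jax.Array:
  return x  # executes on the CPU host colocated with the input's device

# Put a small array on the CPU device colocated with the first local device.
cpu_device = colocated_python.colocated_cpu_devices(jax.local_devices())[0]
x = jax.device_put(np.arange(4), cpu_device)

wrapper = colocated_python.colocated_python(_copy)
# specialize() leaves `wrapper` unchanged and returns a NEW callable, so the
# result must be captured, which is exactly what the diff above now does.
specialized = wrapper.specialize(out_specs_fn=lambda spec: spec)
result = jax.block_until_ready(specialized(x))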


@@ -65,23 +65,34 @@ def read_process_metadata(directory: epath.Path):
   return distributed_to_device_ids, device_ids
 
 
-async def save_process_metadata(
+def write_process_metadata(
     directory: epath.Path,
-    global_mesh: jax.sharding.Mesh,
+    device_ids: List[int],
     distributed_to_device_ids: List[List[int]],
 ):
-  """Saves process metadata to local storage. Runs on every process."""
+  """Synchronously writes process metadata to local storage."""
   metadata_folder = process_metadata_folder(directory)
   metadata_folder.mkdir(parents=True, exist_ok=True)
   logging.info('Saving process index metadata at %s', metadata_folder)
 
   (metadata_folder / _GLOBAL_PROCESS_METADATA_FILE_NAME).write_text(
       json.dumps(distributed_to_device_ids)
   )
   (metadata_folder / _MESH_METADATA_FILE_NAME).write_text(
-      json.dumps([int(id) for id in global_mesh.device_ids.flatten()])
+      json.dumps(device_ids)
   )
 
 
+async def save_process_metadata(
+    directory: epath.Path,
+    global_mesh: jax.sharding.Mesh,
+    distributed_to_device_ids: List[List[int]],
+):
+  """Saves process metadata to local storage. Runs on every process."""
+  device_ids = [int(id) for id in global_mesh.device_ids.flatten()]
+  write_process_metadata(directory, device_ids, distributed_to_device_ids)
 
 
 def consistent_restore_mesh_from_metadata(
     global_mesh: jax.sharding.Mesh,
     current_distributed_to_device_ids: List[List[int]],
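The refactor above extracts a synchronous core, `write_process_metadata`, so the metadata write can be invoked from non-async contexts such as a colocated Python closure, while `save_process_metadata` keeps the existing async signature. A rough usage sketch, assuming both helpers are imported from this module and JAX devices are available (the directory and device-id values are illustrative):

import asyncio
import jax
import numpy as np
from etils import epath

directory = epath.Path('/tmp/ckpt-metadata')  # hypothetical location
mesh = jax.sharding.Mesh(np.array(jax.devices()), ('d',))
distributed_to_device_ids = [[d.id for d in jax.devices()]]

# New synchronous path: callable where no event loop is running.
device_ids = [int(i) for i in mesh.device_ids.flatten()]
write_process_metadata(directory, device_ids, distributed_to_device_ids)

# Unchanged async API: now a thin wrapper over the synchronous writer.
asyncio.run(save_process_metadata(directory, mesh, distributed_to_device_ids))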
@@ -13,16 +13,20 @@
 # limitations under the License.
 
 """Initialization for multi-tier checkpointing."""
+
 import os
 import time
 from typing import List, Optional
 
 from absl import logging
 from etils import epath
 import jax
+from jax.experimental import colocated_python
+import numpy as np
 from orbax.checkpoint._src.multihost import multihost
 from orbax.checkpoint._src.multihost import multislice
 
+
 _REPLICATOR_FILE = 'replicator.yaml'
 _TEMP_REPLICATOR_FILE_NAME = _REPLICATOR_FILE + '.tmp'
 _JAX_INIT_INFO_FILE = 'jax-init-info.txt'
@@ -70,6 +74,108 @@ def _create_replicator_file(
   os.rename(temp_file, replicator_file)
 
 
+def _initialize_mtc_colocated(
+    local_checkpoint_directory: epath.Path,
+    backup_interval_minutes: int,
+    num_slices: int,
+    run_name: str,
+    data_parallelism: int,
+    timeout_seconds: int,
+):
+  """Initializes multi-tier checkpointing with Colocated Python.
+
+  Args:
+    local_checkpoint_directory: The local checkpoint directory.
+    backup_interval_minutes: The backup interval in minutes.
+    num_slices: The number of slices.
+    run_name: The run name.
+    data_parallelism: The number of identical data-parallel pipelines.
+    timeout_seconds: The timeout in seconds.
+  """
+  # 1. Obtain CPU devices for all remote hosts.
+  cpu_devices = colocated_python.colocated_cpu_devices(jax.devices())
+
+  # Ensure one CPU device per process (worker node).
+  unique_cpu_devices = list(
+      {dev.process_index: dev for dev in cpu_devices}.values()
+  )
+  num_nodes = len(unique_cpu_devices)
+  nodes_per_slice = max(1, num_nodes // num_slices)
+
+  # 2. Pre-calculate the node_rank and peer_ranks for every node.
+  all_node_ranks = np.arange(num_nodes, dtype=np.int32)
+  all_peer_ranks = []
+  for nr in range(num_nodes):
+    my_in_pipeline_index = nr % nodes_per_slice
+    peers = [
+        i * nodes_per_slice + my_in_pipeline_index
+        for i in range(num_slices)
+        if (i * nodes_per_slice + my_in_pipeline_index) != nr
+    ]
+    all_peer_ranks.append(peers)
+
+  # Handle the single-slice edge case, where the peer lists are empty.
+  if not all_peer_ranks[0]:
+    all_peer_ranks = np.zeros((num_nodes, 0), dtype=np.int32)
+  else:
+    all_peer_ranks = np.array(all_peer_ranks, dtype=np.int32)
+
+  # 3. Create a 1D mesh over the remote hosts and shard the configuration
+  # arrays.
+  cpu_mesh = jax.sharding.Mesh(np.array(unique_cpu_devices), ('d',))
+  sharding = jax.sharding.NamedSharding(
+      cpu_mesh, jax.sharding.PartitionSpec('d')
+  )
+
+  # JAX distributes these arrays across the workers natively.
+  sharded_node_ranks = jax.device_put(all_node_ranks, sharding)
+  sharded_peer_ranks = jax.device_put(all_peer_ranks, sharding)
+
+  # 4. Define the SPMD closure that runs on each remote worker.
+  def _setup(local_nr_arr, local_pr_arr):
+    loc_dir = epath.Path(local_checkpoint_directory)
+
+    # JAX sharding slices the arrays into chunks of shape (1,) and (1, P);
+    # index at [0] to extract the scalar rank and the flat peer list.
+    node_rank = int(np.asarray(local_nr_arr)[0])
+    peer_ranks = np.asarray(local_pr_arr)[0].tolist()
+
+    _wait_for_replicator_file_to_disappear(
+        loc_dir, timeout_seconds=timeout_seconds
+    )
+
+    _create_replicator_file(
+        loc_dir,
+        run_name=run_name,
+        num_nodes=num_nodes,
+        data_parallelism=data_parallelism,
+        node_rank=node_rank,
+        peer_ranks=peer_ranks,
+        backup_interval_minutes=backup_interval_minutes,
+    )
+
+    _wait_for_replicator_file_to_disappear(
+        loc_dir, timeout_seconds=timeout_seconds
+    )
+    _block_and_process_restore_dir(loc_dir, timeout_seconds=timeout_seconds)
+
+    # Return an array to satisfy SPMD device matching.
+    return local_nr_arr
+
+  # 5. Wrap and dispatch using native JAX SPMD.
+  wrapped_setup_fn = colocated_python.colocated_python(_setup)
+  wrapped_setup_fn = wrapped_setup_fn.specialize(out_specs_fn=lambda x, y: x)
+
+  # Triggers concurrent execution across all nodes without a thread pool.
+  jax.block_until_ready(
+      wrapped_setup_fn(sharded_node_ranks, sharded_peer_ranks)
+  )
+
+  logging.info(
+      'Successfully initialized multi-tier checkpointing on all remote hosts '
+      'via Colocated Python.'
+  )
+
+
 def _initialize_jax_from_mtc(
     local_checkpoint_directory: epath.Path,
     jax_initialization_timeout_seconds: int = 900,
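For intuition on step 2 of `_initialize_mtc_colocated` above, here is the same rank math as a standalone sketch with illustrative sizes (four nodes split across two slices); each node's peers are the nodes that hold the same in-slice position in every other slice:

num_nodes, num_slices = 4, 2  # illustrative values
nodes_per_slice = max(1, num_nodes // num_slices)

all_peer_ranks = []
for nr in range(num_nodes):
  in_slice_index = nr % nodes_per_slice
  peers = [
      s * nodes_per_slice + in_slice_index
      for s in range(num_slices)
      if s * nodes_per_slice + in_slice_index != nr
  ]
  all_peer_ranks.append(peers)

print(all_peer_ranks)  # [[2], [3], [0], [1]]

Because the (num_nodes,) and (num_nodes, P) configuration arrays are sharded over the one-dimensional 'd' mesh, each host's addressable shard has shape (1,) and (1, P) respectively, which is why the closure indexes `[0]` before converting to a scalar rank and a flat peer list.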
Expand Down Expand Up @@ -107,6 +213,7 @@ def initialize_multi_tier_checkpointing(
data_parallelism: Optional[int] = None,
jax_initialization_timeout_seconds: int = 900,
use_mtc_process_ids: bool = True,
use_colocated_python: bool = False,
):
"""Initializes multi-tier checkpointing.

@@ -116,12 +223,34 @@
       minutes.
     num_slices: The number of slices.
     run_name: The name of the run.
-    data_parallelism: Number of identical pipelines in job, should be
-      equal to ICI data parallelism * DCN data parallelism. If not provided, it
-      will be inferred from the number of slices.
+    data_parallelism: Number of identical pipelines in the job; should equal
+      ICI data parallelism * DCN data parallelism. If not provided, it will be
+      inferred from the number of slices.
     jax_initialization_timeout_seconds: The timeout for JAX initialization.
     use_mtc_process_ids: Use the MTC rank server to calculate process ids.
+    use_colocated_python: Whether to use Colocated Python for initialization.
   """
+  run_name = run_name if run_name else os.environ.get('JOBSET_NAME')
+  num_slices = num_slices or multislice.slice_count()
+  data_parallelism = data_parallelism or num_slices
+  if not run_name:
+    raise ValueError(
+        'Run name is not set and JOBSET_NAME is not set in the environment.'
+    )
+
+  if use_colocated_python:
+    logging.info('Initializing multi-tier checkpointing via Colocated Python.')
+    _initialize_mtc_colocated(
+        local_checkpoint_directory=local_checkpoint_directory,
+        backup_interval_minutes=backup_interval_minutes,
+        num_slices=num_slices,
+        run_name=run_name,
+        data_parallelism=data_parallelism,
+        timeout_seconds=jax_initialization_timeout_seconds,
+    )
+    return
+
+  # Standard multi-controller path.
   if use_mtc_process_ids:
     process_id = _initialize_jax_from_mtc(
         local_checkpoint_directory, jax_initialization_timeout_seconds
@@ -135,14 +264,9 @@
   multihost.initialize_runtime_to_distributed_ids()
   multihost.initialize_distributed_to_device_ids()
   _wait_for_replicator_file_to_disappear(local_checkpoint_directory)
-  num_slices = (
-      num_slices
-      or multislice.slice_count()
-  )
   num_nodes = jax.process_count()
   nodes_per_slice = num_nodes // num_slices
   node_rank = jax._src.distributed.global_state.process_id  # pylint: disable=protected-access
-  data_parallelism = data_parallelism or num_slices
   my_process_index = jax.process_index()
   process_index_to_node_rank = (
       multihost.runtime_to_distributed_ids()
@@ -173,11 +297,7 @@
       peer_process_rank = process_index_to_node_rank[peer_process_index]
       peer_ranks.append(peer_process_rank)
   logging.info('Peers for NodeRank %s: %s', node_rank, peer_ranks)
-  run_name = run_name if run_name else os.environ.get('JOBSET_NAME')
-  if not run_name:
-    raise ValueError(
-        'Run name is not set and JOBSET_NAME is not set in the environment.'
-    )
+
   _create_replicator_file(
       local_checkpoint_directory,
       run_name=run_name,
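A sketch of exercising the new flag end to end (the import path and argument values are illustrative; the test below drives the same entry point):

from etils import epath
from orbax.checkpoint._src.multihost import initialization  # assumed module path

initialization.initialize_multi_tier_checkpointing(
    epath.Path('/local/ckpt'),
    backup_interval_minutes=15,
    num_slices=2,
    run_name='my-run',
    use_colocated_python=True,  # take the colocated Python / Pathways path
)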
@@ -279,14 +279,43 @@ def test_initialize_multi_tier_checkpointing_run_name_not_set(
           num_slices=1,
           run_name="",
       )
-    mock_jax_distributed_initialize.assert_called_once_with(
-        process_id=0,
-        coordinator_address="coordinator_address",
-        initialization_timeout=900,
-    )
-    mock_initialize_runtime_to_distributed_ids.assert_called_once()
-    mock_initialize_distributed_to_device_ids.assert_called_once()
-    self.assertEqual(mock_wait_for_replicator_file_to_disappear.call_count, 1)
+
+    mock_jax_distributed_initialize.assert_not_called()
+    mock_initialize_runtime_to_distributed_ids.assert_not_called()
+    mock_initialize_distributed_to_device_ids.assert_not_called()
+    self.assertEqual(mock_wait_for_replicator_file_to_disappear.call_count, 0)
+
+  @mock.patch.object(initialization, "_initialize_mtc_colocated", autospec=True)
+  @mock.patch.object(jax.distributed, "initialize", autospec=True)
+  def test_initialize_multi_tier_checkpointing_colocated_success(
+      self,
+      mock_jax_distributed_initialize,
+      mock_init_mtc_colocated,
+  ):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+      tmp_dir_path = epath.Path(tmp_dir)
+
+      initialization.initialize_multi_tier_checkpointing(
+          tmp_dir_path,
+          num_slices=1,
+          run_name="test-colocated-run",
+          data_parallelism=1,
+          use_colocated_python=True,
+          backup_interval_minutes=15,
+      )
+
+      # Verify the colocated Python path is taken.
+      mock_init_mtc_colocated.assert_called_once_with(
+          local_checkpoint_directory=tmp_dir_path,
+          backup_interval_minutes=15,
+          num_slices=1,
+          run_name="test-colocated-run",
+          data_parallelism=1,
+          timeout_seconds=900,
+      )
+
+      # Verify the standard multi-controller JAX init is bypassed.
+      mock_jax_distributed_initialize.assert_not_called()
 
   @mock.patch.object(
       initialization, "_wait_for_replicator_file_to_disappear", autospec=True