diff --git a/checkpoint/orbax/checkpoint/_src/multihost/multihost.py b/checkpoint/orbax/checkpoint/_src/multihost/multihost.py
index fd6bcf5df..b81cb3e05 100644
--- a/checkpoint/orbax/checkpoint/_src/multihost/multihost.py
+++ b/checkpoint/orbax/checkpoint/_src/multihost/multihost.py
@@ -14,6 +14,7 @@
 """Orbax utils related to multihost_utils functionality."""
 
+import functools
 import threading
 import time
 from typing import List, Optional, Protocol, Set
 
@@ -441,3 +442,32 @@ def process_index_from_device(device: jax.Device) -> int:
 def unique_processes_from_devices(device_array: np.ndarray) -> Set[int]:
   get_pids_from_devices = np.vectorize(process_index_from_device)
   return set(get_pids_from_devices(device_array).flat)
+
+
+def global_max(values: list[int]) -> list[int]:
+  """Computes the global max of a list of integers across all hosts."""
+  host_mesh = jax.sharding.Mesh(
+      np.asarray(jax.devices()).reshape(
+          process_count(), jax.local_device_count()
+      ),
+      ['host', 'dev'],
+  )
+  sharding = jax.sharding.NamedSharding(
+      host_mesh, jax.sharding.PartitionSpec('host', None)
+  )
+  local_array = np.array([values], dtype=np.int32)
+  # Create the global array, which is sharded across hosts.
+  global_array = jax.make_array_from_process_local_data(sharding, local_array)
+
+  @jax.jit
+  @functools.partial(
+      jax.shard_map,
+      mesh=host_mesh,
+      in_specs=jax.sharding.PartitionSpec('host', None),
+      out_specs=jax.sharding.PartitionSpec(),
+  )
+  def reduce_max_fn(x):
+    return jax.lax.pmax(x, axis_name='host')
+
+  max_values_array = reduce_max_fn(global_array).squeeze(axis=0)
+  return list(np.asarray(max_values_array).astype(int))
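Note: a minimal usage sketch for the new multihost.global_max helper (illustrative only; it assumes an initialized multi-host JAX runtime, with every host calling the function collectively and passing a list of the same length; the variable names below are placeholders, not Orbax APIs):

    from orbax.checkpoint._src.multihost import multihost

    local_should_save = True  # hypothetical per-host flag
    num_local_steps = 3       # hypothetical per-host count
    # Every host receives the elementwise maximum across all hosts, so the
    # first entry is 1 if any host set its flag.
    global_values = multihost.global_max([int(local_should_save), num_local_steps])
    any_host_should_save = bool(global_values[0])
    max_steps_on_any_host = global_values[1]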
diff --git a/checkpoint/orbax/checkpoint/_src/multihost/pathways.py b/checkpoint/orbax/checkpoint/_src/multihost/pathways.py
new file mode 100644
index 000000000..ab5a67c53
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/_src/multihost/pathways.py
@@ -0,0 +1,53 @@
+# Copyright 2025 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Pathways-specific multihost utilities."""
+
+import functools
+import jax
+import numpy as np
+from .learning.deepmind.jax.ocean.remote_python import rp
+
+
+@functools.lru_cache(maxsize=1)
+def worker_count() -> int:
+  """Gets the number of Pathways workers."""
+  fully_replicated_sharding = jax.sharding.NamedSharding(
+      jax.sharding.Mesh(jax.devices(), 'x'),
+      jax.sharding.PartitionSpec(),
+  )
+
+  @rp.stateless_fn
+  def _get_worker_count(_) -> jax.Array:
+    wc = np.asarray(jax.process_count(), dtype=np.int32)
+    return jax.make_array_from_callback(
+        (),
+        fully_replicated_sharding,
+        lambda _: wc,
+        dtype=np.int32,
+    )
+
+  dummy_input = jax.device_put(
+      np.asarray(0, dtype=np.int32),
+      device=fully_replicated_sharding,
+  )
+  _get_worker_count.register_shape_fn(
+      lambda _: jax.ShapeDtypeStruct(
+          (), dtype=np.int32, sharding=fully_replicated_sharding
+      )
+  )
+  result = _get_worker_count(rp.to_remote_python(dummy_input))
+  jax.block_until_ready(result)
+  result = rp.from_remote_python(result)
+  return result.item()
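Note: a sketch of the intended call site (this rests on an assumption: under Pathways' single-controller runtime, jax.process_count() on the client reports 1, which is why the count above is computed on the workers through a Remote Python call):

    from orbax.checkpoint._src.multihost import pathways as pathways_multihost

    # Result is cached after the first call via functools.lru_cache.
    num_workers = pathways_multihost.worker_count()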
diff --git a/checkpoint/orbax/checkpoint/_src/path/deleter.py b/checkpoint/orbax/checkpoint/_src/path/deleter.py
index 4c8566fba..b88a1dc58 100644
--- a/checkpoint/orbax/checkpoint/_src/path/deleter.py
+++ b/checkpoint/orbax/checkpoint/_src/path/deleter.py
@@ -21,14 +21,16 @@
 import threading
 import time
 from typing import Optional, Protocol, Sequence
+
 from absl import logging
 from etils import epath
 import jax
-from orbax.checkpoint import utils
 from orbax.checkpoint._src.logging import event_tracking
+from orbax.checkpoint._src.multihost import multihost
 from orbax.checkpoint._src.path import gcs_utils
 from orbax.checkpoint._src.path import step as step_lib
+
 
 PurePosixPath = pathlib.PurePosixPath
 
 _THREADED_DELETE_DURATION = (
@@ -155,7 +157,7 @@ def delete(self, step: int) -> None:
     """
     start = time.time()
     try:
-      if not utils.is_primary_host(self._primary_host):
+      if not multihost.is_primary_host(self._primary_host):
         logging.info(
             'Not primary host(%s), skipping deletion of step %d.',
             self._primary_host,
diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py b/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py
index f885a9dee..1461c3dc4 100644
--- a/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py
+++ b/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py
@@ -26,7 +26,6 @@
 """
 
 import asyncio
-import collections
 import dataclasses
 import functools
 import time
@@ -35,7 +34,6 @@
 from etils import epath
 from etils import epy
 import jax
-from jax.experimental import multihost_utils
 import jax.numpy as jnp
 import numpy as np
 from orbax.checkpoint import abstract_checkpoint_manager
@@ -55,6 +53,7 @@
 from orbax.checkpoint._src.serialization import type_handlers
 from orbax.checkpoint.experimental.emergency import local_checkpoint_data_debugging
 from orbax.checkpoint.experimental.emergency import mesh_consistency
+from orbax.checkpoint.experimental.emergency import path as emergency_path_utils
 from orbax.checkpoint.experimental.emergency import process_metadata_checkpoint_handler
 from typing_extensions import override
 from typing_extensions import Self  # for Python version < 3.11
@@ -310,80 +309,6 @@ class CheckpointManagerOptions:
   single_host_load_and_broadcast: bool = True
 
 
-def _common_values_per_slice(
-    per_process_values: Dict[int, Set[int]],
-    global_mesh: jax.sharding.Mesh,
-    *,
-    replica_axis_index: int,
-) -> Dict[int, Set[int]]:
-  """Obtains values shared in common across all processes in each slice.
-
-  Args:
-    per_process_values: A mapping of process index to a list of values local to
-      that process.
-    global_mesh: The global mesh.
-    replica_axis_index: The index of the replica axis in the global mesh.
-
-  Returns:
-    A mapping of slice index to a set of values shared in common across all
-    processes in that slice. A value appearing in one process but not another
-    in the same slice will not appear in the output.
-  """
-  total_num_replicas = multislice.replica_count(
-      global_mesh, replica_axis_index=replica_axis_index
-  )
-  num_processes_per_replica = (
-      global_mesh.devices.size // total_num_replicas // jax.local_device_count()
-  )
-  per_replica_values = collections.defaultdict(list)
-  for pid, values in per_process_values.items():
-    replica_id = multislice.process_replica_id(
-        pid, global_mesh, replica_axis_index=replica_axis_index
-    )
-    per_replica_values[replica_id].extend(values)
-
-  for replica_id, values in per_replica_values.items():
-    counter = collections.Counter(values)
-    common_values = [
-        k for k in counter if counter[k] == num_processes_per_replica
-    ]
-    # Here `len(common_values)`` will be less than or equal to `len(values)`
-    # because a value can only appear in `common_values` if it occurs
-    # `num_processes_per_slice` times in `values`.
-    if len(common_values) > len(values):
-      raise AssertionError(
-          f' len(common_values) ({common_values}) exceeded length of input'
-          f' values ({values}).'
-      )
-    per_replica_values[replica_id] = common_values
-
-  return {k: set(v) for k, v in per_replica_values.items()}
-
-
-def _global_max(values: list[int]) -> list[int]:
-  """Computes the global max of a list of values across all hosts."""
-  host_mesh = jax.sharding.Mesh(
-      np.asarray(jax.devices()).reshape(
-          multihost.process_count(), jax.local_device_count()
-      ),
-      ['host', 'dev'],
-  )
-  sharding = jax.sharding.NamedSharding(host_mesh, P('host', None))
-  local_array = np.array([values], dtype=np.int32)
-  # Create the global array, which is sharded across hosts.
-  global_array = jax.make_array_from_process_local_data(sharding, local_array)
-
-  @jax.jit
-  @functools.partial(
-      jax.shard_map, mesh=host_mesh, in_specs=P('host', None), out_specs=P()
-  )
-  def reduce_max_fn(x):
-    return jax.lax.pmax(x, axis_name='host')
-
-  max_values_array = reduce_max_fn(global_array).squeeze(axis=0)
-  return list(np.asarray(max_values_array).astype(int))
-
-
 class _LocalCheckpointManager(checkpoint_manager.CheckpointManager):
   """A checkpoint manager that checkpoints to local storage."""
 
@@ -825,7 +750,7 @@ def should_save(self, step: int) -> bool:
       should_save = self._persistent_checkpoint_manager.should_save(step)
     else:
      should_save = self._local_checkpoint_manager.should_save(step)
-    return bool(_global_max([int(should_save)])[0])
+    return bool(multihost.global_max([int(should_save)])[0])
 
   def delete(self, step: int):
     """Deletes a step checkpoint."""
@@ -901,7 +826,8 @@ def save(
 
     start = time.time()
     saved = tuple(
-        bool(e) for e in _global_max([int(persistent_saved), int(local_saved)])
+        bool(e)
+        for e in multihost.global_max([int(persistent_saved), int(local_saved)])
     )
     persistent_saved, local_saved = saved
     logging.info('Broadcast `saved` bool in %f seconds.', time.time() - start)
@@ -914,42 +840,12 @@ def save(
     return persistent_saved or local_saved
 
   def _get_per_slice_local_steps(self) -> Dict[int, Set[int]]:
-    """Gets the set of steps present in each slice from all hosts."""
-    local_steps = set(step_lib.checkpoint_steps(self._local_directory))
-    logging.info(
-        'Found steps: %s in local host storage: %s.',
-        local_steps,
+    return emergency_path_utils.get_per_replica_local_steps(
         self._local_directory,
-    )
-
-    num_local_steps = len(local_steps)
-    max_num_local_steps = _global_max([num_local_steps])[0]
-    # Pad the local steps so all hosts have an array of the same length.
-    padded_local_steps = list(local_steps) + [-1] * (
-        max_num_local_steps - num_local_steps
-    )
-    local_steps_per_process_array = np.array(
-        [multihost.process_index()] + padded_local_steps, dtype=np.int32
-    )
-
-    # Use all_gather to collect the arrays from every host.
-    global_steps_per_process = multihost_utils.process_allgather(
-        local_steps_per_process_array, tiled=False
-    )
-
-    # The rest of the logic works on the gathered NumPy array.
-    per_process_steps = {}
-    for process_and_steps in global_steps_per_process:
-      per_process_steps[process_and_steps[0]] = set(
-          s for s in process_and_steps[1:] if s != -1
-      )
-    per_slice_steps = _common_values_per_slice(
-        per_process_steps,
-        self._global_mesh,
+        step_name_format=self._options.step_name_format,
+        global_mesh=self._global_mesh,
        replica_axis_index=self._replica_axis_index,
     )
-    logging.vlog(1, 'per_slice_steps=%s', per_slice_steps)
-    return per_slice_steps
 
   def _find_slice_with_complete_local_checkpoint(self, step: int) -> int:
     """Return the slice id which has the step."""
diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/path.py b/checkpoint/orbax/checkpoint/experimental/emergency/path.py
new file mode 100644
index 000000000..9cb8b5c5d
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/experimental/emergency/path.py
@@ -0,0 +1,120 @@
+# Copyright 2025 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Path utilities for emergency checkpointing."""
+
+import collections
+from absl import logging
+from etils import epath
+import jax
+from jax.experimental import multihost_utils
+import numpy as np
+from orbax.checkpoint._src.multihost import multihost
+from orbax.checkpoint._src.multihost import multislice
+from orbax.checkpoint._src.path import step as step_lib
+
+
+def _common_values_per_replica(
+    per_process_values: dict[int, set[int]],
+    *,
+    global_mesh: jax.sharding.Mesh,
+    replica_axis_index: int,
+) -> dict[int, set[int]]:
+  """Obtains values shared in common across all processes in each replica.
+
+  Args:
+    per_process_values: A mapping of process index to a list of values local to
+      that process.
+    global_mesh: The global mesh.
+    replica_axis_index: The index of the replica axis in the global mesh.
+
+  Returns:
+    A mapping of replica index to a set of values shared in common across all
+    processes in that replica. A value appearing in one process but not another
+    in the same replica will not appear in the output.
+  """
+  total_num_replicas = multislice.replica_count(
+      global_mesh, replica_axis_index=replica_axis_index
+  )
+  num_processes_per_replica = (
+      global_mesh.devices.size // total_num_replicas // jax.local_device_count()
+  )
+  per_replica_values = collections.defaultdict(list)
+  for pid, values in per_process_values.items():
+    replica_id = multislice.process_replica_id(
+        pid, global_mesh, replica_axis_index=replica_axis_index
+    )
+    per_replica_values[replica_id].extend(values)
+
+  for replica_id, values in per_replica_values.items():
+    counter = collections.Counter(values)
+    common_values = [
+        k for k in counter if counter[k] == num_processes_per_replica
+    ]
+    # Here `len(common_values)` will be less than or equal to `len(values)`
+    # because a value can only appear in `common_values` if it occurs
+    # `num_processes_per_replica` times in `values`.
+    if len(common_values) > len(values):
+      raise AssertionError(
+          f' len(common_values) ({common_values}) exceeded length of input'
+          f' values ({values}).'
+      )
+    per_replica_values[replica_id] = common_values
+
+  return {k: set(v) for k, v in per_replica_values.items()}
+
+
+def get_per_replica_local_steps(
+    local_directory: epath.Path,
+    *,
+    step_name_format: step_lib.NameFormat[step_lib.Metadata],
+    global_mesh: jax.sharding.Mesh,
+    replica_axis_index: int,
+) -> dict[int, set[int]]:
+  """Gets the set of steps present in each replica from all hosts."""
+  local_steps = set(m.step for m in step_name_format.find_all(local_directory))
+  logging.info(
+      'Found steps: %s in local host storage: %s.',
+      local_steps,
+      local_directory,
+  )
+
+  num_local_steps = len(local_steps)
+  max_num_local_steps = multihost.global_max([num_local_steps])[0]
+  # Pad the local steps so all hosts have an array of the same length.
+  padded_local_steps = list(local_steps) + [-1] * (
+      max_num_local_steps - num_local_steps
+  )
+  local_steps_per_process_array = np.array(
+      [multihost.process_index()] + padded_local_steps, dtype=np.int32
+  )
+
+  # Use all_gather to collect the arrays from every host.
+  global_steps_per_process = multihost_utils.process_allgather(
+      local_steps_per_process_array, tiled=False
+  )
+
+  # The rest of the logic works on the gathered NumPy array.
+  per_process_steps = {}
+  for process_and_steps in global_steps_per_process:
+    per_process_steps[process_and_steps[0]] = set(
+        s for s in process_and_steps[1:] if s != -1
+    )
+  per_replica_steps = _common_values_per_replica(
+      per_process_steps,
+      global_mesh=global_mesh,
+      replica_axis_index=replica_axis_index,
+  )
+  logging.vlog(1, 'per_replica_steps=%s', per_replica_steps)
+  return per_replica_steps
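Note: a toy walk-through of the intersection semantics implemented by _common_values_per_replica above (plain Python only; the process-to-replica assignment here is hypothetical and in the real function comes from the mesh via multislice.process_replica_id):

    import collections

    # Two replicas with two processes each; process 3 is missing step 5.
    per_process_steps = {0: {3, 5}, 1: {3, 5}, 2: {3, 5}, 3: {3}}
    replica_of_process = {0: 0, 1: 0, 2: 1, 3: 1}
    num_processes_per_replica = 2

    per_replica = collections.defaultdict(list)
    for pid, steps in per_process_steps.items():
      per_replica[replica_of_process[pid]].extend(steps)
    # A step is kept for a replica only if every process in that replica has it.
    common = {
        r: {s for s, c in collections.Counter(v).items()
            if c == num_processes_per_replica}
        for r, v in per_replica.items()
    }
    assert common == {0: {3, 5}, 1: {3}}  # step 5 dropped for replica 1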
diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/test_utils/test_base.py b/checkpoint/orbax/checkpoint/experimental/emergency/test_utils/test_base.py
index f943442f1..676c81dbc 100644
--- a/checkpoint/orbax/checkpoint/experimental/emergency/test_utils/test_base.py
+++ b/checkpoint/orbax/checkpoint/experimental/emergency/test_utils/test_base.py
@@ -38,6 +38,7 @@
 from orbax.checkpoint._src.multihost import multislice
 from orbax.checkpoint.experimental.emergency import checkpoint_manager
 from orbax.checkpoint.experimental.emergency import mesh_consistency
+from orbax.checkpoint.experimental.emergency import path as emergency_path_utils
 from orbax.checkpoint.experimental.emergency import process_metadata_checkpoint_handler
 from orbax.checkpoint.experimental.emergency.test_utils import dataset_iterator_checkpoint_handler
 
@@ -597,8 +598,8 @@ def test_common_steps_per_slice(self, process_steps, expectation):
     per_process_steps = {
         pid: steps for pid, steps in enumerate(process_steps)
     }
-    result = checkpoint_manager._common_values_per_slice(  # pylint: disable=protected-access
-        per_process_steps, self.global_mesh, replica_axis_index=0
+    result = emergency_path_utils._common_values_per_replica(  # pylint: disable=protected-access
+        per_process_steps, global_mesh=self.global_mesh, replica_axis_index=0
     )
     self.assertSameElements(result, expectation)
 
@@ -752,7 +753,7 @@ def test_global_max(self, inputs, expectation):
       local_host_inputs = [local_host_inputs]
       expectation = [expectation]
     self.assertEqual(
-        checkpoint_manager._global_max(local_host_inputs),
+        multihost.global_max(local_host_inputs),
         expectation,
     )
 
diff --git a/checkpoint/orbax/checkpoint/testing/local_path.py b/checkpoint/orbax/checkpoint/testing/local_path.py
index fe056b063..3d3467580 100644
--- a/checkpoint/orbax/checkpoint/testing/local_path.py
+++ b/checkpoint/orbax/checkpoint/testing/local_path.py
@@ -16,15 +16,16 @@
 
 from __future__ import annotations
 
-import typing
-from typing import Iterator, Protocol
+import os
+import pathlib
+from typing import Iterator
 
 from etils import epath
 from orbax.checkpoint._src.multihost import multihost
 
 
-@typing.runtime_checkable
-class LocalPath(Protocol):
+@epath.register_path_cls
+class LocalPath(pathlib.PurePosixPath):
   """A Path implementation for testing process-local paths.
 
   In the future, this class may more completely provide all functions and
@@ -45,103 +46,9 @@ class LocalPath(Protocol):
   process index must be appended when path operations are performed.
   """
 
-  @property
-  def path(self) -> epath.Path:
-    ...
-
-  def exists(self) -> bool:
-    """Returns True if self exists."""
-    ...
-
-  def is_dir(self) -> bool:
-    """Returns True if self is a dir."""
-    ...
-
-  def is_file(self) -> bool:
-    """Returns True if self is a file."""
-    ...
-
-  def iterdir(self) -> Iterator[LocalPath]:
-    """Iterates over the directory."""
-    ...
-
-  def glob(self, pattern: str) -> Iterator[LocalPath]:
-    """Yields all matching files (of any kind)."""
-    ...
-
-  def read_bytes(self) -> bytes:
-    """Reads contents of self as bytes."""
-    ...
-
-  def read_text(self, encoding: str | None = None) -> str:
-    """Reads contents of self as a string."""
-    ...
-
-  def mkdir(
-      self,
-      mode: int | None = None,
-      parents: bool = False,
-      exist_ok: bool = False,
-  ) -> None:
-    """Create a new directory at this given path."""
-    ...
-
-  def rmdir(self) -> None:
-    """Remove the empty directory at this given path."""
-    ...
-
-  def rmtree(self, missing_ok: bool = False) -> None:
-    """Remove the directory, including all sub-files."""
-    ...
-
-  def unlink(self, missing_ok: bool = False) -> None:
-    """Remove this file or symbolic link."""
-    ...
-
-  def write_bytes(self, data: bytes) -> int:
-    """Writes content as bytes."""
-    ...
-
-  def write_text(
-      self,
-      data: str,
-      encoding: str | None = None,
-      errors: str | None = None,
-  ) -> int:
-    """Writes content as str."""
-    ...
-
-  ### PosixPath methods ###
-
-  def as_posix(self) -> str:
-    ...
-
-  def __truediv__(self, key: epath.PathLike) -> epath.Path:
-    ...
-
-  @property
-  def name(self) -> str:
-    ...
-
-  @property
-  def parent(self) -> epath.Path:
-    ...
-
-
-LocalPathLike = LocalPath | str
-
-
-def _resolve_local_path_like(path: LocalPathLike) -> epath.Path:
-  if isinstance(path, LocalPath):
-    return typing.cast(LocalPath, path).path
-  return epath.Path(path)
-
-
-class _LocalPathImpl(LocalPath):
-  """etils.epath.Path implementation for multiprocess tests."""
-
-  def __init__(self, path: epath.PathLike):
-    self._path = epath.Path(path)
+  def __init__(self, *parts: epath.PathLike):
+    super().__init__(*parts)
+    self._path = epath.Path('/'.join(os.fspath(p) for p in parts))
 
   def __repr__(self) -> str:
     return f'{self.__class__.__name__}({self.path})'
@@ -168,11 +75,11 @@ def is_file(self) -> bool:
 
   def iterdir(self) -> Iterator[LocalPath]:
     """Iterates over the directory."""
-    return (_LocalPathImpl(p) for p in self.path.iterdir())
+    return (LocalPath(p) for p in self.path.iterdir())
 
   def glob(self, pattern: str) -> Iterator[LocalPath]:
     """Yields all matching files (of any kind)."""
-    return (_LocalPathImpl(p) for p in self.path.glob(pattern))
+    return (LocalPath(p) for p in self.path.glob(pattern))
 
   def read_bytes(self) -> bytes:
     """Reads contents of self as bytes."""
@@ -230,8 +137,3 @@ def name(self) -> str:
   @property
   def parent(self) -> epath.Path:
     return self.path.parent
-
-
-def local_path(path: epath.PathLike) -> LocalPath:
-  """Returns a LocalPath for the given path."""
-  return _LocalPathImpl(path)
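Note: a small usage sketch for the reworked LocalPath (a hypothetical test snippet; the '/tmp/orbax_test' location is made up, and the behaviour of methods not shown in this diff, such as any per-process suffixing of the underlying path, is unchanged from the existing class):

    from orbax.checkpoint.testing import local_path

    # LocalPath is now a pathlib.PurePosixPath subclass registered with epath,
    # so it can be built from multiple parts directly, without the removed
    # local_path() factory function.
    d = local_path.LocalPath('/tmp/orbax_test')
    d.mkdir(parents=True, exist_ok=True)
    f = local_path.LocalPath('/tmp/orbax_test', 'data.txt')
    f.write_text('hello')
    print(f.exists(), f.read_text())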