30 changes: 30 additions & 0 deletions checkpoint/orbax/checkpoint/_src/multihost/multihost.py
@@ -14,6 +14,7 @@

"""Orbax utils related to multihost_utils functionality."""

import functools
import threading
import time
from typing import List, Optional, Protocol, Set
@@ -441,3 +442,32 @@ def process_index_from_device(device: jax.Device) -> int:
def unique_processes_from_devices(device_array: np.ndarray) -> Set[int]:
get_pids_from_devices = np.vectorize(process_index_from_device)
return set(get_pids_from_devices(device_array).flat)


def global_max(values: list[int]) -> list[int]:
"""Computes the global max of a list of integers across all hosts."""
  host_mesh = jax.sharding.Mesh(
      np.asarray(jax.devices()).reshape(
          process_count(), jax.local_device_count()
      ),
      ['host', 'dev'],
  )
  sharding = jax.sharding.NamedSharding(
      host_mesh, jax.sharding.PartitionSpec('host', None)
  )
  local_array = np.array([values], dtype=np.int32)
  # Create the global array, which is sharded across hosts.
  global_array = jax.make_array_from_process_local_data(sharding, local_array)
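  # Each host contributes one (1, len(values)) row; the shard_map'd reduction
  # below takes the element-wise max across hosts with pmax and replicates it.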

  @jax.jit
  @functools.partial(
      jax.shard_map,
      mesh=host_mesh,
      in_specs=jax.sharding.PartitionSpec('host', None),
      out_specs=jax.sharding.PartitionSpec(),
  )
  def reduce_max_fn(x):
    return jax.lax.pmax(x, axis_name='host')

  max_values_array = reduce_max_fn(global_array).squeeze(axis=0)
  return list(np.asarray(max_values_array).astype(int))
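
A usage sketch (illustrative only, not lines from this diff; variable names are hypothetical): every host calls the helper collectively with an equal-length list of ints and receives the element-wise maximum, which is the pattern the emergency checkpoint manager uses below for `should_save` and `saved`.

from orbax.checkpoint._src.multihost import multihost

# All hosts must call this collectively with lists of the same length.
local_flags = [1, 0]  # e.g. [int(persistent_saved), int(local_saved)] on this host
agreed = multihost.global_max(local_flags)  # element-wise max across hosts
persistent_saved, local_saved = (bool(v) for v in agreed)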
53 changes: 53 additions & 0 deletions checkpoint/orbax/checkpoint/_src/multihost/pathways.py
@@ -0,0 +1,53 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pathways-specific multihost utilities."""

import functools
import jax
import numpy as np
from .learning.deepmind.jax.ocean.remote_python import rp


@functools.lru_cache(maxsize=1)
def worker_count() -> int:
"""Gets the number of Pathways workers."""
fully_replicated_sharding = jax.sharding.NamedSharding(
jax.sharding.Mesh(jax.devices(), 'x'),
jax.sharding.PartitionSpec(),
)

@rp.stateless_fn
def _get_worker_count(_) -> jax.Array:
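    # This body runs via Remote Python; on the Pathways workers,
    # jax.process_count() reports the number of workers (an inference from the
    # surrounding code, not stated explicitly in the diff).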
    wc = np.asarray(jax.process_count(), dtype=np.int32)
    return jax.make_array_from_callback(
        (),
        fully_replicated_sharding,
        lambda _: wc,
        dtype=np.int32,
    )

  dummy_input = jax.device_put(
      np.asarray(0, dtype=np.int32),
      device=fully_replicated_sharding,
  )
  _get_worker_count.register_shape_fn(
      lambda _: jax.ShapeDtypeStruct(
          (), dtype=np.int32, sharding=fully_replicated_sharding
      )
  )
  result = _get_worker_count(rp.to_remote_python(dummy_input))
  jax.block_until_ready(result)
  result = rp.from_remote_python(result)
  return result.item()
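
A usage sketch (illustrative only, not part of the diff; the import path is inferred from the new file's location): the result is cached by `functools.lru_cache`, so the Remote Python round trip runs at most once per process.

from orbax.checkpoint._src.multihost import pathways

num_workers = pathways.worker_count()  # cached after the first call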
6 changes: 4 additions & 2 deletions checkpoint/orbax/checkpoint/_src/path/deleter.py
@@ -21,14 +21,16 @@
import threading
import time
from typing import Optional, Protocol, Sequence

from absl import logging
from etils import epath
import jax
from orbax.checkpoint import utils
from orbax.checkpoint._src.logging import event_tracking
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import gcs_utils
from orbax.checkpoint._src.path import step as step_lib


PurePosixPath = pathlib.PurePosixPath

_THREADED_DELETE_DURATION = (
Expand Down Expand Up @@ -155,7 +157,7 @@ def delete(self, step: int) -> None:
"""
start = time.time()
try:
-      if not utils.is_primary_host(self._primary_host):
+      if not multihost.is_primary_host(self._primary_host):
logging.info(
'Not primary host(%s), skipping deletion of step %d.',
self._primary_host,
@@ -26,7 +26,6 @@
"""

import asyncio
import collections
import dataclasses
import functools
import time
@@ -35,7 +34,6 @@
from etils import epath
from etils import epy
import jax
from jax.experimental import multihost_utils
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint import abstract_checkpoint_manager
@@ -55,6 +53,7 @@
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint.experimental.emergency import local_checkpoint_data_debugging
from orbax.checkpoint.experimental.emergency import mesh_consistency
from orbax.checkpoint.experimental.emergency import path as emergency_path_utils
from orbax.checkpoint.experimental.emergency import process_metadata_checkpoint_handler
from typing_extensions import override
from typing_extensions import Self # for Python version < 3.11
@@ -310,80 +309,6 @@ class CheckpointManagerOptions:
single_host_load_and_broadcast: bool = True


def _common_values_per_slice(
per_process_values: Dict[int, Set[int]],
global_mesh: jax.sharding.Mesh,
*,
replica_axis_index: int,
) -> Dict[int, Set[int]]:
"""Obtains values shared in common across all processes in each slice.

Args:
per_process_values: A mapping of process index to a list of values local to
that process.
global_mesh: The global mesh.
replica_axis_index: The index of the replica axis in the global mesh.

Returns:
A mapping of slice index to a set of values shared in common across all
processes in that slice. A value appearing in one process but not another
in the same slice will not appear in the output.
"""
total_num_replicas = multislice.replica_count(
global_mesh, replica_axis_index=replica_axis_index
)
num_processes_per_replica = (
global_mesh.devices.size // total_num_replicas // jax.local_device_count()
)
per_replica_values = collections.defaultdict(list)
for pid, values in per_process_values.items():
replica_id = multislice.process_replica_id(
pid, global_mesh, replica_axis_index=replica_axis_index
)
per_replica_values[replica_id].extend(values)

for replica_id, values in per_replica_values.items():
counter = collections.Counter(values)
common_values = [
k for k in counter if counter[k] == num_processes_per_replica
]
    # Here `len(common_values)` will be less than or equal to `len(values)`
    # because a value can only appear in `common_values` if it occurs
    # `num_processes_per_replica` times in `values`.
if len(common_values) > len(values):
raise AssertionError(
f' len(common_values) ({common_values}) exceeded length of input'
f' values ({values}).'
)
per_replica_values[replica_id] = common_values

return {k: set(v) for k, v in per_replica_values.items()}
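
The intersection semantics described in the docstring above can be shown with a tiny, self-contained sketch (process ids, slice assignments, and step numbers are hypothetical, and the mesh machinery is elided):

# A step survives for a slice only if every process in that slice reports it.
per_process_steps = {0: {1, 2}, 1: {2, 3},  # processes 0, 1 -> slice 0
                     2: {5, 6}, 3: {5, 6}}  # processes 2, 3 -> slice 1
slice_of_process = {0: 0, 1: 0, 2: 1, 3: 1}

per_slice_steps = {}
for pid, steps in per_process_steps.items():
    sid = slice_of_process[pid]
    per_slice_steps[sid] = steps if sid not in per_slice_steps else per_slice_steps[sid] & steps

print(per_slice_steps)  # {0: {2}, 1: {5, 6}}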


def _global_max(values: list[int]) -> list[int]:
"""Computes the global max of a list of values across all hosts."""
host_mesh = jax.sharding.Mesh(
np.asarray(jax.devices()).reshape(
multihost.process_count(), jax.local_device_count()
),
['host', 'dev'],
)
sharding = jax.sharding.NamedSharding(host_mesh, P('host', None))
local_array = np.array([values], dtype=np.int32)
# Create the global array, which is sharded across hosts.
global_array = jax.make_array_from_process_local_data(sharding, local_array)

@jax.jit
@functools.partial(
jax.shard_map, mesh=host_mesh, in_specs=P('host', None), out_specs=P()
)
def reduce_max_fn(x):
return jax.lax.pmax(x, axis_name='host')

max_values_array = reduce_max_fn(global_array).squeeze(axis=0)
return list(np.asarray(max_values_array).astype(int))


class _LocalCheckpointManager(checkpoint_manager.CheckpointManager):
"""A checkpoint manager that checkpoints to local storage."""

@@ -825,7 +750,7 @@ def should_save(self, step: int) -> bool:
should_save = self._persistent_checkpoint_manager.should_save(step)
else:
should_save = self._local_checkpoint_manager.should_save(step)
-    return bool(_global_max([int(should_save)])[0])
+    return bool(multihost.global_max([int(should_save)])[0])

def delete(self, step: int):
"""Deletes a step checkpoint."""
@@ -901,7 +826,8 @@ def save(

start = time.time()
saved = tuple(
-        bool(e) for e in _global_max([int(persistent_saved), int(local_saved)])
+        bool(e)
+        for e in multihost.global_max([int(persistent_saved), int(local_saved)])
)
persistent_saved, local_saved = saved
logging.info('Broadcast `saved` bool in %f seconds.', time.time() - start)
@@ -914,42 +840,12 @@
return persistent_saved or local_saved

   def _get_per_slice_local_steps(self) -> Dict[int, Set[int]]:
-    """Gets the set of steps present in each slice from all hosts."""
-    local_steps = set(step_lib.checkpoint_steps(self._local_directory))
-    logging.info(
-        'Found steps: %s in local host storage: %s.',
-        local_steps,
+    return emergency_path_utils.get_per_replica_local_steps(
         self._local_directory,
-    )
-
-    num_local_steps = len(local_steps)
-    max_num_local_steps = _global_max([num_local_steps])[0]
-    # Pad the local steps so all hosts have an array of the same length.
-    padded_local_steps = list(local_steps) + [-1] * (
-        max_num_local_steps - num_local_steps
-    )
-    local_steps_per_process_array = np.array(
-        [multihost.process_index()] + padded_local_steps, dtype=np.int32
-    )
-
-    # Use all_gather to collect the arrays from every host.
-    global_steps_per_process = multihost_utils.process_allgather(
-        local_steps_per_process_array, tiled=False
-    )
-
-    # The rest of the logic works on the gathered NumPy array.
-    per_process_steps = {}
-    for process_and_steps in global_steps_per_process:
-      per_process_steps[process_and_steps[0]] = set(
-          s for s in process_and_steps[1:] if s != -1
-      )
-    per_slice_steps = _common_values_per_slice(
-        per_process_steps,
-        self._global_mesh,
+        step_name_format=self._options.step_name_format,
+        global_mesh=self._global_mesh,
         replica_axis_index=self._replica_axis_index,
     )
-    logging.vlog(1, 'per_slice_steps=%s', per_slice_steps)
-    return per_slice_steps

def _find_slice_with_complete_local_checkpoint(self, step: int) -> int:
"""Return the slice id which has the step."""