Skip to content

Commit b0e19da

Browse files
cpgaffney1Orbax Authors
authored andcommitted
Eliminate usages of ParamInfo.path in favor of parent_dir / name.
PiperOrigin-RevId: 823187713
1 parent e4d8241 commit b0e19da

19 files changed

+131
-79
lines changed

checkpoint/orbax/checkpoint/_src/handlers/array_checkpoint_handler.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,8 @@ def restore(
153153
# checkpoints lacking PYTREE_METADATA_FILE is no longer needed.
154154
restore_args = args.restore_args or type_handlers.RestoreArgs()
155155

156-
checkpoint_path = directory / self._checkpoint_name
157156
info = type_handlers.ParamInfo(
158157
name=self._checkpoint_name,
159-
path=checkpoint_path,
160158
parent_dir=directory,
161159
skip_deserialize=False,
162160
is_ocdbt_checkpoint=type_handlers.is_ocdbt_checkpoint(directory),

checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,6 @@ def _param_info(keypath, name, value):
474474
return ParamInfo(
475475
name=name,
476476
keypath=keypath,
477-
path=(directory / name),
478477
parent_dir=directory,
479478
skip_deserialize=skip_deserialize,
480479
is_ocdbt_checkpoint=use_ocdbt,

checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,6 @@ def _get_param_info(
268268
skip_deserialize = meta_or_value.skip_deserialize
269269
return ParamInfo(
270270
name=name,
271-
path=directory / name,
272271
parent_dir=directory,
273272
skip_deserialize=skip_deserialize,
274273
is_ocdbt_checkpoint=is_ocdbt_checkpoint,
@@ -283,7 +282,9 @@ def _get_param_info(
283282
if partial_restore:
284283
for key, meta in flat_structure.items():
285284
if key not in flat_item:
286-
flat_param_infos[key] = ParamInfo(skip_deserialize=True)
285+
flat_param_infos[key] = ParamInfo(
286+
name='', parent_dir=directory, skip_deserialize=True
287+
)
287288
flat_input_restore_args[key] = RestoreArgs()
288289
else:
289290
flat_param_infos[key] = _get_param_info(flat_param_names[key], meta)
@@ -322,7 +323,9 @@ def _get_param_info(
322323
# Specified `use_fallback`, but key was also present in the
323324
# checkpoint. This means we should skip loading, since it will be
324325
# overridden with a new value.
325-
flat_param_infos[input_key] = ParamInfo(skip_deserialize=True)
326+
flat_param_infos[input_key] = ParamInfo(
327+
name='', parent_dir=directory, skip_deserialize=True
328+
)
326329
flat_input_restore_args[input_key] = RestoreArgs()
327330
else:
328331
# Specified `use_fallback`, but `transforms_default_to_original`
@@ -343,12 +346,16 @@ def _get_param_info(
343346
else:
344347
# Take the value from the user-provided `item`, ignoring any value
345348
# in the checkpoint.
346-
flat_param_infos[input_key] = ParamInfo(skip_deserialize=True)
349+
flat_param_infos[input_key] = ParamInfo(
350+
name='', parent_dir=directory, skip_deserialize=True
351+
)
347352
flat_input_restore_args[input_key] = RestoreArgs()
348353
else:
349354
# No match, restoration not required since it will be dropped from the
350355
# output.
351-
flat_param_infos[input_key] = ParamInfo(skip_deserialize=True)
356+
flat_param_infos[input_key] = ParamInfo(
357+
name='', parent_dir=directory, skip_deserialize=True
358+
)
352359
flat_input_restore_args[input_key] = RestoreArgs()
353360

354361
restore_args = tree_utils.from_flat_dict(

checkpoint/orbax/checkpoint/_src/metadata/tree.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,6 @@ def as_custom_metadata(
485485
param_name = '.'.join(keypath)
486486
flat_param_infos[keypath] = types.ParamInfo(
487487
name=param_name,
488-
path=directory / param_name,
489488
parent_dir=directory,
490489
skip_deserialize=value_meta.skip_deserialize,
491490
is_ocdbt_checkpoint=use_ocdbt,

checkpoint/orbax/checkpoint/_src/metadata/tree_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from absl.testing import absltest
1818
from absl.testing import parameterized
1919
import chex
20+
from etils import epath
2021
import jax
2122
import numpy as np
2223
from orbax.checkpoint._src.metadata import tree as tree_metadata_lib
@@ -43,6 +44,8 @@ def _to_param_infos(
4344
return jax.tree.map(
4445
# Other properties are not relevant.
4546
lambda x: types.ParamInfo(
47+
name='',
48+
parent_dir=epath.Path(''),
4649
value_typestr=type_handler_registry.get_param_typestr(
4750
x,
4851
type_handler_registry.GLOBAL_TYPE_HANDLER_REGISTRY,

checkpoint/orbax/checkpoint/_src/multihost/pathways.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,27 @@
1616

1717
import functools
1818
import jax
19-
import numpy as np
20-
from .learning.deepmind.jax.ocean.remote_python import rp
2119

2220

2321
@functools.lru_cache(maxsize=1)
24-
def worker_count() -> int:
25-
"""Gets the number of Pathways workers."""
26-
fully_replicated_sharding = jax.sharding.NamedSharding(
27-
jax.sharding.Mesh(jax.devices(), 'x'),
28-
jax.sharding.PartitionSpec(),
29-
)
22+
def worker_count(global_mesh: jax.sharding.Mesh | None) -> int:
23+
"""Gets the number of Pathways workers.
3024
31-
@rp.stateless_fn
32-
def _get_worker_count(_) -> jax.Array:
33-
wc = np.asarray(jax.process_count(), dtype=np.int32)
34-
return jax.make_array_from_callback(
35-
(),
36-
fully_replicated_sharding,
37-
lambda _: wc,
38-
dtype=np.int32,
39-
)
25+
Args:
26+
global_mesh: The global mesh of active devices. If None is provided,
27+
`jax.devices()` will be used.
4028
41-
dummy_input = jax.device_put(
42-
np.asarray(0, dtype=np.int32),
43-
device=fully_replicated_sharding,
44-
)
45-
_get_worker_count.register_shape_fn(
46-
lambda _: jax.ShapeDtypeStruct(
47-
(), dtype=np.int32, sharding=fully_replicated_sharding
48-
)
49-
)
50-
result = _get_worker_count(rp.to_remote_python(dummy_input))
51-
jax.block_until_ready(result)
52-
result = rp.from_remote_python(result)
53-
return result.item()
29+
Returns:
30+
The number of Pathways workers in the mesh.
31+
"""
32+
global_mesh = global_mesh or jax.sharding.Mesh(jax.devices(), 'x')
33+
devices = global_mesh.devices.flatten()
34+
workers = set()
35+
for d in devices:
36+
attrs = []
37+
if hasattr(d, 'virtual_task_index'):
38+
attrs.append(d.virtual_task_index)
39+
if hasattr(d, 'slice_index'):
40+
attrs.append(d.slice_index)
41+
workers.add(tuple(attrs))
42+
return len(workers)

checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ async def deserialize(
656656
)
657657
if not info.is_ocdbt_checkpoint:
658658
await ts_utils.assert_parameter_files_exist(
659-
info.path,
659+
info.parent_dir / info.name,
660660
self._metadata_key,
661661
info.use_zarr3,
662662
)
@@ -887,8 +887,8 @@ async def deserialize(
887887
for info, arg in zip(infos, args):
888888
arg = cast(SingleReplicaArrayRestoreArgs, arg)
889889
if not info.is_ocdbt_checkpoint:
890-
await ts_utils.assert_parameter_files_exist( # pylint: disable=protected-access
891-
info.path, self._metadata_key
890+
await ts_utils.assert_parameter_files_exist(
891+
info.parent_dir / info.name, self._metadata_key
892892
)
893893
if not isinstance(arg, SingleReplicaArrayRestoreArgs):
894894
raise ValueError(

checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -582,8 +582,8 @@ def _get_json_tspec(
582582
raise_array_data_missing_error: bool = True,
583583
) -> dict[str, Any]:
584584
"""Gets Tensorstore spec in JSON format."""
585-
if info.path is None:
586-
raise ValueError('Must construct serialization path.')
585+
if info.name is None or info.parent_dir is None:
586+
raise ValueError('Must provide info.name and info.parent_dir.')
587587
parent_dir = info.parent_dir
588588
assert parent_dir is not None
589589
directory = parent_dir.as_posix()
@@ -690,8 +690,8 @@ def build_array_write_spec(
690690
ext_metadata: dict[str, Any] | None = None,
691691
) -> ArrayWriteSpec:
692692
"""Gets ArrayWriteSpec for writing."""
693-
if info.path is None:
694-
raise ValueError('Must construct serialization path.')
693+
if info.name is None or info.parent_dir is None:
694+
raise ValueError('Must provide info.name and info.parent_dir.')
695695
parent_dir = info.parent_dir
696696
assert parent_dir is not None
697697
directory = parent_dir.as_posix()

checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from typing import Any, Dict, Optional, Sequence, Tuple, TypeAlias, Union
2323

2424
from absl import logging
25-
from etils import epath
2625
import jax
2726
import numpy as np
2827
from orbax.checkpoint._src.futures import future
@@ -198,7 +197,7 @@ async def deserialize(
198197
for info, arg in zip(infos, args):
199198
if not info.is_ocdbt_checkpoint:
200199
await ts_utils.assert_parameter_files_exist(
201-
info.path, self._metadata_key, info.use_zarr3
200+
info.parent_dir / info.name, self._metadata_key, info.use_zarr3
202201
)
203202
# Use OCDBT flag from the existing checkpoint.
204203
use_ocdbt = info.is_ocdbt_checkpoint
@@ -306,8 +305,8 @@ def _get_json_tspec(
306305
info: types.ParamInfo,
307306
) -> Dict[str, Any]:
308307
"""Gets Tensorstore spec in JSON format."""
309-
if info.path is None:
310-
raise ValueError('Must construct serialization path.')
308+
if info.parent_dir is None:
309+
raise ValueError('Must provide info.parent_dir.')
311310
directory = (info.parent_dir / self._filename).as_posix()
312311
kvstore_tspec = ts_utils.build_kvstore_tspec(directory, use_ocdbt=False)
313312
tspec = {
@@ -380,11 +379,9 @@ async def deserialize(
380379
"""See superclass documentation."""
381380
del args
382381
types.check_input_arguments(infos)
383-
directory = epath.Path(infos[0].path).parent
384382
open_futures = []
385383

386384
for info in infos:
387-
info.path = epath.Path(directory / self._filename)
388385
tspec = self._get_json_tspec(info)
389386
open_future = ts.open(
390387
tspec, open=True, read=True, context=self._ts_context

checkpoint/orbax/checkpoint/_src/serialization/types.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def check_input_arguments(*args):
6060
raise ValueError('Found input args with mismatched lengths.')
6161

6262

63-
@dataclasses.dataclass
63+
@dataclasses.dataclass(kw_only=True)
6464
class ParamInfo:
6565
"""Information describing a parameter in a PyTree.
6666
@@ -70,9 +70,6 @@ class ParamInfo:
7070
7171
name:
7272
Name of the parameter.
73-
path:
74-
A path providing a location where file(s) should be saved. The path is
75-
assumed to be a directory.
7673
parent_dir:
7774
A path providing location where all files under the same checkpoint should
7875
be saved under. All `ParamInfo` provided to a given TypeHandler should have
@@ -115,10 +112,9 @@ class ParamInfo:
115112
is_prioritized_key_fn: See `IsPrioritizedKeyFn` definition.
116113
"""
117114

118-
name: Optional[str] = None
115+
name: str
116+
parent_dir: epath.Path
119117
keypath: Optional[Tuple[Any, ...]] = None
120-
path: Optional[epath.Path] = None
121-
parent_dir: Optional[epath.Path] = None
122118
skip_deserialize: Optional[bool] = None
123119
byte_limiter: Optional[limits.ByteLimiter] = None
124120
device_host_byte_limiter: Optional[limits.ByteLimiter] = None

0 commit comments

Comments
 (0)