Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
3 changes: 3 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ devices anyway.
`load_checkpointables()` each with their own dedicated loading logic
- Refactor v0 Pytree validation and metadata resolution and add `OrbaxV0Layout`
tests
- Refactor logic for handler resolution and loading checkpointables for
`OrbaxLayout` and `OrbaxV0Layout`, adding additional fallback capabilities for
non-standard checkpoint formats.

### Fixed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
given checkpointable will be used.
"""

from typing import Type
from typing import Sequence, Type

from orbax.checkpoint.experimental.v1._src.handlers import json_handler
from orbax.checkpoint.experimental.v1._src.handlers import proto_handler
Expand All @@ -34,23 +34,45 @@
def _try_register_handler(
    handler_type: Type[handler_types.CheckpointableHandler],
    name: str | None = None,
    recognized_handler_typestrs: Sequence[str] | None = None,
) -> None:
  """Tries to register handler globally with name and recognized typestrs.

  Registration is best-effort: if the global registry reports that the entry
  already exists, the error is ignored and the existing entry is left as is.

  Args:
    handler_type: The handler class to add to the global registry.
    name: Optional checkpointable name forwarded to the registry's `add`.
    recognized_handler_typestrs: Optional alternate typestrs forwarded to the
      registry's `add` so they are stored together with the handler.
  """
  try:
    # Register with a single call so the recognized typestrs are passed along
    # with the entry instead of being lost to a duplicate-entry error.
    registration.global_registry().add(
        handler_type,
        name,
        recognized_handler_typestrs=recognized_handler_typestrs,
    )
  except registration.AlreadyExistsError:
    # Deliberately ignored: an existing registration is not an error here.
    pass


# Default global registrations. Each handler is registered exactly once per
# (handler, name) pair; alternate typestrs are supplied in that same call so
# they are not dropped by a duplicate-entry error.
_try_register_handler(
    proto_handler.ProtoHandler,
    recognized_handler_typestrs=[
        'orbax.checkpoint._src.handlers.proto_checkpoint_handler.ProtoCheckpointHandler',
    ],
)
_try_register_handler(
    json_handler.JsonHandler,
    recognized_handler_typestrs=[
        'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler',
    ],
)
_try_register_handler(
    stateful_checkpointable_handler.StatefulCheckpointableHandler
)
# Scoped to the metrics checkpointable name.
_try_register_handler(
    json_handler.MetricsHandler,
    checkpoint_layout.METRICS_CHECKPOINTABLE_KEY,
)
_try_register_handler(
    pytree_handler.PyTreeHandler,
    recognized_handler_typestrs=[
        'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler',
        'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler',
    ],
)
# Scoped to the pytree checkpointable name.
_try_register_handler(
    pytree_handler.PyTreeHandler, checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY
)
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ def add_all(
) -> CheckpointableHandlerRegistry:
"""Adds all entries from `other_registry` to `registry`."""
for handler, checkpointable in other_registry.get_all_entries():
registry.add(handler, checkpointable)
registry.add(
handler,
checkpointable,
recognized_handler_typestrs=other_registry.get_recognized_handler_typestrs(
handler
),
)
return registry


Expand All @@ -87,6 +93,7 @@ def add(
self,
handler_type: Type[CheckpointableHandler],
checkpointable: str | None = None,
recognized_handler_typestrs: Sequence[str] | None = None,
) -> CheckpointableHandlerRegistry:
"""Adds an entry to the registry."""
...
Expand All @@ -110,6 +117,13 @@ def get_all_entries(
) -> Sequence[RegistryEntry]:
...

def get_recognized_handler_typestrs(
    self,
    handler_type: Type[CheckpointableHandler],
) -> Sequence[str]:
  """Returns the recognized handler typestrs for a given handler_type.

  Args:
    handler_type: The handler class whose recognized typestrs are requested.

  Returns:
    The sequence of typestrs associated with `handler_type`.
  """
  ...


class AlreadyExistsError(ValueError):
"""Raised when an entry already exists in the registry."""
Expand All @@ -126,6 +140,9 @@ def __init__(
self, other_registry: CheckpointableHandlerRegistry | None = None
):
self._registry: list[RegistryEntry] = []
self._recognized_handler_typestrs: dict[
Type[CheckpointableHandler], Sequence[str]
] = {}

# Initialize the registry with entries from other registry.
if other_registry:
Expand All @@ -135,6 +152,7 @@ def add(
self,
handler_type: Type[CheckpointableHandler],
checkpointable: str | None = None,
recognized_handler_typestrs: Sequence[str] | None = None,
) -> CheckpointableHandlerRegistry:
"""Adds an entry to the registry.

Expand All @@ -143,6 +161,8 @@ def add(
checkpointable: The checkpointable name. If not-None, the registered
handler will be scoped to that specific name. Otherwise, the handler
will be available for any checkpointable name.
recognized_handler_typestrs: A sequence of alternate typestrs that are
recognized and mapped to this handler.

Returns:
The registry itself.
Expand Down Expand Up @@ -170,6 +190,10 @@ def add(
f'Handler type {handler_type} already exists in the registry.'
)
self._registry.append((handler_type, checkpointable))
if recognized_handler_typestrs is not None:
self._recognized_handler_typestrs[handler_type] = (
recognized_handler_typestrs
)
return self

def get(
Expand Down Expand Up @@ -220,6 +244,13 @@ def get_all_entries(
"""Returns all entries in the registry."""
return self._registry

def get_recognized_handler_typestrs(
    self,
    handler_type: Type[CheckpointableHandler],
) -> Sequence[str]:
  """Looks up the alternate typestrs stored for `handler_type`.

  Args:
    handler_type: The handler class to look up.

  Returns:
    The typestrs stored for `handler_type`, or an empty list when none were
    stored for it.
  """
  try:
    return self._recognized_handler_typestrs[handler_type]
  except KeyError:
    # Nothing was recorded for this handler type.
    return []

def __repr__(self):
  """Returns a debug string listing the entries from `get_all_entries()`."""
  return f'_DefaultCheckpointableHandlerRegistry({self.get_all_entries()})'

Expand All @@ -237,6 +268,7 @@ def add(
self,
handler_type: Type[CheckpointableHandler],
checkpointable: str | None = None,
recognized_handler_typestrs: Sequence[str] | None = None,
) -> CheckpointableHandlerRegistry:
raise NotImplementedError('Adding not implemented for read-only registry.')

Expand All @@ -257,6 +289,12 @@ def get_all_entries(
) -> Sequence[RegistryEntry]:
return self._registry.get_all_entries()

def get_recognized_handler_typestrs(
    self,
    handler_type: Type[CheckpointableHandler],
) -> Sequence[str]:
  """Delegates the typestr lookup for `handler_type` to the wrapped registry."""
  wrapped = self._registry
  return wrapped.get_recognized_handler_typestrs(handler_type)

def __repr__(self):
  """Returns a debug string listing the entries from `get_all_entries()`."""
  return f'ReadOnlyCheckpointableHandlerRegistry({self.get_all_entries()})'

Expand Down Expand Up @@ -303,6 +341,8 @@ def local_registry(

def register_handler(
cls: CheckpointableHandlerType,
*,
recognized_handler_typestrs: Sequence[str] | None = None,
) -> CheckpointableHandlerType:
"""Registers a :py:class:`~.v1.handlers.CheckpointableHandler` globally.

Expand All @@ -322,11 +362,15 @@ class FooHandler(ocp.handlers.CheckpointableHandler[Foo, AbstractFoo]):

Args:
cls: The handler class.
recognized_handler_typestrs: A sequence of alternate handler typestrs that
are recognized and mapped to this handler.

Returns:
The handler class.
"""
_GLOBAL_REGISTRY.add(cls)
_GLOBAL_REGISTRY.add(
cls, recognized_handler_typestrs=recognized_handler_typestrs
)
return cls


Expand Down Expand Up @@ -392,6 +436,16 @@ def _get_possible_handlers(
return possible_handlers


def get_registered_handler_by_name(
    registry: CheckpointableHandlerRegistry,
    name: str,
) -> CheckpointableHandler | None:
  """Builds the handler registered under `name`, if there is one.

  Args:
    registry: The registry to query.
    name: The checkpointable name to look up.

  Returns:
    The result of `_construct_handler_instance` for the registry entry under
    `name`, or None when `registry.has(name)` is false.
  """
  if not registry.has(name):
    return None
  registered_entry = registry.get(name)
  return _construct_handler_instance(name, registered_entry)


def resolve_handler_for_save(
registry: CheckpointableHandlerRegistry,
checkpointable: Any,
Expand Down Expand Up @@ -435,7 +489,7 @@ def is_handleable_fn(handler: CheckpointableHandler, ckpt: Any) -> bool:
registry, is_handleable_fn, checkpointable, name
)

# Prefer the first handler in the absence of any other information.
# Prefer the last handler in the absence of any other information.
return possible_handlers[-1]


Expand All @@ -444,7 +498,7 @@ def resolve_handler_for_load(
abstract_checkpointable: Any | None,
*,
name: str,
handler_typestr: str,
handler_typestr: str | None = None,
) -> CheckpointableHandler:
"""Resolves a :py:class:`~.v1.handlers.CheckpointableHandler` for loading.

Expand All @@ -471,7 +525,9 @@ def resolve_handler_for_load(
abstract_checkpointable: An abstract checkpointable to resolve.
name: The name of the checkpointable.
handler_typestr: A :py:class:`~.v1.handlers.CheckpointableHandler` typestr
to guide resolution.
to guide resolution. We allow a None value for handler_typestr as it's
possible to find the last registered handler given a specified
abstract_checkpointable.

Returns:
A :py:class:`~.v1.handlers.CheckpointableHandler` instance.
Expand All @@ -492,15 +548,34 @@ def is_handleable_fn(
handler_types.typestr(type(handler)) for handler in possible_handlers
]

try:
idx = possible_handler_typestrs.index(handler_typestr)
return possible_handlers[idx]
except ValueError:
if handler_typestr:
if handler_typestr in possible_handler_typestrs:
idx = possible_handler_typestrs.index(handler_typestr)
return possible_handlers[idx]

# Check if handler_typestr is recognized by any possible handler.
# Check backwards to prioritize most recently added handlers.
for i in reversed(range(len(possible_handlers))):
if handler_typestr in registry.get_recognized_handler_typestrs(
type(possible_handlers[i])
):
return possible_handlers[i]

# If neither lookup matched, log a warning and fall through.
logging.warning(
'No handler found for typestr %s. The checkpointable may be restored'
' with different handler logic than was used for saving.',
'No handler found for typestr %s (or its converted form). The '
'checkpointable may be restored with different handler logic '
'than was used for saving.',
handler_typestr,
)

# Prefer the first handler in the absence of any other information.
return possible_handlers[-1]
if abstract_checkpointable:
# Prefer the last handler in the absence of any other information.
return possible_handlers[-1]

raise NoEntryError(
f'No entry for checkpointable={name} in the registry, using'
f' handler_typestr={handler_typestr} and'
f' abstract_checkpointable={abstract_checkpointable}. Registry contents:'
f' {registry.get_all_entries()}'
)
Loading
Loading