From df70e158822ac72c07494de3625082227145f507 Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Fri, 20 Feb 2026 17:46:23 -0800 Subject: [PATCH] Improve SerializationContext and DeserializationContext class docstrings PiperOrigin-RevId: 873143586 --- .../v1/_src/serialization/types.py | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py index 0890f72f0..88c0c0e19 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/types.py @@ -110,6 +110,38 @@ def name(self) -> str: @typing.final @dataclasses.dataclass(frozen=True, kw_only=True) class SerializationContext: + """A container for the execution context passed to :py:class:`LeafHandler`. + + This class aggregates global resources—such as the destination path and + concurrency limits—enabling the implementation of serialization support for + custom leaf objects. + + Example Usage: + SerializationContext is accessed within `LeafHandler.serialize` to determine + where and how to write data:: + + class MyCustomHandler(LeafHandler): + async def serialize( + self, + params: Sequence[SerializationParam], + context: SerializationContext + ): + # Use the context to determine the save location + save_path = context.parent_dir / "data.bin" + + # Use the context's limiter to manage I/O concurrency + if context.byte_limiter: + async with context.byte_limiter: + await self._write_to_disk(save_path, params) + + Attributes: + parent_dir: The base directory where the checkpoint or leaf data should be + saved. + ts_context: An optional :py:class:`tensorstore.Context` object used to + configure storage backends and shared resources. + byte_limiter: An optional rate limiter used to throttle I/O operations. + """ + parent_dir: path_types.PathAwaitingCreation ts_context: ts.Context | None = None byte_limiter: limits.LimitInFlightBytes | None = None @@ -129,6 +161,38 @@ def name(self) -> str: @typing.final @dataclasses.dataclass(frozen=True, kw_only=True) class DeserializationContext: + """A container for the execution context passed to :py:class:`LeafHandler`. + + This class aggregates global resources—such as the source path and + format-specific checkpoint handles—enabling the implementation of + deserialization support for custom leaf objects. + + Example Usage: + DeserializationContext is accessed within `LeafHandler.deserialize` to + determine the source location and read data:: + + class MyCustomHandler(LeafHandler): + async def deserialize( + self, + params: Sequence[SerializationParam], + context: DeserializationContext + ): + # Use the context to determine the source location. + load_path = context.parent_dir / "data.bin" + + # Use the context's limiter to manage I/O concurrency. + if context.byte_limiter: + async with context.byte_limiter: + return await self._read_from_disk(load_path, params) + Attributes: + parent_dir: The base directory where the checkpoint or leaf data is located. + ocdbt_checkpoint: A boolean indicating if the source is an OCDBT checkpoint. + zarr3_checkpoint: A boolean indicating if the source is a Zarr3 checkpoint. + ts_context: A TensorStore context object used to configure storage backends + and shared resources. + byte_limiter: An optional rate limiter used to throttle I/O operations. + """ + parent_dir: path_types.Path ocdbt_checkpoint: bool zarr3_checkpoint: bool