Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions model/orbax/experimental/model/core/python/manifest_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from collections.abc import Mapping, Sequence
from absl import logging
from orbax.experimental.model.core.protos import manifest_pb2
from orbax.experimental.model.core.python import manifest_constants
from orbax.experimental.model.core.python import unstructured_data
from orbax.experimental.model.core.python.device_assignment import DeviceAssignment
from orbax.experimental.model.core.python.function import Function
Expand All @@ -29,6 +28,7 @@
from orbax.experimental.model.core.python.unstructured_data import UnstructuredData
from orbax.experimental.model.core.python.value import ExternalValue


def _build_function(
fn: Function,
path: str,
Expand Down Expand Up @@ -63,7 +63,7 @@ def _build_function(
supp_proto = supp.proto
if supp.ext_name is not None:
filename = unstructured_data.build_filename_from_extension(
name + "_supplemental", supp.ext_name
name + "_" + supp_name + "_supplemental", supp.ext_name
)
supp_proto = unstructured_data.write_inlined_data_to_file(
supp_proto, path, filename
Expand Down
112 changes: 50 additions & 62 deletions model/orbax/experimental/model/core/python/persistence_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,23 @@

"""Saving to disk.

This file contains facilities that can save a `Module` to disk.
This file contains facilities that can save an OBM module to disk or load a
persisted OBM module.
"""

from collections.abc import Mapping, Sequence
import dataclasses
import os
from typing import Optional

from absl import logging
from orbax.experimental.model.core.protos import manifest_pb2
from orbax.experimental.model.core.python import device_assignment
from orbax.experimental.model.core.python import file_utils
from orbax.experimental.model.core.python import manifest_constants
from orbax.experimental.model.core.python import manifest_util
from orbax.experimental.model.core.python import metadata
from orbax.experimental.model.core.python import saveable
from orbax.experimental.model.core.python import unstructured_data
from orbax.experimental.model.core.python.device_assignment import DeviceAssignment
from orbax.experimental.model.core.python.manifest_util import build_manifest_proto
from orbax.experimental.model.core.python.saveable import Saveable
from orbax.experimental.model.core.python.unstructured_data import UnstructuredData


@dataclasses.dataclass
Expand All @@ -45,9 +44,9 @@ class GlobalSupplemental:
mustn't be a `file_location` already).
"""

data: UnstructuredData
data: unstructured_data.UnstructuredData

save_as: str | None
save_as: str | None = None


@dataclasses.dataclass
Expand All @@ -58,81 +57,55 @@ class SaveOptions:
save an FSM SavedModel.

Attributes:
function_aliases: A mapping from user-chosen function alias name to the
function that runs on TPU.
version: The serialization format version. With version >= 2 it generates
manifest.pb whereas with version <= 1 it generates saved_model.pb.
supplemental_info: Optional. An `UnstructuredData` (or a string-map of them)
to be saved in the manifest.pb as the global supplemental info.
supplementals: Optional. A map of UnstructuredData to be saved in the
manifest as the global supplemental info.
visibility: Optional. A mapping from function name to its visibility (e.g.,
`manifest_pb2.PUBLIC`, `manifest_pb2.PRIVATE`). If this parameter is not
provided, all functions will be public. If only a subset of functions are
provided in the mapping, the rest will be public by default.
DeviceAssignmentByCoords
device_assignment_by_coords: Optional. A sequence of DeviceAssignment to be
saved in the manifest.pb.
"""

version: int | None = None
supplemental_info: Mapping[str, GlobalSupplemental] | None = None
supplementals: Mapping[str, GlobalSupplemental] | None = None
visibility: Mapping[str, manifest_pb2.Visibility] | None = None
device_assignment_by_coords: Sequence[DeviceAssignment] | None = None


def _save_single_supplemental(
supplemental_info: GlobalSupplemental,
path: str,
) -> UnstructuredData:
"""Saves a single supplemental to disk."""
data = supplemental_info.data
if supplemental_info.save_as is not None:
data = unstructured_data.write_inlined_data_to_file(
data,
path,
supplemental_info.save_as,
)
return data


def _save_supplementals(
supplemental_info: (
GlobalSupplemental | Mapping[str, GlobalSupplemental] | None
),
path: str,
) -> UnstructuredData | Mapping[str, UnstructuredData] | None:
"""Saves supplementals to disk."""
if supplemental_info is None:
return None
if isinstance(supplemental_info, GlobalSupplemental):
return _save_single_supplemental(supplemental_info, path)
else:
return {
name: _save_single_supplemental(supp, path)
for name, supp in supplemental_info.items()
}
device_assignment_by_coords: (
Sequence[device_assignment.DeviceAssignment] | None
) = None


def save(
m: dict[str, Saveable],
path: str,
options: Optional[SaveOptions] = None,
module: dict[str, saveable.Saveable],
target_dir: str,
options: SaveOptions,
) -> None:
"""Saved the Module to disk."""
assert options is not None
assert options.version is not None
assert options.version >= 2
"""Saves module `module` in the target directory `target_dir`."""
if options.version is None or options.version < 2:
raise ValueError('Version must be >= 2')

logging.info('Save version: %d', options.version)
# Generate and export the Manifest proto.
supplemental_info = _save_supplementals(options.supplemental_info, path)
manifest_proto = build_manifest_proto(
m,
path,
# Generate and export the manifest proto.
supplemental_info = None
if options.supplementals is not None:
supplemental_info = {
name: _save_supplemental(supp, target_dir)
for name, supp in options.supplementals.items()
}

manifest_proto = manifest_util.build_manifest_proto(
module,
target_dir,
supplemental_info=supplemental_info,
names_to_visibilities=options.visibility,
device_assignment_by_coords=options.device_assignment_by_coords,
)

manifest_path = os.path.join(path, manifest_constants.MANIFEST_FILE_PATH)
manifest_path = os.path.join(
target_dir, manifest_constants.MANIFEST_FILE_PATH
)
file_utils.mkdir_p(os.path.dirname(manifest_path))
with file_utils.open_file(manifest_path, 'wb') as f:
f.write(manifest_proto.SerializeToString())
Expand All @@ -146,7 +119,7 @@ def save(
# should be THE LAST file to be written. It is used to validate the export and
# identify an Orbax Model.
model_version.save(
os.path.join(path, manifest_constants.MODEL_VERSION_FILENAME)
os.path.join(target_dir, manifest_constants.MODEL_VERSION_FILENAME)
)


Expand All @@ -169,3 +142,18 @@ def load(saved_state_dir: str) -> manifest_pb2.Manifest:
)
with file_utils.open_file(manifest_path, 'rb') as f:
return manifest_pb2.Manifest.FromString(f.read())


def _save_supplemental(
supplemental: GlobalSupplemental,
target_dir: str,
) -> unstructured_data.UnstructuredData:
"""Returns the supplemental data, either inlined or saved to disk."""
if supplemental.save_as is None:
return supplemental.data

return unstructured_data.write_inlined_data_to_file(
supplemental.data,
target_dir,
supplemental.save_as,
)
Loading
Loading