diff --git a/model/orbax/experimental/model/core/python/manifest_util.py b/model/orbax/experimental/model/core/python/manifest_util.py index 2f9092df6..7392928e6 100644 --- a/model/orbax/experimental/model/core/python/manifest_util.py +++ b/model/orbax/experimental/model/core/python/manifest_util.py @@ -18,7 +18,6 @@ from collections.abc import Mapping, Sequence from absl import logging from orbax.experimental.model.core.protos import manifest_pb2 -from orbax.experimental.model.core.python import manifest_constants from orbax.experimental.model.core.python import unstructured_data from orbax.experimental.model.core.python.device_assignment import DeviceAssignment from orbax.experimental.model.core.python.function import Function @@ -29,6 +28,7 @@ from orbax.experimental.model.core.python.unstructured_data import UnstructuredData from orbax.experimental.model.core.python.value import ExternalValue + def _build_function( fn: Function, path: str, @@ -63,7 +63,7 @@ def _build_function( supp_proto = supp.proto if supp.ext_name is not None: filename = unstructured_data.build_filename_from_extension( - name + "_supplemental", supp.ext_name + name + "_" + supp_name + "_supplemental", supp.ext_name ) supp_proto = unstructured_data.write_inlined_data_to_file( supp_proto, path, filename diff --git a/model/orbax/experimental/model/core/python/persistence_lib.py b/model/orbax/experimental/model/core/python/persistence_lib.py index c77f7ac75..cad970204 100644 --- a/model/orbax/experimental/model/core/python/persistence_lib.py +++ b/model/orbax/experimental/model/core/python/persistence_lib.py @@ -14,24 +14,23 @@ """Saving to disk. -This file contains facilities that can save a `Module` to disk. +This file contains facilities that can save an OBM module to disk or load a +persisted OBM module. """ from collections.abc import Mapping, Sequence import dataclasses import os -from typing import Optional from absl import logging from orbax.experimental.model.core.protos import manifest_pb2 +from orbax.experimental.model.core.python import device_assignment from orbax.experimental.model.core.python import file_utils from orbax.experimental.model.core.python import manifest_constants +from orbax.experimental.model.core.python import manifest_util from orbax.experimental.model.core.python import metadata +from orbax.experimental.model.core.python import saveable from orbax.experimental.model.core.python import unstructured_data -from orbax.experimental.model.core.python.device_assignment import DeviceAssignment -from orbax.experimental.model.core.python.manifest_util import build_manifest_proto -from orbax.experimental.model.core.python.saveable import Saveable -from orbax.experimental.model.core.python.unstructured_data import UnstructuredData @dataclasses.dataclass @@ -45,9 +44,9 @@ class GlobalSupplemental: mustn't be a `file_location` already). """ - data: UnstructuredData + data: unstructured_data.UnstructuredData - save_as: str | None + save_as: str | None = None @dataclasses.dataclass @@ -58,81 +57,55 @@ class SaveOptions: save an FSM SavedModel. Attributes: - function_aliases: A mapping from user-chosen function alias name to the - function that runs on TPU. version: The serialization format version. With version >= 2 it generates manifest.pb whereas with version <= 1 it generates saved_model.pb. - supplemental_info: Optional. An `UnstructuredData` (or a string-map of them) - to be saved in the manifest.pb as the global supplemental info. + supplementals: Optional. A map of UnstructuredData to be saved in the + manifest as the global supplemental info. visibility: Optional. A mapping from function name to its visibility (e.g., `manifest_pb2.PUBLIC`, `manifest_pb2.PRIVATE`). If this parameter is not provided, all functions will be public. If only a subset of functions are provided in the mapping, the rest will be public by default. - DeviceAssignmentByCoords device_assignment_by_coords: Optional. A sequence of DeviceAssignment to be saved in the manifest.pb. """ version: int | None = None - supplemental_info: Mapping[str, GlobalSupplemental] | None = None + supplementals: Mapping[str, GlobalSupplemental] | None = None visibility: Mapping[str, manifest_pb2.Visibility] | None = None - device_assignment_by_coords: Sequence[DeviceAssignment] | None = None - - -def _save_single_supplemental( - supplemental_info: GlobalSupplemental, - path: str, -) -> UnstructuredData: - """Saves a single supplemental to disk.""" - data = supplemental_info.data - if supplemental_info.save_as is not None: - data = unstructured_data.write_inlined_data_to_file( - data, - path, - supplemental_info.save_as, - ) - return data - - -def _save_supplementals( - supplemental_info: ( - GlobalSupplemental | Mapping[str, GlobalSupplemental] | None - ), - path: str, -) -> UnstructuredData | Mapping[str, UnstructuredData] | None: - """Saves supplementals to disk.""" - if supplemental_info is None: - return None - if isinstance(supplemental_info, GlobalSupplemental): - return _save_single_supplemental(supplemental_info, path) - else: - return { - name: _save_single_supplemental(supp, path) - for name, supp in supplemental_info.items() - } + device_assignment_by_coords: ( + Sequence[device_assignment.DeviceAssignment] | None + ) = None def save( - m: dict[str, Saveable], - path: str, - options: Optional[SaveOptions] = None, + module: dict[str, saveable.Saveable], + target_dir: str, + options: SaveOptions, ) -> None: - """Saved the Module to disk.""" - assert options is not None - assert options.version is not None - assert options.version >= 2 + """Saves module `module` in the target directory `target_dir`.""" + if options.version is None or options.version < 2: + raise ValueError('Version must be >= 2') + logging.info('Save version: %d', options.version) - # Generate and export the Manifest proto. - supplemental_info = _save_supplementals(options.supplemental_info, path) - manifest_proto = build_manifest_proto( - m, - path, + # Generate and export the manifest proto. + supplemental_info = None + if options.supplementals is not None: + supplemental_info = { + name: _save_supplemental(supp, target_dir) + for name, supp in options.supplementals.items() + } + + manifest_proto = manifest_util.build_manifest_proto( + module, + target_dir, supplemental_info=supplemental_info, names_to_visibilities=options.visibility, device_assignment_by_coords=options.device_assignment_by_coords, ) - manifest_path = os.path.join(path, manifest_constants.MANIFEST_FILE_PATH) + manifest_path = os.path.join( + target_dir, manifest_constants.MANIFEST_FILE_PATH + ) file_utils.mkdir_p(os.path.dirname(manifest_path)) with file_utils.open_file(manifest_path, 'wb') as f: f.write(manifest_proto.SerializeToString()) @@ -146,7 +119,7 @@ def save( # should be THE LAST file to be written. It is used to validate the export and # identify an Orbax Model. model_version.save( - os.path.join(path, manifest_constants.MODEL_VERSION_FILENAME) + os.path.join(target_dir, manifest_constants.MODEL_VERSION_FILENAME) ) @@ -169,3 +142,18 @@ def load(saved_state_dir: str) -> manifest_pb2.Manifest: ) with file_utils.open_file(manifest_path, 'rb') as f: return manifest_pb2.Manifest.FromString(f.read()) + + +def _save_supplemental( + supplemental: GlobalSupplemental, + target_dir: str, +) -> unstructured_data.UnstructuredData: + """Returns the supplemental data, either inlined or saved to disk.""" + if supplemental.save_as is None: + return supplemental.data + + return unstructured_data.write_inlined_data_to_file( + supplemental.data, + target_dir, + supplemental.save_as, + ) diff --git a/model/orbax/experimental/model/core/python/persistence_lib_test.py b/model/orbax/experimental/model/core/python/persistence_lib_test.py index c95aa4d17..d82e7ecc7 100644 --- a/model/orbax/experimental/model/core/python/persistence_lib_test.py +++ b/model/orbax/experimental/model/core/python/persistence_lib_test.py @@ -12,45 +12,363 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Saving tests.""" +"""Persistence tests (saving/loading the model).""" -from collections.abc import Sequence import os -from typing import Tuple from absl.testing import absltest -import jax -from jax import export as jax_export -import jax.numpy as jnp +from orbax.experimental.model.core.protos import manifest_pb2 +from orbax.experimental.model.core.protos import type_pb2 +from orbax.experimental.model.core.python import device_assignment +from orbax.experimental.model.core.python import file_utils +from orbax.experimental.model.core.python import function +from orbax.experimental.model.core.python import manifest_constants from orbax.experimental.model.core.python import persistence_lib -from orbax.experimental.model.core.python.function import np_dtype_to_shlo_dtype -from orbax.experimental.model.core.python.function import ShloTensorSpec -from orbax.experimental.model.core.python.shlo_function import ShloFunction +from orbax.experimental.model.core.python import serializable_function +from orbax.experimental.model.core.python import shlo_function +from orbax.experimental.model.core.python import unstructured_data +from orbax.experimental.model.core.python import value +from tensorflow.python.util.protobuf import compare -save = persistence_lib.save +class PersistenceLibTest(absltest.TestCase): -def jax_spec_from_aval(x: jax.core.AbstractValue) -> jax.ShapeDtypeStruct: - assert isinstance(x, jax.core.ShapedArray) - return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype) + def test_save_succeeds_with_all_fields(self): + target_dir: str = self.create_tempdir().full_path + func1 = TestShloFunction(b'func1') + func2 = TestShloFunction(b'func2') + func3 = TestSerializableFunction(b'func3') -def jax_spec_to_shlo_spec( - jax_spec: Sequence[jax.ShapeDtypeStruct], -) -> Tuple[ShloTensorSpec, ...]: - return tuple( - ShloTensorSpec(shape=x.shape, dtype=np_dtype_to_shlo_dtype(x.dtype)) - for x in jax_spec - ) + persistence_lib.save( + module={ + 'my_func': func1, + 'my_seq_of_funcs': [func2, func3], + 'my_value': value.ExternalValue( + data=manifest_pb2.UnstructuredData(inlined_string='my_value'), + ), + }, + target_dir=target_dir, + options=persistence_lib.SaveOptions( + version=2, + supplementals={ + 'supp_a': persistence_lib.GlobalSupplemental( + data=manifest_pb2.UnstructuredData(inlined_string='supp_a'), + save_as='supp_a.pb', + ), + 'supp_b': persistence_lib.GlobalSupplemental( + data=manifest_pb2.UnstructuredData(inlined_string='supp_b') + ), + }, + visibility={ + 'my_func': manifest_pb2.PRIVATE, + 'my_seq_of_funcs': manifest_pb2.PUBLIC, + }, + device_assignment_by_coords=[ + device_assignment.DeviceAssignment( + id=0, + coords=[0, 1, 2], + core_on_chip=0, + ), + device_assignment.DeviceAssignment( + id=1, + coords=[1, 2, 3], + core_on_chip=1, + ), + ], + ), + ) + model_version_path = os.path.join( + target_dir, manifest_constants.MODEL_VERSION_FILENAME + ) + self.assertTrue(os.path.exists(model_version_path)) -class SaveTest(absltest.TestCase): - # TODO(wangpeng): We can move relevant parts of test from - # orbax/experimental/model/integration_tests/orbax_model_test.py - # here. - def test_save(self): - pass + actual_manifest = persistence_lib.load(target_dir) + expected_manifest = manifest_pb2.Manifest( + objects={ + 'my_func': manifest_pb2.TopLevelObject( + function=func1.manifest_repr('my_func', manifest_pb2.PRIVATE) + ), + 'my_seq_of_funcs': manifest_pb2.TopLevelObject( + poly_fn=manifest_pb2.PolymorphicFunction( + concrete_functions=[ + func2.manifest_repr( + '__my_seq_of_funcs_0', manifest_pb2.PUBLIC + ), + func3.manifest_repr( + '__my_seq_of_funcs_1', + manifest_pb2.PUBLIC, + ), + ], + ) + ), + 'my_value': manifest_pb2.TopLevelObject( + value=manifest_pb2.Value( + external=manifest_pb2.ExternalValue( + data=manifest_pb2.UnstructuredData( + inlined_string='my_value' + ) + ) + ) + ), + }, + supplemental_info={ + 'supp_a': manifest_pb2.UnstructuredData( + file_system_location=manifest_pb2.FileSystemLocation( + string_path='supp_a.pb' + ), + mime_type='', + version='', + ), + 'supp_b': manifest_pb2.UnstructuredData( + inlined_string='supp_b', + ), + }, + device_assignment_by_coords=manifest_pb2.DeviceAssignmentByCoords( + devices=[ + manifest_pb2.DeviceAssignmentByCoords.Device( + id=0, coords=[0, 1, 2], core_on_chip=0 + ), + manifest_pb2.DeviceAssignmentByCoords.Device( + id=1, coords=[1, 2, 3], core_on_chip=1 + ), + ] + ), + ) + + compare.assertProtoEqual( + self, + expected_manifest, + actual_manifest, + ) + + # Check that the supplemental files are written only for the functions + # that don't have them inlined. + self.assertTrue(os.path.exists(os.path.join(target_dir, 'supp_a.pb'))) + self.assertTrue( + os.path.exists( + os.path.join(target_dir, 'my_func_supp_a_supplemental.pb') + ) + ) + self.assertTrue( + os.path.exists( + os.path.join(target_dir, 'my_func_supp_b_supplemental.pb') + ) + ) + self.assertFalse( + os.path.exists(os.path.join(target_dir, '__my_seq_of_funcs_0.pb')) + ) + self.assertTrue( + os.path.exists(os.path.join(target_dir, '__my_seq_of_funcs_1.pb')) + ) + + def test_load_fails_version_file_missing(self): + target_dir: str = self.create_tempdir().full_path + + with self.assertRaisesRegex(Exception, 'orbax_model_version.txt'): + persistence_lib.load(target_dir) + + def test_load_fails_with_unsupported_version(self): + target_dir: str = self.create_tempdir().full_path + + path = os.path.join(target_dir, 'orbax_model_version.txt') + file_content = """ + version: "42" + manifest_file_path: "test/path" + mime_type: "test_mime_type; application/foo" + """ + with file_utils.open_file(path, 'w') as f: + f.write(file_content) + + with self.assertRaisesRegex( + ValueError, 'Unsupported manifest version "42"' + ): + persistence_lib.load(target_dir) + + def test_load_fails_with_unsupported_mime_type(self): + target_dir: str = self.create_tempdir().full_path + + path = os.path.join(target_dir, 'orbax_model_version.txt') + file_content = """ + version: "0.0.1" + mime_type: "text/plain" + manifest_file_path: "test/path" + """ + with file_utils.open_file(path, 'w') as f: + f.write(file_content) + + with self.assertRaisesRegex( + ValueError, 'Unsupported manifest type "text/plain"' + ): + persistence_lib.load(target_dir) + + +class TestShloFunctionSupplementalInfo( + shlo_function.ShloFunctionSupplementalInfo +): + + def __init__(self, b: bytes): + self._bytes = b + + def serializable_to_proto( + self, + ) -> unstructured_data.UnstructuredDataWithExtName: + return unstructured_data.UnstructuredDataWithExtName( + proto=manifest_pb2.UnstructuredData(inlined_bytes=self._bytes), + ext_name='pb', + ) + + +class TestShloFunction(shlo_function.ShloFunction): + + def __init__(self, b: bytes): + self.bytes = b + self.input_signature = function.ShloTensorSpec((1,), function.ShloDType.f32) + self.output_signature = function.ShloTensorSpec( + ( + 1, + 2, + ), + function.ShloDType.f16, + ) + self.mlir_module_serialized = b + self.calling_convention_version = 42 + self.lowering_platforms = ('a', 'b') + self.module_kept_var_idx = (0, 1) + self.supplemental_info = { + 'supp_a': TestShloFunctionSupplementalInfo(b'supp_a'), + 'supp_b': TestShloFunctionSupplementalInfo(b'supp_b'), + } + self.physical_in_dtypes = (None, function.ShloDType.f32) + self.physical_out_dtypes = (function.ShloDType.f16, None) + + def manifest_repr( + self, key: str, visibility: manifest_pb2.Visibility + ) -> manifest_pb2.Function: + return manifest_pb2.Function( + signature=type_pb2.FunctionSignature( + input=type_pb2.Type( + leaf=type_pb2.LeafType( + tensor_type=type_pb2.TensorType( + shape=type_pb2.Shape( + shape_with_known_rank=type_pb2.ShapeWithKnownRank( + dimension_sizes=[type_pb2.DimensionSize(size=1)] + ) + ), + dtype=type_pb2.DType.f32, + ) + ) + ), + output=type_pb2.Type( + leaf=type_pb2.LeafType( + tensor_type=type_pb2.TensorType( + shape=type_pb2.Shape( + shape_with_known_rank=type_pb2.ShapeWithKnownRank( + dimension_sizes=[ + type_pb2.DimensionSize(size=1), + type_pb2.DimensionSize(size=2), + ] + ) + ), + dtype=type_pb2.DType.f16, + ) + ) + ), + ), + body=manifest_pb2.FunctionBody( + stable_hlo_body=manifest_pb2.StableHloFunctionBody( + stable_hlo=manifest_pb2.UnstructuredData( + inlined_bytes=self.bytes, + mime_type='application/x.mlir-stablehlo', + version='1.0', + ), + calling_convention_version=42, + lowering_platforms=['a', 'b'], + module_kept_var_idx=[0, 1], + supplemental_info={ + 'supp_a': manifest_pb2.UnstructuredData( + file_system_location=manifest_pb2.FileSystemLocation( + string_path=f'{key}_supp_a_supplemental.pb' + ), + mime_type='', + version='', + ), + 'supp_b': manifest_pb2.UnstructuredData( + file_system_location=manifest_pb2.FileSystemLocation( + string_path=f'{key}_supp_b_supplemental.pb' + ), + mime_type='', + version='', + ), + }, + ) + ), + visibility=visibility, + ) + + +class TestSerializableFunction(serializable_function.SerializableFunction): + + def __init__(self, b: bytes): + self.input_signature = function.ShloTensorSpec((1,), function.ShloDType.f32) + self.output_signature = function.ShloTensorSpec( + ( + 1, + 2, + ), + function.ShloDType.f16, + ) + self.body = unstructured_data.UnstructuredDataWithExtName( + proto=manifest_pb2.UnstructuredData(inlined_bytes=b), + ext_name='pb', + ) + + def manifest_repr( + self, key: str, visibility: manifest_pb2.Visibility + ) -> manifest_pb2.Function: + return manifest_pb2.Function( + signature=type_pb2.FunctionSignature( + input=type_pb2.Type( + leaf=type_pb2.LeafType( + tensor_type=type_pb2.TensorType( + shape=type_pb2.Shape( + shape_with_known_rank=type_pb2.ShapeWithKnownRank( + dimension_sizes=[type_pb2.DimensionSize(size=1)] + ) + ), + dtype=type_pb2.DType.f32, + ) + ) + ), + output=type_pb2.Type( + leaf=type_pb2.LeafType( + tensor_type=type_pb2.TensorType( + shape=type_pb2.Shape( + shape_with_known_rank=type_pb2.ShapeWithKnownRank( + dimension_sizes=[ + type_pb2.DimensionSize(size=1), + type_pb2.DimensionSize(size=2), + ] + ) + ), + dtype=type_pb2.DType.f16, + ) + ) + ), + ), + body=manifest_pb2.FunctionBody( + other=manifest_pb2.UnstructuredData( + file_system_location=manifest_pb2.FileSystemLocation( + string_path=f'{key}.pb', + ), + mime_type='', + version='', + ) + ), + visibility=visibility, + ) if __name__ == '__main__': diff --git a/model/orbax/experimental/model/jax2obm/main_lib_test.py b/model/orbax/experimental/model/jax2obm/main_lib_test.py index 516d3bd67..3a46d9322 100644 --- a/model/orbax/experimental/model/jax2obm/main_lib_test.py +++ b/model/orbax/experimental/model/jax2obm/main_lib_test.py @@ -135,7 +135,7 @@ def _generated_sharded_params(): save_dir_path, obm.SaveOptions( version=2, - supplemental_info={ + supplementals={ simple_orchestration.TEST_ORCHESTRATION_SUPPLEMENTAL_NAME: ( obm.GlobalSupplemental( simple_orchestration.create( @@ -298,7 +298,9 @@ def _generated_sharded_params(): """ ) - jax_supplemental_filename = f'{model_function_name}_supplemental.pb' + jax_supplemental_filename = ( + f'{model_function_name}_jax_specific_info_supplemental.pb' + ) expected_manifest_proto_text = ( """ @@ -552,7 +554,9 @@ def get_mesh(): ) manifest_proto = obm.load(save_dir_path) - jax_supplemental_filename = f'{model_function_name}_supplemental.pb' + jax_supplemental_filename = ( + f'{model_function_name}_jax_specific_info_supplemental.pb' + ) expected_manifest_proto_text = ( """ diff --git a/model/orbax/experimental/model/jax2obm/obm_to_jax_test.py b/model/orbax/experimental/model/jax2obm/obm_to_jax_test.py index 420609ae2..5f5a1d1b7 100644 --- a/model/orbax/experimental/model/jax2obm/obm_to_jax_test.py +++ b/model/orbax/experimental/model/jax2obm/obm_to_jax_test.py @@ -294,7 +294,7 @@ def jax_model_fn(params, x): save_dir_path, obm.SaveOptions( version=2, - supplemental_info={ + supplementals={ simple_orchestration.TEST_ORCHESTRATION_SUPPLEMENTAL_NAME: ( obm.GlobalSupplemental( simple_orchestration.create( @@ -443,7 +443,7 @@ def get_mesh(): save_dir_path, obm.SaveOptions( version=2, - supplemental_info={ + supplementals={ simple_orchestration.TEST_ORCHESTRATION_SUPPLEMENTAL_NAME: ( obm.GlobalSupplemental( simple_orchestration.create( diff --git a/model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm_test.py b/model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm_test.py index 835921b0e..9850dc387 100644 --- a/model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm_test.py +++ b/model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm_test.py @@ -302,7 +302,7 @@ def tf_fn(a): save_dir_path, obm.SaveOptions( version=2, - supplemental_info={ + supplementals={ TF_SAVED_MODEL_SUPPLEMENTAL_NAME: obm.GlobalSupplemental( tf_global_supplemental, None ),