diff --git a/pathwaysutils/persistence/helper.py b/pathwaysutils/persistence/helper.py index 794deaf..8648fb2 100644 --- a/pathwaysutils/persistence/helper.py +++ b/pathwaysutils/persistence/helper.py @@ -26,55 +26,55 @@ def base64_utf8_stringify(bs: bytes) -> str: - """Converts bytes to a base64-encoded utf-8 string. + """Converts bytes to a base64-encoded utf-8 string. - Args: - bs: The bytes to convert. + Args: + bs: The bytes to convert. - Returns: - The base64-encoded utf-8 string. - """ - return base64.b64encode(bs).decode("utf-8") + Returns: + The base64-encoded utf-8 string. + """ + return base64.b64encode(bs).decode("utf-8") def string_to_base64(text: str) -> str: - """Encodes a string to base64 format. + """Encodes a string to base64 format. - Args: - text: The string to encode. + Args: + text: The string to encode. - Returns: - The base64-encoded string. - """ - return base64_utf8_stringify(text.encode("utf-8")) + Returns: + The base64-encoded string. + """ + return base64_utf8_stringify(text.encode("utf-8")) def get_hlo_sharding_string( sharding: jax.sharding.Sharding, num_dimensions: int, ) -> str: - """Serializes the sharding to an hlo-sharding, encodes it to base64 and returns the base-64 as an utf-8 string.""" - return base64_utf8_stringify( - # pylint:disable=protected-access - sharding._to_xla_hlo_sharding(num_dimensions) # pytype: disable=attribute-error - # pylint:enable=protected-access - .to_proto().SerializeToString() - ) + """Serializes the sharding to an hlo-sharding, encodes it to base64 and returns the base-64 as an utf-8 string.""" + return base64_utf8_stringify( + # pylint:disable=protected-access + sharding._to_xla_hlo_sharding(num_dimensions) # pytype: disable=attribute-error + # pylint:enable=protected-access + .to_proto().SerializeToString() + ) def get_shape_string( dtype: np.dtype, shape: Sequence[int], ) -> str: - """Serializes the shape, encodes it to base64 and returns the base-64 as an utf-8 string.""" - return base64_utf8_stringify( - xc.Shape.array_shape( - xc.PrimitiveType(xc.dtype_to_etype(dtype)), - shape, - ) - .with_major_to_minor_layout_if_absent() - .to_serialized_proto() - ) + """Serializes the shape, encodes it to base64 and returns the base-64 as an utf-8 string.""" + return base64_utf8_stringify( + xc.Shape.array_shape( + xc.PrimitiveType(xc.dtype_to_etype(dtype)), + shape, + ) + .with_major_to_minor_layout_if_absent() + .to_serialized_proto() + ) def get_write_request( @@ -83,36 +83,36 @@ def get_write_request( jax_array: jax.Array, timeout: datetime.timedelta, ) -> str: - """Returns a string representation of the plugin program which writes the given jax_array to the given location.""" - sharding = jax_array.sharding - assert isinstance(sharding, jax.sharding.Sharding), sharding - - timeout_seconds, timeout_fractional_seconds = divmod( - timeout.total_seconds(), 1 - ) - timeout_nanoseconds = timeout_fractional_seconds * 1e9 - return json.dumps({ - "persistenceWriteRequest": { - "b64_location": string_to_base64(location_path), - "b64_name": string_to_base64(name), - "b64_hlo_sharding_string": get_hlo_sharding_string( - jax_array.sharding, len(jax_array.shape) - ), - "shape": jax_array.shape, - "devices": { - "device_ids": [ - # pylint:disable=protected-access - device.id - for device in sharding._device_assignment - # pylint:enable=protected-access - ], - }, - "timeout": { - "seconds": int(timeout_seconds), - "nanos": int(timeout_nanoseconds), - }, - } - }) + """Returns a string representation of the plugin program which writes the given jax_array to the given location.""" + sharding = jax_array.sharding + assert isinstance(sharding, jax.sharding.Sharding), sharding + + timeout_seconds, timeout_fractional_seconds = divmod(timeout.total_seconds(), 1) + timeout_nanoseconds = timeout_fractional_seconds * 1e9 + return json.dumps( + { + "persistenceWriteRequest": { + "b64_location": string_to_base64(location_path), + "b64_name": string_to_base64(name), + "b64_hlo_sharding_string": get_hlo_sharding_string( + jax_array.sharding, len(jax_array.shape) + ), + "shape": jax_array.shape, + "devices": { + "device_ids": [ + # pylint:disable=protected-access + device.id + for device in sharding._device_assignment + # pylint:enable=protected-access + ], + }, + "timeout": { + "seconds": int(timeout_seconds), + "nanos": int(timeout_nanoseconds), + }, + } + } + ) def get_read_request( @@ -124,31 +124,29 @@ def get_read_request( devices: Sequence[jax.Device], timeout: datetime.timedelta, ) -> str: - """Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding.""" - if not isinstance(devices, np.ndarray): - devices = np.array(devices) - - timeout_seconds, timeout_fractional_seconds = divmod( - timeout.total_seconds(), 1 - ) - timeout_nanoseconds = timeout_fractional_seconds * 1e9 - return json.dumps({ - "persistenceReadRequest": { - "b64_location": string_to_base64(location_path), - "b64_shape_proto_string": get_shape_string(dtype, shape), - "b64_name": string_to_base64(name), - "b64_hlo_sharding_string": get_hlo_sharding_string( - sharding, len(shape) - ), - "devices": { - "device_ids": [device.id for device in devices.flatten()] - }, - "timeout": { - "seconds": int(timeout_seconds), - "nanos": int(timeout_nanoseconds), - }, - } - }) + """Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding.""" + if not isinstance(devices, np.ndarray): + devices = np.array(devices) + + timeout_seconds, timeout_fractional_seconds = divmod(timeout.total_seconds(), 1) + timeout_nanoseconds = timeout_fractional_seconds * 1e9 + return json.dumps( + { + "persistenceReadRequest": { + "b64_location": string_to_base64(location_path), + "b64_shape_proto_string": get_shape_string(dtype, shape), + "b64_name": string_to_base64(name), + "b64_hlo_sharding_string": get_hlo_sharding_string( + sharding, len(shape) + ), + "devices": {"device_ids": [device.id for device in devices.flatten()]}, + "timeout": { + "seconds": int(timeout_seconds), + "nanos": int(timeout_nanoseconds), + }, + } + } + ) def write_one_array( @@ -157,11 +155,11 @@ def write_one_array( value: jax.Array, timeout: datetime.timedelta, ): - """Creates the write array plugin program string, compiles it to an executable, calls it and returns an awaitable future.""" - write_request = get_write_request(location, name, value, timeout) - write_executable = plugin_executable.PluginExecutable(write_request) - _, write_future = write_executable.call([value]) - return write_future + """Creates the write array plugin program string, compiles it to an executable, calls it and returns an awaitable future.""" + write_request = get_write_request(location, name, value, timeout) + write_executable = plugin_executable.PluginExecutable(write_request) + _, write_future = write_executable.call([value]) + return write_future def read_one_array( @@ -173,20 +171,20 @@ def read_one_array( devices: Union[Sequence[jax.Device], np.ndarray], timeout: datetime.timedelta, ): - """Creates the read array plugin program string, compiles it to an executable, calls it and returns the result.""" - read_request = get_read_request( - location, - name, - dtype, - shape, - shardings, - devices, - timeout, - ) - read_executable = plugin_executable.PluginExecutable(read_request) - out_aval = core.ShapedArray(shape, dtype) - read_array, read_future = read_executable.call( - out_shardings=[shardings], out_avals=[out_aval] - ) - read_future.result() - return read_array[0] + """Creates the read array plugin program string, compiles it to an executable, calls it and returns the result.""" + read_request = get_read_request( + location, + name, + dtype, + shape, + shardings, + devices, + timeout, + ) + read_executable = plugin_executable.PluginExecutable(read_request) + out_aval = core.ShapedArray(shape, dtype) + read_array, read_future = read_executable.call( + out_shardings=[shardings], out_avals=[out_aval] + ) + # read_future.result() + return (read_array, read_future) diff --git a/pathwaysutils/persistence/pathways_orbax_handler.py b/pathwaysutils/persistence/pathways_orbax_handler.py index 43b20db..b301c34 100644 --- a/pathwaysutils/persistence/pathways_orbax_handler.py +++ b/pathwaysutils/persistence/pathways_orbax_handler.py @@ -14,6 +14,7 @@ """TypeHandlers supporting Pathways backend.""" import collections +import concurrent.futures import datetime import functools import typing @@ -34,159 +35,170 @@ def extract_parent_dir_and_name( infos: Sequence[ParamInfo], ) -> tuple[Sequence[str], Sequence[str]]: - """Extracts names and locations from ParamInfos.""" - parent_dirs = [str(info.parent_dir) for info in infos] - names = [str(info.name) for info in infos] - return parent_dirs, names + """Extracts names and locations from ParamInfos.""" + parent_dirs = [str(info.parent_dir) for info in infos] + names = [str(info.name) for info in infos] + return parent_dirs, names class CloudPathwaysArrayHandler(type_handlers.ArrayHandler): - """A TypeHandler for array types when using Pathways.""" - - def __init__( - self, - read_timeout: Optional[datetime.timedelta] = None, - use_ocdbt: bool = False, - ): - """Constructor. - - Args: - read_timeout: Duration indicating the timeout for reading arrays - use_ocdbt: allows using Tensorstore OCDBT driver. - """ - self._read_timeout = read_timeout - - if use_ocdbt: - raise ValueError('OCDBT not supported for Pathways.') - super().__init__() - - async def serialize( - self, - values: Sequence[jax.Array], - infos: Sequence[ParamInfo], - args: Optional[Sequence[SaveArgs]] = None, - ) -> Sequence[future.Future]: - """Uses Pathways Persistence API to serialize a jax array.""" - type_handlers.check_input_arguments(values, infos, args) - - if any([arg.dtype is not None for arg in args]): - raise ValueError('Casting during save not supported for Pathways.') - - locations, names = extract_parent_dir_and_name(infos) - f = functools.partial( - helper.write_one_array, timeout=self._read_timeout - ) - return list(map(f, locations, names, values)) - - async def deserialize( - self, - infos: Sequence[ParamInfo], - args: Optional[Sequence[RestoreArgs]] = None, - ) -> Sequence[jax.Array]: - """Uses Pathways Persistence API to deserialize a jax array.""" - if args is None: - raise ValueError('Must provide ArrayRestoreArgs to restore as jax.Array.') - type_handlers.check_input_arguments(infos, args) - - global_meshes = [] - mesh_axes = [] - global_shapes = [] - dtypes = [] - shardings = [] - - should_open_metadata = False - for arg in args: - if not isinstance(arg, ArrayRestoreArgs): - raise ValueError( - 'To restore jax.Array, provide ArrayRestoreArgs; found' - f' {type(arg).__name__}' - ) - arg = typing.cast(ArrayRestoreArgs, arg) - if arg.sharding is None and (arg.mesh is None or arg.mesh_axes is None): - raise ValueError( - 'Sharding of jax.Array cannot be None. Provide `mesh`' - ' and `mesh_axes` OR `sharding`.' - ) - if arg.sharding is None: - global_meshes.append(arg.mesh) - mesh_axes.append(arg.mesh_axes) - shardings.append( - jax.sharding.NamedSharding(mesh=arg.mesh, spec=arg.mesh_axes) - ) - else: - if not isinstance(arg.sharding, jax.sharding.NamedSharding): - raise ValueError('Pathways only supports jax.sharding.NamedSharding.') - sharding = typing.cast(jax.sharding.NamedSharding, arg.sharding) - global_meshes.append(sharding.mesh) - mesh_axes.append(sharding.spec) - shardings.append(sharding) - if arg.global_shape is None or arg.dtype is None: - logging.warning( - 'Shape or dtype not provided for restoration. Provide these' - ' properties for improved performance.' - ) - should_open_metadata = True - global_shapes.append(arg.global_shape) - dtypes.append(arg.dtype) - - if should_open_metadata: - metadatas = await self.metadata(infos) - global_shapes = [ - m.shape if s is None else s for m, s in zip(metadatas, global_shapes) - ] - dtypes = [m.dtype if d is None else d for m, d in zip(metadatas, dtypes)] - - # Group inputs by global_mesh so that we can perform batched Array - # construction for each global_mesh. - inputs_by_global_mesh = collections.defaultdict(list) - for i, global_mesh in enumerate(global_meshes): - inputs_by_global_mesh[global_mesh].append(i) - - results = [None] * len(infos) - - for global_mesh, idxs in inputs_by_global_mesh.items(): - grouped_infos = [infos[idx] for idx in idxs] - grouped_global_shapes = [global_shapes[idx] for idx in idxs] - grouped_dtypes = [dtypes[idx] for idx in idxs] - grouped_shardings = [shardings[idx] for idx in idxs] - locations, names = extract_parent_dir_and_name(grouped_infos) - f = functools.partial( - helper.read_one_array, - devices=global_mesh.devices, - timeout=self._read_timeout, - ) - grouped_arrays = [ - f( - location=location, - name=name, - dtype=dtype, - shape=shape, - shardings=sharding, - ) - for location, name, dtype, shape, sharding in zip( - locations, - names, - grouped_dtypes, - grouped_global_shapes, - grouped_shardings, - ) - ] - for idx, arr in zip(idxs, grouped_arrays): - results[idx] = arr - return results # pytype: disable=bad-return-type + """A TypeHandler for array types when using Pathways.""" + + def __init__( + self, + read_timeout: Optional[datetime.timedelta] = None, + use_ocdbt: bool = False, + ): + """Constructor. + + Args: + read_timeout: Duration indicating the timeout for reading arrays + use_ocdbt: allows using Tensorstore OCDBT driver. + """ + self._read_timeout = read_timeout + + if use_ocdbt: + raise ValueError("OCDBT not supported for Pathways.") + super().__init__() + + async def serialize( + self, + values: Sequence[jax.Array], + infos: Sequence[ParamInfo], + args: Optional[Sequence[SaveArgs]] = None, + ) -> Sequence[future.Future]: + """Uses Pathways Persistence API to serialize a jax array.""" + type_handlers.check_input_arguments(values, infos, args) + + if any([arg.dtype is not None for arg in args]): + raise ValueError("Casting during save not supported for Pathways.") + + locations, names = extract_parent_dir_and_name(infos) + f = functools.partial(helper.write_one_array, timeout=self._read_timeout) + return list(map(f, locations, names, values)) + + async def deserialize( + self, + infos: Sequence[ParamInfo], + args: Optional[Sequence[RestoreArgs]] = None, + ) -> Sequence[jax.Array]: + """Uses Pathways Persistence API to deserialize a jax array.""" + if args is None: + raise ValueError("Must provide ArrayRestoreArgs to restore as jax.Array.") + type_handlers.check_input_arguments(infos, args) + + global_meshes = [] + mesh_axes = [] + global_shapes = [] + dtypes = [] + shardings = [] + + should_open_metadata = False + for arg in args: + if not isinstance(arg, ArrayRestoreArgs): + raise ValueError( + "To restore jax.Array, provide ArrayRestoreArgs; found" + f" {type(arg).__name__}" + ) + arg = typing.cast(ArrayRestoreArgs, arg) + if arg.sharding is None and (arg.mesh is None or arg.mesh_axes is None): + raise ValueError( + "Sharding of jax.Array cannot be None. Provide `mesh`" + " and `mesh_axes` OR `sharding`." + ) + if arg.sharding is None: + global_meshes.append(arg.mesh) + mesh_axes.append(arg.mesh_axes) + shardings.append( + jax.sharding.NamedSharding(mesh=arg.mesh, spec=arg.mesh_axes) + ) + else: + if not isinstance(arg.sharding, jax.sharding.NamedSharding): + raise ValueError( + "Pathways only supports jax.sharding.NamedSharding." + ) + sharding = typing.cast(jax.sharding.NamedSharding, arg.sharding) + global_meshes.append(sharding.mesh) + mesh_axes.append(sharding.spec) + shardings.append(sharding) + if arg.global_shape is None or arg.dtype is None: + logging.warning( + "Shape or dtype not provided for restoration. Provide these" + " properties for improved performance." + ) + should_open_metadata = True + global_shapes.append(arg.global_shape) + dtypes.append(arg.dtype) + + if should_open_metadata: + metadatas = await self.metadata(infos) + global_shapes = [ + m.shape if s is None else s for m, s in zip(metadatas, global_shapes) + ] + dtypes = [m.dtype if d is None else d for m, d in zip(metadatas, dtypes)] + + # Group inputs by global_mesh so that we can perform batched Array + # construction for each global_mesh. + inputs_by_global_mesh = collections.defaultdict(list) + for i, global_mesh in enumerate(global_meshes): + inputs_by_global_mesh[global_mesh].append(i) + + results = [None] * len(infos) + + for global_mesh, idxs in inputs_by_global_mesh.items(): + grouped_infos = [infos[idx] for idx in idxs] + grouped_global_shapes = [global_shapes[idx] for idx in idxs] + grouped_dtypes = [dtypes[idx] for idx in idxs] + grouped_shardings = [shardings[idx] for idx in idxs] + locations, names = extract_parent_dir_and_name(grouped_infos) + f = functools.partial( + helper.read_one_array, + devices=global_mesh.devices, + timeout=self._read_timeout, + ) + grouped_arrays_and_futures = [ + f( + location=location, + name=name, + dtype=dtype, + shape=shape, + shardings=sharding, + ) + for location, name, dtype, shape, sharding in zip( + locations, + names, + grouped_dtypes, + grouped_global_shapes, + grouped_shardings, + ) + ] + + arrays = [ + array_and_future[0] for array_and_future in grouped_arrays_and_futures + ] + futures = [ + array_and_future[1] for array_and_future in grouped_arrays_and_futures + ] + + _ = concurrent.futures.wait( + futures, return_when=concurrent.futures.ALL_COMPLETED + ) + grouped_arrays = [array[0] for array in arrays] + + for idx, arr in zip(idxs, grouped_arrays): + results[idx] = arr + return results # pytype: disable=bad-return-type def register_pathways_handlers( read_timeout: Optional[datetime.timedelta] = None, ): - """Function that must be called before saving or restoring with Pathways.""" - logging.debug( - 'Registering CloudPathwaysArrayHandler (Pathways Persistence API).' - ) - type_handlers.register_type_handler( - jax.Array, - CloudPathwaysArrayHandler( - read_timeout=read_timeout, - ), - override=True, - ) + """Function that must be called before saving or restoring with Pathways.""" + logging.debug("Registering CloudPathwaysArrayHandler (Pathways Persistence API).") + type_handlers.register_type_handler( + jax.Array, + CloudPathwaysArrayHandler( + read_timeout=read_timeout, + ), + override=True, + )