Skip to content

Commit

Permalink
format + add pool for parallel read creation
Browse files Browse the repository at this point in the history
  • Loading branch information
sadikneipp committed Nov 27, 2024
1 parent a3707dd commit 03f5ef6
Show file tree
Hide file tree
Showing 2 changed files with 289 additions and 273 deletions.
215 changes: 105 additions & 110 deletions pathwaysutils/persistence/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,57 +26,55 @@


def base64_utf8_stringify(bs: bytes) -> str:
    """Converts bytes to a base64-encoded utf-8 string.

    Args:
        bs: The bytes to convert.

    Returns:
        The base64-encoded utf-8 string.
    """
    return base64.b64encode(bs).decode("utf-8")


def string_to_base64(text: str) -> str:
    """Encodes a string to base64 format.

    The text is first encoded to utf-8 bytes, which are then base64-encoded.

    Args:
        text: The string to encode.

    Returns:
        The base64-encoded string.
    """
    return base64_utf8_stringify(text.encode("utf-8"))


def get_hlo_sharding_string(
    sharding: jax.sharding.Sharding,
    num_dimensions: int,
) -> str:
    """Returns the sharding as a base64-encoded, serialized HLO sharding.

    The sharding is converted to an XLA HLO sharding for `num_dimensions`
    dimensions, turned into its proto form, serialized to bytes, and those
    bytes are base64-encoded.

    Args:
        sharding: The sharding to serialize.
        num_dimensions: The number of dimensions passed to the HLO sharding
            conversion (callers in this file pass the rank of the array).

    Returns:
        The base64 encoding of the serialized proto, as a utf-8 string.
    """
    # pylint:disable=protected-access
    hlo_sharding = sharding._to_xla_hlo_sharding(
        num_dimensions
    )  # pytype: disable=attribute-error
    # pylint:enable=protected-access
    return base64_utf8_stringify(hlo_sharding.to_proto().SerializeToString())


def get_shape_string(
    dtype: np.dtype,
    shape: Sequence[int],
) -> str:
    """Returns the array shape as a base64-encoded, serialized shape proto.

    Builds an `xc.Shape` array shape from the element type corresponding to
    `dtype` and the given dimensions, applies a major-to-minor layout when the
    shape has none, serializes it, and base64-encodes the bytes.

    Args:
        dtype: The element dtype of the array.
        shape: The dimensions of the array.

    Returns:
        The base64 encoding of the serialized shape proto, as a utf-8 string.
    """
    element_type = xc.PrimitiveType(xc.dtype_to_etype(dtype))
    array_shape = xc.Shape.array_shape(element_type, shape)
    return base64_utf8_stringify(
        array_shape.with_major_to_minor_layout_if_absent().to_serialized_proto()
    )


def get_write_request(
    location_path: str,
    name: str,
    jax_array: jax.Array,
    timeout: datetime.timedelta,
) -> str:
    """Returns the plugin program string which writes `jax_array` to a location.

    The program is a JSON object with a single "persistenceWriteRequest" entry
    holding the base64-encoded location and name, the base64-encoded HLO
    sharding, the array shape, the ids of the devices in the sharding's device
    assignment, and the timeout split into whole seconds and nanoseconds.

    Args:
        location_path: The location to write to.
        name: The name under which the array is written.
        jax_array: The array to write; its sharding and shape are embedded in
            the request.
        timeout: The timeout to embed in the request.

    Returns:
        The JSON-serialized write request.
    """
    sharding = jax_array.sharding
    assert isinstance(sharding, jax.sharding.Sharding), sharding

    # A timedelta holds a whole number of microseconds, so integer arithmetic
    # splits it exactly. Going through float seconds can land just below the
    # true value and truncate to one nanosecond short (e.g. 2.3 s would yield
    # 299999999 ns instead of 300000000 ns).
    timeout_seconds, timeout_microseconds = divmod(
        timeout // datetime.timedelta(microseconds=1), 1_000_000
    )

    # pylint:disable=protected-access
    device_ids = [device.id for device in sharding._device_assignment]
    # pylint:enable=protected-access

    return json.dumps(
        {
            "persistenceWriteRequest": {
                "b64_location": string_to_base64(location_path),
                "b64_name": string_to_base64(name),
                "b64_hlo_sharding_string": get_hlo_sharding_string(
                    sharding, len(jax_array.shape)
                ),
                "shape": jax_array.shape,
                "devices": {
                    "device_ids": device_ids,
                },
                "timeout": {
                    "seconds": timeout_seconds,
                    "nanos": timeout_microseconds * 1000,
                },
            }
        }
    )


def get_read_request(
    location_path: str,
    name: str,
    dtype: np.dtype,
    shape: Sequence[int],
    sharding: jax.sharding.Sharding,
    devices: Sequence[jax.Device],
    timeout: datetime.timedelta,
) -> str:
    """Returns the plugin program string which reads an array from a location.

    The program is a JSON object with a single "persistenceReadRequest" entry
    holding the base64-encoded location and name, the base64-encoded shape
    proto and HLO sharding, the ids of the given devices, and the timeout split
    into whole seconds and nanoseconds.

    Args:
        location_path: The location to read from.
        name: The name of the array to read.
        dtype: The element dtype of the array.
        shape: The dimensions of the array.
        sharding: The sharding to read the array into.
        devices: The devices whose ids are listed in the request; a sequence or
            an array of any shape, which is flattened.
        timeout: The timeout to embed in the request.

    Returns:
        The JSON-serialized read request.
    """
    # Accept both plain sequences and (possibly multi-dimensional) arrays of
    # devices; the request carries a flat list of ids.
    device_ids = [device.id for device in np.asarray(devices).flatten()]

    # A timedelta holds a whole number of microseconds, so integer arithmetic
    # splits it exactly. Going through float seconds can land just below the
    # true value and truncate to one nanosecond short (e.g. 2.3 s would yield
    # 299999999 ns instead of 300000000 ns).
    timeout_seconds, timeout_microseconds = divmod(
        timeout // datetime.timedelta(microseconds=1), 1_000_000
    )

    return json.dumps(
        {
            "persistenceReadRequest": {
                "b64_location": string_to_base64(location_path),
                "b64_shape_proto_string": get_shape_string(dtype, shape),
                "b64_name": string_to_base64(name),
                "b64_hlo_sharding_string": get_hlo_sharding_string(
                    sharding, len(shape)
                ),
                "devices": {"device_ids": device_ids},
                "timeout": {
                    "seconds": timeout_seconds,
                    "nanos": timeout_microseconds * 1000,
                },
            }
        }
    )


def write_one_array(
    location: str,
    name: str,
    value: jax.Array,
    timeout: datetime.timedelta,
):
    """Writes one array through a plugin executable.

    Builds the write request for `value`, wraps it in a `PluginExecutable`,
    calls the executable with the array as its only input, and returns the
    future produced by that call.

    Args:
        location: The location to write to.
        name: The name under which the array is written.
        value: The array to write.
        timeout: The timeout to embed in the write request.

    Returns:
        The future returned by the executable call.
    """
    write_request = get_write_request(location, name, value, timeout)
    write_executable = plugin_executable.PluginExecutable(write_request)
    # The first element of the call result is not needed for a write.
    _, write_future = write_executable.call([value])
    return write_future


async def read_one_array(
    location: str,
    name: str,
    dtype: np.dtype,
    shape: Sequence[int],
    shardings: jax.sharding.Sharding,
    devices: Union[Sequence[jax.Device], np.ndarray],
    timeout: datetime.timedelta,
):
    """Reads one array through a plugin executable.

    Builds the read request, wraps it in a `PluginExecutable`, and calls the
    executable with a single output sharding and a single output abstract
    value describing the expected array.

    This is a coroutine function, but its body contains no `await`: all work
    runs when the coroutine is driven.
    NOTE(review): the returned future is not waited on here; presumably the
    caller does so - confirm.

    Args:
        location: The location to read from.
        name: The name of the array to read.
        dtype: The element dtype of the array.
        shape: The dimensions of the array.
        shardings: The sharding to read the array into.
        devices: The devices listed in the read request.
        timeout: The timeout to embed in the read request.

    Returns:
        A `(read_array, read_future)` tuple as produced by the executable call.
    """
    read_request = get_read_request(
        location,
        name,
        dtype,
        shape,
        shardings,
        devices,
        timeout,
    )
    read_executable = plugin_executable.PluginExecutable(read_request)
    out_aval = core.ShapedArray(shape, dtype)
    read_array, read_future = read_executable.call(
        out_shardings=[shardings], out_avals=[out_aval]
    )
    return (read_array, read_future)
Loading

0 comments on commit 03f5ef6

Please sign in to comment.