Refactor sharded serialization to remove code duplication #245

Merged
merged 1 commit into from Jul 23, 2024
266 changes: 96 additions & 170 deletions model_signing/serialization/serialize_by_file_shard.py
@@ -14,10 +14,11 @@

"""Model serializers that operated at file shard level of granularity."""

import abc
import base64
import concurrent.futures
import pathlib
from typing import Callable, Iterable, TypeAlias
from typing import Callable, Iterable, cast
from typing_extensions import override

from model_signing.hashing import file
@@ -27,21 +28,16 @@
from model_signing.serialization import serialize_by_file


_ShardSignTask: TypeAlias = tuple[pathlib.PurePath, str, int, int]


def _build_header(
*,
entry_name: str,
entry_type: str,
name: str,
start: int,
end: int,
) -> bytes:
"""Builds a header to encode a path with given name and type.
"""Builds a header to encode a path with given name and shard range.

Args:
name: The name of the entry to build the header for.
entry_type: The type of the entry (file or directory).
start: Offset for the start of the path shard.
end: Offset for the end of the path shard.

@@ -50,14 +46,11 @@ def _build_header(
bytes. Each argument is separated by dots and the last byte is also a
dot (so the file digest can be appended unambiguously).
"""
# Note: This will get replaced in subsequent change, right now we're just
# moving existing code around.
encoded_type = entry_type.encode("utf-8")
# Prevent confusion if name has a "." inside by encoding to base64.
encoded_name = base64.b64encode(entry_name.encode("utf-8"))
encoded_name = base64.b64encode(name.encode("utf-8"))
encoded_range = f"{start}-{end}".encode("utf-8")
# Note: empty string at the end, to terminate header with a "."
return b".".join([encoded_type, encoded_name, encoded_range, b""])
return b".".join([encoded_name, encoded_range, b""])


def _endpoints(step: int, end: int) -> Iterable[int]:
@@ -83,164 +76,15 @@ def _endpoints(step: int, end: int) -> Iterable[int]:
yield end
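As a quick sanity check of the sharding endpoints (assuming the usual pattern of yielding multiples of `step` followed by the final `end`, which matches the collapsed body above; the example values are made up):

# A 9-byte file cut into 4-byte shards ends with a short final shard:
assert list(_endpoints(4, 9)) == [4, 8, 9]
# An exact multiple produces no trailing partial shard:
assert list(_endpoints(4, 8)) == [4, 8]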


class ShardedDFSSerializer(serialization.Serializer):
"""DFSSerializer that uses a sharded hash engine to exploit parallelism."""

def __init__(
self,
file_hasher_factory: Callable[
[pathlib.Path, int, int], file.ShardedFileHasher
],
merge_hasher: hashing.StreamingHashEngine,
max_workers: int | None = None,
):
"""Initializes an instance to serialize a model with this serializer.

Args:
file_hasher_factory: A callable to build the hash engine used to hash
every shard of the files in the model. Because each shard is
processed in parallel, every thread needs to call the factory to
start hashing. The arguments are the file, and the endpoints of
the shard.
merge_hasher: A `hashing.StreamingHashEngine` instance used to merge
individual file digests to compute an aggregate digest.
max_workers: Maximum number of workers to use in parallel. Default
is to defer to the `concurrent.futures` library.
"""
self._file_hasher_factory = file_hasher_factory
self._merge_hasher = merge_hasher
self._max_workers = max_workers

# Precompute some private values only once by using a mock file hasher.
# None of the arguments used to build the hasher are used.
hasher = file_hasher_factory(pathlib.Path(), 0, 1)
self._shard_size = hasher.shard_size

@override
def serialize(self, model_path: pathlib.Path) -> manifest.DigestManifest:
# Note: This function currently uses `pathlib.Path.glob` so the DFS
# expansion relies on the `glob` implementation performing a DFS. We
# will be truthful again when switching to `pathlib.Path.walk`, after
# Python 3.12 is the minimum version we support.

# TODO: github.com/sigstore/model-transparency/issues/196 - Add checks
# to exclude symlinks if desired.
serialize_by_file.check_file_or_directory(model_path)

if model_path.is_file():
entries = [model_path]
else:
# TODO: github.com/sigstore/model-transparency/issues/200 - When
# Python3.12 is the minimum supported version, this can be replaced
# with `pathlib.Path.walk` for a clearer interface, and some speed
# improvement.
entries = sorted(model_path.glob("**/*"))

tasks = self._convert_paths_to_tasks(entries, model_path)

digest_len = self._merge_hasher.digest_size
digests_buffer = bytearray(len(tasks) * digest_len)

with concurrent.futures.ThreadPoolExecutor(
max_workers=self._max_workers
) as tpe:
futures_dict = {
tpe.submit(self._perform_hash_task, model_path, task): i
for i, task in enumerate(tasks)
}
for future in concurrent.futures.as_completed(futures_dict):
i = futures_dict[future]
task_digest = future.result()

task_path, task_type, task_start, task_end = tasks[i]
header = _build_header(
entry_name=task_path.name,
entry_type=task_type,
start=task_start,
end=task_end,
)
self._merge_hasher.reset(header)
self._merge_hasher.update(task_digest)
digest = self._merge_hasher.compute().digest_value

start = i * digest_len
end = start + digest_len
digests_buffer[start:end] = digest

self._merge_hasher.reset(digests_buffer)
return manifest.DigestManifest(self._merge_hasher.compute())

def _convert_paths_to_tasks(
self, paths: Iterable[pathlib.Path], root_path: pathlib.Path
) -> list[_ShardSignTask]:
"""Returns the tasks that would hash shards of files in parallel.

Every file in `paths` is replaced by a set of tasks. Each task computes
the digest over a shard of the file. Directories result in a single
task, just to compute a digest over a header.

To differentiate between (empty) files and directories with the same
name, every task needs to also include a header. The header needs to
include relative path to the model root, as we want to obtain the same
digest if the model is moved.

We don't construct an enum for the type of the entry, because these will
never escape this class.

Note that the path component of the tasks is a `pathlib.PurePath`, so
operations on it cannot touch the filesystem.
"""
# TODO: github.com/sigstore/model-transparency/issues/196 - Add support
# for excluded files.

tasks = []
for path in paths:
serialize_by_file.check_file_or_directory(path)
relative_path = path.relative_to(root_path)

if path.is_file():
path_size = path.stat().st_size
start = 0
for end in _endpoints(self._shard_size, path_size):
tasks.append((relative_path, "file", start, end))
start = end
else:
tasks.append((relative_path, "dir", 0, 0))

return tasks

def _perform_hash_task(
self, model_path: pathlib.Path, task: _ShardSignTask
) -> bytes:
"""Produces the hash of the file shard included in `task`."""
task_path, task_type, task_start, task_end = task

# TODO: github.com/sigstore/model-transparency/issues/197 - Directories
# don't need to use the file hasher. Rather than starting a process
# just for them, we should filter these ahead of time, and only use
# threading for file shards. For now, just return an empty result.
if task_type == "dir":
return b""

# TODO: github.com/sigstore/model-transparency/issues/197 - Similarly,
# empty files should be hashed outside of a parallel task, to not waste
# resources.
if task_start == task_end:
return b""

full_path = model_path.joinpath(task_path)
hasher = self._file_hasher_factory(full_path, task_start, task_end)
return hasher.compute().digest_value


class ShardedFilesSerializer(serialization.Serializer):
"""Model serializers that produces an itemized manifest, at shard level.
"""Generic file shard serializer.

Traverses the model directory and creates digests for every file found,
sharding the file in equal shards and computing the digests in parallel.

Since the manifest lists each item individually, this will also enable
support for incremental updates (to be added later).
Subclasses can then create a manifest with these digests, either listing
them item by item, combining them into file digests, or combining all of
them into a single digest.
"""

def __init__(
@@ -270,9 +114,7 @@ def __init__(
self._shard_size = hasher.shard_size

@override
def serialize(
self, model_path: pathlib.Path
) -> manifest.ShardLevelManifest:
def serialize(self, model_path: pathlib.Path) -> manifest.Manifest:
# TODO: github.com/sigstore/model-transparency/issues/196 - Add checks
# to exclude symlinks if desired.
serialize_by_file.check_file_or_directory(model_path)
@@ -337,12 +179,96 @@ def _compute_hash(
path=relative_path, digest=digest, start=start, end=end
)

@abc.abstractmethod
def _build_manifest(
self, items: Iterable[manifest.ShardedFileManifestItem]
) -> manifest.ShardLevelManifest:
) -> manifest.Manifest:
"""Builds an itemized manifest from a given list of items.

Every subclass needs to implement this method to determine the format of
the manifest.
"""
pass


class ManifestSerializer(ShardedFilesSerializer):
Model serializer that produces an itemized manifest, at shard level.

Since the manifest lists each item individually, this will also enable
support for incremental updates (to be added later).
"""

@override
def serialize(
self, model_path: pathlib.Path
) -> manifest.ShardLevelManifest:
"""Serializes the model given by the `model_path` argument.

The only reason for the override is to change the return type, to be
more restrictive. This is to signal that the only manifests that can be
returned are `manifest.ShardLevelManifest` instances.
"""
return cast(manifest.ShardLevelManifest, super().serialize(model_path))

@override
def _build_manifest(
self, items: Iterable[manifest.ShardedFileManifestItem]
) -> manifest.ShardLevelManifest:
return manifest.ShardLevelManifest(items)


class DigestSerializer(ShardedFilesSerializer):
Serializer that performs a traversal of the model directory.

This serializer produces a single hash for the entire model.
"""

def __init__(
self,
file_hasher_factory: Callable[
[pathlib.Path, int, int], file.ShardedFileHasher
],
merge_hasher: hashing.StreamingHashEngine,
max_workers: int | None = None,
):
"""Initializes an instance to serialize a model with this serializer.

Args:
file_hasher_factory: A callable to build the hash engine used to hash
every shard of the files in the model. Because each shard is
processed in parallel, every thread needs to call the factory to
start hashing. The arguments are the file, and the endpoints of
the shard.
merge_hasher: A `hashing.StreamingHashEngine` instance used to merge
individual file shard digests to compute an aggregate digest.
max_workers: Maximum number of workers to use in parallel. Default
is to defer to the `concurrent.futures` library.
"""
super().__init__(file_hasher_factory, max_workers)
self._merge_hasher = merge_hasher

@override
def serialize(self, model_path: pathlib.Path) -> manifest.DigestManifest:
"""Serializes the model given by the `model_path` argument.

The only reason for the override is to change the return type, to be
more restrictive. This is to signal that the only manifests that can be
returned are `manifest.DigestManifest` instances.
"""
return cast(manifest.DigestManifest, super().serialize(model_path))

@override
def _build_manifest(
self, items: Iterable[manifest.ShardedFileManifestItem]
) -> manifest.DigestManifest:
self._merge_hasher.reset()

for item in sorted(items, key=lambda i: (i.path, i.start, i.end)):
header = _build_header(
name=item.path.name, start=item.start, end=item.end
)
self._merge_hasher.update(header)
self._merge_hasher.update(item.digest.digest_value)

digest = self._merge_hasher.compute()
return manifest.DigestManifest(digest)
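To make the split concrete, here is a hypothetical usage sketch of the two new subclasses (the `memory.SHA256` engine and the `ShardedFileHasher` constructor arguments are assumptions inferred from the surrounding code, not verified signatures):

import pathlib

from model_signing.hashing import file, memory
from model_signing.serialization import serialize_by_file_shard


def hasher_factory(
    path: pathlib.Path, start: int, end: int
) -> file.ShardedFileHasher:
    # Each worker thread calls this to get its own hasher for one shard.
    return file.ShardedFileHasher(path, memory.SHA256(), start=start, end=end)


model = pathlib.Path("path/to/model")  # hypothetical model directory

# Itemized, shard-level manifest (enables future incremental updates):
sharded_manifest = serialize_by_file_shard.ManifestSerializer(
    hasher_factory
).serialize(model)

# Single aggregate digest over all shards:
single_digest = serialize_by_file_shard.DigestSerializer(
    hasher_factory, memory.SHA256()
).serialize(model)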