Refactor DFSSerializer to remove code duplication. (#241)
* Refactor DFSSerializer to remove code duplication.

There is duplication in the directory traversal between
`DFSSerializer` and `FilesSerializer`. Since the latter supports parallel
hashing, let's prepare to use only that.

We make `FilesSerializer` an abstract parent class that performs
the directory traversal. We introduce `ManifestSerializer` as the new name
for the old `FilesSerializer` class, which created a manifest out of the
model. We rename `DFSSerializer` to `DigestSerializer` for consistency.
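
For reference, an informal sketch of the resulting class hierarchy (names
as introduced in this change):

    FilesSerializer (abstract)  -- traverses the model, hashes every file
    |-- ManifestSerializer      -- the old FilesSerializer: itemized manifest
    `-- DigestSerializer        -- the old DFSSerializer: one aggregate digest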

Since `FilesSerializer` (the directory traversal) now only considers
files, we need to add a `_FileDigestTree` class to transform the list of
files and their hashes (the `FileManifestItem` objects) into a directory
tree, so we can build the digest for `DigestSerializer` in a bottom-up
fashion, as before. We could have included only the files, instead of the
directories, but that would require changing many expected constants in the
tests. So we add this transformation now; we plan to migrate the tests to
goldens and then maybe change the hashing to include only the files when
rolling up to a single digest.

We still had to update one test: since the hashes are computed only for
files, we no longer differentiate between a model with an empty directory
and a model where that empty directory is removed entirely. This is a
corner case, and the new behavior is acceptable.

In fact, ignoring empty directories is part of the optimization hinted
at in #197.

Signed-off-by: Mihai Maruseac <mihaimaruseac@google.com>

* Fix Windows

Signed-off-by: Mihai Maruseac <mihaimaruseac@google.com>

* Document `__init__`

Signed-off-by: Mihai Maruseac <mihaimaruseac@google.com>

* Fix typos, add a type signature

Signed-off-by: Mihai Maruseac <mihaimaruseac@google.com>

---------

Signed-off-by: Mihai Maruseac <mihaimaruseac@google.com>
mihaimaruseac authored Jul 22, 2024
1 parent bfc60f1 commit 99f4d8a
Showing 2 changed files with 253 additions and 113 deletions.
246 changes: 174 additions & 72 deletions model_signing/serialization/serialize_by_file.py
@@ -14,10 +14,11 @@

"""Model serializers that operated at file level granularity."""

import abc
import base64
import concurrent.futures
import pathlib
from typing import Callable, Iterable, cast
from typing_extensions import override

from model_signing.hashing import file
@@ -65,84 +66,21 @@ def _build_header(
bytes. Each argument is separated by dots and the last byte is also a
dot (so the file digest can be appended unambiguously).
"""
# Note: This will get replaced in a subsequent change; right now we're just
# moving existing code around.
encoded_type = entry_type.encode("utf-8")
# Prevent confusion if name has a "." inside by encoding to base64.
encoded_name = base64.b64encode(entry_name.encode("utf-8"))
# Note: empty string at the end, to terminate header with a "."
return b".".join([encoded_type, encoded_name, b""])


class DFSSerializer(serialization.Serializer):
"""Serializer for a model that performs a traversal of the model directory.

This serializer produces a single hash for the entire model. If the model is
a file, the hash is the digest of the file. If the model is a directory, we
perform a depth-first traversal of the directory, hash each individual file
and aggregate the hashes together.
"""

def __init__(
self,
file_hasher: file.SimpleFileHasher,
merge_hasher_factory: Callable[[], hashing.StreamingHashEngine],
):
"""Initializes an instance to serialize a model with this serializer.
Args:
hasher: The hash engine used to hash the individual files.
merge_hasher_factory: A callable that returns a
`hashing.StreamingHashEngine` instance used to merge individual
file digests to compute an aggregate digest.
"""
self._file_hasher = file_hasher
self._merge_hasher_factory = merge_hasher_factory

@override
def serialize(self, model_path: pathlib.Path) -> manifest.DigestManifest:
# TODO: github.com/sigstore/model-transparency/issues/196 - Add checks
# to exclude symlinks if desired.
check_file_or_directory(model_path)

if model_path.is_file():
self._file_hasher.set_file(model_path)
return manifest.DigestManifest(self._file_hasher.compute())

return manifest.DigestManifest(self._dfs(model_path))

def _dfs(self, directory: pathlib.Path) -> hashing.Digest:
# TODO: github.com/sigstore/model-transparency/issues/196 - Add support
# for excluded files.
children = sorted([x for x in directory.iterdir()])

hasher = self._merge_hasher_factory()
for child in children:
check_file_or_directory(child)

if child.is_file():
header = _build_header(entry_name=child.name, entry_type="file")
hasher.update(header)
self._file_hasher.set_file(child)
digest = self._file_hasher.compute()
hasher.update(digest.digest_value)
else:
header = _build_header(entry_name=child.name, entry_type="dir")
hasher.update(header)
digest = self._dfs(child)
hasher.update(digest.digest_value)

return hasher.compute()


class FilesSerializer(serialization.Serializer):
"""Model serializers that produces an itemized manifest, at file level.
"""Generic file serializer.
Traverses the model directory and creates digests for every file found,
possibly in parallel.
Since the manifest lists each item individually, this will also enable
support for incremental updates (to be added later).
Subclasses can then create a manifest with these digests, either listing
them item by item, or combining everything into a single digest.
"""

def __init__(
@@ -162,7 +100,7 @@ def __init__(
self._max_workers = max_workers

@override
def serialize(self, model_path: pathlib.Path) -> manifest.Manifest:
# TODO: github.com/sigstore/model-transparency/issues/196 - Add checks
# to exclude symlinks if desired.
check_file_or_directory(model_path)
@@ -210,12 +148,176 @@ def _compute_hash(
digest = self._hasher_factory(path).compute()
return manifest.FileManifestItem(path=relative_path, digest=digest)

@abc.abstractmethod
def _build_manifest(
self, items: Iterable[manifest.FileManifestItem]
) -> manifest.Manifest:
"""Builds the manifest representing the serialization of the model."""
pass


class ManifestSerializer(FilesSerializer):
"""Model serializer that produces an itemized manifest, at file level.
Since the manifest lists each item individually, this will also enable
support for incremental updates (to be added later).
"""

@override
def serialize(self, model_path: pathlib.Path) -> manifest.FileLevelManifest:
"""Serializes the model given by the `model_path` argument.
The only reason for the override is to change the return type, to be
more restrictive. This is to signal that the only manifests that can be
returned are `manifest.FileLevelManifest` instances.
"""
return cast(manifest.FileLevelManifest, super().serialize(model_path))

@override
def _build_manifest(
self, items: Iterable[manifest.FileManifestItem]
) -> manifest.FileLevelManifest:
return manifest.FileLevelManifest(items)
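
# Hedged usage sketch, not part of this diff; it assumes the factory-based
# `FilesSerializer.__init__` above and the SHA256 engine from
# model_signing.hashing.memory:
#
#   def hasher_factory(path: pathlib.Path) -> file.FileHasher:
#       return file.SimpleFileHasher(path, memory.SHA256())
#
#   serializer = ManifestSerializer(hasher_factory, max_workers=4)
#   file_manifest = serializer.serialize(pathlib.Path("path/to/model"))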


class _FileDigestTree:
"""A tree of files with their digests.
Every leaf in the tree is a file, paired with its digest. Every intermediate
node represents a directory. We need to pair every directory with a digest,
in a bottom-up fashion.
"""

def __init__(
self, path: pathlib.PurePath, digest: hashing.Digest | None = None
):
"""Builds a node in the digest tree.
Don't call this from outside of the class. Instead, use `build_tree`.
Args:
path: Path included in the node.
digest: Optional hash of the path. Files must have a digest,
directories never have one.
"""
self._path = path
self._digest = digest
self._children: list[_FileDigestTree] = []

@classmethod
def build_tree(
cls, items: Iterable[manifest.FileManifestItem]
) -> "_FileDigestTree":
"""Builds a tree out of the sequence of manifest items."""
path_to_node: dict[pathlib.PurePath, _FileDigestTree] = {}

for file_item in items:
file = file_item.path
node = cls(file, file_item.digest)
for parent in file.parents:
if parent in path_to_node:
parent_node = path_to_node[parent]
parent_node._children.append(node)
break # everything else already exists

parent_node = cls(parent) # no digest for directories
parent_node._children.append(node)
path_to_node[parent] = parent_node
node = parent_node

# Handle empty model
if not path_to_node:
return cls(pathlib.PurePosixPath())

return path_to_node[pathlib.PurePosixPath()]

def get_digest(
self,
hasher_factory: Callable[[], hashing.StreamingHashEngine],
) -> hashing.Digest:
"""Returns the digest of this tree of files.
Args:
hasher_factory: A callable that returns a
`hashing.StreamingHashEngine` instance used to merge individual
digests to compute an aggregate digest.
"""
hasher = hasher_factory()

for child in sorted(self._children, key=lambda c: c._path):
name = child._path.name
if child._digest is not None:
header = _build_header(entry_name=name, entry_type="file")
hasher.update(header)
hasher.update(child._digest.digest_value)
else:
header = _build_header(entry_name=name, entry_type="dir")
hasher.update(header)
digest = child.get_digest(hasher_factory)
hasher.update(digest.digest_value)

return hasher.compute()
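
# Hedged sketch, not part of this diff, of how the tree is used; d0 and d1
# stand for precomputed hashing.Digest values, and the item constructor is
# assumed to accept these arguments:
#
#   items = [
#       manifest.FileManifestItem(path=pathlib.PurePosixPath("a/b"), digest=d0),
#       manifest.FileManifestItem(path=pathlib.PurePosixPath("c"), digest=d1),
#   ]
#   tree = _FileDigestTree.build_tree(items)  # root "" -> {a -> {b}, c}
#   digest = tree.get_digest(memory.SHA256)   # bottom-up aggregation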


class DigestSerializer(FilesSerializer):
"""Serializer for a model that performs a traversal of the model directory.
This serializer produces a single hash for the entire model. If the model is
a file, the hash is the digest of the file. If the model is a directory, we
perform a depth-first traversal of the directory, hash each individual files
and aggregate the hashes together.
Currently, this has a different initialization than `FilesSerializer`, but
this will likely change in a subsequent change. Similarly, currently, this
only supports one single worker, but this will change in the future.
"""

def __init__(
self,
file_hasher: file.SimpleFileHasher,
merge_hasher_factory: Callable[[], hashing.StreamingHashEngine],
):
"""Initializes an instance to serialize a model with this serializer.
Args:
hasher: The hash engine used to hash the individual files.
merge_hasher_factory: A callable that returns a
`hashing.StreamingHashEngine` instance used to merge individual
file digests to compute an aggregate digest.
"""

def _factory(path: pathlib.Path) -> file.FileHasher:
file_hasher.set_file(path)
return file_hasher

super().__init__(_factory, max_workers=1)
self._merge_hasher_factory = merge_hasher_factory

@override
def serialize(self, model_path: pathlib.Path) -> manifest.DigestManifest:
"""Serializes the model given by the `model_path` argument.
The only reason for the override is to change the return type, to be
more restrictive. This is to signal that the only manifests that can be
returned are `manifest.DigestManifest` instances.
"""
return cast(manifest.DigestManifest, super().serialize(model_path))

@override
def _build_manifest(
self, items: Iterable[manifest.FileManifestItem]
) -> manifest.DigestManifest:
# Note: we do several computations here to try and match the old
# behavior, but these will be simplified in the future. Since we are
# defining the hashing behavior, we can freely change this.

# If the model is just one file, return the hash of the file.
# A model is a file if we have one item only and its path is empty.
items = list(items)
if len(items) == 1 and not items[0].path.name:
return manifest.DigestManifest(items[0].digest)

# Otherwise, build a tree of files and compute the digests.
tree = _FileDigestTree.build_tree(items)
digest = tree.get_digest(self._merge_hasher_factory)
return manifest.DigestManifest(digest)
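
# Hedged usage sketch, not part of this diff; SimpleFileHasher is assumed to
# take a path and a streaming engine, and the path is a placeholder because
# the factory in __init__ calls set_file() for every file:
#
#   serializer = DigestSerializer(
#       file.SimpleFileHasher(pathlib.Path("unused"), memory.SHA256()),
#       merge_hasher_factory=memory.SHA256,
#   )
#   digest_manifest = serializer.serialize(pathlib.Path("path/to/model"))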
