Skip to content

Commit

Permalink
combine model serialization logic for files and directories (#257)
Browse files Browse the repository at this point in the history
* combine model serialization logic for files and directories

The same serialization logic was written twice, once for a model_path that is
a file and once for a directory, because of two pathlib.Path.glob behaviors:

1. The glob generator only yields items when the path is a directory.
2. The glob does not include the provided directory itself.

By iterating over the provided model_path first and then over the glob
results, all of the serialization logic can live in a single loop. This
will simplify future changes that add more serialization logic, such as
ignore paths.

Signed-off-by: Spencer Schrock <sschrock@google.com>

* avoid using an inner function to generate the root iterator

Signed-off-by: Spencer Schrock <sschrock@google.com>

---------

Signed-off-by: Spencer Schrock <sschrock@google.com>
  • Loading branch information
spencerschrock authored Jul 25, 2024
1 parent 9798149 commit 8e54020
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 32 deletions.
28 changes: 13 additions & 15 deletions model_signing/serialization/serialize_by_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import abc
import base64
import concurrent.futures
import itertools
import pathlib
from typing import Callable, Iterable, cast
from typing_extensions import override
Expand Down Expand Up @@ -124,22 +125,19 @@ def serialize(self, model_path: pathlib.Path) -> manifest.Manifest:
ValueError: The model contains a symbolic link, but the serializer
was not initialized with `allow_symlinks=True`.
"""
check_file_or_directory(model_path, allow_symlinks=self._allow_symlinks)

paths = []
if model_path.is_file():
paths.append(model_path)
else:
# TODO: github.com/sigstore/model-transparency/issues/200 - When
# Python3.12 is the minimum supported version, this can be replaced
# with `pathlib.Path.walk` for a clearer interface, and some speed
# improvement.
for path in model_path.glob("**/*"):
check_file_or_directory(
path, allow_symlinks=self._allow_symlinks
)
if path.is_file():
paths.append(path)
# TODO: github.com/sigstore/model-transparency/issues/200 - When
# Python3.12 is the minimum supported version, the glob can be replaced
# with `pathlib.Path.walk` for a clearer interface, and some speed
# improvement.
for path in itertools.chain(
iter([model_path]), model_path.glob("**/*")
):
check_file_or_directory(
path, allow_symlinks=self._allow_symlinks
)
if path.is_file():
paths.append(path)

manifest_items = []
with concurrent.futures.ThreadPoolExecutor(
Expand Down
30 changes: 13 additions & 17 deletions model_signing/serialization/serialize_by_file_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import abc
import base64
import concurrent.futures
import itertools
import pathlib
from typing import Callable, Iterable, cast
from typing_extensions import override
Expand Down Expand Up @@ -126,24 +127,19 @@ def serialize(self, model_path: pathlib.Path) -> manifest.Manifest:
ValueError: The model contains a symbolic link, but the serializer
was not initialized with `allow_symlinks=True`.
"""
serialize_by_file.check_file_or_directory(
model_path, allow_symlinks=self._allow_symlinks
)

shards = []
if model_path.is_file():
shards.extend(self._get_shards(model_path))
else:
# TODO: github.com/sigstore/model-transparency/issues/200 - When
# Python3.12 is the minimum supported version, this can be replaced
# with `pathlib.Path.walk` for a clearer interface, and some speed
# improvement.
for path in model_path.glob("**/*"):
serialize_by_file.check_file_or_directory(
path, allow_symlinks=self._allow_symlinks
)
if path.is_file():
shards.extend(self._get_shards(path))
# TODO: github.com/sigstore/model-transparency/issues/200 - When
# Python3.12 is the minimum supported version, the glob can be replaced
# with `pathlib.Path.walk` for a clearer interface, and some speed
# improvement.
for path in itertools.chain(
iter([model_path]), model_path.glob("**/*")
):
serialize_by_file.check_file_or_directory(
path, allow_symlinks=self._allow_symlinks
)
if path.is_file():
shards.extend(self._get_shards(path))

manifest_items = []
with concurrent.futures.ThreadPoolExecutor(
Expand Down

0 comments on commit 8e54020

Please sign in to comment.