Skip to content

Commit

Permalink
prevent serialization of symlinks by default (#251)
Browse files Browse the repository at this point in the history
* add serialization test for symlink model file

Signed-off-by: Spencer Schrock <sschrock@google.com>

* prevent serialization of symlinks by default

This can be changed via the allow_symlink argument in the various
serialization initializers.

Signed-off-by: Spencer Schrock <sschrock@google.com>

* convert symlink file fixture to symlink directory fixture

Signed-off-by: Spencer Schrock <sschrock@google.com>

* Address style and documentation feedback

Signed-off-by: Spencer Schrock <sschrock@google.com>

* add `Raises` section to serialize docstrings

Signed-off-by: Spencer Schrock <sschrock@google.com>

---------

Signed-off-by: Spencer Schrock <sschrock@google.com>
  • Loading branch information
spencerschrock authored Jul 24, 2024
1 parent c3c4110 commit 9798149
Show file tree
Hide file tree
Showing 13 changed files with 132 additions and 18 deletions.
12 changes: 12 additions & 0 deletions model_signing/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Test fixtures to share between tests. Not part of the public API."""

import os
import pathlib
import pytest

from model_signing import test_support
Expand Down Expand Up @@ -102,3 +104,13 @@ def deep_model_folder(tmp_path_factory):
file.write_text(f"This is file f{i}.")

return model_root

@pytest.fixture
def symlink_model_folder(tmp_path_factory: pytest.TempPathFactory) -> pathlib.Path:
"""A model folder with a symlink to an external file."""
external_file = tmp_path_factory.mktemp("external") / "file"
external_file.write_bytes(test_support.KNOWN_MODEL_TEXT)
model_dir = tmp_path_factory.mktemp("model")
symlink_file = model_dir / "symlink_file"
os.symlink(external_file.absolute(), symlink_file.absolute())
return model_dir
49 changes: 42 additions & 7 deletions model_signing/serialization/serialize_by_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
from model_signing.serialization import serialization


def check_file_or_directory(path: pathlib.Path) -> None:
def check_file_or_directory(
path: pathlib.Path,
allow_symlinks: bool = False,
) -> None:
"""Checks that the given path is either a file or a directory.
There is no support for sockets, pipes, or any other operating system
Expand All @@ -38,10 +41,19 @@ def check_file_or_directory(path: pathlib.Path) -> None:
Args:
path: The path to check.
allow_symlinks: Controls whether symbolic links are included. If a
symlink is present but the flag is `False` (default) the
serialization would raise an error.
Raises:
ValueError: The path is neither a file or a directory.
ValueError: The path is neither a file or a directory, or the path
is a symlink and `allow_symlinks` is false.
"""
if not allow_symlinks and path.is_symlink():
raise ValueError(
f"Cannot use '{path}' because it is a symlink. This"
" behavior can be changed with `allow_symlinks`."
)
if not (path.is_file() or path.is_dir()):
raise ValueError(
f"Cannot use '{path}' as file or directory. It could be a"
Expand Down Expand Up @@ -87,6 +99,7 @@ def __init__(
self,
file_hasher_factory: Callable[[pathlib.Path], file.FileHasher],
max_workers: int | None = None,
allow_symlinks: bool = False,
):
"""Initializes an instance to serialize a model with this serializer.
Expand All @@ -95,15 +108,23 @@ def __init__(
hash individual files.
max_workers: Maximum number of workers to use in parallel. Default
is to defer to the `concurrent.futures` library.
allow_symlinks: Controls whether symbolic links are included. If a
symlink is present but the flag is `False` (default) the
serialization would raise an error.
"""
self._hasher_factory = file_hasher_factory
self._max_workers = max_workers
self._allow_symlinks = allow_symlinks

@override
def serialize(self, model_path: pathlib.Path) -> manifest.Manifest:
# TODO: github.com/sigstore/model-transparency/issues/196 - Add checks
# to exclude symlinks if desired.
check_file_or_directory(model_path)
"""Serializes the model given by the `model_path` argument.
Raises:
ValueError: The model contains a symbolic link, but the serializer
was not initialized with `allow_symlinks=True`.
"""
check_file_or_directory(model_path, allow_symlinks=self._allow_symlinks)

paths = []
if model_path.is_file():
Expand All @@ -114,7 +135,9 @@ def serialize(self, model_path: pathlib.Path) -> manifest.Manifest:
# with `pathlib.Path.walk` for a clearer interface, and some speed
# improvement.
for path in model_path.glob("**/*"):
check_file_or_directory(path)
check_file_or_directory(
path, allow_symlinks=self._allow_symlinks
)
if path.is_file():
paths.append(path)

Expand Down Expand Up @@ -170,6 +193,10 @@ def serialize(self, model_path: pathlib.Path) -> manifest.FileLevelManifest:
The only reason for the override is to change the return type, to be
more restrictive. This is to signal that the only manifests that can be
returned are `manifest.FileLevelManifest` instances.
Raises:
ValueError: The model contains a symbolic link, but the serializer
was not initialized with `allow_symlinks=True`.
"""
return cast(manifest.FileLevelManifest, super().serialize(model_path))

Expand Down Expand Up @@ -276,6 +303,7 @@ def __init__(
self,
file_hasher: file.SimpleFileHasher,
merge_hasher_factory: Callable[[], hashing.StreamingHashEngine],
allow_symlinks: bool = False,
):
"""Initializes an instance to serialize a model with this serializer.
Expand All @@ -284,13 +312,16 @@ def __init__(
merge_hasher_factory: A callable that returns a
`hashing.StreamingHashEngine` instance used to merge individual
file digests to compute an aggregate digest.
allow_symlinks: Controls whether symbolic links are included. If a
symlink is present but the flag is `False` (default) the
serialization would raise an error.
"""

def _factory(path: pathlib.Path) -> file.FileHasher:
file_hasher.set_file(path)
return file_hasher

super().__init__(_factory, max_workers=1)
super().__init__(_factory, max_workers=1, allow_symlinks=allow_symlinks)
self._merge_hasher_factory = merge_hasher_factory

@override
Expand All @@ -300,6 +331,10 @@ def serialize(self, model_path: pathlib.Path) -> manifest.DigestManifest:
The only reason for the override is to change the return type, to be
more restrictive. This is to signal that the only manifests that can be
returned are `manifest.DigestManifest` instances.
Raises:
ValueError: The model contains a symbolic link, but the serializer
was not initialized with `allow_symlinks=True`.
"""
return cast(manifest.DigestManifest, super().serialize(model_path))

Expand Down
35 changes: 30 additions & 5 deletions model_signing/serialization/serialize_by_file_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
[pathlib.Path, int, int], file.ShardedFileHasher
],
max_workers: int | None = None,
allow_symlinks: bool = False,
):
"""Initializes an instance to serialize a model with this serializer.
Expand All @@ -104,9 +105,13 @@ def __init__(
the shard.
max_workers: Maximum number of workers to use in parallel. Default
is to defer to the `concurrent.futures` library.
allow_symlinks: Controls whether symbolic links are included. If a
symlink is present but the flag is `False` (default) the
serialization would raise an error.
"""
self._hasher_factory = sharded_hasher_factory
self._max_workers = max_workers
self._allow_symlinks = allow_symlinks

# Precompute some private values only once by using a mock file hasher.
# None of the arguments used to build the hasher are used.
Expand All @@ -115,9 +120,15 @@ def __init__(

@override
def serialize(self, model_path: pathlib.Path) -> manifest.Manifest:
# TODO: github.com/sigstore/model-transparency/issues/196 - Add checks
# to exclude symlinks if desired.
serialize_by_file.check_file_or_directory(model_path)
"""Serializes the model given by the `model_path` argument.
Raises:
ValueError: The model contains a symbolic link, but the serializer
was not initialized with `allow_symlinks=True`.
"""
serialize_by_file.check_file_or_directory(
model_path, allow_symlinks=self._allow_symlinks
)

shards = []
if model_path.is_file():
Expand All @@ -128,7 +139,9 @@ def serialize(self, model_path: pathlib.Path) -> manifest.Manifest:
# with `pathlib.Path.walk` for a clearer interface, and some speed
# improvement.
for path in model_path.glob("**/*"):
serialize_by_file.check_file_or_directory(path)
serialize_by_file.check_file_or_directory(
path, allow_symlinks=self._allow_symlinks
)
if path.is_file():
shards.extend(self._get_shards(path))

Expand Down Expand Up @@ -207,6 +220,10 @@ def serialize(
The only reason for the override is to change the return type, to be
more restrictive. This is to signal that the only manifests that can be
returned are `manifest.FileLevelManifest` instances.
Raises:
ValueError: The model contains a symbolic link, but the serializer
was not initialized with `allow_symlinks=True`.
"""
return cast(manifest.ShardLevelManifest, super().serialize(model_path))

Expand All @@ -230,6 +247,7 @@ def __init__(
],
merge_hasher: hashing.StreamingHashEngine,
max_workers: int | None = None,
allow_symlinks: bool = False,
):
"""Initializes an instance to serialize a model with this serializer.
Expand All @@ -243,8 +261,11 @@ def __init__(
individual file shard digests to compute an aggregate digest.
max_workers: Maximum number of workers to use in parallel. Default
is to defer to the `concurent.futures` library.
allow_symlinks: Controls whether symbolic links are included. If a
symlink is present but the flag is `False` (default) the
serialization would raise an error.
"""
super().__init__(file_hasher_factory, max_workers)
super().__init__(file_hasher_factory, max_workers, allow_symlinks)
self._merge_hasher = merge_hasher

@override
Expand All @@ -254,6 +275,10 @@ def serialize(self, model_path: pathlib.Path) -> manifest.DigestManifest:
The only reason for the override is to change the return type, to be
more restrictive. This is to signal that the only manifests that can be
returned are `manifest.FileLevelManifest` instances.
Raises:
ValueError: The model contains a symbolic link, but the serializer
was not initialized with `allow_symlinks=True`.
"""
return cast(manifest.DigestManifest, super().serialize(model_path))

Expand Down
23 changes: 19 additions & 4 deletions model_signing/serialization/serialize_by_file_shard_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_known_models(self, request, model_fixture_name):

# Compute model manifest (act)
serializer = serialize_by_file_shard.DigestSerializer(
self._hasher_factory, memory.SHA256()
self._hasher_factory, memory.SHA256(), allow_symlinks=True
)
manifest = serializer.serialize(model)

Expand All @@ -84,7 +84,9 @@ def test_known_models_small_shards(self, request, model_fixture_name):

# Compute model manifest (act)
serializer = serialize_by_file_shard.DigestSerializer(
self._hasher_factory_small_shards, memory.SHA256()
self._hasher_factory_small_shards,
memory.SHA256(),
allow_symlinks=True,
)
manifest = serializer.serialize(model)

Expand Down Expand Up @@ -299,6 +301,13 @@ def test_shard_size_changes_digests(self, sample_model_folder):

assert manifest1.digest.digest_value != manifest2.digest.digest_value

def test_symlinks_disallowed_by_default(self, symlink_model_folder):
serializer = serialize_by_file_shard.DigestSerializer(
self._hasher_factory, memory.SHA256()
)
with pytest.raises(ValueError):
_ = serializer.serialize(symlink_model_folder)


def _extract_shard_items_from_manifest(
manifest: manifest.ShardLevelManifest,
Expand Down Expand Up @@ -357,7 +366,7 @@ def test_known_models(self, request, model_fixture_name):

# Compute model manifest (act)
serializer = serialize_by_file_shard.ManifestSerializer(
self._hasher_factory
self._hasher_factory, allow_symlinks=True
)
manifest_file = serializer.serialize(model)
items = _extract_shard_items_from_manifest(manifest_file)
Expand Down Expand Up @@ -388,7 +397,7 @@ def test_known_models_small_shards(self, request, model_fixture_name):

# Compute model manifest (act)
serializer = serialize_by_file_shard.ManifestSerializer(
self._hasher_factory_small_shards
self._hasher_factory_small_shards, allow_symlinks=True
)
manifest_file = serializer.serialize(model)
items = _extract_shard_items_from_manifest(manifest_file)
Expand Down Expand Up @@ -654,6 +663,12 @@ def test_max_workers_does_not_change_digest(self, sample_model_folder):
assert manifest1 == manifest2
assert manifest1 == manifest3

def test_symlinks_disallowed_by_default(self, symlink_model_folder):
serializer = serialize_by_file_shard.ManifestSerializer(
self._hasher_factory
)
with pytest.raises(ValueError):
_ = serializer.serialize(symlink_model_folder)

def test_shard_to_string(self):
"""Ensure the shard's `__str__` method behaves as assumed."""
Expand Down
21 changes: 19 additions & 2 deletions model_signing/serialization/serialize_by_file_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_known_models(self, request, model_fixture_name):
test_support.UNUSED_PATH, memory.SHA256()
)
serializer = serialize_by_file.DigestSerializer(
file_hasher, memory.SHA256
file_hasher, memory.SHA256, allow_symlinks=True
)
manifest = serializer.serialize(model)

Expand Down Expand Up @@ -275,6 +275,16 @@ def test_model_with_empty_folder_hashes_differently_than_with_empty_file(

assert folder_manifest != file_manifest

def test_symlinks_disallowed_by_default(self, symlink_model_folder):
file_hasher = file.SimpleFileHasher(
test_support.UNUSED_PATH, memory.SHA256()
)
serializer = serialize_by_file.DigestSerializer(
file_hasher, memory.SHA256
)
with pytest.raises(ValueError):
_ = serializer.serialize(symlink_model_folder)


class TestManifestSerializer:

Expand All @@ -292,7 +302,9 @@ def test_known_models(self, request, model_fixture_name):
model = request.getfixturevalue(model_fixture_name)

# Compute model manifest (act)
serializer = serialize_by_file.ManifestSerializer(self._hasher_factory)
serializer = serialize_by_file.ManifestSerializer(
self._hasher_factory, allow_symlinks=True
)
manifest = serializer.serialize(model)
items = test_support.extract_items_from_manifest(manifest)

Expand Down Expand Up @@ -536,6 +548,11 @@ def test_max_workers_does_not_change_digest(self, sample_model_folder):
assert manifest1 == manifest2
assert manifest1 == manifest3

def test_symlinks_disallowed_by_default(self, symlink_model_folder):
serializer = serialize_by_file.ManifestSerializer(self._hasher_factory)
with pytest.raises(ValueError):
_ = serializer.serialize(symlink_model_folder)


class TestUtilities:

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
8372365be7578241d18db47ec83b735bb450a10a1b4298d9b7b0d8bf543b7271
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
symlink_file:3aab065c7181a173b5dd9e9d32a9f79923440b413be1e1ffcdba26a7365f719b
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
013322ae2977a76252119aa6a4d71044599c7c2d890f7ed96215b52308ee7142
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
de9d3fc1608836778e17540ac0332b42bf01730d4697767571b89460fae92fc3
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.:0:22:3aab065c7181a173b5dd9e9d32a9f79923440b413be1e1ffcdba26a7365f719b
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
symlink_file:0:22:3aab065c7181a173b5dd9e9d32a9f79923440b413be1e1ffcdba26a7365f719b
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
symlink_file:0:8:a37010c994067764d86540bf479d93b4d0c3bb3955de7b61f951caf2fd0301b0
symlink_file:8:16:bd762002a3528a27fb9a8822f822b949d3c9ab7e860af33039c9aa70ebbbe682
symlink_file:16:22:a791e1e893ea4260c77475725101fb4cc6ad85f6340f21f10b239184e318cd21
1 change: 1 addition & 0 deletions model_signing/test_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"empty_model_file",
"empty_model_folder",
"model_folder_with_empty_file",
"symlink_model_folder",
]


Expand Down

0 comments on commit 9798149

Please sign in to comment.