From b2c36e0ebc8d3323a2bd068920a0c9b9c4b31505 Mon Sep 17 00:00:00 2001 From: Thijs Vogels Date: Mon, 16 Feb 2026 11:03:35 +0000 Subject: [PATCH 1/2] Add SHA-256 hash pinning for .fun file loading torch.jit.load deserializes arbitrary code via pickle, making it vulnerable to code execution if a .fun file is tampered with. This change verifies file integrity against pinned SHA-256 digests before calling torch.jit.load. - Add _hashes.py with known digests for microsoft/skala model files - Add expected_hash parameter to TracedFunctional.load() - Pass pinned hash from load_functional() for HF-downloaded models - Skip verification for user-supplied SKALA_LOCAL_MODEL_PATH - Add unit tests for hash verification (match, mismatch, opt-out) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../cpp/cpp_integration/download_model.py | 4 +- src/skala/functional/__init__.py | 13 ++- src/skala/functional/_hashes.py | 19 ++++ src/skala/functional/load.py | 29 ++++++ tests/test_hash_pinning.py | 99 +++++++++++++++++++ 5 files changed, 161 insertions(+), 3 deletions(-) create mode 100644 src/skala/functional/_hashes.py create mode 100644 tests/test_hash_pinning.py diff --git a/examples/cpp/cpp_integration/download_model.py b/examples/cpp/cpp_integration/download_model.py index f4b317b..158c40b 100755 --- a/examples/cpp/cpp_integration/download_model.py +++ b/examples/cpp/cpp_integration/download_model.py @@ -11,6 +11,7 @@ from huggingface_hub import hf_hub_download +from skala.functional._hashes import KNOWN_HASHES from skala.functional.load import TracedFunctional GRID_SIZE = "grid_size" @@ -48,7 +49,8 @@ def download_model(huggingface_repo_id: str, filename: str, output_path: str) -> print(f"Downloaded the {filename} functional to {output_path}") - fun = TracedFunctional.load(output_path) + expected_hash = KNOWN_HASHES.get((huggingface_repo_id, filename)) + fun = TracedFunctional.load(output_path, expected_hash=expected_hash) print("\nExpected inputs:") for 
feature in fun.features: diff --git a/src/skala/functional/__init__.py b/src/skala/functional/__init__.py index f77cdd5..c4cc1a2 100644 --- a/src/skala/functional/__init__.py +++ b/src/skala/functional/__init__.py @@ -8,11 +8,13 @@ (LDA, PBE, TPSS) and the Skala neural functional. """ +import logging import os import torch from huggingface_hub import hf_hub_download +from skala.functional._hashes import KNOWN_HASHES from skala.functional.base import ExcFunctionalBase from skala.functional.load import TracedFunctional from skala.functional.traditional import LDA, PBE, SPW92, TPSS @@ -67,15 +69,22 @@ def load_functional(name: str, device: torch.device | None = None) -> ExcFunctio if func_name == "skala": env_path = os.environ.get("SKALA_LOCAL_MODEL_PATH") if env_path is not None: + logging.getLogger(__name__).warning( + "Loading model from SKALA_LOCAL_MODEL_PATH; " + "SHA-256 hash verification is disabled." + ) path = env_path + expected_hash = None else: device_type = ( torch.get_default_device().type if device is None else device.type ) + repo_id = "microsoft/skala" filename = "skala-1.0.fun" if device_type == "cpu" else "skala-1.0-cuda.fun" - path = hf_hub_download(repo_id="microsoft/skala", filename=filename) + path = hf_hub_download(repo_id=repo_id, filename=filename) + expected_hash = KNOWN_HASHES.get((repo_id, filename)) with open(path, "rb") as fd: - return TracedFunctional.load(fd, device=device) + return TracedFunctional.load(fd, device=device, expected_hash=expected_hash) elif func_name == "lda": func = LDA() elif func_name == "spw92": diff --git a/src/skala/functional/_hashes.py b/src/skala/functional/_hashes.py new file mode 100644 index 0000000..113b99f --- /dev/null +++ b/src/skala/functional/_hashes.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: MIT + +""" +SHA-256 hash digests for known traced functional files. 
+ +These hashes are used to verify file integrity before loading with +``torch.jit.load``, which deserializes arbitrary code and is therefore +security-sensitive. +""" + +# Maps (repo_id, filename) -> expected SHA-256 hex digest. +KNOWN_HASHES: dict[tuple[str, str], str] = { + ("microsoft/skala", "skala-1.0.fun"): ( + "08d94436995937eb57c451af7c92e2c7f9e1bff6b7da029a3887e9f9dd4581c0" + ), + ("microsoft/skala", "skala-1.0-cuda.fun"): ( + "0b38e13237cec771fed331664aace42f8c0db8f15caca6a5c563085e61e2b1fd" + ), +} diff --git a/src/skala/functional/load.py b/src/skala/functional/load.py index ab9e241..ed29714 100644 --- a/src/skala/functional/load.py +++ b/src/skala/functional/load.py @@ -4,6 +4,8 @@ Tools to load functionals from serialized torchscript checkpoints. """ +import hashlib +import io import json import os from collections.abc import Iterable, Mapping @@ -77,7 +79,17 @@ def load( cls, fp: str | bytes | os.PathLike[str] | IO[bytes], device: torch.device | None = None, + *, + expected_hash: str | None = None, ) -> "TracedFunctional": + """Load a traced functional from a file. + + Args: + fp: File path or readable binary stream. + device: Target device for the loaded model. + expected_hash: If provided, the SHA-256 hex digest that the file + content must match. A ``ValueError`` is raised on mismatch. + """ extra_files = { "metadata": b"", "features": b"", @@ -88,6 +100,23 @@ def load( if device is None: device = torch.get_default_device() + if expected_hash is not None: + # Read the whole file into memory so we can hash it before + # passing it to the unsafe torch.jit.load deserializer. + if isinstance(fp, (str, bytes, os.PathLike)): + with open(fp, "rb") as f: + data = f.read() + else: + data = fp.read() + actual_hash = hashlib.sha256(data).hexdigest() + if actual_hash != expected_hash: + raise ValueError( + f"Hash mismatch for functional file: " + f"expected {expected_hash}, got {actual_hash}. " + f"The file may have been tampered with." 
+ ) + fp = io.BytesIO(data) + traced_model = torch.jit.load(fp, _extra_files=extra_files, map_location=device) _metadata = json.loads(extra_files["metadata"].decode("utf-8")) diff --git a/tests/test_hash_pinning.py b/tests/test_hash_pinning.py new file mode 100644 index 0000000..9cc48e5 --- /dev/null +++ b/tests/test_hash_pinning.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: MIT + +"""Tests for hash-pinning on TracedFunctional.load().""" + +import hashlib +import io +import json +import os +import tempfile + +import pytest +import torch + +from skala.functional.load import TracedFunctional + + +def _make_dummy_fun_bytes() -> bytes: + """Create a minimal TorchScript archive that TracedFunctional.load can open.""" + + class Dummy(torch.nn.Module): + features: list[str] = [] + + def get_exc_density(self, data: dict[str, torch.Tensor]) -> torch.Tensor: + return torch.tensor(0.0) + + def get_exc(self, data: dict[str, torch.Tensor]) -> torch.Tensor: + return torch.tensor(0.0) + + scripted = torch.jit.script(Dummy()) + extra_files = { + "metadata": json.dumps({}).encode(), + "features": json.dumps([]).encode(), + "expected_d3_settings": json.dumps(None).encode(), + "protocol_version": json.dumps(2).encode(), + } + buf = io.BytesIO() + torch.jit.save(scripted, buf, _extra_files=extra_files) + return buf.getvalue() + + +@pytest.fixture(scope="module") +def dummy_fun_bytes() -> bytes: + return _make_dummy_fun_bytes() + + +def test_load_with_correct_hash(dummy_fun_bytes: bytes) -> None: + """Loading succeeds when the expected hash matches.""" + correct_hash = hashlib.sha256(dummy_fun_bytes).hexdigest() + func = TracedFunctional.load( + io.BytesIO(dummy_fun_bytes), + expected_hash=correct_hash, + ) + assert isinstance(func, TracedFunctional) + + +def test_load_with_wrong_hash(dummy_fun_bytes: bytes) -> None: + """Loading raises ValueError when the hash does not match.""" + wrong_hash = "0" * 64 + with pytest.raises(ValueError, match="Hash mismatch"): + TracedFunctional.load( 
+ io.BytesIO(dummy_fun_bytes), + expected_hash=wrong_hash, + ) + + +def test_load_without_hash(dummy_fun_bytes: bytes) -> None: + """Loading without expected_hash still works (opt-out).""" + func = TracedFunctional.load( + io.BytesIO(dummy_fun_bytes), + ) + assert isinstance(func, TracedFunctional) + + +def test_load_from_path_with_correct_hash(dummy_fun_bytes: bytes) -> None: + """Hash verification works when loading from a file path.""" + correct_hash = hashlib.sha256(dummy_fun_bytes).hexdigest() + with tempfile.NamedTemporaryFile(suffix=".fun", delete=False) as f: + f.write(dummy_fun_bytes) + f.flush() + path = f.name + try: + func = TracedFunctional.load(path, expected_hash=correct_hash) + assert isinstance(func, TracedFunctional) + finally: + os.unlink(path) + + +def test_load_from_path_with_wrong_hash(dummy_fun_bytes: bytes) -> None: + """Hash verification raises ValueError for file path loading too.""" + wrong_hash = "0" * 64 + with tempfile.NamedTemporaryFile(suffix=".fun", delete=False) as f: + f.write(dummy_fun_bytes) + f.flush() + path = f.name + try: + with pytest.raises(ValueError, match="Hash mismatch"): + TracedFunctional.load(path, expected_hash=wrong_hash) + finally: + os.unlink(path) From 48702acbab9430fdb86b6ccf779d8694ae83798b Mon Sep 17 00:00:00 2001 From: Thijs Vogels Date: Mon, 16 Feb 2026 12:29:02 +0000 Subject: [PATCH 2/2] Add security note about loading .fun files from untrusted sources Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/README.md b/README.md index de86f33..1631b8c 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,16 @@ ks.kernel() Go to [microsoft.github.io/skala](https://microsoft.github.io/skala) for a more detailed installation guide and further examples of how to use the Skala functional with PySCF, GPU4PySCF and ASE and in [Azure AI Foundry](https://ai.azure.com/catalog/models/Skala). 
+## Security: loading `.fun` files + +Skala model files (`.fun`) use TorchScript serialization, which can execute arbitrary code when loaded. **Never load `.fun` files from untrusted sources.** + +When loading the official Skala model via `load_functional("skala")`, file integrity is automatically verified against pinned SHA-256 hashes before deserialization. If you load `.fun` files directly with `TracedFunctional.load()`, pass the `expected_hash` parameter to enable verification: + +```python +TracedFunctional.load("model.fun", expected_hash="<sha256-hex-digest>") +``` + ## Project information See the following files for more information about contributing, reporting issues, and the code of conduct: