Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,16 @@ ks.kernel()

Go to [microsoft.github.io/skala](https://microsoft.github.io/skala) for a more detailed installation guide and further examples of how to use the Skala functional with PySCF, GPU4PySCF and ASE and in [Azure AI Foundry](https://ai.azure.com/catalog/models/Skala).

## Security: loading `.fun` files

Skala model files (`.fun`) use TorchScript serialization, which can execute arbitrary code when loaded. **Never load `.fun` files from untrusted sources.**

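When loading the official Skala model via `load_functional("skala")`, file integrity is automatically verified against pinned SHA-256 hashes before deserialization. This check is skipped (and a warning is logged) when the `SKALA_LOCAL_MODEL_PATH` environment variable is set, so only point that variable at files you trust. If you load `.fun` files directly with `TracedFunctional.load()`, pass the `expected_hash` parameter to enable verification; a `ValueError` is raised if the digest does not match: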

```python
TracedFunctional.load("model.fun", expected_hash="<sha256-hex-digest>")
```

## Project information

See the following files for more information about contributing, reporting issues, and the code of conduct:
Expand Down
4 changes: 3 additions & 1 deletion examples/cpp/cpp_integration/download_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from huggingface_hub import hf_hub_download

from skala.functional._hashes import KNOWN_HASHES
from skala.functional.load import TracedFunctional

GRID_SIZE = "grid_size"
Expand Down Expand Up @@ -48,7 +49,8 @@ def download_model(huggingface_repo_id: str, filename: str, output_path: str) ->

print(f"Downloaded the {filename} functional to {output_path}")

fun = TracedFunctional.load(output_path)
expected_hash = KNOWN_HASHES.get((huggingface_repo_id, filename))
fun = TracedFunctional.load(output_path, expected_hash=expected_hash)

print("\nExpected inputs:")
for feature in fun.features:
Expand Down
13 changes: 11 additions & 2 deletions src/skala/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
(LDA, PBE, TPSS) and the Skala neural functional.
"""

import logging
import os

import torch
from huggingface_hub import hf_hub_download

from skala.functional._hashes import KNOWN_HASHES
from skala.functional.base import ExcFunctionalBase
from skala.functional.load import TracedFunctional
from skala.functional.traditional import LDA, PBE, SPW92, TPSS
Expand Down Expand Up @@ -67,15 +69,22 @@ def load_functional(name: str, device: torch.device | None = None) -> ExcFunctio
if func_name == "skala":
env_path = os.environ.get("SKALA_LOCAL_MODEL_PATH")
if env_path is not None:
logging.getLogger(__name__).warning(
"Loading model from SKALA_LOCAL_MODEL_PATH; "
"SHA-256 hash verification is disabled."
)
path = env_path
expected_hash = None
else:
device_type = (
torch.get_default_device().type if device is None else device.type
)
repo_id = "microsoft/skala"
filename = "skala-1.0.fun" if device_type == "cpu" else "skala-1.0-cuda.fun"
path = hf_hub_download(repo_id="microsoft/skala", filename=filename)
path = hf_hub_download(repo_id=repo_id, filename=filename)
expected_hash = KNOWN_HASHES.get((repo_id, filename))
with open(path, "rb") as fd:
return TracedFunctional.load(fd, device=device)
return TracedFunctional.load(fd, device=device, expected_hash=expected_hash)
elif func_name == "lda":
func = LDA()
elif func_name == "spw92":
Expand Down
19 changes: 19 additions & 0 deletions src/skala/functional/_hashes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# SPDX-License-Identifier: MIT

"""
SHA-256 hash digests for known traced functional files.

These hashes are used to verify file integrity before loading with
``torch.jit.load``, which deserializes arbitrary code and is therefore
security-sensitive.
"""

# Repository that hosts the pinned model files listed below.
_SKALA_REPO_ID = "microsoft/skala"

# Lookup table keyed by (repo_id, filename); each value is the SHA-256 hex
# digest the corresponding file is expected to have.
KNOWN_HASHES: dict[tuple[str, str], str] = {
    (_SKALA_REPO_ID, "skala-1.0.fun"):
        "08d94436995937eb57c451af7c92e2c7f9e1bff6b7da029a3887e9f9dd4581c0",
    (_SKALA_REPO_ID, "skala-1.0-cuda.fun"):
        "0b38e13237cec771fed331664aace42f8c0db8f15caca6a5c563085e61e2b1fd",
}
29 changes: 29 additions & 0 deletions src/skala/functional/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
Tools to load functionals from serialized torchscript checkpoints.
"""

import hashlib
import io
import json
import os
from collections.abc import Iterable, Mapping
Expand Down Expand Up @@ -77,7 +79,17 @@ def load(
cls,
fp: str | bytes | os.PathLike[str] | IO[bytes],
device: torch.device | None = None,
*,
expected_hash: str | None = None,
) -> "TracedFunctional":
"""Load a traced functional from a file.

Args:
fp: File path or readable binary stream.
device: Target device for the loaded model.
expected_hash: If provided, the SHA-256 hex digest that the file
content must match. A ``ValueError`` is raised on mismatch.
"""
extra_files = {
"metadata": b"",
"features": b"",
Expand All @@ -88,6 +100,23 @@ def load(
if device is None:
device = torch.get_default_device()

if expected_hash is not None:
# Read the whole file into memory so we can hash it before
# passing it to the unsafe torch.jit.load deserializer.
if isinstance(fp, (str, bytes, os.PathLike)):
with open(fp, "rb") as f:
data = f.read()
else:
data = fp.read()
actual_hash = hashlib.sha256(data).hexdigest()
if actual_hash != expected_hash:
raise ValueError(
f"Hash mismatch for functional file: "
f"expected {expected_hash}, got {actual_hash}. "
f"The file may have been tampered with."
)
fp = io.BytesIO(data)

traced_model = torch.jit.load(fp, _extra_files=extra_files, map_location=device)

_metadata = json.loads(extra_files["metadata"].decode("utf-8"))
Expand Down
99 changes: 99 additions & 0 deletions tests/test_hash_pinning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# SPDX-License-Identifier: MIT

"""Tests for hash-pinning on TracedFunctional.load()."""

import hashlib
import io
import json
import os
import tempfile

import pytest
import torch

from skala.functional.load import TracedFunctional


def _make_dummy_fun_bytes() -> bytes:
"""Create a minimal TorchScript archive that TracedFunctional.load can open."""

class Dummy(torch.nn.Module):
features: list[str] = []

def get_exc_density(self, data: dict[str, torch.Tensor]) -> torch.Tensor:
return torch.tensor(0.0)

def get_exc(self, data: dict[str, torch.Tensor]) -> torch.Tensor:
return torch.tensor(0.0)

scripted = torch.jit.script(Dummy())
extra_files = {
"metadata": json.dumps({}).encode(),
"features": json.dumps([]).encode(),
"expected_d3_settings": json.dumps(None).encode(),
"protocol_version": json.dumps(2).encode(),
}
buf = io.BytesIO()
torch.jit.save(scripted, buf, _extra_files=extra_files)
return buf.getvalue()


@pytest.fixture(scope="module")
def dummy_fun_bytes() -> bytes:
    """Module-scoped fixture providing the dummy archive bytes."""
    archive = _make_dummy_fun_bytes()
    return archive


def test_load_with_correct_hash(dummy_fun_bytes: bytes) -> None:
    """A stream whose SHA-256 digest equals ``expected_hash`` loads successfully."""
    stream = io.BytesIO(dummy_fun_bytes)
    digest = hashlib.sha256(dummy_fun_bytes).hexdigest()
    loaded = TracedFunctional.load(stream, expected_hash=digest)
    assert isinstance(loaded, TracedFunctional)


def test_load_with_wrong_hash(dummy_fun_bytes: bytes) -> None:
    """A digest that does not match the stream content makes load() raise ValueError."""
    stream = io.BytesIO(dummy_fun_bytes)
    bogus_digest = "0" * 64
    with pytest.raises(ValueError, match="Hash mismatch"):
        TracedFunctional.load(stream, expected_hash=bogus_digest)


def test_load_without_hash(dummy_fun_bytes: bytes) -> None:
    """Omitting ``expected_hash`` skips verification and the load still succeeds."""
    stream = io.BytesIO(dummy_fun_bytes)
    loaded = TracedFunctional.load(stream)
    assert isinstance(loaded, TracedFunctional)


def test_load_from_path_with_correct_hash(dummy_fun_bytes: bytes) -> None:
    """Hash verification works when loading from a file path."""
    correct_hash = hashlib.sha256(dummy_fun_bytes).hexdigest()
    # The temporary directory is removed on exit even if writing the file
    # fails, so no manual unlink is needed and nothing is leaked.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "model.fun")
        # Close the file before load() reopens it by path.
        with open(path, "wb") as f:
            f.write(dummy_fun_bytes)
        func = TracedFunctional.load(path, expected_hash=correct_hash)
        assert isinstance(func, TracedFunctional)


def test_load_from_path_with_wrong_hash(dummy_fun_bytes: bytes) -> None:
    """Hash verification raises ValueError for file path loading too."""
    wrong_hash = "0" * 64
    # The temporary directory is removed on exit even if writing the file
    # fails, so no manual unlink is needed and nothing is leaked.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "model.fun")
        # Close the file before load() reopens it by path.
        with open(path, "wb") as f:
            f.write(dummy_fun_bytes)
        with pytest.raises(ValueError, match="Hash mismatch"):
            TracedFunctional.load(path, expected_hash=wrong_hash)
Loading