Skip to content

Commit a404265

Browse files
authored
Fixing the fact that the context manager wasn't properly cleaning up after itself. (#166)
* Adding failing test for windows. * This should fail. * Fixing Windows issue by actually dropping the rust resources. * Making the rust module private.
1 parent 613a34f commit a404265

File tree

7 files changed

+37
-7
lines changed

7 files changed

+37
-7
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
11
__version__ = "0.2.9"
22

33
# Re-export this
4-
from .safetensors_rust import safe_open # noqa: F401
4+
from ._safetensors_rust import safe_open as rust_open, serialize, serialize_file, deserialize, SafetensorError
5+
6+
7+
class safe_open:
8+
def __init__(self, *args, **kwargs):
9+
self.args = args
10+
self.kwargs = kwargs
11+
12+
def __getattr__(self, __name: str):
13+
return getattr(self.f, __name)
14+
15+
def __enter__(self):
16+
self.f = rust_open(*self.args, **self.kwargs)
17+
return self
18+
19+
def __exit__(self, type, value, traceback):
20+
del self.f

bindings/python/py_src/safetensors/numpy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66

7-
from .safetensors_rust import deserialize, safe_open, serialize, serialize_file
7+
from safetensors import deserialize, safe_open, serialize, serialize_file
88

99

1010
def save(tensor_dict: Dict[str, np.ndarray], metadata: Optional[Dict[str, str]] = None) -> bytes:

bindings/python/py_src/safetensors/torch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import torch
77

8-
from .safetensors_rust import deserialize, safe_open, serialize, serialize_file
8+
from safetensors import deserialize, safe_open, serialize, serialize_file
99

1010

1111
def save(tensors: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes:

bindings/python/setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def deps_list(*pkgs):
5959
author_email="",
6060
url="https://github.com/huggingface/safetensors",
6161
license="Apache License 2.0",
62-
rust_extensions=[RustExtension("safetensors.safetensors_rust", binding=Binding.PyO3, debug=False)],
62+
rust_extensions=[RustExtension("safetensors._safetensors_rust", binding=Binding.PyO3, debug=False)],
6363
extras_require=extras,
6464
classifiers=[
6565
"Development Status :: 5 - Production/Stable",

bindings/python/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1021,7 +1021,7 @@ pyo3::create_exception!(
10211021

10221022
/// A Python module implemented in Rust.
10231023
#[pymodule]
1024-
fn safetensors_rust(py: Python, m: &PyModule) -> PyResult<()> {
1024+
fn _safetensors_rust(py: Python, m: &PyModule) -> PyResult<()> {
10251025
m.add_function(wrap_pyfunction!(serialize, m)?)?;
10261026
m.add_function(wrap_pyfunction!(serialize_file, m)?)?;
10271027
m.add_function(wrap_pyfunction!(deserialize, m)?)?;

bindings/python/tests/test_pt_comparison.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44

5-
from safetensors.safetensors_rust import safe_open
5+
from safetensors import safe_open
66
from safetensors.torch import load, load_file, save, save_file
77

88

bindings/python/tests/test_simple.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77

88
from safetensors.numpy import load, load_file, save, save_file
9-
from safetensors.safetensors_rust import SafetensorError, serialize
9+
from safetensors import safe_open, SafetensorError, serialize
1010
from safetensors.torch import load_file as load_file_pt
1111
from safetensors.torch import save_file as save_file_pt
1212

@@ -59,6 +59,20 @@ def test_accept_path(self):
5959
os.remove(Path("./out.safetensors"))
6060

6161

62+
class WindowsTestCase(unittest.TestCase):
63+
def test_get_correctly_dropped(self):
64+
tensors = {
65+
"a": torch.zeros((2, 2)),
66+
"b": torch.zeros((2, 3), dtype=torch.uint8),
67+
}
68+
save_file_pt(tensors, "./out.safetensors")
69+
with safe_open("./out.safetensors", framework="pt") as f:
70+
pass
71+
72+
with open("./out.safetensors", "w") as g:
73+
g.write("something")
74+
75+
6276
class ReadmeTestCase(unittest.TestCase):
6377
def assertTensorEqual(self, tensors1, tensors2, equality_fn):
6478
self.assertEqual(tensors1.keys(), tensors2.keys(), "tensor keys don't match")

0 commit comments

Comments
 (0)