Skip to content

Commit 9005fe1

Browse files
committed
Better error messages on outdated context manager.
1 parent a895411 commit 9005fe1

File tree

3 files changed

+112
-40
lines changed

3 files changed

+112
-40
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,4 @@
11
__version__ = "0.2.9"
22

33
# Re-export this
4-
from ._safetensors_rust import safe_open as rust_open, serialize, serialize_file, deserialize, SafetensorError
5-
6-
7-
class safe_open:
8-
def __init__(self, *args, **kwargs):
9-
self.args = args
10-
self.kwargs = kwargs
11-
12-
def __getattr__(self, __name: str):
13-
return getattr(self.f, __name)
14-
15-
def __enter__(self):
16-
self.f = rust_open(*self.args, **self.kwargs)
17-
return self
18-
19-
def __exit__(self, type, value, traceback):
20-
del self.f
4+
from ._safetensors_rust import safe_open, serialize, serialize_file, deserialize, SafetensorError # noqa: F401

bindings/python/src/lib.rs

+108-23
Original file line numberDiff line numberDiff line change
@@ -466,32 +466,15 @@ impl Version {
466466
}
467467
}
468468

469-
/// Opens a safetensors lazily and returns tensors as asked
470-
///
471-
/// Args:
472-
/// filename (`str`, or `os.PathLike`):
473-
/// The filename to open
474-
///
475-
/// framework (`str`):
476-
/// The framework you want you tensors in. Supported values:
477-
/// `pt`, `tf`, `flax`, `numpy`.
478-
///
479-
/// device (`str`, defaults to `"cpu"`):
480-
/// The device on which you want the tensors.
481-
#[pyclass]
482-
#[allow(non_camel_case_types)]
483-
#[pyo3(text_signature = "(self, filename, framework, device=\"cpu\")")]
484-
struct safe_open {
469+
struct Open {
485470
metadata: Metadata,
486471
offset: usize,
487472
framework: Framework,
488473
device: Device,
489474
storage: Arc<Storage>,
490475
}
491476

492-
#[pymethods]
493-
impl safe_open {
494-
#[new]
477+
impl Open {
495478
fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> PyResult<Self> {
496479
let file = File::open(&filename)?;
497480
let device = device.unwrap_or(Device::Cpu);
@@ -714,6 +697,105 @@ impl safe_open {
714697
)))
715698
}
716699
}
700+
}
701+
702+
/// Opens a safetensors lazily and returns tensors as asked
703+
///
704+
/// Args:
705+
/// filename (`str`, or `os.PathLike`):
706+
/// The filename to open
707+
///
708+
/// framework (`str`):
709+
/// The framework you want you tensors in. Supported values:
710+
/// `pt`, `tf`, `flax`, `numpy`.
711+
///
712+
/// device (`str`, defaults to `"cpu"`):
713+
/// The device on which you want the tensors.
714+
#[pyclass]
715+
#[allow(non_camel_case_types)]
716+
#[pyo3(text_signature = "(self, filename, framework, device=\"cpu\")")]
717+
struct safe_open {
718+
inner: Option<Open>,
719+
}
720+
721+
impl safe_open {
722+
fn inner(&self) -> PyResult<&Open> {
723+
let inner = self
724+
.inner
725+
.as_ref()
726+
.ok_or_else(|| SafetensorError::new_err(format!("File is closed",)))?;
727+
Ok(inner)
728+
}
729+
}
730+
731+
#[pymethods]
732+
impl safe_open {
733+
#[new]
734+
fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> PyResult<Self> {
735+
let inner = Some(Open::new(filename, framework, device)?);
736+
Ok(Self { inner })
737+
}
738+
739+
/// Return the special non tensor information in the header
740+
///
741+
/// Returns:
742+
/// (`Dict[str, str]`):
743+
/// The freeform metadata.
744+
pub fn metadata(&self) -> PyResult<Option<BTreeMap<String, String>>> {
745+
Ok(self.inner()?.metadata())
746+
}
747+
748+
/// Returns the names of the tensors in the file.
749+
///
750+
/// Returns:
751+
/// (`List[str]`):
752+
/// The name of the tensors contained in that file
753+
pub fn keys(&self) -> PyResult<Vec<String>> {
754+
self.inner()?.keys()
755+
}
756+
757+
/// Returns a full tensor
758+
///
759+
/// Args:
760+
/// name (`str`):
761+
/// The name of the tensor you want
762+
///
763+
/// Returns:
764+
/// (`Tensor`):
765+
/// The tensor in the framework you opened the file for.
766+
///
767+
/// Example:
768+
/// ```python
769+
/// from safetensors import safe_open
770+
///
771+
/// with safe_open("model.safetensors", framework="pt", device=0) as f:
772+
/// tensor = f.get_tensor("embedding")
773+
///
774+
/// ```
775+
pub fn get_tensor(&self, name: &str) -> PyResult<PyObject> {
776+
self.inner()?.get_tensor(name)
777+
}
778+
779+
/// Returns a full slice view object
780+
///
781+
/// Args:
782+
/// name (`str`):
783+
/// The name of the tensor you want
784+
///
785+
/// Returns:
786+
/// (`PySafeSlice`):
787+
/// A dummy object you can slice into to get a real tensor
788+
/// Example:
789+
/// ```python
790+
/// from safetensors import safe_open
791+
///
792+
/// with safe_open("model.safetensors", framework="pt", device=0) as f:
793+
/// tensor_part = f.get_slice("embedding")[:, ::8]
794+
///
795+
/// ```
796+
pub fn get_slice(&self, name: &str) -> PyResult<PySafeSlice> {
797+
self.inner()?.get_slice(name)
798+
}
717799

718800
pub fn __enter__(slf: Py<Self>) -> Py<Self> {
719801
// SAFETY: This code is extremely important to the GPU fast load.
@@ -726,9 +808,10 @@ impl safe_open {
726808
// of the context manager lifecycle.
727809
Python::with_gil(|py| -> PyResult<()> {
728810
let _self: &safe_open = &slf.borrow(py);
729-
if let (Device::Cuda(_), Framework::Pytorch) = (&_self.device, &_self.framework) {
811+
let inner = _self.inner()?;
812+
if let (Device::Cuda(_), Framework::Pytorch) = (&inner.device, &inner.framework) {
730813
let module = get_module(py, &TORCH_MODULE)?;
731-
let device: PyObject = _self.device.clone().into_py(py);
814+
let device: PyObject = inner.device.clone().into_py(py);
732815
let torch_device = module
733816
.getattr(intern!(py, "cuda"))?
734817
.getattr(intern!(py, "device"))?;
@@ -742,10 +825,11 @@ impl safe_open {
742825
}
743826

744827
pub fn __exit__(&mut self, _exc_type: PyObject, _exc_value: PyObject, _traceback: PyObject) {
745-
if let (Device::Cuda(_), Framework::Pytorch) = (&self.device, &self.framework) {
828+
let inner = self.inner().unwrap();
829+
if let (Device::Cuda(_), Framework::Pytorch) = (&inner.device, &inner.framework) {
746830
Python::with_gil(|py| -> PyResult<()> {
747831
let module = get_module(py, &TORCH_MODULE)?;
748-
let device: PyObject = self.device.clone().into_py(py);
832+
let device: PyObject = inner.device.clone().into_py(py);
749833
let torch_device = module
750834
.getattr(intern!(py, "cuda"))?
751835
.getattr(intern!(py, "device"))?;
@@ -756,6 +840,7 @@ impl safe_open {
756840
})
757841
.ok();
758842
}
843+
self.inner = None;
759844
}
760845
}
761846

bindings/python/tests/test_simple.py

+3
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ def test_get_correctly_dropped(self):
6969
with safe_open("./out.safetensors", framework="pt") as f:
7070
pass
7171

72+
with self.assertRaises(SafetensorError):
73+
print(f.keys())
74+
7275
with open("./out.safetensors", "w") as g:
7376
g.write("something")
7477

0 commit comments

Comments
 (0)