Skip to content

Commit a98326a

Browse files
authored
Better error messages on outdated context manager. (#168)
* Better error messages on outdated context manager. * Clippy. * Remove unwrap. * Fix.
1 parent bcb033f commit a98326a

File tree

3 files changed

+130
-53
lines changed

3 files changed

+130
-53
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,4 @@
11
__version__ = "0.2.9"
22

33
# Re-export this
4-
from ._safetensors_rust import safe_open as rust_open, serialize, serialize_file, deserialize, SafetensorError
5-
6-
7-
class safe_open:
8-
def __init__(self, *args, **kwargs):
9-
self.args = args
10-
self.kwargs = kwargs
11-
12-
def __getattr__(self, __name: str):
13-
return getattr(self.f, __name)
14-
15-
def __enter__(self):
16-
self.f = rust_open(*self.args, **self.kwargs)
17-
return self
18-
19-
def __exit__(self, type, value, traceback):
20-
del self.f
4+
from ._safetensors_rust import safe_open, serialize, serialize_file, deserialize, SafetensorError # noqa: F401

bindings/python/src/lib.rs

+126-36
Original file line numberDiff line numberDiff line change
@@ -466,32 +466,15 @@ impl Version {
466466
}
467467
}
468468

469-
/// Opens a safetensors lazily and returns tensors as asked
470-
///
471-
/// Args:
472-
/// filename (`str`, or `os.PathLike`):
473-
/// The filename to open
474-
///
475-
/// framework (`str`):
476-
/// The framework you want you tensors in. Supported values:
477-
/// `pt`, `tf`, `flax`, `numpy`.
478-
///
479-
/// device (`str`, defaults to `"cpu"`):
480-
/// The device on which you want the tensors.
481-
#[pyclass]
482-
#[allow(non_camel_case_types)]
483-
#[pyo3(text_signature = "(self, filename, framework, device=\"cpu\")")]
484-
struct safe_open {
469+
struct Open {
485470
metadata: Metadata,
486471
offset: usize,
487472
framework: Framework,
488473
device: Device,
489474
storage: Arc<Storage>,
490475
}
491476

492-
#[pymethods]
493-
impl safe_open {
494-
#[new]
477+
impl Open {
495478
fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> PyResult<Self> {
496479
let file = File::open(&filename)?;
497480
let device = device.unwrap_or(Device::Cpu);
@@ -661,7 +644,9 @@ impl safe_open {
661644
let start = (info.data_offsets.0 + self.offset) as isize;
662645
let stop = (info.data_offsets.1 + self.offset) as isize;
663646
let slice = pyslice_new(py, start, stop, 1);
664-
let storage: &PyObject = storage.get(py).unwrap();
647+
let storage: &PyObject = storage
648+
.get(py)
649+
.ok_or_else(|| SafetensorError::new_err("Could not find storage"))?;
665650
let storage: &PyAny = storage.as_ref(py);
666651

667652
let storage_slice = storage
@@ -714,6 +699,105 @@ impl safe_open {
714699
)))
715700
}
716701
}
702+
}
703+
704+
/// Opens a safetensors lazily and returns tensors as asked
705+
///
706+
/// Args:
707+
/// filename (`str`, or `os.PathLike`):
708+
/// The filename to open
709+
///
710+
/// framework (`str`):
711+
/// The framework you want you tensors in. Supported values:
712+
/// `pt`, `tf`, `flax`, `numpy`.
713+
///
714+
/// device (`str`, defaults to `"cpu"`):
715+
/// The device on which you want the tensors.
716+
#[pyclass]
717+
#[allow(non_camel_case_types)]
718+
#[pyo3(text_signature = "(self, filename, framework, device=\"cpu\")")]
719+
struct safe_open {
720+
inner: Option<Open>,
721+
}
722+
723+
impl safe_open {
724+
fn inner(&self) -> PyResult<&Open> {
725+
let inner = self
726+
.inner
727+
.as_ref()
728+
.ok_or_else(|| SafetensorError::new_err("File is closed".to_string()))?;
729+
Ok(inner)
730+
}
731+
}
732+
733+
#[pymethods]
734+
impl safe_open {
735+
#[new]
736+
fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> PyResult<Self> {
737+
let inner = Some(Open::new(filename, framework, device)?);
738+
Ok(Self { inner })
739+
}
740+
741+
/// Return the special non tensor information in the header
742+
///
743+
/// Returns:
744+
/// (`Dict[str, str]`):
745+
/// The freeform metadata.
746+
pub fn metadata(&self) -> PyResult<Option<BTreeMap<String, String>>> {
747+
Ok(self.inner()?.metadata())
748+
}
749+
750+
/// Returns the names of the tensors in the file.
751+
///
752+
/// Returns:
753+
/// (`List[str]`):
754+
/// The name of the tensors contained in that file
755+
pub fn keys(&self) -> PyResult<Vec<String>> {
756+
self.inner()?.keys()
757+
}
758+
759+
/// Returns a full tensor
760+
///
761+
/// Args:
762+
/// name (`str`):
763+
/// The name of the tensor you want
764+
///
765+
/// Returns:
766+
/// (`Tensor`):
767+
/// The tensor in the framework you opened the file for.
768+
///
769+
/// Example:
770+
/// ```python
771+
/// from safetensors import safe_open
772+
///
773+
/// with safe_open("model.safetensors", framework="pt", device=0) as f:
774+
/// tensor = f.get_tensor("embedding")
775+
///
776+
/// ```
777+
pub fn get_tensor(&self, name: &str) -> PyResult<PyObject> {
778+
self.inner()?.get_tensor(name)
779+
}
780+
781+
/// Returns a full slice view object
782+
///
783+
/// Args:
784+
/// name (`str`):
785+
/// The name of the tensor you want
786+
///
787+
/// Returns:
788+
/// (`PySafeSlice`):
789+
/// A dummy object you can slice into to get a real tensor
790+
/// Example:
791+
/// ```python
792+
/// from safetensors import safe_open
793+
///
794+
/// with safe_open("model.safetensors", framework="pt", device=0) as f:
795+
/// tensor_part = f.get_slice("embedding")[:, ::8]
796+
///
797+
/// ```
798+
pub fn get_slice(&self, name: &str) -> PyResult<PySafeSlice> {
799+
self.inner()?.get_slice(name)
800+
}
717801

718802
pub fn __enter__(slf: Py<Self>) -> Py<Self> {
719803
// SAFETY: This code is extremely important to the GPU fast load.
@@ -726,9 +810,10 @@ impl safe_open {
726810
// of the context manager lifecycle.
727811
Python::with_gil(|py| -> PyResult<()> {
728812
let _self: &safe_open = &slf.borrow(py);
729-
if let (Device::Cuda(_), Framework::Pytorch) = (&_self.device, &_self.framework) {
813+
let inner = _self.inner()?;
814+
if let (Device::Cuda(_), Framework::Pytorch) = (&inner.device, &inner.framework) {
730815
let module = get_module(py, &TORCH_MODULE)?;
731-
let device: PyObject = _self.device.clone().into_py(py);
816+
let device: PyObject = inner.device.clone().into_py(py);
732817
let torch_device = module
733818
.getattr(intern!(py, "cuda"))?
734819
.getattr(intern!(py, "device"))?;
@@ -742,20 +827,23 @@ impl safe_open {
742827
}
743828

744829
pub fn __exit__(&mut self, _exc_type: PyObject, _exc_value: PyObject, _traceback: PyObject) {
745-
if let (Device::Cuda(_), Framework::Pytorch) = (&self.device, &self.framework) {
746-
Python::with_gil(|py| -> PyResult<()> {
747-
let module = get_module(py, &TORCH_MODULE)?;
748-
let device: PyObject = self.device.clone().into_py(py);
749-
let torch_device = module
750-
.getattr(intern!(py, "cuda"))?
751-
.getattr(intern!(py, "device"))?;
752-
let none = py.None();
753-
let lock = torch_device.call1((device,))?;
754-
lock.call_method1(intern!(py, "__exit__"), (&none, &none, &none))?;
755-
Ok(())
756-
})
757-
.ok();
830+
if let Some(inner) = &self.inner {
831+
if let (Device::Cuda(_), Framework::Pytorch) = (&inner.device, &inner.framework) {
832+
Python::with_gil(|py| -> PyResult<()> {
833+
let module = get_module(py, &TORCH_MODULE)?;
834+
let device: PyObject = inner.device.clone().into_py(py);
835+
let torch_device = module
836+
.getattr(intern!(py, "cuda"))?
837+
.getattr(intern!(py, "device"))?;
838+
let none = py.None();
839+
let lock = torch_device.call1((device,))?;
840+
lock.call_method1(intern!(py, "__exit__"), (&none, &none, &none))?;
841+
Ok(())
842+
})
843+
.ok();
844+
}
758845
}
846+
self.inner = None;
759847
}
760848
}
761849

@@ -874,7 +962,9 @@ impl PySafeSlice {
874962
let start = (self.info.data_offsets.0 + self.offset) as isize;
875963
let stop = (self.info.data_offsets.1 + self.offset) as isize;
876964
let slice = pyslice_new(py, start, stop, 1);
877-
let storage: &PyObject = storage.get(py).unwrap();
965+
let storage: &PyObject = storage
966+
.get(py)
967+
.ok_or_else(|| SafetensorError::new_err("Could not find storage"))?;
878968
let storage: &PyAny = storage.as_ref(py);
879969

880970
let storage_slice = storage

bindings/python/tests/test_simple.py

+3
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ def test_get_correctly_dropped(self):
6969
with safe_open("./out.safetensors", framework="pt") as f:
7070
pass
7171

72+
with self.assertRaises(SafetensorError):
73+
print(f.keys())
74+
7275
with open("./out.safetensors", "w") as g:
7376
g.write("something")
7477

0 commit comments

Comments
 (0)