
Commit 613a34f

Adding custom error type for users to properly catch. (#165)
1 parent 2a5cac8 commit 613a34f
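
In short: error paths in the Python bindings now raise a dedicated SafetensorError (a subclass of Python's Exception, registered on the safetensors_rust module) instead of a bare PyException, so callers can catch safetensors failures specifically while broad "except Exception" handlers keep working. A minimal sketch of the intended usage follows; it is not part of the commit, mirrors the new test added below, and assumes the safetensors.safetensors_rust import path used there.

from safetensors.safetensors_rust import SafetensorError, serialize

# Mirrors the new test: this incomplete tensor dict makes serialize() fail,
# and the failure now surfaces as SafetensorError rather than a generic Exception.
try:
    serialize({"test": {"dtype": "float32", "shape": [1]}})
except SafetensorError as e:
    print(f"serialization failed: {e}")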

File tree

2 files changed, +37 -30 lines changed

bindings/python/src/lib.rs
bindings/python/tests/test_simple.py


bindings/python/src/lib.rs

+30 -30
@@ -2,7 +2,7 @@
 //! Dummy doc
 use libloading::{Library, Symbol};
 use memmap2::{Mmap, MmapOptions};
-use pyo3::exceptions;
+use pyo3::exceptions::PyException;
 use pyo3::once_cell::GILOnceCell;
 use pyo3::prelude::*;
 use pyo3::types::IntoPyDict;
@@ -65,7 +65,7 @@ fn prepare(tensor_dict: HashMap<String, &PyDict>) -> PyResult<BTreeMap<String, T
 "float64" => Dtype::F64,
 "bfloat16" => Dtype::BF16,
 dtype_str => {
-return Err(exceptions::PyException::new_err(format!(
+return Err(SafetensorError::new_err(format!(
 "dtype {dtype_str} is not covered",
 )));
 }
@@ -102,9 +102,8 @@ fn serialize<'a, 'b>(
 ) -> PyResult<&'b PyBytes> {
 let tensors = prepare(tensor_dict)?;
 let metadata_btreemap = metadata.map(|data| BTreeMap::from_iter(data.into_iter()));
-let out = safetensors::tensor::serialize(&tensors, &metadata_btreemap).map_err(|e| {
-exceptions::PyException::new_err(format!("Error while serializing: {:?}", e))
-})?;
+let out = safetensors::tensor::serialize(&tensors, &metadata_btreemap)
+.map_err(|e| SafetensorError::new_err(format!("Error while serializing: {:?}", e)))?;
 let pybytes = PyBytes::new(py, &out);
 Ok(pybytes)
 }
@@ -133,9 +132,7 @@ fn serialize_file(
 let tensors = prepare(tensor_dict)?;
 let metadata_btreemap = metadata.map(|data| BTreeMap::from_iter(data.into_iter()));
 safetensors::tensor::serialize_to_file(&tensors, &metadata_btreemap, filename.as_path())
-.map_err(|e| {
-exceptions::PyException::new_err(format!("Error while serializing: {:?}", e))
-})?;
+.map_err(|e| SafetensorError::new_err(format!("Error while serializing: {:?}", e)))?;
 Ok(())
 }

@@ -152,9 +149,8 @@ fn serialize_file(
 #[pyfunction]
 #[pyo3(text_signature = "(bytes)")]
 fn deserialize(py: Python, bytes: &[u8]) -> PyResult<Vec<(String, HashMap<String, PyObject>)>> {
-let safetensor = SafeTensors::deserialize(bytes).map_err(|e| {
-exceptions::PyException::new_err(format!("Error while deserializing: {:?}", e))
-})?;
+let safetensor = SafeTensors::deserialize(bytes)
+.map_err(|e| SafetensorError::new_err(format!("Error while deserializing: {:?}", e)))?;
 let mut items = vec![];

 for (tensor_name, tensor) in safetensor.tensors() {
@@ -217,7 +213,7 @@ impl<'source> FromPyObject<'source> for Framework {

 "jax" => Ok(Framework::Flax),
 "flax" => Ok(Framework::Flax),
-name => Err(exceptions::PyException::new_err(format!(
+name => Err(SafetensorError::new_err(format!(
 "framework {name} is invalid"
 ))),
 }
@@ -244,21 +240,19 @@ impl<'source> FromPyObject<'source> for Device {
 let device: usize = tokens[1].parse()?;
 Ok(Device::Cuda(device))
 } else {
-Err(exceptions::PyException::new_err(format!(
+Err(SafetensorError::new_err(format!(
 "device {name} is invalid"
 )))
 }
 }
-name => Err(exceptions::PyException::new_err(format!(
+name => Err(SafetensorError::new_err(format!(
 "device {name} is invalid"
 ))),
 }
 } else if let Ok(number) = ob.extract::<usize>() {
 Ok(Device::Cuda(number))
 } else {
-Err(exceptions::PyException::new_err(format!(
-"device {ob} is invalid"
-)))
+Err(SafetensorError::new_err(format!("device {ob} is invalid")))
 }
 }
 }
@@ -503,7 +497,7 @@ impl safe_open {
 let device = device.unwrap_or(Device::Cpu);

 if device != Device::Cpu && framework != Framework::Pytorch {
-return Err(exceptions::PyException::new_err(format!(
+return Err(SafetensorError::new_err(format!(
 "Device {device:?} is not support for framework {framework:?}",
 )));
 }
@@ -513,7 +507,7 @@ impl safe_open {
 let buffer = unsafe { MmapOptions::new().map(&file)? };

 let (n, metadata) = SafeTensors::read_metadata(&buffer).map_err(|e| {
-exceptions::PyException::new_err(format!("Error while deserializing header: {:?}", e))
+SafetensorError::new_err(format!("Error while deserializing header: {:?}", e))
 })?;

 let offset = n + 8;
@@ -542,8 +536,7 @@ impl safe_open {
 let module = get_module(py, &TORCH_MODULE)?;

 let version: String = module.getattr(intern!(py, "__version__"))?.extract()?;
-let version =
-Version::from_string(&version).map_err(exceptions::PyException::new_err)?;
+let version = Version::from_string(&version).map_err(SafetensorError::new_err)?;

 // Untyped storage only exists for versions over 1.11.0
 // Same for torch.asarray which is necessary for zero-copy tensor
@@ -626,7 +619,7 @@ impl safe_open {
 /// ```
 pub fn get_tensor(&self, name: &str) -> PyResult<PyObject> {
 let info = self.metadata.tensors().get(name).ok_or_else(|| {
-exceptions::PyException::new_err(format!("File does not contain tensor {name}",))
+SafetensorError::new_err(format!("File does not contain tensor {name}",))
 })?;

 match &self.storage.as_ref() {
@@ -716,7 +709,7 @@ impl safe_open {
 storage: self.storage.clone(),
 })
 } else {
-Err(exceptions::PyException::new_err(format!(
+Err(SafetensorError::new_err(format!(
 "File does not contain tensor {name}",
 )))
 }
@@ -824,7 +817,7 @@ impl PySafeSlice {
 .collect::<Result<_, _>>()?;

 let iterator = tensor.sliced_data(slices.clone()).map_err(|e| {
-exceptions::PyException::new_err(format!(
+SafetensorError::new_err(format!(
 "Error during slicing {slices:?} vs {:?}: {:?}",
 self.info.shape, e
 ))
@@ -923,7 +916,7 @@ fn get_module<'a>(
 ) -> PyResult<&'a PyModule> {
 let module: &PyModule = cell
 .get(py)
-.ok_or_else(|| exceptions::PyException::new_err("Could not find module"))?
+.ok_or_else(|| SafetensorError::new_err("Could not find module"))?
 .as_ref(py);
 Ok(module)
 }
@@ -940,9 +933,7 @@ fn create_tensor(
 Framework::Pytorch => TORCH_MODULE.get(py),
 _ => NUMPY_MODULE.get(py),
 }
-.ok_or_else(|| {
-exceptions::PyException::new_err(format!("Could not find module {framework:?}",))
-})?
+.ok_or_else(|| SafetensorError::new_err(format!("Could not find module {framework:?}",)))?
 .as_ref(py);
 let frombuffer = module.getattr(intern!(py, "frombuffer"))?;
 let dtype: PyObject = get_pydtype(module, dtype)?;
@@ -1011,7 +1002,7 @@ fn get_pydtype(module: &PyModule, dtype: Dtype) -> PyResult<PyObject> {
 Dtype::I8 => module.getattr(intern!(py, "int8"))?.into(),
 Dtype::BOOL => module.getattr(intern!(py, "bool"))?.into(),
 dtype => {
-return Err(exceptions::PyException::new_err(format!(
+return Err(SafetensorError::new_err(format!(
 "Dtype not understood: {:?}",
 dtype
 )))
@@ -1020,13 +1011,22 @@ fn get_pydtype(module: &PyModule, dtype: Dtype) -> PyResult<PyObject> {
 Ok(dtype)
 })
 }
+
+pyo3::create_exception!(
+safetensors_rust,
+SafetensorError,
+PyException,
+"Custom Python Exception for Safetensor errors."
+);
+
 /// A Python module implemented in Rust.
 #[pymodule]
-fn safetensors_rust(_py: Python, m: &PyModule) -> PyResult<()> {
+fn safetensors_rust(py: Python, m: &PyModule) -> PyResult<()> {
 m.add_function(wrap_pyfunction!(serialize, m)?)?;
 m.add_function(wrap_pyfunction!(serialize_file, m)?)?;
 m.add_function(wrap_pyfunction!(deserialize, m)?)?;
 m.add_class::<safe_open>()?;
+m.add("SafetensorError", py.get_type::<SafetensorError>())?;
 Ok(())
 }
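
For reference, a short sketch (not part of the commit) of what the pyo3::create_exception! declaration and the m.add("SafetensorError", ...) registration above mean on the Python side, assuming the compiled module is importable as safetensors.safetensors_rust as in the test below.

from safetensors.safetensors_rust import SafetensorError

# The exception is declared with PyException as its base, so it is an ordinary
# subclass of Python's built-in Exception and fits the usual exception hierarchy.
assert issubclass(SafetensorError, Exception)

# It can be raised and caught like any other Python exception; the message passed
# to new_err() on the Rust side becomes the exception's message.
try:
    raise SafetensorError("example error message")
except SafetensorError as err:
    print(err)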

bindings/python/tests/test_simple.py

+7
@@ -6,6 +6,7 @@
 import torch

 from safetensors.numpy import load, load_file, save, save_file
+from safetensors.safetensors_rust import SafetensorError, serialize
 from safetensors.torch import load_file as load_file_pt
 from safetensors.torch import save_file as save_file_pt

@@ -94,3 +95,9 @@ def test_torch_example(self):
         # Now loading
         loaded = load_file_pt("./out.safetensors")
         self.assertTensorEqual(tensors2, loaded, torch.allclose)
+
+    def test_exception(self):
+        flattened = {"test": {"dtype": "float32", "shape": [1]}}
+
+        with self.assertRaises(SafetensorError):
+            serialize(flattened)
