2
2
//! Dummy doc
3
3
use libloading:: { Library , Symbol } ;
4
4
use memmap2:: { Mmap , MmapOptions } ;
5
- use pyo3:: exceptions;
5
+ use pyo3:: exceptions:: PyException ;
6
6
use pyo3:: once_cell:: GILOnceCell ;
7
7
use pyo3:: prelude:: * ;
8
8
use pyo3:: types:: IntoPyDict ;
@@ -65,7 +65,7 @@ fn prepare(tensor_dict: HashMap<String, &PyDict>) -> PyResult<BTreeMap<String, T
65
65
"float64" => Dtype :: F64 ,
66
66
"bfloat16" => Dtype :: BF16 ,
67
67
dtype_str => {
68
- return Err ( exceptions :: PyException :: new_err ( format ! (
68
+ return Err ( SafetensorError :: new_err ( format ! (
69
69
"dtype {dtype_str} is not covered" ,
70
70
) ) ) ;
71
71
}
@@ -102,9 +102,8 @@ fn serialize<'a, 'b>(
102
102
) -> PyResult < & ' b PyBytes > {
103
103
let tensors = prepare ( tensor_dict) ?;
104
104
let metadata_btreemap = metadata. map ( |data| BTreeMap :: from_iter ( data. into_iter ( ) ) ) ;
105
- let out = safetensors:: tensor:: serialize ( & tensors, & metadata_btreemap) . map_err ( |e| {
106
- exceptions:: PyException :: new_err ( format ! ( "Error while serializing: {:?}" , e) )
107
- } ) ?;
105
+ let out = safetensors:: tensor:: serialize ( & tensors, & metadata_btreemap)
106
+ . map_err ( |e| SafetensorError :: new_err ( format ! ( "Error while serializing: {:?}" , e) ) ) ?;
108
107
let pybytes = PyBytes :: new ( py, & out) ;
109
108
Ok ( pybytes)
110
109
}
@@ -133,9 +132,7 @@ fn serialize_file(
133
132
let tensors = prepare ( tensor_dict) ?;
134
133
let metadata_btreemap = metadata. map ( |data| BTreeMap :: from_iter ( data. into_iter ( ) ) ) ;
135
134
safetensors:: tensor:: serialize_to_file ( & tensors, & metadata_btreemap, filename. as_path ( ) )
136
- . map_err ( |e| {
137
- exceptions:: PyException :: new_err ( format ! ( "Error while serializing: {:?}" , e) )
138
- } ) ?;
135
+ . map_err ( |e| SafetensorError :: new_err ( format ! ( "Error while serializing: {:?}" , e) ) ) ?;
139
136
Ok ( ( ) )
140
137
}
141
138
@@ -152,9 +149,8 @@ fn serialize_file(
152
149
#[ pyfunction]
153
150
#[ pyo3( text_signature = "(bytes)" ) ]
154
151
fn deserialize ( py : Python , bytes : & [ u8 ] ) -> PyResult < Vec < ( String , HashMap < String , PyObject > ) > > {
155
- let safetensor = SafeTensors :: deserialize ( bytes) . map_err ( |e| {
156
- exceptions:: PyException :: new_err ( format ! ( "Error while deserializing: {:?}" , e) )
157
- } ) ?;
152
+ let safetensor = SafeTensors :: deserialize ( bytes)
153
+ . map_err ( |e| SafetensorError :: new_err ( format ! ( "Error while deserializing: {:?}" , e) ) ) ?;
158
154
let mut items = vec ! [ ] ;
159
155
160
156
for ( tensor_name, tensor) in safetensor. tensors ( ) {
@@ -217,7 +213,7 @@ impl<'source> FromPyObject<'source> for Framework {
217
213
218
214
"jax" => Ok ( Framework :: Flax ) ,
219
215
"flax" => Ok ( Framework :: Flax ) ,
220
- name => Err ( exceptions :: PyException :: new_err ( format ! (
216
+ name => Err ( SafetensorError :: new_err ( format ! (
221
217
"framework {name} is invalid"
222
218
) ) ) ,
223
219
}
@@ -244,21 +240,19 @@ impl<'source> FromPyObject<'source> for Device {
244
240
let device: usize = tokens[ 1 ] . parse ( ) ?;
245
241
Ok ( Device :: Cuda ( device) )
246
242
} else {
247
- Err ( exceptions :: PyException :: new_err ( format ! (
243
+ Err ( SafetensorError :: new_err ( format ! (
248
244
"device {name} is invalid"
249
245
) ) )
250
246
}
251
247
}
252
- name => Err ( exceptions :: PyException :: new_err ( format ! (
248
+ name => Err ( SafetensorError :: new_err ( format ! (
253
249
"device {name} is invalid"
254
250
) ) ) ,
255
251
}
256
252
} else if let Ok ( number) = ob. extract :: < usize > ( ) {
257
253
Ok ( Device :: Cuda ( number) )
258
254
} else {
259
- Err ( exceptions:: PyException :: new_err ( format ! (
260
- "device {ob} is invalid"
261
- ) ) )
255
+ Err ( SafetensorError :: new_err ( format ! ( "device {ob} is invalid" ) ) )
262
256
}
263
257
}
264
258
}
@@ -503,7 +497,7 @@ impl safe_open {
503
497
let device = device. unwrap_or ( Device :: Cpu ) ;
504
498
505
499
if device != Device :: Cpu && framework != Framework :: Pytorch {
506
- return Err ( exceptions :: PyException :: new_err ( format ! (
500
+ return Err ( SafetensorError :: new_err ( format ! (
507
501
"Device {device:?} is not support for framework {framework:?}" ,
508
502
) ) ) ;
509
503
}
@@ -513,7 +507,7 @@ impl safe_open {
513
507
let buffer = unsafe { MmapOptions :: new ( ) . map ( & file) ? } ;
514
508
515
509
let ( n, metadata) = SafeTensors :: read_metadata ( & buffer) . map_err ( |e| {
516
- exceptions :: PyException :: new_err ( format ! ( "Error while deserializing header: {:?}" , e) )
510
+ SafetensorError :: new_err ( format ! ( "Error while deserializing header: {:?}" , e) )
517
511
} ) ?;
518
512
519
513
let offset = n + 8 ;
@@ -542,8 +536,7 @@ impl safe_open {
542
536
let module = get_module ( py, & TORCH_MODULE ) ?;
543
537
544
538
let version: String = module. getattr ( intern ! ( py, "__version__" ) ) ?. extract ( ) ?;
545
- let version =
546
- Version :: from_string ( & version) . map_err ( exceptions:: PyException :: new_err) ?;
539
+ let version = Version :: from_string ( & version) . map_err ( SafetensorError :: new_err) ?;
547
540
548
541
// Untyped storage only exists for versions over 1.11.0
549
542
// Same for torch.asarray which is necessary for zero-copy tensor
@@ -626,7 +619,7 @@ impl safe_open {
626
619
/// ```
627
620
pub fn get_tensor ( & self , name : & str ) -> PyResult < PyObject > {
628
621
let info = self . metadata . tensors ( ) . get ( name) . ok_or_else ( || {
629
- exceptions :: PyException :: new_err ( format ! ( "File does not contain tensor {name}" , ) )
622
+ SafetensorError :: new_err ( format ! ( "File does not contain tensor {name}" , ) )
630
623
} ) ?;
631
624
632
625
match & self . storage . as_ref ( ) {
@@ -716,7 +709,7 @@ impl safe_open {
716
709
storage : self . storage . clone ( ) ,
717
710
} )
718
711
} else {
719
- Err ( exceptions :: PyException :: new_err ( format ! (
712
+ Err ( SafetensorError :: new_err ( format ! (
720
713
"File does not contain tensor {name}" ,
721
714
) ) )
722
715
}
@@ -824,7 +817,7 @@ impl PySafeSlice {
824
817
. collect :: < Result < _ , _ > > ( ) ?;
825
818
826
819
let iterator = tensor. sliced_data ( slices. clone ( ) ) . map_err ( |e| {
827
- exceptions :: PyException :: new_err ( format ! (
820
+ SafetensorError :: new_err ( format ! (
828
821
"Error during slicing {slices:?} vs {:?}: {:?}" ,
829
822
self . info. shape, e
830
823
) )
@@ -923,7 +916,7 @@ fn get_module<'a>(
923
916
) -> PyResult < & ' a PyModule > {
924
917
let module: & PyModule = cell
925
918
. get ( py)
926
- . ok_or_else ( || exceptions :: PyException :: new_err ( "Could not find module" ) ) ?
919
+ . ok_or_else ( || SafetensorError :: new_err ( "Could not find module" ) ) ?
927
920
. as_ref ( py) ;
928
921
Ok ( module)
929
922
}
@@ -940,9 +933,7 @@ fn create_tensor(
940
933
Framework :: Pytorch => TORCH_MODULE . get ( py) ,
941
934
_ => NUMPY_MODULE . get ( py) ,
942
935
}
943
- . ok_or_else ( || {
944
- exceptions:: PyException :: new_err ( format ! ( "Could not find module {framework:?}" , ) )
945
- } ) ?
936
+ . ok_or_else ( || SafetensorError :: new_err ( format ! ( "Could not find module {framework:?}" , ) ) ) ?
946
937
. as_ref ( py) ;
947
938
let frombuffer = module. getattr ( intern ! ( py, "frombuffer" ) ) ?;
948
939
let dtype: PyObject = get_pydtype ( module, dtype) ?;
@@ -1011,7 +1002,7 @@ fn get_pydtype(module: &PyModule, dtype: Dtype) -> PyResult<PyObject> {
1011
1002
Dtype :: I8 => module. getattr ( intern ! ( py, "int8" ) ) ?. into ( ) ,
1012
1003
Dtype :: BOOL => module. getattr ( intern ! ( py, "bool" ) ) ?. into ( ) ,
1013
1004
dtype => {
1014
- return Err ( exceptions :: PyException :: new_err ( format ! (
1005
+ return Err ( SafetensorError :: new_err ( format ! (
1015
1006
"Dtype not understood: {:?}" ,
1016
1007
dtype
1017
1008
) ) )
@@ -1020,13 +1011,22 @@ fn get_pydtype(module: &PyModule, dtype: Dtype) -> PyResult<PyObject> {
1020
1011
Ok ( dtype)
1021
1012
} )
1022
1013
}
1014
+
1015
+ pyo3:: create_exception!(
1016
+ safetensors_rust,
1017
+ SafetensorError ,
1018
+ PyException ,
1019
+ "Custom Python Exception for Safetensor errors."
1020
+ ) ;
1021
+
1023
1022
/// A Python module implemented in Rust.
1024
1023
#[ pymodule]
1025
- fn safetensors_rust ( _py : Python , m : & PyModule ) -> PyResult < ( ) > {
1024
+ fn safetensors_rust ( py : Python , m : & PyModule ) -> PyResult < ( ) > {
1026
1025
m. add_function ( wrap_pyfunction ! ( serialize, m) ?) ?;
1027
1026
m. add_function ( wrap_pyfunction ! ( serialize_file, m) ?) ?;
1028
1027
m. add_function ( wrap_pyfunction ! ( deserialize, m) ?) ?;
1029
1028
m. add_class :: < safe_open > ( ) ?;
1029
+ m. add ( "SafetensorError" , py. get_type :: < SafetensorError > ( ) ) ?;
1030
1030
Ok ( ( ) )
1031
1031
}
1032
1032
0 commit comments