@@ -466,32 +466,15 @@ impl Version {
466
466
}
467
467
}
468
468
469
- /// Opens a safetensors lazily and returns tensors as asked
470
- ///
471
- /// Args:
472
- /// filename (`str`, or `os.PathLike`):
473
- /// The filename to open
474
- ///
475
- /// framework (`str`):
476
- /// The framework you want you tensors in. Supported values:
477
- /// `pt`, `tf`, `flax`, `numpy`.
478
- ///
479
- /// device (`str`, defaults to `"cpu"`):
480
- /// The device on which you want the tensors.
481
- #[ pyclass]
482
- #[ allow( non_camel_case_types) ]
483
- #[ pyo3( text_signature = "(self, filename, framework, device=\" cpu\" )" ) ]
484
- struct safe_open {
469
/// Lazily-opened safetensors file: the parsed header plus a shared handle
/// to the underlying byte storage. This is the framework-agnostic core that
/// the Python-facing `safe_open` class wraps.
struct Open {
    // Parsed safetensors header: per-tensor info (names, data offsets, ...).
    metadata: Metadata,
    // Byte offset of the start of the tensor data section; tensor
    // `data_offsets` are relative to this (see the slicing code below,
    // which computes `info.data_offsets.0 + self.offset`).
    offset: usize,
    // Framework the tensors are returned in (`pt`, `tf`, `flax`, `numpy`).
    framework: Framework,
    // Device the tensors should land on (defaults to CPU upstream).
    device: Device,
    // Shared handle to the raw file bytes; `Arc` so slices/views can hold
    // onto it independently of this struct's lifetime.
    storage: Arc<Storage>,
}
491
476
492
- #[ pymethods]
493
- impl safe_open {
494
- #[ new]
477
+ impl Open {
495
478
fn new ( filename : PathBuf , framework : Framework , device : Option < Device > ) -> PyResult < Self > {
496
479
let file = File :: open ( & filename) ?;
497
480
let device = device. unwrap_or ( Device :: Cpu ) ;
@@ -661,7 +644,9 @@ impl safe_open {
661
644
let start = ( info. data_offsets . 0 + self . offset ) as isize ;
662
645
let stop = ( info. data_offsets . 1 + self . offset ) as isize ;
663
646
let slice = pyslice_new ( py, start, stop, 1 ) ;
664
- let storage: & PyObject = storage. get ( py) . unwrap ( ) ;
647
+ let storage: & PyObject = storage
648
+ . get ( py)
649
+ . ok_or_else ( || SafetensorError :: new_err ( "Could not find storage" ) ) ?;
665
650
let storage: & PyAny = storage. as_ref ( py) ;
666
651
667
652
let storage_slice = storage
@@ -714,6 +699,105 @@ impl safe_open {
714
699
) ) )
715
700
}
716
701
}
702
+ }
703
+
704
/// Opens a safetensors lazily and returns tensors as asked
///
/// Args:
///     filename (`str`, or `os.PathLike`):
///         The filename to open
///
///     framework (`str`):
///         The framework you want your tensors in. Supported values:
///         `pt`, `tf`, `flax`, `numpy`.
///
///     device (`str`, defaults to `"cpu"`):
///         The device on which you want the tensors.
#[pyclass]
#[allow(non_camel_case_types)]
#[pyo3(text_signature = "(self, filename, framework, device=\"cpu\")")]
struct safe_open {
    // `Some` while the file is open; set to `None` by `__exit__`, so any
    // method call after the context manager exits fails with a clean
    // "File is closed" error instead of touching stale state.
    inner: Option<Open>,
}
722
+
723
+ impl safe_open {
724
+ fn inner ( & self ) -> PyResult < & Open > {
725
+ let inner = self
726
+ . inner
727
+ . as_ref ( )
728
+ . ok_or_else ( || SafetensorError :: new_err ( "File is closed" . to_string ( ) ) ) ?;
729
+ Ok ( inner)
730
+ }
731
+ }
732
+
733
+ #[ pymethods]
734
+ impl safe_open {
735
+ #[ new]
736
+ fn new ( filename : PathBuf , framework : Framework , device : Option < Device > ) -> PyResult < Self > {
737
+ let inner = Some ( Open :: new ( filename, framework, device) ?) ;
738
+ Ok ( Self { inner } )
739
+ }
740
+
741
+ /// Return the special non tensor information in the header
742
+ ///
743
+ /// Returns:
744
+ /// (`Dict[str, str]`):
745
+ /// The freeform metadata.
746
+ pub fn metadata ( & self ) -> PyResult < Option < BTreeMap < String , String > > > {
747
+ Ok ( self . inner ( ) ?. metadata ( ) )
748
+ }
749
+
750
+ /// Returns the names of the tensors in the file.
751
+ ///
752
+ /// Returns:
753
+ /// (`List[str]`):
754
+ /// The name of the tensors contained in that file
755
+ pub fn keys ( & self ) -> PyResult < Vec < String > > {
756
+ self . inner ( ) ?. keys ( )
757
+ }
758
+
759
+ /// Returns a full tensor
760
+ ///
761
+ /// Args:
762
+ /// name (`str`):
763
+ /// The name of the tensor you want
764
+ ///
765
+ /// Returns:
766
+ /// (`Tensor`):
767
+ /// The tensor in the framework you opened the file for.
768
+ ///
769
+ /// Example:
770
+ /// ```python
771
+ /// from safetensors import safe_open
772
+ ///
773
+ /// with safe_open("model.safetensors", framework="pt", device=0) as f:
774
+ /// tensor = f.get_tensor("embedding")
775
+ ///
776
+ /// ```
777
+ pub fn get_tensor ( & self , name : & str ) -> PyResult < PyObject > {
778
+ self . inner ( ) ?. get_tensor ( name)
779
+ }
780
+
781
+ /// Returns a full slice view object
782
+ ///
783
+ /// Args:
784
+ /// name (`str`):
785
+ /// The name of the tensor you want
786
+ ///
787
+ /// Returns:
788
+ /// (`PySafeSlice`):
789
+ /// A dummy object you can slice into to get a real tensor
790
+ /// Example:
791
+ /// ```python
792
+ /// from safetensors import safe_open
793
+ ///
794
+ /// with safe_open("model.safetensors", framework="pt", device=0) as f:
795
+ /// tensor_part = f.get_slice("embedding")[:, ::8]
796
+ ///
797
+ /// ```
798
+ pub fn get_slice ( & self , name : & str ) -> PyResult < PySafeSlice > {
799
+ self . inner ( ) ?. get_slice ( name)
800
+ }
717
801
718
802
pub fn __enter__ ( slf : Py < Self > ) -> Py < Self > {
719
803
// SAFETY: This code is extremely important to the GPU fast load.
@@ -726,9 +810,10 @@ impl safe_open {
726
810
// of the context manager lifecycle.
727
811
Python :: with_gil ( |py| -> PyResult < ( ) > {
728
812
let _self: & safe_open = & slf. borrow ( py) ;
729
- if let ( Device :: Cuda ( _) , Framework :: Pytorch ) = ( & _self. device , & _self. framework ) {
813
+ let inner = _self. inner ( ) ?;
814
+ if let ( Device :: Cuda ( _) , Framework :: Pytorch ) = ( & inner. device , & inner. framework ) {
730
815
let module = get_module ( py, & TORCH_MODULE ) ?;
731
- let device: PyObject = _self . device . clone ( ) . into_py ( py) ;
816
+ let device: PyObject = inner . device . clone ( ) . into_py ( py) ;
732
817
let torch_device = module
733
818
. getattr ( intern ! ( py, "cuda" ) ) ?
734
819
. getattr ( intern ! ( py, "device" ) ) ?;
@@ -742,20 +827,23 @@ impl safe_open {
742
827
}
743
828
744
829
/// Context-manager exit: tears down the CUDA device context (PyTorch only)
/// and marks the file as closed so later calls raise "File is closed".
pub fn __exit__(&mut self, _exc_type: PyObject, _exc_value: PyObject, _traceback: PyObject) {
    // Only act if the file is still open; `__exit__` must be safe to call
    // even if inner state was never set or was already cleared.
    if let Some(inner) = &self.inner {
        // The torch.cuda.device context is only entered for the
        // (CUDA device, PyTorch) combination in `__enter__`, so it is only
        // exited for that same combination here.
        if let (Device::Cuda(_), Framework::Pytorch) = (&inner.device, &inner.framework) {
            Python::with_gil(|py| -> PyResult<()> {
                let module = get_module(py, &TORCH_MODULE)?;
                let device: PyObject = inner.device.clone().into_py(py);
                // torch.cuda.device(device).__exit__(None, None, None)
                let torch_device = module
                    .getattr(intern!(py, "cuda"))?
                    .getattr(intern!(py, "device"))?;
                let none = py.None();
                let lock = torch_device.call1((device,))?;
                lock.call_method1(intern!(py, "__exit__"), (&none, &none, &none))?;
                Ok(())
            })
            // Best-effort cleanup: errors from torch here are deliberately
            // swallowed — __exit__ must not raise during teardown.
            .ok();
        }
    }
    // Drop the inner state: releases the storage handle and makes every
    // subsequent method call fail via `self.inner()`.
    self.inner = None;
}
760
848
}
761
849
@@ -874,7 +962,9 @@ impl PySafeSlice {
874
962
let start = ( self . info . data_offsets . 0 + self . offset ) as isize ;
875
963
let stop = ( self . info . data_offsets . 1 + self . offset ) as isize ;
876
964
let slice = pyslice_new ( py, start, stop, 1 ) ;
877
- let storage: & PyObject = storage. get ( py) . unwrap ( ) ;
965
+ let storage: & PyObject = storage
966
+ . get ( py)
967
+ . ok_or_else ( || SafetensorError :: new_err ( "Could not find storage" ) ) ?;
878
968
let storage: & PyAny = storage. as_ref ( py) ;
879
969
880
970
let storage_slice = storage
0 commit comments