@@ -466,32 +466,15 @@ impl Version {
466
466
}
467
467
}
468
468
469
- /// Opens a safetensors lazily and returns tensors as asked
470
- ///
471
- /// Args:
472
- /// filename (`str`, or `os.PathLike`):
473
- /// The filename to open
474
- ///
475
- /// framework (`str`):
476
- /// The framework you want you tensors in. Supported values:
477
- /// `pt`, `tf`, `flax`, `numpy`.
478
- ///
479
- /// device (`str`, defaults to `"cpu"`):
480
- /// The device on which you want the tensors.
481
- #[ pyclass]
482
- #[ allow( non_camel_case_types) ]
483
- #[ pyo3( text_signature = "(self, filename, framework, device=\" cpu\" )" ) ]
484
- struct safe_open {
469
+ struct Open {
485
470
metadata : Metadata ,
486
471
offset : usize ,
487
472
framework : Framework ,
488
473
device : Device ,
489
474
storage : Arc < Storage > ,
490
475
}
491
476
492
- #[ pymethods]
493
- impl safe_open {
494
- #[ new]
477
+ impl Open {
495
478
fn new ( filename : PathBuf , framework : Framework , device : Option < Device > ) -> PyResult < Self > {
496
479
let file = File :: open ( & filename) ?;
497
480
let device = device. unwrap_or ( Device :: Cpu ) ;
@@ -714,6 +697,105 @@ impl safe_open {
714
697
) ) )
715
698
}
716
699
}
700
+ }
701
+
702
+ /// Opens a safetensors lazily and returns tensors as asked
703
+ ///
704
+ /// Args:
705
+ /// filename (`str`, or `os.PathLike`):
706
+ /// The filename to open
707
+ ///
708
+ /// framework (`str`):
709
+ /// The framework you want you tensors in. Supported values:
710
+ /// `pt`, `tf`, `flax`, `numpy`.
711
+ ///
712
+ /// device (`str`, defaults to `"cpu"`):
713
+ /// The device on which you want the tensors.
714
+ #[ pyclass]
715
+ #[ allow( non_camel_case_types) ]
716
+ #[ pyo3( text_signature = "(self, filename, framework, device=\" cpu\" )" ) ]
717
+ struct safe_open {
718
+ inner : Option < Open > ,
719
+ }
720
+
721
+ impl safe_open {
722
+ fn inner ( & self ) -> PyResult < & Open > {
723
+ let inner = self
724
+ . inner
725
+ . as_ref ( )
726
+ . ok_or_else ( || SafetensorError :: new_err ( format ! ( "File is closed" , ) ) ) ?;
727
+ Ok ( inner)
728
+ }
729
+ }
730
+
731
+ #[ pymethods]
732
+ impl safe_open {
733
+ #[ new]
734
+ fn new ( filename : PathBuf , framework : Framework , device : Option < Device > ) -> PyResult < Self > {
735
+ let inner = Some ( Open :: new ( filename, framework, device) ?) ;
736
+ Ok ( Self { inner } )
737
+ }
738
+
739
+ /// Return the special non tensor information in the header
740
+ ///
741
+ /// Returns:
742
+ /// (`Dict[str, str]`):
743
+ /// The freeform metadata.
744
+ pub fn metadata ( & self ) -> PyResult < Option < BTreeMap < String , String > > > {
745
+ Ok ( self . inner ( ) ?. metadata ( ) )
746
+ }
747
+
748
+ /// Returns the names of the tensors in the file.
749
+ ///
750
+ /// Returns:
751
+ /// (`List[str]`):
752
+ /// The name of the tensors contained in that file
753
+ pub fn keys ( & self ) -> PyResult < Vec < String > > {
754
+ self . inner ( ) ?. keys ( )
755
+ }
756
+
757
+ /// Returns a full tensor
758
+ ///
759
+ /// Args:
760
+ /// name (`str`):
761
+ /// The name of the tensor you want
762
+ ///
763
+ /// Returns:
764
+ /// (`Tensor`):
765
+ /// The tensor in the framework you opened the file for.
766
+ ///
767
+ /// Example:
768
+ /// ```python
769
+ /// from safetensors import safe_open
770
+ ///
771
+ /// with safe_open("model.safetensors", framework="pt", device=0) as f:
772
+ /// tensor = f.get_tensor("embedding")
773
+ ///
774
+ /// ```
775
+ pub fn get_tensor ( & self , name : & str ) -> PyResult < PyObject > {
776
+ self . inner ( ) ?. get_tensor ( name)
777
+ }
778
+
779
+ /// Returns a full slice view object
780
+ ///
781
+ /// Args:
782
+ /// name (`str`):
783
+ /// The name of the tensor you want
784
+ ///
785
+ /// Returns:
786
+ /// (`PySafeSlice`):
787
+ /// A dummy object you can slice into to get a real tensor
788
+ /// Example:
789
+ /// ```python
790
+ /// from safetensors import safe_open
791
+ ///
792
+ /// with safe_open("model.safetensors", framework="pt", device=0) as f:
793
+ /// tensor_part = f.get_slice("embedding")[:, ::8]
794
+ ///
795
+ /// ```
796
+ pub fn get_slice ( & self , name : & str ) -> PyResult < PySafeSlice > {
797
+ self . inner ( ) ?. get_slice ( name)
798
+ }
717
799
718
800
pub fn __enter__ ( slf : Py < Self > ) -> Py < Self > {
719
801
// SAFETY: This code is extremely important to the GPU fast load.
@@ -726,9 +808,10 @@ impl safe_open {
726
808
// of the context manager lifecycle.
727
809
Python :: with_gil ( |py| -> PyResult < ( ) > {
728
810
let _self: & safe_open = & slf. borrow ( py) ;
729
- if let ( Device :: Cuda ( _) , Framework :: Pytorch ) = ( & _self. device , & _self. framework ) {
811
+ let inner = _self. inner ( ) ?;
812
+ if let ( Device :: Cuda ( _) , Framework :: Pytorch ) = ( & inner. device , & inner. framework ) {
730
813
let module = get_module ( py, & TORCH_MODULE ) ?;
731
- let device: PyObject = _self . device . clone ( ) . into_py ( py) ;
814
+ let device: PyObject = inner . device . clone ( ) . into_py ( py) ;
732
815
let torch_device = module
733
816
. getattr ( intern ! ( py, "cuda" ) ) ?
734
817
. getattr ( intern ! ( py, "device" ) ) ?;
@@ -742,10 +825,11 @@ impl safe_open {
742
825
}
743
826
744
827
pub fn __exit__ ( & mut self , _exc_type : PyObject , _exc_value : PyObject , _traceback : PyObject ) {
745
- if let ( Device :: Cuda ( _) , Framework :: Pytorch ) = ( & self . device , & self . framework ) {
828
+ let inner = self . inner ( ) . unwrap ( ) ;
829
+ if let ( Device :: Cuda ( _) , Framework :: Pytorch ) = ( & inner. device , & inner. framework ) {
746
830
Python :: with_gil ( |py| -> PyResult < ( ) > {
747
831
let module = get_module ( py, & TORCH_MODULE ) ?;
748
- let device: PyObject = self . device . clone ( ) . into_py ( py) ;
832
+ let device: PyObject = inner . device . clone ( ) . into_py ( py) ;
749
833
let torch_device = module
750
834
. getattr ( intern ! ( py, "cuda" ) ) ?
751
835
. getattr ( intern ! ( py, "device" ) ) ?;
@@ -756,6 +840,7 @@ impl safe_open {
756
840
} )
757
841
. ok ( ) ;
758
842
}
843
+ self . inner = None ;
759
844
}
760
845
}
761
846
0 commit comments