30
30
)
31
31
32
32
from torch import Tensor
33
+ from typing import NamedTuple
34
+
33
35
34
36
# Module-level logger, named after this module via __name__.
logger : logging .Logger = logging .getLogger (__name__ )
35
37
36
38
39
+
37
40
class NewtonConvergenceFlag (enum .Enum ):
38
41
"""
39
42
Enum class for the state of the Newton / higher-order iteration method.
@@ -619,13 +622,24 @@ def compute_effective_rank(eigenvalues: torch.Tensor, threshold: float = 0.95) -
619
622
return effective_rank
620
623
621
624
625
class EigenStats(NamedTuple):
    """Rank statistics describing how compressible an eigenvalue spectrum is.

    Being a NamedTuple, instances are immutable and unpack positionally as
    ``(effective_rank, og_rank)``.
    """

    # Rank estimate for the spectrum (in this file it is filled from
    # compute_effective_rank).
    effective_rank: int
    # Original (full) rank, i.e. the number of eigenvalues.
    og_rank: int

    @property
    def compression_ratio(self) -> float:
        """Return the fraction of the original rank that is dropped.

        Computed as ``1 - effective_rank / og_rank``.

        Returns:
            The ratio as a float, or ``0.0`` when ``og_rank`` is 0: an empty
            spectrum has nothing to compress, and dividing would raise
            ``ZeroDivisionError``.
        """
        if self.og_rank == 0:
            return 0.0
        return 1 - self.effective_rank / self.og_rank

    def __repr__(self) -> str:
        """Return a one-line, human-readable summary of the statistics."""
        # The message text is kept exactly as-is because callers print it.
        return f"Effective rank: { self .effective_rank } , og_rank: { self .og_rank } , compression_ratio: { self .compression_ratio } "
635
+
622
636
def matrix_eigenvectors (
623
637
A : Tensor ,
624
638
eigenvectors_estimate : Tensor | None = None ,
625
639
eigenvector_computation_config : EigenvectorConfig = DefaultEighEigenvectorConfig ,
626
640
is_diagonal : bool = False ,
627
641
step : int | None = None ,
628
- ) -> Tensor :
642
+ ) -> Tensor | EigenStats :
629
643
"""Compute eigenvectors of matrix using eigendecomposition of symmetric positive (semi-)definite matrix.
630
644
A = Q L Q^T => Q
631
645
@@ -663,12 +677,20 @@ def matrix_eigenvectors(
663
677
device = A .device ,
664
678
)
665
679
680
+
666
681
if isinstance (eigenvector_computation_config , EighEigenvectorConfig ):
667
682
eigenvalues , eigenvectors = matrix_eigenvalue_decomposition (
668
683
A ,
669
684
retry_double_precision = eigenvector_computation_config .retry_double_precision ,
670
685
)
671
686
687
+ compression_t = eigenvector_computation_config .compression_t if isinstance (eigenvector_computation_config , TopKCompressionEigenvectorConfig ) else TopKCompressionEigenvectorConfig ().compression_t
688
+
689
+ eigen_stats = EigenStats (
690
+ effective_rank = compute_effective_rank (eigenvalues , compression_t ),
691
+ og_rank = eigenvalues .shape [0 ],
692
+ )
693
+
672
694
if step is None :
673
695
raise ValueError ("step param is required when using EighEigenvectorConfig." )
674
696
@@ -677,36 +699,19 @@ def matrix_eigenvectors(
677
699
and eigenvalues .shape [0 ] > eigenvector_computation_config .min_dim
678
700
and step > eigenvector_computation_config .warmup_steps
679
701
):
680
- effective_rank = compute_effective_rank (eigenvalues , eigenvector_computation_config .compression_t )
681
702
682
- # rank = int(os.environ.get("RANK", 0))
683
- potential_compression_ratio = 1 - effective_rank / eigenvalues .shape [0 ]
684
703
685
- # if rank == 0:
686
- # import wandb
687
-
688
- # wandb.log(
689
- # {
690
- # "effective_rank": effective_rank,
691
- # "og_rank": eigenvalues.shape[0],
692
- # "potential_compression_ratio": 1 - effective_rank / eigenvalues.shape[0],
693
- # }
694
- # )
695
704
696
705
if eigenvector_computation_config .auto :
697
- topk = effective_rank
698
- print (
699
- f"Effective rank: { effective_rank } , og_rank: { eigenvalues .shape [0 ]} , compression_ratio: { potential_compression_ratio } "
700
- )
706
+ topk = eigen_stats .effective_rank
707
+ print (eigen_stats )
701
708
elif isinstance (eigenvector_computation_config .topk_compression , int ):
702
709
topk = eigenvector_computation_config .topk_compression
703
710
else :
704
711
topk = int (eigenvector_computation_config .topk_compression * eigenvalues .shape [0 ])
705
712
706
- if potential_compression_ratio < eigenvector_computation_config .min_compression_ratio :
707
- print (
708
- f"Skipping eigenvector computation due to low compression ratio: { potential_compression_ratio } , effective_rank = { effective_rank } , og_rank = { eigenvalues .shape [0 ]} "
709
- )
713
+ if eigen_stats .compression_ratio < eigenvector_computation_config .min_compression_ratio :
714
+ print (f"Skipping eigenvector computation due to low compression ratio: { eigen_stats } " )
710
715
return eigenvectors
711
716
# Sort eigenvalues and eigenvectors in descending order
712
717
eigenvalues , indices = torch .sort (
@@ -720,9 +725,12 @@ def matrix_eigenvectors(
720
725
mask [:, :topk ] = 1.0
721
726
eigenvectors = eigenvectors * mask
722
727
723
- return eigenvectors
728
+ return eigenvectors , eigen_stats
724
729
725
730
elif isinstance (eigenvector_computation_config , QRConfig ):
731
+
732
+ raise NotImplementedError ("QRConfig is not implemented yet." )
733
+
726
734
assert eigenvectors_estimate is not None , "Estimate of eigenvectors is required when using QRConfig."
727
735
728
736
eigenvectors = _compute_orthogonal_iterations (
0 commit comments