Skip to content

Commit 48ea985

Browse files
committed
add egein stats
1 parent 7a6f84e commit 48ea985

File tree

2 files changed

+36
-26
lines changed

2 files changed

+36
-26
lines changed

distributed_shampoo/utils/shampoo_preconditioner_list.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
)
2525
from distributed_shampoo.utils.shampoo_block_info import BlockInfo
2626
from distributed_shampoo.utils.shampoo_utils import compress_list, get_dtype_size
27-
from matrix_functions import check_diagonal, matrix_eigenvectors, matrix_inverse_root
27+
from matrix_functions import EigenStats, check_diagonal, matrix_eigenvectors, matrix_inverse_root
2828

2929
from matrix_functions_types import EigenvectorConfig, RootInvConfig
3030
from optimizer_modules import OptimizerModule
@@ -303,6 +303,7 @@ class EigenvalueCorrectedShampooKroneckerFactorsList(BaseShampooKroneckerFactors
303303

304304
factor_matrices_eigenvectors: tuple[Tensor, ...]
305305
corrected_eigenvalues: Tensor
306+
eigen_stats: EigenStats | None = None
306307

307308
def __post_init__(self) -> None:
308309
super().__post_init__()
@@ -1065,8 +1066,7 @@ def _amortized_computation(self, step: int) -> None:
10651066
self._preconditioner_config.amortized_computation_config,
10661067
)
10671068
try:
1068-
logger.info(f"TYPEEEE: {type(eigenvector_computation_config)}...")
1069-
computed_eigenvectors = matrix_eigenvectors(
1069+
computed_eigenvectors, eigen_stats = matrix_eigenvectors(
10701070
A=factor_matrix,
10711071
eigenvectors_estimate=factor_matrix_eigenvectors,
10721072
eigenvector_computation_config=eigenvector_computation_config,
@@ -1093,6 +1093,8 @@ def _amortized_computation(self, step: int) -> None:
10931093
f"To mitigate, check factor matrix before the matrix computation: {factor_matrix=}"
10941094
)
10951095
factor_matrix_eigenvectors.copy_(computed_eigenvectors)
1096+
print(type(self._masked_kronecker_factors_list[idx]))
1097+
# self._masked_kronecker_factors_list[idx].eigen_stats = eigen_stats
10961098

10971099
# Only reuse previous eigenvectors if tolerance is not exceeded.
10981100
self._raise_exception_if_failure_tolerance_exceeded(

matrix_functions.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,13 @@
3030
)
3131

3232
from torch import Tensor
33+
from typing import NamedTuple
34+
3335

3436
logger: logging.Logger = logging.getLogger(__name__)
3537

3638

39+
3740
class NewtonConvergenceFlag(enum.Enum):
3841
"""
3942
Enum class for the state of the Newton / higher-order iteration method.
@@ -619,13 +622,24 @@ def compute_effective_rank(eigenvalues: torch.Tensor, threshold: float = 0.95) -
619622
return effective_rank
620623

621624

625+
class EigenStats(NamedTuple):
626+
effective_rank: int
627+
og_rank: int
628+
629+
@property
630+
def compression_ratio(self):
631+
return 1 - self.effective_rank / self.og_rank
632+
633+
def __repr__(self):
634+
return f"Effective rank: {self.effective_rank}, og_rank: {self.og_rank}, compression_ratio: {self.compression_ratio}"
635+
622636
def matrix_eigenvectors(
623637
A: Tensor,
624638
eigenvectors_estimate: Tensor | None = None,
625639
eigenvector_computation_config: EigenvectorConfig = DefaultEighEigenvectorConfig,
626640
is_diagonal: bool = False,
627641
step: int | None = None,
628-
) -> Tensor:
642+
) -> Tensor | EigenStats:
629643
"""Compute eigenvectors of matrix using eigendecomposition of symmetric positive (semi-)definite matrix.
630644
A = Q L Q^T => Q
631645
@@ -663,12 +677,20 @@ def matrix_eigenvectors(
663677
device=A.device,
664678
)
665679

680+
666681
if isinstance(eigenvector_computation_config, EighEigenvectorConfig):
667682
eigenvalues, eigenvectors = matrix_eigenvalue_decomposition(
668683
A,
669684
retry_double_precision=eigenvector_computation_config.retry_double_precision,
670685
)
671686

687+
compression_t = eigenvector_computation_config.compression_t if isinstance(eigenvector_computation_config, TopKCompressionEigenvectorConfig) else TopKCompressionEigenvectorConfig().compression_t
688+
689+
eigen_stats = EigenStats(
690+
effective_rank=compute_effective_rank(eigenvalues, compression_t),
691+
og_rank=eigenvalues.shape[0],
692+
)
693+
672694
if step is None:
673695
raise ValueError("step param is required when using EighEigenvectorConfig.")
674696

@@ -677,36 +699,19 @@ def matrix_eigenvectors(
677699
and eigenvalues.shape[0] > eigenvector_computation_config.min_dim
678700
and step > eigenvector_computation_config.warmup_steps
679701
):
680-
effective_rank = compute_effective_rank(eigenvalues, eigenvector_computation_config.compression_t)
681702

682-
# rank = int(os.environ.get("RANK", 0))
683-
potential_compression_ratio = 1 - effective_rank / eigenvalues.shape[0]
684703

685-
# if rank == 0:
686-
# import wandb
687-
688-
# wandb.log(
689-
# {
690-
# "effective_rank": effective_rank,
691-
# "og_rank": eigenvalues.shape[0],
692-
# "potential_compression_ratio": 1 - effective_rank / eigenvalues.shape[0],
693-
# }
694-
# )
695704

696705
if eigenvector_computation_config.auto:
697-
topk = effective_rank
698-
print(
699-
f"Effective rank: {effective_rank}, og_rank: {eigenvalues.shape[0]}, compression_ratio: {potential_compression_ratio}"
700-
)
706+
topk = eigen_stats.effective_rank
707+
print(eigen_stats)
701708
elif isinstance(eigenvector_computation_config.topk_compression, int):
702709
topk = eigenvector_computation_config.topk_compression
703710
else:
704711
topk = int(eigenvector_computation_config.topk_compression * eigenvalues.shape[0])
705712

706-
if potential_compression_ratio < eigenvector_computation_config.min_compression_ratio:
707-
print(
708-
f"Skipping eigenvector computation due to low compression ratio: {potential_compression_ratio}, effective_rank = {effective_rank}, og_rank = {eigenvalues.shape[0]}"
709-
)
713+
if eigen_stats.compression_ratio < eigenvector_computation_config.min_compression_ratio:
714+
print(f"Skipping eigenvector computation due to low compression ratio: {eigen_stats}")
710715
return eigenvectors
711716
# Sort eigenvalues and eigenvectors in descending order
712717
eigenvalues, indices = torch.sort(
@@ -720,9 +725,12 @@ def matrix_eigenvectors(
720725
mask[:, :topk] = 1.0
721726
eigenvectors = eigenvectors * mask
722727

723-
return eigenvectors
728+
return eigenvectors, eigen_stats
724729

725730
elif isinstance(eigenvector_computation_config, QRConfig):
731+
732+
raise NotImplementedError("QRConfig is not implemented yet.")
733+
726734
assert eigenvectors_estimate is not None, "Estimate of eigenvectors is required when using QRConfig."
727735

728736
eigenvectors = _compute_orthogonal_iterations(

0 commit comments

Comments
 (0)