Skip to content

Commit 5da5a91

Browse files
committed
add eigenvector stats
1 parent 48ea985 commit 5da5a91

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

distributed_shampoo/distributed_shampoo.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,3 +1250,27 @@ def load_distributed_state_dict(
12501250
param_group_to_load = param_groups_to_load[param_group_key]
12511251
for key, value in param_group_to_load.items():
12521252
group[key] = deepcopy(value)
1253+
1254+
1255+
@torch.no_grad()
def eigenvector_stats(self, key_to_param: Iterator[tuple[str, torch.Tensor]], summary: bool = False) -> dict[str, dict]:
    """Collect per-parameter eigenvector statistics, grouped by parameter group.

    Only groups whose preconditioner list is an
    ``EigenvalueCorrectedShampooPreconditionerList`` contribute an entry.

    Args:
        key_to_param: Iterable of ``(name, parameter)`` pairs used to translate
            parameters back into their names.
        summary: Currently not read by this method; kept so existing callers
            that pass it keep working.

    Returns:
        A dict keyed by ``"group_{idx}"`` whose values map parameter name to
        the statistic reported by the group's preconditioner list. Parameters
        that do not appear in ``key_to_param`` are omitted.
    """
    # Invert the (name, parameter) pairs so a parameter can be looked up directly.
    param_to_key = {param: key for key, param in key_to_param}

    stats = {}
    for idx, (state_lists, group) in enumerate(zip(self._per_group_state_lists, self.param_groups)):
        shampoo_preconditioner_list = state_lists[SHAMPOO_PRECONDITIONER_LIST]
        if not isinstance(shampoo_preconditioner_list, EigenvalueCorrectedShampooPreconditionerList):
            # Other preconditioner types expose no eigenvector statistics here.
            continue

        group_eigen_stats = shampoo_preconditioner_list.eigenvector_stats()

        # NOTE(review): this pairing assumes eigenvector_stats() yields one entry
        # per parameter, in the same order as group[PARAMS]; zip() silently
        # truncates to the shorter sequence otherwise -- confirm against the
        # preconditioner list's blocking/masking.
        stats[f"group_{idx}"] = {
            param_to_key[param]: eigen_stat
            for param, eigen_stat in zip(group[PARAMS], group_eigen_stats)
            if param in param_to_key
        }

    return stats

distributed_shampoo/utils/shampoo_preconditioner_list.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,8 +1093,7 @@ def _amortized_computation(self, step: int) -> None:
10931093
f"To mitigate, check factor matrix before the matrix computation: {factor_matrix=}"
10941094
)
10951095
factor_matrix_eigenvectors.copy_(computed_eigenvectors)
1096-
print(type(self._masked_kronecker_factors_list[idx]))
1097-
# self._masked_kronecker_factors_list[idx].eigen_stats = eigen_stats
1096+
self._masked_kronecker_factors_list[idx].eigen_stats = eigen_stats
10981097

10991098
# Only reuse previous eigenvectors if tolerance is not exceeded.
11001099
self._raise_exception_if_failure_tolerance_exceeded(
@@ -1104,3 +1103,6 @@ def _amortized_computation(self, step: int) -> None:
11041103
f"The number of failed eigenvector computations for factors {kronecker_factors.factor_matrix_indices} exceeded the allowed tolerance."
11051104
),
11061105
)
1106+
1107+
def eigenvector_stats(self) -> tuple[EigenStats | None, ...]:
    """Return the ``eigen_stats`` attribute of every masked Kronecker factors entry.

    Returns:
        A tuple with one element per entry of
        ``self._masked_kronecker_factors_list``, in list order.
    """
    return tuple(
        kronecker_factors.eigen_stats
        for kronecker_factors in self._masked_kronecker_factors_list
    )

matrix_functions.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,13 @@ def compression_ratio(self):
633633
def __repr__(self):
634634
return f"Effective rank: {self.effective_rank}, og_rank: {self.og_rank}, compression_ratio: {self.compression_ratio}"
635635

636+
def log_stats(self) -> dict[str, int|float]:
637+
return {
638+
"effective_rank": self.effective_rank,
639+
"og_rank": self.og_rank,
640+
"compression_ratio": self.compression_ratio,
641+
}
642+
636643
def matrix_eigenvectors(
637644
A: Tensor,
638645
eigenvectors_estimate: Tensor | None = None,
@@ -704,7 +711,6 @@ def matrix_eigenvectors(
704711

705712
if eigenvector_computation_config.auto:
706713
topk = eigen_stats.effective_rank
707-
print(eigen_stats)
708714
elif isinstance(eigenvector_computation_config.topk_compression, int):
709715
topk = eigenvector_computation_config.topk_compression
710716
else:

0 commit comments

Comments
 (0)