Skip to content

Commit 7a6f84e

Browse files
committed
add warmup step
1 parent 6e4f995 commit 7a6f84e

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

distributed_shampoo/utils/shampoo_preconditioner_list.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ def _get_inverse_roots_from_override_with_high_order_default(
532532
)
533533

534534
@abstractmethod
535-
def _amortized_computation(self) -> None:
535+
def _amortized_computation(self, step: int) -> None:
536536
"""
537537
Computes the amortized computation needed for each Shampoo preconditioner implementation.
538538
This amortized computation is computation heavy work that cannot be done for each step.
@@ -631,7 +631,7 @@ def update_preconditioners(
631631
# In Shampoo, this is equivalent to computing the inverse factor matrix.
632632
# In Eigenvalue-Corrected Shampoo, this is equivalent to computing the eigenvector of the factor matrix.
633633
if perform_amortized_computation:
634-
self._amortized_computation()
634+
self._amortized_computation(step=step)
635635

636636
def _initialize_state_lists(
637637
self,
@@ -797,7 +797,7 @@ def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, ..
797797
)
798798

799799
@torch.compiler.disable
800-
def _amortized_computation(self) -> None:
800+
def _amortized_computation(self, step: int) -> None:
801801
# NOTE: This function currently only computes the matrix root inverse based on
802802
# the masked lists which combines both selection based on the distributor and where
803803
# grad is not None. Implicitly, this assumes that there are no changes between the
@@ -1032,7 +1032,7 @@ def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, ..
10321032
return tuple(preconditioned_grad_list)
10331033

10341034
@torch.compiler.disable
1035-
def _amortized_computation(self) -> None:
1035+
def _amortized_computation(self, step: int) -> None:
10361036
# NOTE: This function currently only computes the preconditioner eigenvectors based on
10371037
# the masked lists which combines both selection based on the distributor and where
10381038
# grad is not None. Implicitly, this assumes that there are no changes between the
@@ -1071,6 +1071,7 @@ def _amortized_computation(self) -> None:
10711071
eigenvectors_estimate=factor_matrix_eigenvectors,
10721072
eigenvector_computation_config=eigenvector_computation_config,
10731073
is_diagonal=bool(is_factor_matrix_diagonal),
1074+
step=step,
10741075
)
10751076
# Add success to success tracker.
10761077
success_tracker.append(True)

matrix_functions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,7 @@ def matrix_eigenvectors(
624624
eigenvectors_estimate: Tensor | None = None,
625625
eigenvector_computation_config: EigenvectorConfig = DefaultEighEigenvectorConfig,
626626
is_diagonal: bool = False,
627+
step: int | None = None,
627628
) -> Tensor:
628629
"""Compute eigenvectors of matrix using eigendecomposition of symmetric positive (semi-)definite matrix.
629630
A = Q L Q^T => Q
@@ -668,9 +669,13 @@ def matrix_eigenvectors(
668669
retry_double_precision=eigenvector_computation_config.retry_double_precision,
669670
)
670671

672+
if step is None:
673+
raise ValueError("step param is required when using EighEigenvectorConfig.")
674+
671675
if (
672676
isinstance(eigenvector_computation_config, TopKCompressionEigenvectorConfig)
673677
and eigenvalues.shape[0] > eigenvector_computation_config.min_dim
678+
and step > eigenvector_computation_config.warmup_steps
674679
):
675680
effective_rank = compute_effective_rank(eigenvalues, eigenvector_computation_config.compression_t)
676681

matrix_functions_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ class TopKCompressionEigenvectorConfig(EighEigenvectorConfig):
151151

152152
compression_t: float = 0.95
153153

154+
warmup_steps: int = 0
155+
154156
def __post_init__(self):
155157
if isinstance(self.topk_compression, float):
156158
if not 0 < self.topk_compression <= 1:

0 commit comments

Comments (0)