@@ -532,7 +532,7 @@ def _get_inverse_roots_from_override_with_high_order_default(
532
532
)
533
533
534
534
@abstractmethod
535
- def _amortized_computation (self ) -> None :
535
+ def _amortized_computation (self , step : int ) -> None :
536
536
"""
537
537
Performs the amortized computation required by each Shampoo preconditioner implementation.
538
538
This amortized computation is computationally heavy work that cannot be done at every step.
@@ -631,7 +631,7 @@ def update_preconditioners(
631
631
# In Shampoo, this is equivalent to computing the inverse factor matrix.
632
632
# In Eigenvalue-Corrected Shampoo, this is equivalent to computing the eigenvector of the factor matrix.
633
633
if perform_amortized_computation :
634
- self ._amortized_computation ()
634
+ self ._amortized_computation (step = step )
635
635
636
636
def _initialize_state_lists (
637
637
self ,
@@ -797,7 +797,7 @@ def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, ..
797
797
)
798
798
799
799
@torch .compiler .disable
800
- def _amortized_computation (self ) -> None :
800
+ def _amortized_computation (self , step : int ) -> None :
801
801
# NOTE: This function currently only computes the matrix root inverse based on
802
802
# the masked lists which combines both selection based on the distributor and where
803
803
# grad is not None. Implicitly, this assumes that there are no changes between the
@@ -1032,7 +1032,7 @@ def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, ..
1032
1032
return tuple (preconditioned_grad_list )
1033
1033
1034
1034
@torch .compiler .disable
1035
- def _amortized_computation (self ) -> None :
1035
+ def _amortized_computation (self , step : int ) -> None :
1036
1036
# NOTE: This function currently only computes the preconditioner eigenvectors based on
1037
1037
# the masked lists which combines both selection based on the distributor and where
1038
1038
# grad is not None. Implicitly, this assumes that there are no changes between the
@@ -1071,6 +1071,7 @@ def _amortized_computation(self) -> None:
1071
1071
eigenvectors_estimate = factor_matrix_eigenvectors ,
1072
1072
eigenvector_computation_config = eigenvector_computation_config ,
1073
1073
is_diagonal = bool (is_factor_matrix_diagonal ),
1074
+ step = step ,
1074
1075
)
1075
1076
# Add success to success tracker.
1076
1077
success_tracker .append (True )
0 commit comments