diff --git a/distributed_shampoo/README.md b/distributed_shampoo/README.md index aaf0a053..27dc1fb2 100644 --- a/distributed_shampoo/README.md +++ b/distributed_shampoo/README.md @@ -1,6 +1,6 @@ # PyTorch Distributed Shampoo -Distributed Shampoo is a preconditioned stochastic gradient optimizer in the adaptive gradient (Adagrad) family of methods [1, 2]. It converges faster by leveraging neural network-specific structures to achieve comparable model quality/accuracy in fewer iterations or epochs at the cost of additional FLOPs and memory, or achieve higher model quality in the same number of iterations or epochs. Our implementation offers specialized support for serial, [Distributed Data Parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html), and [Fully Sharded Data Parallel (FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html) training. +Distributed Shampoo is a preconditioned stochastic gradient optimizer in the adaptive gradient (Adagrad) family of methods [1, 2]. It converges faster by leveraging neural network-specific structures to achieve comparable model quality/accuracy in fewer iterations or epochs at the cost of additional FLOPs and memory, or achieve higher model quality in the same number of iterations or epochs. Our implementation offers specialized support for serial, [Distributed Data Parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html), [Fully Sharded Data Parallel (FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Hybrid Sharding Data Parallel](https://pytorch.org/tutorials/recipes/distributed_device_mesh.html#how-to-use-devicemesh-with-hsdp) training. Distributed Shampoo currently only supports dense parameters. @@ -16,23 +16,7 @@ Developers: with contributions and support from: -Ganesh Ajjanagadde (Meta), Rohan Anil (Google), Adnan Aziz (Meta), Pavan Balaji (Meta), Shuo Chang (Meta), Weiwei Chu (Meta), Assaf Eisenman (Meta), Will Feng (Meta), Zhuobo Feng (Meta), Jose Gallego-Posada (Mila / Meta Platforms, Inc.), Avirup Ghosh (Meta), Yizi Gu (Meta), Vineet Gupta (Google), Yuchen Hao (Meta), Brian Hirsh (Meta), Yusuo Hu (Meta), Yuxi Hu (Meta), Minhui Huang (Meta), Guna Lakshminarayanan (Meta), Michael Lazos (Meta), Zhijing Li (Meta), Ming Liang (Meta), Wanchao Liang (Meta), Ying Liu (Meta), Wenguang Mao (Meta), Dheevatsa Mudigere (NVIDIA), Maxim Naumov (Meta), Jongsoo Park (Meta), Mike Rabbat (Meta), Kaushik Rangadurai (Meta), Dennis van der Staay (Meta), Fei Tian (Meta), Rohan Varma (Meta), Sanjay Vishwakarma (Meta), Xunnan (Shawn) Xu (Meta), Jiyan Yang (Meta), Chunxing Yin (Meta), Iris Zhang (Meta), and Will Zou (Meta). - -## Updates -- (7/18/24) This update contains - - PyTorch 2 compile bug fixes. - - HSDP Shampoo via `HSDPDistributor`. - - Mixed-precision optimizer states. - - Higher-order coupled iterations, with relative epsilon based on estimate of largest eigenvalue. - - Further modularization of Shampoo step function. - - Other simplifications. -- (2/14/24) We have released our Distributed Shampoo v2.0.0 implementation, a ground-up re-write of our PyTorch Shampoo implementation. Our v2.0.0 implementation includes: - - Incorporates new performance optimizations, such as the usage of `torch._foreach_*` operators and PyTorch 2 compile. - - Shared support and enablement of DDP and FSDP Shampoo, via the specification of the `distributed_config` field. - - Cleaner API for configuring grafting methods through specifying the `grafting_config` field. 
- - Deprecation of handling large tensors by diagonalizing the Shampoo preconditioners and using standard diagonal Adagrad. - - While we do not currently support LAMB/LARS grafting, we intend to add support for this in the future. - - We will update our [ArXiv paper](https://arxiv.org/pdf/2309.06497.pdf) to reflect our implementation changes. +Ganesh Ajjanagadde (Meta), Rohan Anil (Google), Adnan Aziz (Meta), Pavan Balaji (Meta), Shuo Chang (Meta), Weiwei Chu (Meta), Assaf Eisenman (Meta), Will Feng (Meta), Zhuobo Feng (Meta), Jose Gallego-Posada (Mila / Meta Platforms, Inc.), Avirup Ghosh (Meta), Yizi Gu (Meta), Vineet Gupta (Google), Yuchen Hao (Meta), Brian Hirsh (Meta), Yusuo Hu (Meta), Yuxi Hu (Meta), Minhui Huang (Meta), Guna Lakshminarayanan (Meta), Michael Lazos (Meta), Zhijing Li (Meta), Ming Liang (Meta), Wanchao Liang (Meta), Ying Liu (Meta), Wenguang Mao (Meta), Dheevatsa Mudigere (NVIDIA), Maxim Naumov (Meta), Jongsoo Park (Meta), Mike Rabbat (Meta), Kaushik Rangadurai (Meta), Dennis van der Staay (Meta), Fei Tian (Meta), Rohan Varma (Meta), Sanjay Vishwakarma (Meta), Xunnan (Shawn) Xu (Meta), Jiyan Yang (Meta), Chunxing Yin (Meta), Iris Zhang (Meta), Chuanhao Zhuge (Meta), and Will Zou (Meta). ## Features diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index 74b2727f..4a0a9a5b 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -61,6 +61,7 @@ PRECONDITIONER_DTYPE, PREVIOUS_GRAD_SELECTOR, RMSpropGraftingConfig, + ROOT_INV_CONFIG, SGDGraftingConfig, SHAMPOO_PRECONDITIONER_LIST, ShampooPT2CompileConfig, @@ -99,6 +100,8 @@ QuantizedTensorList, ) from distributed_shampoo.utils.shampoo_utils import compress_list + +from matrix_functions_types import DefaultEigenConfig, RootInvConfig from torch.optim.optimizer import ParamsT logger: logging.Logger = logging.getLogger(__name__) @@ -107,45 +110,6 @@ class DistributedShampoo(torch.optim.Optimizer): """Implements distributed Shampoo algorithm. - Developers: - Hao-Jun Michael Shi (Meta Platforms, Inc.) - Tsung-Hsien Lee - Anna Cai (Meta Platforms, Inc.) - Shintaro Iwasaki (Meta Platforms, Inc.) - Ke Sang (Meta Platforms, Inc.) - Wang Zhou (Meta Platforms, Inc.) - - with contributions and support from: - - Ganesh Ajjanagadde (Meta), Rohan Anil (Google), Adnan Aziz (Meta), Pavan Balaji (Meta), Shuo Chang (Meta), Weiwei Chu (Meta), - Assaf Eisenman (Meta), Will Feng (Meta), Zhuobo Feng (Meta), Jose Gallego-Posada (Mila / Meta Platforms, Inc.), Avirup Ghosh (Meta), - Yizi Gu (Meta), Vineet Gupta (Google), Yuchen Hao (Meta), Brian Hirsh (Meta), Yusuo Hu (Meta), Yuxi Hu (Meta), Minhui Huang (Meta), - Guna Lakshminarayanan (Meta), Michael Lazos (Meta), Zhijing Li (Meta), Ming Liang (Meta), Wanchao Liang (Meta), Ying Liu - (Meta), Wenguang Mao (Meta), Dheevatsa Mudigere (NVIDIA), Maxim Naumov (Meta), Jongsoo Park (Meta), Mike Rabbat (Meta), - Kaushik Rangadurai (Meta), Dennis van der Staay (Meta), Fei Tian (Meta), Sanjay Vishwakarma (Meta), Xunnan (Shawn) Xu (Meta), - Jiyan Yang (Meta), Chunxing Yin (Meta), and Iris Zhang (Meta). - - Details in: https://arxiv.org/pdf/2309.06497.pdf. - - Partly based on the work in: - - https://arxiv.org/pdf/1802.09568.pdf - - https://arxiv.org/pdf/2002.09018.pdf - - ------------ - Requirements - ------------ - - 1. PyTorch >= 2.0 - 2. Python >= 3.10 - 3. CUDA 11.3, 11.4, 12.2+ - - In order to support checkpointing, one must use torch.distributed.checkpoint and pass the named parameters into state_dict. 
- Note that the standard checkpointing solution by PyTorch is not supported! - - Note: We have observed known instabilities with the torch.linalg.eigh operator on CUDA 11.6-12.1, specifically for low-rank - matrices, which may appear with using a small start_preconditioning_step. Please avoid these versions of CUDA if possible. - See: https://github.com/pytorch/pytorch/issues/94772. - -------- Features -------- @@ -296,6 +260,7 @@ class DistributedShampoo(torch.optim.Optimizer): 3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail. track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes. (Default: False) + root_inv_config (RootInvConfig): Configuration for root inverse computation. (Default: DefaultEigenConfig) """ @@ -326,6 +291,7 @@ def __init__( precision_config: Optional[PrecisionConfig] = None, use_protected_eigh: bool = True, track_root_inv_residuals: bool = False, + root_inv_config: RootInvConfig = DefaultEigenConfig, ) -> None: # Hyperparameter checks. if not lr >= 0.0: @@ -464,6 +430,7 @@ def __init__( USE_MERGE_DIMS: use_merge_dims, PRECONDITIONER_DTYPE: preconditioner_dtype, PRECISION_CONFIG: precision_config, + ROOT_INV_CONFIG: root_inv_config, }, ) @@ -542,6 +509,7 @@ def _instantiate_shampoo_preconditioner_list(self) -> None: state=self.state, block_info_list=state_lists[DISTRIBUTOR].global_block_info_list, distributor_selector=state_lists[DISTRIBUTOR].distributor_selector, + root_inv_config=group[ROOT_INV_CONFIG], beta2=group[BETAS][1], epsilon=group[EPSILON], inv_root_override=group[INV_ROOT_OVERRIDE], @@ -1171,6 +1139,19 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] if not state_lists[MASKED_BLOCKED_GRADS]: continue + # Convert the gradient dtype to the computation dtype set in the precision_config if + # necessary. + # + # This conversion is needed because the blocked gradient list has float32 dtype, and we + # need to convert it to the desired precision before precondition computation. + if ( + computation_dtype := group[PRECISION_CONFIG].computation_dtype + ) != state_lists[MASKED_BLOCKED_GRADS][0].dtype: + state_lists[MASKED_BLOCKED_GRADS] = tuple( + tensor.to(dtype=computation_dtype) + for tensor in state_lists[MASKED_BLOCKED_GRADS] + ) + # Iterate group step counter and define Python scalar step. 
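Note on the two knobs exercised above: `root_inv_config` is a new constructor argument (defaulting to `DefaultEigenConfig`), and `PrecisionConfig.computation_dtype` now controls the dtype that blocked gradients are cast to before the preconditioner update, as in the conversion hunk just shown. Below is a minimal, hedged sketch of how these options would be passed to the optimizer; it uses only classes that appear in this patch, the import paths mirror the diff's file layout, and the remaining hyperparameters are illustrative defaults rather than prescribed values.

```python
# Hedged sketch (not part of the patch): wiring the new precision_config and
# root_inv_config options into DistributedShampoo. Only classes that appear in
# this diff are used; import paths follow the diff's file layout.
import torch
import torch.nn as nn

from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import PrecisionConfig
from matrix_functions_types import CoupledHigherOrderConfig

model = nn.Linear(5, 10, bias=False)

optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.01,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    # Blocked gradients are cast to computation_dtype before the preconditioner
    # update, which is the conversion added to step() above.
    precision_config=PrecisionConfig(computation_dtype=torch.float64),
    # Select the higher-order coupled iteration instead of the default eigendecomposition.
    root_inv_config=CoupledHigherOrderConfig(order=3),
)

loss = model(torch.randn(5)).sum()
loss.backward()
optimizer.step()
```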
step = state_lists[STEP].add_(1) # NOTE: Wrap scalar of group[LR] into a 0D tensor to avoid PT2 recompilation; diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index d2129d08..627e6ae1 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -35,6 +35,7 @@ PRECISION_CONFIG = "precision_config" PRECONDITION_FREQUENCY = "precondition_frequency" PRECONDITIONER_DTYPE = "preconditioner_dtype" +ROOT_INV_CONFIG = "root_inv_config" START_PRECONDITIONING_STEP = "start_preconditioning_step" USE_BIAS_CORRECTION = "use_bias_correction" USE_DECOUPLED_WEIGHT_DECAY = "use_decoupled_weight_decay" diff --git a/distributed_shampoo/tests/distributed_shampoo_test.py b/distributed_shampoo/tests/distributed_shampoo_test.py index 704be722..ebbfa803 100644 --- a/distributed_shampoo/tests/distributed_shampoo_test.py +++ b/distributed_shampoo/tests/distributed_shampoo_test.py @@ -36,6 +36,7 @@ ShampooPreconditionerList, ) from distributed_shampoo.utils.shampoo_quantization import QuantizedTensorList +from matrix_functions_types import DefaultEigenConfig from torch import nn @@ -267,15 +268,15 @@ def closure() -> float: self.assertEqual(self._optimizer.step(closure=closure), 1.0) - def test_step_with_empty_grad_list(self) -> None: + @mock.patch.object(ShampooPreconditionerList, "update_preconditioners") + def test_step_with_empty_grad_list( + self, mock_upgrade_preconditioners: mock.Mock + ) -> None: # Test the case that the grad_list is empty. self._optimizer.zero_grad() - with mock.patch.object( - ShampooPreconditionerList, "update_preconditioners" - ) as mock_upgrade_preconditioners: - self._optimizer.step() - # Because the gradient list is empty, the preconditioners should not be updated. - mock_upgrade_preconditioners.assert_not_called() + self._optimizer.step() + # Because the gradient list is empty, the preconditioners should not be updated. 
+ mock_upgrade_preconditioners.assert_not_called() class DistributedShampooStateDictTest(unittest.TestCase): @@ -447,6 +448,7 @@ def setUp(self) -> None: "use_merge_dims": True, "preconditioner_dtype": None, "precision_config": PrecisionConfig(), + "root_inv_config": DefaultEigenConfig, } }, } @@ -674,6 +676,8 @@ def setUp(self) -> None: self._model = nn.Sequential( nn.Linear(5, 10, bias=False), ) + self._x = torch.randn(5) + self._y = torch.randn(10) def _instantiate_optimizer( self, precision_config: PrecisionConfig @@ -756,6 +760,12 @@ def test_precision_configs(self) -> None: filtered_grad_dtype=torch.float16, momentum_dtype=torch.float16, ), + PrecisionConfig( + factor_matrix_dtype=torch.float64, + inv_factor_matrix_dtype=torch.float64, + filtered_grad_dtype=torch.float64, + computation_dtype=torch.float64, + ), ] for precision_config in precision_configs: @@ -767,6 +777,10 @@ def test_precision_configs(self) -> None: self._assert_state_list_dtype(state_list, precision_config) for _ in range(2): + optimizer.zero_grad() + y_hat = self._model(self._x) + loss = torch.nn.CrossEntropyLoss()(y_hat, self._y) + loss.backward() optimizer.step() for state_list in optimizer._per_group_state_lists: self._assert_state_list_dtype(state_list, precision_config)
diff --git a/distributed_shampoo/tests/shampoo_test_utils.py b/distributed_shampoo/tests/shampoo_test_utils.py index 98ac489b..6fe86198 100644 --- a/distributed_shampoo/tests/shampoo_test_utils.py +++ b/distributed_shampoo/tests/shampoo_test_utils.py @@ -23,13 +23,14 @@ def __init__( ) -> None: super().__init__() self.linear_layers = nn.ModuleList( - [ - nn.Linear(a, b, bias=bias) - for a, b in itertools.pairwise(model_linear_layers_dims) - ] + nn.Linear(a, b, bias=bias) + for a, b in itertools.pairwise(model_linear_layers_dims) ) if model_dead_layer_dims is not None: - self.useless: nn.Module = nn.Linear(*model_dead_layer_dims, bias=False) + self.dead_layers: nn.ModuleList = nn.ModuleList( + nn.Linear(a, b, bias=False) + for a, b in itertools.pairwise(model_dead_layer_dims) + ) def forward(self, x: torch.Tensor) -> torch.Tensor: for linear_layer in self.linear_layers: @@ -49,7 +50,7 @@ def construct_training_problem( Args: model_linear_layers_dims (tuple[int, ...]): The dimensions of the model linear layers. - model_dead_layer_dims (Optional[tuple[int, ...]]): The dimensions of the model dead layer. (Default: (10, 10)) + model_dead_layer_dims (Optional[tuple[int, ...]]): The dimensions of the model dead linear layers. (Default: (10, 10)) device (Optional[torch.device]): The device to use. (Default: None) bias (bool): Whether to use bias in the linear (non-dead) layers. (Default: False) fill (float | tuple[float, ...]): The value(s) to fill the model parameters. If a tuple, each element should correspond to one layer. (Default: 0.0)
diff --git a/distributed_shampoo/utils/gpu_tests/shampoo_ddp_distributor_test.py b/distributed_shampoo/utils/gpu_tests/shampoo_ddp_distributor_test.py index ad6a3ac2..3af5495d 100644 --- a/distributed_shampoo/utils/gpu_tests/shampoo_ddp_distributor_test.py +++ b/distributed_shampoo/utils/gpu_tests/shampoo_ddp_distributor_test.py @@ -165,36 +165,33 @@ def test_losses(self) -> None: device=torch.device("cuda"), ) - def test_empty_local_blocked_params(self) -> None:
+ # This mock is used to catch the number of calls to Shampoo's step() that happen after __init__().
+ # If there are no blocked params, __init__() will raise and step() should not be called.
+ # Otherwise, step() will be called.
+ @mock.patch.object(DistributedShampoo, "step")
+ def test_empty_local_blocked_params(self, mock_step: mock.Mock) -> None: self._init_distributed() # The test setting is only rank 0 has params, so all other ranks have no parameters to work on. has_blocked_params = dist.get_rank() == 0 with ( - # This mock is used to catch the number of calls to Shampoo's step(), which happened after __init__(). - # If there is no blocked params, __init__() will raise and step() should not be called. - # Otherwise, step() will be called. - mock.patch.object(DistributedShampoo, "step") - ) as mock_step: - with ( - contextlib.nullcontext() - if has_blocked_params - else self.assertRaisesRegex( - AssertionError, - re.escape("Some workers have no parameters to work on."), - ) - ): - ShampooDDPDistributorTest._train_model( - self._shampoo_optim_factory(distributed_config=DDPShampooConfig()), - device=torch.device("cuda"), - # Setting model_linear_layers_dims to (20, 1) creates an model with one linear layer with 20x1 weight. - # Because Shampoo's max_preconditioner_dim = 20, there will be only one block. - # In the case of two trainers per group, there will be one trainer has no params to work on. - model_linear_layers_dims=(20, 1), - model_dead_layer_dims=None, - ) + contextlib.nullcontext() + if has_blocked_params + else self.assertRaisesRegex( + AssertionError, re.escape("Some workers have no parameters to work on.") + ) + ): + ShampooDDPDistributorTest._train_model( + self._shampoo_optim_factory(distributed_config=DDPShampooConfig()), + device=torch.device("cuda"),
+ # Setting model_linear_layers_dims to (20, 1) creates a model with one linear layer with a 20x1 weight.
+ # Because Shampoo's max_preconditioner_dim = 20, there will be only one block.
+ # In the case of two trainers per group, one trainer will have no params to work on.
+ model_linear_layers_dims=(20, 1), + model_dead_layer_dims=None, + ) - if has_blocked_params: - mock_step.assert_called() - else: - mock_step.assert_not_called() + if has_blocked_params: + mock_step.assert_called() + else: + mock_step.assert_not_called()
diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index b8f07fc5..0e7c26c8 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -32,6 +32,8 @@ compute_matrix_root_inverse_residuals, matrix_inverse_root, ) + +from matrix_functions_types import DefaultEigenConfig, RootInvConfig from optimizer_modules import OptimizerModule from torch import Tensor from torch.autograd import profiler @@ -362,6 +364,7 @@ class ShampooPreconditionerList(PreconditionerList): Note that this should have the same length as block_list. distributor_selector (Tuple[bool, ...]): Distributor selector is a boolean list indicating whether a blocked parameter is selected by the current Distributor. + root_inv_config (RootInvConfig): Configuration for root inverse computation. (Default: DefaultEigenConfig) beta2 (float): Exponential moving average factor for Shampoo factor matrices. If beta2 = 1., will use unweighted sum. (Default: 1.0) epsilon (float): Epsilon term for regularizing preconditioner to ensure positive definiteness.
(Default: 1e-12) @@ -388,6 +391,7 @@ def __init__( state: DefaultDict[Tensor, Any], block_info_list: Tuple[BlockInfo, ...], distributor_selector: Tuple[bool, ...], + root_inv_config: RootInvConfig = DefaultEigenConfig, beta2: float = 1.0, epsilon: float = 1e-12, inv_root_override: Union[int, Tuple[int, ...]] = 0, @@ -401,6 +405,7 @@ def __init__( super().__init__(block_list) # Initialize parameters. + self._root_inv_config = root_inv_config self._beta2 = beta2 self._epsilon = epsilon self._inv_root_override = inv_root_override @@ -690,10 +695,10 @@ def compute_root_inverse(self) -> None: computed_inv_factor_matrix = matrix_inverse_root( A=bias_corrected_factor_matrix, root=root, + root_inv_config=self._root_inv_config, epsilon=self._epsilon, exponent_multiplier=self._exponent_multiplier, is_diagonal=is_factor_matrix_diagonal, - retry_double_precision=self._use_protected_eigh, ).to(dtype=inv_factor_matrix.dtype) # Check if we encounter NaN or inf values in computed inverse matrix. @@ -776,11 +781,12 @@ def compute_root_inverse_residuals( relative_error, relative_residual, ) = compute_matrix_root_inverse_residuals( - bias_corrected_factor_matrix, - inv_factor_matrix, - root, - self._epsilon, - self._exponent_multiplier, + A=bias_corrected_factor_matrix, + X_hat=inv_factor_matrix, + root=root, + epsilon=self._epsilon, + exponent_multiplier=self._exponent_multiplier, + root_inv_config=self._root_inv_config, ) relative_errors.append(relative_error) relative_residuals.append(relative_residual) diff --git a/matrix_functions.py b/matrix_functions.py index baededcb..93f0bfc5 100644 --- a/matrix_functions.py +++ b/matrix_functions.py @@ -11,11 +11,20 @@ import logging import math import time +from dataclasses import asdict from fractions import Fraction from math import isfinite from typing import Tuple, Union import torch +from matrix_functions_types import ( + CoupledHigherOrderConfig, + CoupledNewtonConfig, + DefaultEigenConfig, + EigenConfig, + RootInvConfig, +) + from torch import Tensor logger: logging.Logger = logging.getLogger(__name__) @@ -35,20 +44,6 @@ class NewtonConvergenceFlag(enum.Enum): EARLY_STOP = 2 -class RootInvMethod(enum.Enum): - """ - Enum class for supported root inverse methods, i.e., computing M -> M^{-1/root}. - - EIGEN: Uses eigendecomposition followed by diagonal powering. - NEWTON: Uses coupled inverse Newton iteration (Higham, Functions of Matrices). - HIGHER_ORDER: Uses higher-order variants of NEWTON (Lakic, 1998: "On the Computation of the Matrix k-th Root"). - """ - - EIGEN = 0 - NEWTON = 1 - HIGHER_ORDER = 2 - - def check_diagonal(A: Tensor) -> bool: """Checks if symmetric matrix is diagonal. Throw if the input is not a square matrix.""" @@ -66,30 +61,21 @@ def check_diagonal(A: Tensor) -> bool: def matrix_inverse_root( A: Tensor, root: Union[Fraction, int], + root_inv_config: RootInvConfig = DefaultEigenConfig, epsilon: float = 0.0, exponent_multiplier: float = 1.0, - root_inv_method: RootInvMethod = RootInvMethod.EIGEN, - max_iterations: int = 100, - tolerance: float = 1e-6, is_diagonal: Union[Tensor, bool] = False, - retry_double_precision: bool = True, - order: int = 3, ) -> Tensor: """Computes matrix root inverse of square symmetric positive definite matrix. Args: A (Tensor): Square matrix of interest. root (int): Root of interest. Any natural number. + root_inv_config (RootInvConfig): Configuration for root inverse computation. (Default: DefaultEigenConfig) epsilon (float): Adds epsilon * I to matrix before taking matrix root. 
(Default: 0.0) exponent_multiplier (float): exponent multiplier in the eigen method (Default: 1.0) - root_inv_method (RootInvMethod): Specifies method to use to compute root inverse. (Default: RootInvMethod.EIGEN) - max_iterations (int): Maximum number of iterations for coupled Newton iteration. (Default: 1000) - tolerance (float): Tolerance for computing root inverse using coupled Newton iteration. (Default: 1e-6) is_diagonal (Tensor, bool): Flag for whether or not matrix is diagonal. If so, will compute root inverse by computing root inverse of diagonal entries. (Default: False) - retry_double_precision (bool): Flag for re-trying eigendecomposition with higher precision if lower precision fails due - to CuSOLVER failure. (Default: True) - order (int): Order used in the higher-order method. (Default: 3) Returns: X (Tensor): Inverse root of matrix A. @@ -108,24 +94,22 @@ def matrix_inverse_root( raise ValueError("Matrix is not square!") if is_diagonal: - X = matrix_root_diagonal( + X = _matrix_root_diagonal( A=A, root=root, epsilon=epsilon, - inverse=True, exponent_multiplier=exponent_multiplier, return_full_matrix=True, ) - elif root_inv_method == RootInvMethod.EIGEN: + elif type(root_inv_config) is EigenConfig: X, _, _ = _matrix_root_eigen( A=A, root=root, epsilon=epsilon, - inverse=True, exponent_multiplier=exponent_multiplier, - retry_double_precision=retry_double_precision, + **asdict(root_inv_config), ) - elif root_inv_method == RootInvMethod.NEWTON: + elif type(root_inv_config) is CoupledNewtonConfig: if exponent_multiplier != 1.0: raise ValueError( f"Exponent multiplier {exponent_multiplier} must be equal to 1 to use coupled inverse Newton iteration!" @@ -140,14 +124,13 @@ def matrix_inverse_root( A=A, root=root, epsilon=epsilon, - max_iterations=max_iterations, - tolerance=tolerance, + **asdict(root_inv_config), ) if termination_flag == NewtonConvergenceFlag.REACHED_MAX_ITERS: logging.warning( "Newton did not converge and reached maximum number of iterations!" ) - elif root_inv_method == RootInvMethod.HIGHER_ORDER: + elif type(root_inv_config) is CoupledHigherOrderConfig: if exponent_multiplier != 1.0: raise ValueError( f"Exponent multiplier {exponent_multiplier} must be equal to 1 to use coupled higher order method!" @@ -156,11 +139,8 @@ def matrix_inverse_root( X, _, termination_flag, _, _ = _matrix_inverse_root_higher_order( A=A, root=Fraction(root), - rel_epsilon=epsilon, abs_epsilon=epsilon, - order=order, - max_iterations=max_iterations, - tolerance=tolerance, + **asdict(root_inv_config), ) if termination_flag == NewtonConvergenceFlag.REACHED_MAX_ITERS: logging.warning( @@ -168,17 +148,16 @@ def matrix_inverse_root( ) else: raise NotImplementedError( - f"Root inverse method is not implemented! Specified root inverse method is {str(root_inv_method)}." + f"Root inverse config is not implemented! Specified root inverse config is {root_inv_config=}." ) return X -def matrix_root_diagonal( +def _matrix_root_diagonal( A: Tensor, root: Union[Fraction, int], epsilon: float = 0.0, - inverse: bool = True, exponent_multiplier: float = 1.0, return_full_matrix: bool = False, ) -> Tensor: @@ -188,46 +167,38 @@ def matrix_root_diagonal( A (Tensor): One- or two-dimensional tensor containing either the diagonal entries of the matrix or a diagonal matrix. root (int): Root of interest. Any natural number. epsilon (float): Adds epsilon * I to matrix before taking matrix root. (Default: 0.0) - inverse (bool): Returns inverse root matrix. 
(Default: True) + exponent_multiplier (float): exponent multiplier in the eigen method (Default: 1.0) return_full_matrix (bool): Returns full matrix by taking torch.diag of diagonal entries. (bool: False) Returns: X (Tensor): Inverse root of diagonal entries. """ - - # check order of tensor - order = len(A.shape) - if order == 2: - A = torch.diag(A) - elif order > 2: - raise ValueError("Matrix is not 2-dimensional!") - # check if root is positive integer if root <= 0: raise ValueError(f"Root {root} should be positive!") # compute matrix power - alpha = exponent_multiplier / root - if inverse: - alpha = -alpha + alpha = -exponent_multiplier / root - X = (A + epsilon).pow(alpha) - return torch.diag(X) if return_full_matrix else X + return ( + torch.diag(X := (torch.diag(A) + epsilon).pow(alpha)) + if return_full_matrix + else X + ) def _matrix_root_eigen( A: Tensor, root: Union[Fraction, int], epsilon: float = 0.0, - inverse: bool = True, exponent_multiplier: float = 1.0, make_positive_semidefinite: bool = True, retry_double_precision: bool = True, ) -> Tuple[Tensor, Tensor, Tensor]: - """Compute matrix (inverse) root using eigendecomposition of symmetric positive (semi-)definite matrix. + """Compute matrix inverse root using eigendecomposition of symmetric positive (semi-)definite matrix. - A = Q L Q^T => A^{1/r} = Q L^{1/r} Q^T OR A^{-1/r} = Q L^{-1/r} Q^T + A^{-1/r} = Q L^{-1/r} Q^T Assumes matrix A is symmetric. @@ -235,7 +206,6 @@ def _matrix_root_eigen( A (Tensor): Square matrix of interest. root (int): Root of interest. Any natural number. epsilon (float): Adds epsilon * I to matrix before taking matrix root. (Default: 0.0) - inverse (bool): Returns inverse root matrix. (Default: True) exponent_multiplier (float): exponent multiplier in the eigen method (Default: 1.0) make_positive_semidefinite (bool): Perturbs matrix eigenvalues to ensure it is numerically positive semi-definite. (Default: True) retry_double_precision (bool): Flag for re-trying eigendecomposition with higher precision if lower precision fails due @@ -253,9 +223,7 @@ def _matrix_root_eigen( raise ValueError(f"Root {root} should be positive!") # compute matrix power - alpha = exponent_multiplier / root - if inverse: - alpha = -alpha + alpha = -exponent_multiplier / root # compute eigendecomposition and compute minimum eigenvalue try: @@ -393,7 +361,7 @@ def _matrix_inverse_root_higher_order( Generally recommend setting according to A.dtype (1e-3 for tf32, 1e-5 for fp32, 1e-9 for fp64) (Default: 0.0) max_iterations (int): Maximum number of iterations. Typically we need < 20 iterations. (Default: 100) tolerance (float): Tolerance for determining exit criterion from iterations. (Default: 1e-20, which in practice guarantees they run to convergence) - order (int): Order of the method. Order must be >= 2. Higher order methods accelerate convergence (fewer iterations), but can take more matmuls per iteration. (Default: 2, ie Newton) + order (int): Order of the method. Order must be >= 2. Higher order methods accelerate convergence (fewer iterations), but can take more matmuls per iteration. (Default: 3) disable_tf32 (bool): Whether to disable tf32 matmuls or not internally. Highly recommend keeping True, since tf32 is challenging numerically here. 
(Default: True) Returns: @@ -482,7 +450,6 @@ def _matrix_inverse_root_higher_order( ) # main while loop - termination_flag = NewtonConvergenceFlag.CONVERGED while error > tolerance and iteration < max_iterations: t_iter_begin = time.time() iteration += 1 @@ -515,10 +482,13 @@ def _matrix_inverse_root_higher_order( logger.debug( f"Iteration dur (s): {t_iter_end - t_iter_begin}, Error (|M-I|) at iteration {iteration}: {error.item()}" ) - - # determine convergence flag - if termination_flag != NewtonConvergenceFlag.EARLY_STOP and error > tolerance: - termination_flag = NewtonConvergenceFlag.REACHED_MAX_ITERS + else: + # determine convergence flag based on error and tolerance because the main while loop exited with False condition. + termination_flag = ( + NewtonConvergenceFlag.REACHED_MAX_ITERS + if error > tolerance + else NewtonConvergenceFlag.CONVERGED + ) # compute a cheap error proxy true_error = torch.linalg.vector_norm( @@ -565,6 +535,7 @@ def compute_matrix_root_inverse_residuals( root: int, epsilon: float, exponent_multiplier: float, + root_inv_config: RootInvConfig = DefaultEigenConfig, ) -> Tuple[Tensor, Tensor]: """Compute residual of matrix root inverse for debugging purposes. @@ -577,6 +548,7 @@ def compute_matrix_root_inverse_residuals( root (int): Root of interest. epsilon (float): Adds epsilon * I to matrix. exponent_multiplier (float): Exponent multiplier to be multiplied to the numerator of the inverse root. + root_inv_config (RootInvConfig): Configuration for root inverse computation (only supports EigenConfig for now). (Default: DefaultEigenConfig) Returns: absolute_error (Tensor): absolute error of matrix root inverse @@ -584,6 +556,10 @@ def compute_matrix_root_inverse_residuals( residual (Tensor): residual of matrix root inverse """ + # only do root inverse residual computation for EigenConfig + assert ( + type(root_inv_config) is EigenConfig + ), f"Only EigenConfig is supported for compute_matrix_root_inverse_residuals; currently {root_inv_config=}." # check shape of matrix if len(A.shape) != 2: @@ -595,7 +571,11 @@ def compute_matrix_root_inverse_residuals( # compute error by comparing against double precision X = matrix_inverse_root( - A.double(), root, epsilon=epsilon, exponent_multiplier=exponent_multiplier + A.double(), + root, + root_inv_config=root_inv_config, + epsilon=epsilon, + exponent_multiplier=exponent_multiplier, ) relative_error = torch.dist(X, X_hat, p=torch.inf) / torch.norm(X, p=torch.inf) @@ -607,7 +587,6 @@ def compute_matrix_root_inverse_residuals( X_hat.double(), root=1, epsilon=0.0, - inverse=True, make_positive_semidefinite=True, exponent_multiplier=root / exponent_multiplier, ) diff --git a/matrix_functions_types.py b/matrix_functions_types.py new file mode 100644 index 00000000..8e4a130c --- /dev/null +++ b/matrix_functions_types.py @@ -0,0 +1,72 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. + +""" + +from dataclasses import dataclass + + +@dataclass(kw_only=True) +class RootInvConfig: + """Base dataclass for matrix root inverse method configurations in Shampoo.""" + + ... + + +@dataclass(kw_only=True) +class EigenConfig(RootInvConfig): + """Configuration for eigendecomposition method in Shampoo. + + Args: + make_positive_semidefinite (bool): Perturbs matrix eigenvalues to ensure it is numerically positive semi-definite. 
(Default: True) + retry_double_precision (bool): Whether to re-trying eigendecomposition with higher(double) precision if lower precision fails due + to CuSOLVER failure. (Default: True) + + """ + + make_positive_semidefinite: bool = True + retry_double_precision: bool = True + + +DefaultEigenConfig = EigenConfig() + + +@dataclass(kw_only=True) +class CoupledNewtonConfig(RootInvConfig): + """Configuration for coupled Newton method in Shampoo. + + Args: + max_iterations (int): Maximum number of iterations for coupled Newton iteration. (Default: 100) + tolerance (float): Tolerance for computing root inverse using coupled Newton iteration. (Default: 1e-6) + + """ + + max_iterations: int = 100 + tolerance: float = 1e-6 + + +@dataclass(kw_only=True) +class CoupledHigherOrderConfig(RootInvConfig): + """Configuration for coupled higher-order method in Shampoo. + + Args: + rel_epsilon (float): Relative epsilon for coupled higher order method. Adds epsilon * lambda_max * I to matrix + before taking matrix root, where lambda_max is an upper bound on maximum eigenvalue. (Default: 0.0) + max_iterations (int): Maximum number of iterations for coupled higher order method. (Default: 100) + tolerance (float): Tolerance for computing root inverse using coupled higher order method. (Default: 1e-8) + order (int): Order of the method. Order must be >= 2. Higher order methods accelerate convergence (fewer iterations), + but can take more matmuls per iteration. order=2 represents Newton's method. (Default: 3) + disable_tf32 (bool): Whether to disable tf32 matmuls or not internally. Highly recommend keeping True, + since tf32 is challenging numerically here. (Default: True) + + """ + + rel_epsilon: float = 0.0 + max_iterations: int = 100 + tolerance: float = 1e-8 + order: int = 3 + disable_tf32: bool = True diff --git a/tests/matrix_functions_test.py b/tests/matrix_functions_test.py index 0d61f9c6..c0982b62 100644 --- a/tests/matrix_functions_test.py +++ b/tests/matrix_functions_test.py @@ -27,9 +27,13 @@ check_diagonal, compute_matrix_root_inverse_residuals, matrix_inverse_root, - matrix_root_diagonal, NewtonConvergenceFlag, - RootInvMethod, +) +from matrix_functions_types import ( + CoupledHigherOrderConfig, + CoupledNewtonConfig, + EigenConfig, + RootInvConfig, ) from torch import Tensor @@ -61,7 +65,9 @@ def test_matrix_inverse_root_scalar(self) -> None: self.assertEqual( A ** torch.tensor(-1.82 / 2), matrix_inverse_root( - A, root=root, exponent_multiplier=exponent_multiplier + A, + root=root, + exponent_multiplier=exponent_multiplier, ), ) with self.subTest("Test with matrix case."): @@ -142,7 +148,6 @@ def test_matrix_inverse_root(self) -> None: A_list[i], root=root, exponent_multiplier=exponent_multiplier, - root_inv_method=RootInvMethod.EIGEN, is_diagonal=False, ), atol=atol, @@ -155,8 +160,8 @@ def test_matrix_inverse_root(self) -> None: matrix_inverse_root( A_list[i], root=root, + root_inv_config=CoupledNewtonConfig(), exponent_multiplier=exponent_multiplier, - root_inv_method=RootInvMethod.NEWTON, is_diagonal=False, ), atol=atol, @@ -170,9 +175,8 @@ def test_matrix_inverse_root(self) -> None: matrix_inverse_root( A_list[i], root=Fraction(root), + root_inv_config=CoupledHigherOrderConfig(order=order), exponent_multiplier=exponent_multiplier, - root_inv_method=RootInvMethod.HIGHER_ORDER, - order=order, is_diagonal=False, ), atol=atol, @@ -185,9 +189,8 @@ def test_matrix_inverse_root(self) -> None: matrix_inverse_root( A_list[i], root=Fraction(root) / exp, + 
root_inv_config=CoupledHigherOrderConfig(order=order), exponent_multiplier=exponent_multiplier, - root_inv_method=RootInvMethod.HIGHER_ORDER, - order=order, is_diagonal=False, ), atol=atol, @@ -196,7 +199,7 @@ def test_matrix_inverse_root(self) -> None: def test_matrix_inverse_root_higher_order_blowup(self) -> None: A = torch.tensor([[1.0, 0.0], [0.0, 1e-4]]) - root_inv_method = RootInvMethod.HIGHER_ORDER + root_inv_config = CoupledHigherOrderConfig() self.assertRaisesRegex( ArithmeticError, re.escape( @@ -206,30 +209,30 @@ def test_matrix_inverse_root_higher_order_blowup(self) -> None: A=A, root=Fraction(1, 20), exponent_multiplier=1.0, - root_inv_method=root_inv_method, + root_inv_config=root_inv_config, ) def test_matrix_inverse_root_with_no_effect_exponent_multiplier(self) -> None: A = torch.tensor([[1.0, 0.0], [0.0, 4.0]]) - root_inv_method_and_msg: List[Tuple[RootInvMethod, str]] = [ - (RootInvMethod.NEWTON, "inverse Newton iteration"), - (RootInvMethod.HIGHER_ORDER, "higher order method"), + root_inv_config_and_msg: List[Tuple[RootInvConfig, str]] = [ + (CoupledNewtonConfig(), "inverse Newton iteration"), + (CoupledHigherOrderConfig(), "higher order method"), ] - for root_inv_method, root_inv_method_msg in root_inv_method_and_msg: + for root_inv_config, root_inv_config_msg in root_inv_config_and_msg: with self.subTest( - root_inv_method=root_inv_method, root_inv_method_msg=root_inv_method_msg + root_inv_config=root_inv_config, root_inv_config_msg=root_inv_config_msg ): self.assertRaisesRegex( ValueError, re.escape( - f"Exponent multiplier 2.0 must be equal to 1 to use coupled {root_inv_method_msg}!" + f"Exponent multiplier 2.0 must be equal to 1 to use coupled {root_inv_config_msg}!" ), matrix_inverse_root, A=A, root=2, exponent_multiplier=2.0, - root_inv_method=root_inv_method, + root_inv_config=root_inv_config, ) def test_matrix_inverse_root_newton_fraction(self) -> None: @@ -242,28 +245,28 @@ def test_matrix_inverse_root_newton_fraction(self) -> None: matrix_inverse_root, A=A, root=Fraction(numerator=1, denominator=2), - root_inv_method=RootInvMethod.NEWTON, + root_inv_config=CoupledNewtonConfig(), is_diagonal=False, ) def test_matrix_inverse_root_reach_max_iterations(self) -> None: A = torch.tensor([[1.0, 0.0], [0.0, 4.0]]) root = 4 - root_inv_method_and_implementation_and_msg: List[ - Tuple[RootInvMethod, str, str] + root_inv_config_and_implementation_and_msg: List[ + Tuple[RootInvConfig, str, str] ] = [ - (RootInvMethod.NEWTON, "_matrix_inverse_root_newton", "Newton"), + (CoupledNewtonConfig(), "_matrix_inverse_root_newton", "Newton"), ( - RootInvMethod.HIGHER_ORDER, + CoupledHigherOrderConfig(), "_matrix_inverse_root_higher_order", "Higher order method", ), ] for ( - root_inv_method, + root_inv_config, implementation, msg, - ) in root_inv_method_and_implementation_and_msg: + ) in root_inv_config_and_implementation_and_msg: with mock.patch.object( matrix_functions, implementation, @@ -275,7 +278,7 @@ def test_matrix_inverse_root_reach_max_iterations(self) -> None: None, ), ), self.subTest( - root_inv_method=root_inv_method, + root_inv_config=root_inv_config, implementation=implementation, msg=msg, ), self.assertLogs( @@ -284,7 +287,7 @@ def test_matrix_inverse_root_reach_max_iterations(self) -> None: matrix_inverse_root( A=A, root=root, - root_inv_method=root_inv_method, + root_inv_config=root_inv_config, ) self.assertIn( f"{msg} did not converge and reached maximum number of iterations!", @@ -303,56 +306,41 @@ def test_matrix_inverse_root_higher_order_tf32_preservation(self) 
-> None: A=A, root=Fraction(root), exponent_multiplier=exponent_multiplier, - root_inv_method=RootInvMethod.HIGHER_ORDER, + root_inv_config=CoupledHigherOrderConfig(), ) tf32_flag_after = torch.backends.cuda.matmul.allow_tf32 assert tf32_flag_before == tf32_flag_after - def test_matrix_inverse_root_with_invalid_root_inv_method(self) -> None: + def test_matrix_inverse_root_with_invalid_root_inv_config(self) -> None: A = torch.tensor([[1.0, 0.0], [0.0, 4.0]]) root = 4 - with mock.patch.object( - RootInvMethod, "__eq__", return_value=False - ), self.assertRaisesRegex( + with self.assertRaisesRegex( NotImplementedError, re.escape( - "Root inverse method is not implemented! Specified root inverse method is RootInvMethod.NEWTON." + "Root inverse config is not implemented! Specified root inverse config is root_inv_config=RootInvConfig()." ), ): matrix_inverse_root( A=A, root=root, - root_inv_method=RootInvMethod.NEWTON, + root_inv_config=RootInvConfig(), is_diagonal=False, ) class MatrixRootDiagonalTest(unittest.TestCase): - def test_matrix_root_diagonal_with_not_two_dim_matrix(self) -> None: - A = torch.zeros((1, 2, 3)) - root = 4 - exponent_multiplier = 1.82 - self.assertRaisesRegex( - ValueError, - re.escape("Matrix is not 2-dimensional!"), - matrix_root_diagonal, - A=A, - root=root, - exponent_multiplier=exponent_multiplier, - return_full_matrix=True, - ) - def test_matrix_root_diagonal_nonpositive_root(self) -> None: A = torch.tensor([[-1.0, 0.0], [0.0, 2.0]]) - root = -1 - self.assertRaisesRegex( - ValueError, - re.escape(f"Root {root} should be positive!"), - matrix_root_diagonal, - A=A, - root=root, - return_full_matrix=True, - ) + for root in (-1, 0): + with self.subTest(f"With {root=}"): + self.assertRaisesRegex( + ValueError, + re.escape(f"Root {root} should be positive!"), + matrix_inverse_root, + A=A, + root=root, + is_diagonal=True, + ) class EigenRootTest(unittest.TestCase): @@ -361,7 +349,6 @@ def _test_eigen_root( A: torch.Tensor, root: int, make_positive_semidefinite: bool, - inverse: bool, epsilon: float, tolerance: float, eig_sols: Tensor, @@ -371,11 +358,8 @@ def _test_eigen_root( root=root, epsilon=epsilon, make_positive_semidefinite=make_positive_semidefinite, - inverse=inverse, ) - if inverse: - root = -root - abs_error = torch.dist(torch.linalg.matrix_power(X, root), A, p=torch.inf) + abs_error = torch.dist(torch.linalg.matrix_power(X, -root), A, p=torch.inf) A_norm = torch.linalg.norm(A, ord=torch.inf) rel_error = abs_error / torch.maximum(torch.tensor(1.0), A_norm) torch.testing.assert_close(L, eig_sols) @@ -397,16 +381,6 @@ def _test_eigen_root_multi_dim( A(n), root, make_positive_semidefinite, - False, - epsilon, - tolerance, - eig_sols(n), - ) - self._test_eigen_root( - A(n), - root, - make_positive_semidefinite, - True, epsilon, tolerance, eig_sols(n), @@ -524,7 +498,7 @@ def test_matrix_root_eigen_nonpositive_root(self) -> None: self.assertRaisesRegex( ValueError, re.escape(f"Root {root} should be positive!"), - _matrix_root_eigen, + matrix_inverse_root, A=A, root=root, ) @@ -539,13 +513,11 @@ def test_no_retry_double_precision_raise_exception( ) -> None: A = torch.tensor([[-1.0, 0.0], [0.0, 2.0]]) with self.assertRaisesRegex(RuntimeError, re.escape("Mock Eigen Error")): - _matrix_root_eigen( + matrix_inverse_root( A=A, root=2, + root_inv_config=EigenConfig(retry_double_precision=False), epsilon=0.0, - make_positive_semidefinite=True, - inverse=False, - retry_double_precision=False, ) mock_eigh.assert_called_once() @@ -555,13 +527,10 @@ def 
test_no_retry_double_precision_raise_exception( def test_retry_double_precision_raise_exception(self, mock_eigh: mock.Mock) -> None: A = torch.tensor([[-1.0, 0.0], [0.0, 2.0]]) with self.assertRaisesRegex(RuntimeError, re.escape("Mock Eigen Error")): - _matrix_root_eigen( + matrix_inverse_root( A=A, root=2, epsilon=0.0, - make_positive_semidefinite=True, - inverse=False, - retry_double_precision=True, ) mock_eigh.assert_called() self.assertEqual(mock_eigh.call_count, 2) @@ -578,13 +547,10 @@ def test_retry_double_precision_double_precision( self, mock_eigh: mock.Mock ) -> None: A = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) - X, _, _ = _matrix_root_eigen( + X = matrix_inverse_root( A=A, root=2, epsilon=0.0, - make_positive_semidefinite=True, - inverse=False, - retry_double_precision=True, ) torch.testing.assert_close(X, torch.eye(2)) mock_eigh.assert_called() @@ -601,7 +567,7 @@ def _test_newton_root_inverse( A_tol: float, M_tol: float, ) -> None: - X, M, flag, iteration, M_error = _matrix_inverse_root_newton( + X, _, _, _, M_error = _matrix_inverse_root_newton( A, root, epsilon, max_iterations, M_tol ) abs_A_error = torch.dist(torch.linalg.matrix_power(X, -root), A, p=torch.inf)
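Taken together, the matrix_functions changes replace the old `RootInvMethod` enum and its loose keyword arguments (`max_iterations`, `tolerance`, `order`, `retry_double_precision`) with a single `root_inv_config` object: the concrete dataclass type selects the algorithm, and its fields are forwarded to the backend routine via `asdict`. `ShampooPreconditionerList` likewise threads `self._root_inv_config` through `matrix_inverse_root` in place of the old `retry_double_precision=self._use_protected_eigh` argument. Below is a small, hedged usage sketch of the resulting call patterns; the test matrix and tolerances are illustrative only.

```python
# Hedged usage sketch (not part of the patch) of the config-driven API above.
# matrix_inverse_root now dispatches on the concrete type of root_inv_config:
# EigenConfig -> eigendecomposition, CoupledNewtonConfig -> coupled inverse Newton,
# CoupledHigherOrderConfig -> higher-order coupled iteration.
import torch

from matrix_functions import matrix_inverse_root
from matrix_functions_types import (
    CoupledHigherOrderConfig,
    CoupledNewtonConfig,
    EigenConfig,
)

# Symmetric positive definite matrix with a known inverse square root.
A = torch.tensor([[4.0, 0.0], [0.0, 9.0]], dtype=torch.float64)
expected = torch.tensor([[0.5, 0.0], [0.0, 1.0 / 3.0]], dtype=torch.float64)  # A^{-1/2}

# Default path: eigendecomposition, here with the double-precision retry disabled.
X_eigen = matrix_inverse_root(
    A, root=2, root_inv_config=EigenConfig(retry_double_precision=False)
)

# Coupled inverse Newton iteration (exponent_multiplier must remain at its default of 1.0).
X_newton = matrix_inverse_root(A, root=2, root_inv_config=CoupledNewtonConfig())

# Higher-order coupled iteration; order=2 recovers Newton, order=3 is the default.
X_higher = matrix_inverse_root(
    A, root=2, root_inv_config=CoupledHigherOrderConfig(order=3)
)

for X in (X_eigen, X_newton, X_higher):
    torch.testing.assert_close(X.to(dtype=expected.dtype), expected, atol=1e-4, rtol=1e-4)
```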