Open-sourced update on 09/26/2024
Summary:
1. Add `RootInvConfig` and a new `matrix_functions_types.py` module for controlling root inverse computation (usage sketch below).
2. Refactor `matrix_functions.py`.
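
A minimal usage sketch of the new option, assuming the standard `DistributedShampoo` constructor; the import paths and the non-`root_inv_config` arguments are illustrative, not part of this commit:

```python
# Hedged sketch: selecting the root inverse computation via the new root_inv_config.
# Import paths and the other constructor arguments are assumptions for illustration only.
import torch.nn as nn

from distributed_shampoo.distributed_shampoo import DistributedShampoo
from matrix_functions_types import DefaultEigenConfig

model = nn.Linear(5, 10, bias=False)
optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.01,
    # New in this commit: controls how the inverse root of each factor matrix is computed.
    root_inv_config=DefaultEigenConfig,
)
```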

Reviewed By: hjmshi

Differential Revision: D63465130

fbshipit-source-id: a3aa124d0a5a7844ae6ff14ee2a1fce0d2e9f537
tsunghsienlee authored and facebook-github-bot committed Sep 27, 2024
1 parent 7e2ac7b commit b854998
Showing 10 changed files with 260 additions and 259 deletions.
20 changes: 2 additions & 18 deletions distributed_shampoo/README.md
@@ -1,6 +1,6 @@
# PyTorch Distributed Shampoo

Distributed Shampoo is a preconditioned stochastic gradient optimizer in the adaptive gradient (Adagrad) family of methods [1, 2]. It converges faster by leveraging neural network-specific structures to achieve comparable model quality/accuracy in fewer iterations or epochs at the cost of additional FLOPs and memory, or achieve higher model quality in the same number of iterations or epochs. Our implementation offers specialized support for serial, [Distributed Data Parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html), and [Fully Sharded Data Parallel (FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html) training.
Distributed Shampoo is a preconditioned stochastic gradient optimizer in the adaptive gradient (Adagrad) family of methods [1, 2]. It converges faster by leveraging neural network-specific structures to achieve comparable model quality/accuracy in fewer iterations or epochs at the cost of additional FLOPs and memory, or achieve higher model quality in the same number of iterations or epochs. Our implementation offers specialized support for serial, [Distributed Data Parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html), [Fully Sharded Data Parallel (FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Hybrid Sharding Data Parallel](https://pytorch.org/tutorials/recipes/distributed_device_mesh.html#how-to-use-devicemesh-with-hsdp) training.

Distributed Shampoo currently only supports dense parameters.

@@ -16,23 +16,7 @@ Developers:

with contributions and support from:

Ganesh Ajjanagadde (Meta), Rohan Anil (Google), Adnan Aziz (Meta), Pavan Balaji (Meta), Shuo Chang (Meta), Weiwei Chu (Meta), Assaf Eisenman (Meta), Will Feng (Meta), Zhuobo Feng (Meta), Jose Gallego-Posada (Mila / Meta Platforms, Inc.), Avirup Ghosh (Meta), Yizi Gu (Meta), Vineet Gupta (Google), Yuchen Hao (Meta), Brian Hirsh (Meta), Yusuo Hu (Meta), Yuxi Hu (Meta), Minhui Huang (Meta), Guna Lakshminarayanan (Meta), Michael Lazos (Meta), Zhijing Li (Meta), Ming Liang (Meta), Wanchao Liang (Meta), Ying Liu (Meta), Wenguang Mao (Meta), Dheevatsa Mudigere (NVIDIA), Maxim Naumov (Meta), Jongsoo Park (Meta), Mike Rabbat (Meta), Kaushik Rangadurai (Meta), Dennis van der Staay (Meta), Fei Tian (Meta), Rohan Varma (Meta), Sanjay Vishwakarma (Meta), Xunnan (Shawn) Xu (Meta), Jiyan Yang (Meta), Chunxing Yin (Meta), Iris Zhang (Meta), and Will Zou (Meta).

## Updates
- (7/18/24) This update contains
- PyTorch 2 compile bug fixes.
- HSDP Shampoo via `HSDPDistributor`.
- Mixed-precision optimizer states.
- Higher-order coupled iterations, with relative epsilon based on estimate of largest eigenvalue.
- Further modularization of Shampoo step function.
- Other simplifications.
- (2/14/24) We have released our Distributed Shampoo v2.0.0 implementation, a ground-up re-write of our PyTorch Shampoo implementation. Our v2.0.0 implementation includes:
- Incorporates new performance optimizations, such as the usage of `torch._foreach_*` operators and PyTorch 2 compile.
- Shared support and enablement of DDP and FSDP Shampoo, via the specification of the `distributed_config` field.
- Cleaner API for configuring grafting methods through specifying the `grafting_config` field.
- Deprecation of handling large tensors by diagonalizing the Shampoo preconditioners and using standard diagonal Adagrad.
- While we do not currently support LAMB/LARS grafting, we intend to add support for this in the future.
- We will update our [ArXiv paper](https://arxiv.org/pdf/2309.06497.pdf) to reflect our implementation changes.
Ganesh Ajjanagadde (Meta), Rohan Anil (Google), Adnan Aziz (Meta), Pavan Balaji (Meta), Shuo Chang (Meta), Weiwei Chu (Meta), Assaf Eisenman (Meta), Will Feng (Meta), Zhuobo Feng (Meta), Jose Gallego-Posada (Mila / Meta Platforms, Inc.), Avirup Ghosh (Meta), Yizi Gu (Meta), Vineet Gupta (Google), Yuchen Hao (Meta), Brian Hirsh (Meta), Yusuo Hu (Meta), Yuxi Hu (Meta), Minhui Huang (Meta), Guna Lakshminarayanan (Meta), Michael Lazos (Meta), Zhijing Li (Meta), Ming Liang (Meta), Wanchao Liang (Meta), Ying Liu (Meta), Wenguang Mao (Meta), Dheevatsa Mudigere (NVIDIA), Maxim Naumov (Meta), Jongsoo Park (Meta), Mike Rabbat (Meta), Kaushik Rangadurai (Meta), Dennis van der Staay (Meta), Fei Tian (Meta), Rohan Varma (Meta), Sanjay Vishwakarma (Meta), Xunnan (Shawn) Xu (Meta), Jiyan Yang (Meta), Chunxing Yin (Meta), Iris Zhang (Meta), Chuanhao Zhuge (Meta), and Will Zou (Meta).

## Features

59 changes: 20 additions & 39 deletions distributed_shampoo/distributed_shampoo.py
@@ -61,6 +61,7 @@
PRECONDITIONER_DTYPE,
PREVIOUS_GRAD_SELECTOR,
RMSpropGraftingConfig,
ROOT_INV_CONFIG,
SGDGraftingConfig,
SHAMPOO_PRECONDITIONER_LIST,
ShampooPT2CompileConfig,
@@ -99,6 +100,8 @@
QuantizedTensorList,
)
from distributed_shampoo.utils.shampoo_utils import compress_list

from matrix_functions_types import DefaultEigenConfig, RootInvConfig
from torch.optim.optimizer import ParamsT

logger: logging.Logger = logging.getLogger(__name__)
@@ -107,45 +110,6 @@
class DistributedShampoo(torch.optim.Optimizer):
"""Implements distributed Shampoo algorithm.
Developers:
Hao-Jun Michael Shi (Meta Platforms, Inc.)
Tsung-Hsien Lee
Anna Cai (Meta Platforms, Inc.)
Shintaro Iwasaki (Meta Platforms, Inc.)
Ke Sang (Meta Platforms, Inc.)
Wang Zhou (Meta Platforms, Inc.)
with contributions and support from:
Ganesh Ajjanagadde (Meta), Rohan Anil (Google), Adnan Aziz (Meta), Pavan Balaji (Meta), Shuo Chang (Meta), Weiwei Chu (Meta),
Assaf Eisenman (Meta), Will Feng (Meta), Zhuobo Feng (Meta), Jose Gallego-Posada (Mila / Meta Platforms, Inc.), Avirup Ghosh (Meta),
Yizi Gu (Meta), Vineet Gupta (Google), Yuchen Hao (Meta), Brian Hirsh (Meta), Yusuo Hu (Meta), Yuxi Hu (Meta), Minhui Huang (Meta),
Guna Lakshminarayanan (Meta), Michael Lazos (Meta), Zhijing Li (Meta), Ming Liang (Meta), Wanchao Liang (Meta), Ying Liu
(Meta), Wenguang Mao (Meta), Dheevatsa Mudigere (NVIDIA), Maxim Naumov (Meta), Jongsoo Park (Meta), Mike Rabbat (Meta),
Kaushik Rangadurai (Meta), Dennis van der Staay (Meta), Fei Tian (Meta), Sanjay Vishwakarma (Meta), Xunnan (Shawn) Xu (Meta),
Jiyan Yang (Meta), Chunxing Yin (Meta), and Iris Zhang (Meta).
Details in: https://arxiv.org/pdf/2309.06497.pdf.
Partly based on the work in:
- https://arxiv.org/pdf/1802.09568.pdf
- https://arxiv.org/pdf/2002.09018.pdf
------------
Requirements
------------
1. PyTorch >= 2.0
2. Python >= 3.10
3. CUDA 11.3, 11.4, 12.2+
In order to support checkpointing, one must use torch.distributed.checkpoint and pass the named parameters into state_dict.
Note that the standard checkpointing solution by PyTorch is not supported!
Note: We have observed known instabilities with the torch.linalg.eigh operator on CUDA 11.6-12.1, specifically for low-rank
matrices, which may appear when using a small start_preconditioning_step. Please avoid these versions of CUDA if possible.
See: https://github.com/pytorch/pytorch/issues/94772.
--------
Features
--------
@@ -296,6 +260,7 @@ class DistributedShampoo(torch.optim.Optimizer):
3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail.
track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes.
(Default: False)
root_inv_config (RootInvConfig): Configuration for root inverse computation. (Default: DefaultEigenConfig)
"""

@@ -326,6 +291,7 @@ def __init__(
precision_config: Optional[PrecisionConfig] = None,
use_protected_eigh: bool = True,
track_root_inv_residuals: bool = False,
root_inv_config: RootInvConfig = DefaultEigenConfig,
) -> None:
# Hyperparameter checks.
if not lr >= 0.0:
@@ -464,6 +430,7 @@ def __init__(
USE_MERGE_DIMS: use_merge_dims,
PRECONDITIONER_DTYPE: preconditioner_dtype,
PRECISION_CONFIG: precision_config,
ROOT_INV_CONFIG: root_inv_config,
},
)

@@ -542,6 +509,7 @@ def _instantiate_shampoo_preconditioner_list(self) -> None:
state=self.state,
block_info_list=state_lists[DISTRIBUTOR].global_block_info_list,
distributor_selector=state_lists[DISTRIBUTOR].distributor_selector,
root_inv_config=group[ROOT_INV_CONFIG],
beta2=group[BETAS][1],
epsilon=group[EPSILON],
inv_root_override=group[INV_ROOT_OVERRIDE],
@@ -1171,6 +1139,19 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
if not state_lists[MASKED_BLOCKED_GRADS]:
continue

# Convert the gradient dtype to the computation dtype set in the precision_config if
# necessary.
#
# This conversion is needed because the blocked gradient list has float32 dtype, and we
# need to convert it to the desired precision before the preconditioner computation.
if (
computation_dtype := group[PRECISION_CONFIG].computation_dtype
) != state_lists[MASKED_BLOCKED_GRADS][0].dtype:
state_lists[MASKED_BLOCKED_GRADS] = tuple(
tensor.to(dtype=computation_dtype)
for tensor in state_lists[MASKED_BLOCKED_GRADS]
)

# Iterate group step counter and define Python scalar step.
step = state_lists[STEP].add_(1)
# NOTE: Wrap scalar of group[LR] into a 0D tensor to avoid PT2 recompilation;
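
The `step()` hunk above casts the masked blocked gradients to `precision_config.computation_dtype` before preconditioning. A self-contained sketch of that pattern; `cast_blocked_grads` is a hypothetical helper, not code from this commit:

```python
# Standalone sketch of the dtype-conversion pattern added to step().
# cast_blocked_grads is a hypothetical helper used only for illustration.
import torch


def cast_blocked_grads(
    blocked_grads: tuple[torch.Tensor, ...], computation_dtype: torch.dtype
) -> tuple[torch.Tensor, ...]:
    # Convert only when the blocks are not already in the target dtype.
    if blocked_grads and blocked_grads[0].dtype != computation_dtype:
        blocked_grads = tuple(g.to(dtype=computation_dtype) for g in blocked_grads)
    return blocked_grads


grads = (torch.randn(3, 3), torch.randn(4, 2))
converted = cast_blocked_grads(grads, torch.float64)
assert all(g.dtype == torch.float64 for g in converted)
```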
1 change: 1 addition & 0 deletions distributed_shampoo/shampoo_types.py
@@ -35,6 +35,7 @@
PRECISION_CONFIG = "precision_config"
PRECONDITION_FREQUENCY = "precondition_frequency"
PRECONDITIONER_DTYPE = "preconditioner_dtype"
ROOT_INV_CONFIG = "root_inv_config"
START_PRECONDITIONING_STEP = "start_preconditioning_step"
USE_BIAS_CORRECTION = "use_bias_correction"
USE_DECOUPLED_WEIGHT_DECAY = "use_decoupled_weight_decay"
28 changes: 21 additions & 7 deletions distributed_shampoo/tests/distributed_shampoo_test.py
@@ -36,6 +36,7 @@
ShampooPreconditionerList,
)
from distributed_shampoo.utils.shampoo_quantization import QuantizedTensorList
from matrix_functions_types import DefaultEigenConfig
from torch import nn


@@ -267,15 +268,15 @@ def closure() -> float:

self.assertEqual(self._optimizer.step(closure=closure), 1.0)

def test_step_with_empty_grad_list(self) -> None:
@mock.patch.object(ShampooPreconditionerList, "update_preconditioners")
def test_step_with_empty_grad_list(
self, mock_upgrade_preconditioners: mock.Mock
) -> None:
# Test the case that the grad_list is empty.
self._optimizer.zero_grad()
with mock.patch.object(
ShampooPreconditionerList, "update_preconditioners"
) as mock_upgrade_preconditioners:
self._optimizer.step()
# Because the gradient list is empty, the preconditioners should not be updated.
mock_upgrade_preconditioners.assert_not_called()
self._optimizer.step()
# Because the gradient list is empty, the preconditioners should not be updated.
mock_upgrade_preconditioners.assert_not_called()


class DistributedShampooStateDictTest(unittest.TestCase):
@@ -447,6 +448,7 @@ def setUp(self) -> None:
"use_merge_dims": True,
"preconditioner_dtype": None,
"precision_config": PrecisionConfig(),
"root_inv_config": DefaultEigenConfig,
}
},
}
@@ -674,6 +676,8 @@ def setUp(self) -> None:
self._model = nn.Sequential(
nn.Linear(5, 10, bias=False),
)
self._x = torch.randn(5)
self._y = torch.randn(10)

def _instantiate_optimizer(
self, precision_config: PrecisionConfig
@@ -756,6 +760,12 @@ def test_precision_configs(self) -> None:
filtered_grad_dtype=torch.float16,
momentum_dtype=torch.float16,
),
PrecisionConfig(
factor_matrix_dtype=torch.float64,
inv_factor_matrix_dtype=torch.float64,
filtered_grad_dtype=torch.float64,
computation_dtype=torch.float64,
),
]

for precision_config in precision_configs:
@@ -767,6 +777,10 @@ def test_precision_configs(self) -> None:
self._assert_state_list_dtype(state_list, precision_config)

for _ in range(2):
optimizer.zero_grad()
y_hat = self._model(self._x)
loss = torch.nn.CrossEntropyLoss()(y_hat, self._y)
loss.backward()
optimizer.step()
for state_list in optimizer._per_group_state_lists:
self._assert_state_list_dtype(state_list, precision_config)
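
The updated precision test now performs a real forward/backward pass before each `optimizer.step()`. A minimal standalone version of that loop, with `torch.optim.SGD` standing in for `DistributedShampoo` so the snippet has no repository dependencies:

```python
# Sketch of the forward/backward/step loop the test now exercises.
# torch.optim.SGD is a stand-in for DistributedShampoo here.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(5, 10, bias=False))
x, y = torch.randn(5), torch.randn(10)  # random input and soft targets, mirroring the test
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for _ in range(2):
    optimizer.zero_grad()
    y_hat = model(x)
    loss = nn.CrossEntropyLoss()(y_hat, y)
    loss.backward()
    optimizer.step()
```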
13 changes: 7 additions & 6 deletions distributed_shampoo/tests/shampoo_test_utils.py
@@ -23,13 +23,14 @@ def __init__(
) -> None:
super().__init__()
self.linear_layers = nn.ModuleList(
[
nn.Linear(a, b, bias=bias)
for a, b in itertools.pairwise(model_linear_layers_dims)
]
nn.Linear(a, b, bias=bias)
for a, b in itertools.pairwise(model_linear_layers_dims)
)
if model_dead_layer_dims is not None:
self.useless: nn.Module = nn.Linear(*model_dead_layer_dims, bias=False)
self.dead_layers: nn.ModuleList = nn.ModuleList(
nn.Linear(a, b, bias=False)
for a, b in itertools.pairwise(model_dead_layer_dims)
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
for linear_layer in self.linear_layers:
@@ -49,7 +50,7 @@ def construct_training_problem(
Args:
model_linear_layers_dims (tuple[int, ...]): The dimensions of the model linear layers.
model_dead_layer_dims (Optional[tuple[int, ...]]): The dimensions of the model dead layer. (Default: (10, 10))
model_dead_layer_dims (Optional[tuple[int, ...]]): The dimensions of the model dead linear layers. (Default: (10, 10))
device (Optional[torch.device]): The device to use. (Default: None)
bias (bool): Whether to use bias in the linear (non-dead) layers. (Default: False)
fill (float | tuple[float, ...]): The value(s) to fill the model parameters. If a tuple, each element should correspond to one layer. (Default: 0.0)
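
The refactor above builds both the live and dead linear stacks by feeding a generator over `itertools.pairwise` of the layer dimensions directly into `nn.ModuleList`. A small standalone illustration (the dimensions are made up; requires Python >= 3.10 for `itertools.pairwise`):

```python
# Sketch of the pairwise-dims pattern used to build a stack of linear layers.
import itertools

import torch
import torch.nn as nn

dims = (5, 10, 3)  # hypothetical layer widths: 5 -> 10 -> 3
layers = nn.ModuleList(
    nn.Linear(a, b, bias=False) for a, b in itertools.pairwise(dims)
)

x = torch.randn(5)
for layer in layers:
    x = layer(x)
assert x.shape == torch.Size([3])
```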
51 changes: 24 additions & 27 deletions distributed_shampoo/utils/gpu_tests/shampoo_ddp_distributor_test.py
@@ -165,36 +165,33 @@ def test_losses(self) -> None:
device=torch.device("cuda"),
)

def test_empty_local_blocked_params(self) -> None:
# This mock is used to catch the number of calls to Shampoo's step(), which happens after __init__().
# If there are no blocked params, __init__() will raise and step() should not be called.
# Otherwise, step() will be called.
@mock.patch.object(DistributedShampoo, "step")
def test_empty_local_blocked_params(self, mock_step: mock.Mock) -> None:
self._init_distributed()

# The test setting is only rank 0 has params, so all other ranks have no parameters to work on.
has_blocked_params = dist.get_rank() == 0
with (
# This mock is used to catch the number of calls to Shampoo's step(), which happens after __init__().
# If there are no blocked params, __init__() will raise and step() should not be called.
# Otherwise, step() will be called.
mock.patch.object(DistributedShampoo, "step")
) as mock_step:
with (
contextlib.nullcontext()
if has_blocked_params
else self.assertRaisesRegex(
AssertionError,
re.escape("Some workers have no parameters to work on."),
)
):
ShampooDDPDistributorTest._train_model(
self._shampoo_optim_factory(distributed_config=DDPShampooConfig()),
device=torch.device("cuda"),
# Setting model_linear_layers_dims to (20, 1) creates a model with one linear layer with a 20x1 weight.
# Because Shampoo's max_preconditioner_dim = 20, there will be only one block.
# In the case of two trainers per group, one trainer will have no params to work on.
model_linear_layers_dims=(20, 1),
model_dead_layer_dims=None,
)
contextlib.nullcontext()
if has_blocked_params
else self.assertRaisesRegex(
AssertionError, re.escape("Some workers have no parameters to work on.")
)
):
ShampooDDPDistributorTest._train_model(
self._shampoo_optim_factory(distributed_config=DDPShampooConfig()),
device=torch.device("cuda"),
# Setting model_linear_layers_dims to (20, 1) creates a model with one linear layer with a 20x1 weight.
# Because Shampoo's max_preconditioner_dim = 20, there will be only one block.
# In the case of two trainers per group, one trainer will have no params to work on.
model_linear_layers_dims=(20, 1),
model_dead_layer_dims=None,
)

if has_blocked_params:
mock_step.assert_called()
else:
mock_step.assert_not_called()
if has_blocked_params:
mock_step.assert_called()
else:
mock_step.assert_not_called()
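
Both test refactors in this commit replace `with mock.patch.object(...)` context managers with the decorator form, which injects the mock as a test-method argument. A self-contained illustration using a toy class that is not from this repository:

```python
# Standalone sketch of decorator-style patching with unittest.mock.
# Greeter is a toy class used only for illustration.
import unittest
from unittest import mock


class Greeter:
    def greet(self) -> str:
        return "hello"


class GreeterTest(unittest.TestCase):
    @mock.patch.object(Greeter, "greet")
    def test_greet_not_called(self, mock_greet: mock.Mock) -> None:
        Greeter()  # constructing the object alone should not call greet()
        mock_greet.assert_not_called()


if __name__ == "__main__":
    unittest.main()
```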