
Commit 5df18c7

runame authored and facebook-github-bot committed
Add support for mypy type checking (#37)
Summary: Pull Request resolved: facebookresearch/optimizers#37
Reviewed By: anana10c
Differential Revision: D65546537
Pulled By: tsunghsienlee
fbshipit-source-id: 0eaf18d0c732101d5634c693e45fd5511e349fa5
1 parent 397ad17 commit 5df18c7

25 files changed: +152 −119 lines
New GitHub Actions workflow

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+name: type-check-mypy
+
+on: [push, pull_request]
+
+jobs:
+  mypy:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up and update uv.
+        run: |
+          curl -LsSf https://astral.sh/uv/install.sh | sh
+          uv self update
+      - name: Install Python.
+        run: uv python install 3.10
+      - name: Create venv and install the package.
+        run: |
+          uv venv && source .venv/bin/activate
+          uv pip install -e ".[dev]"
+      - name: Run type checking with mypy.
+        run: |
+          source .venv/bin/activate
+          make type-check
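The final step defers to `make type-check`; the Makefile target itself is not part of this commit, so the exact mypy invocation is not shown here. As a rough local stand-in, assuming the target simply runs mypy over the `distributed_shampoo` package with default flags (an assumption, not taken from this diff), the same check can be driven from Python through mypy's documented programmatic API, `mypy.api.run`:

# Sketch of a local equivalent of the CI type-check step, assuming
# "make type-check" just invokes mypy on the package (path/flags are guesses).
from mypy import api


def run_type_check(target: str = "distributed_shampoo") -> None:
    # api.run mirrors the CLI and returns (stdout, stderr, exit_status).
    stdout, stderr, exit_status = api.run([target])
    if stdout:
        print(stdout, end="")
    if exit_status != 0:
        raise SystemExit(f"mypy reported problems (exit status {exit_status})")


if __name__ == "__main__":
    run_type_check()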

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ We actively welcome your pull requests for existing optimizers.
 2. If you've added code that should be tested, add tests.
 3. If you've changed APIs, update the documentation.
 4. Ensure the test suite passes. To run the subset of the tests that can be run on CPU use `make test`; to run the tests for a single GPU use `make test-gpu` and to run the subset of tests that require 2-4 GPUs use `make test-multi-gpu`.
-5. Make sure your code lints. You can use `make lint` and `make format` to automatically lint and format the code where possible.
+5. Make sure your code lints. You can use `make lint` and `make format` to automatically lint and format the code where possible. Use `make type-check` for type checking.
 6. If you haven't already, complete the Contributor License Agreement ("CLA").

 ## Contributor License Agreement ("CLA")

distributed_shampoo/distributed_shampoo.py

Lines changed: 11 additions & 14 deletions

@@ -502,18 +502,18 @@ def _instantiate_distributor(
        if distributed_config is None:
            distributor = Distributor
        elif type(distributed_config) is DDPShampooConfig:
-            distributor = partial(DDPDistributor, distributed_config=distributed_config)
+            distributor = partial(DDPDistributor, distributed_config=distributed_config)  # type: ignore[assignment]
        elif type(distributed_config) is FSDPShampooConfig:
            distributor = partial(
                FSDPDistributor, distributed_config=distributed_config
-            )
+            )  # type: ignore[assignment]
        elif type(distributed_config) is FullyShardShampooConfig:
            distributor = FullyShardDistributor
        elif type(distributed_config) is HSDPShampooConfig:
            distributor = partial(
                HSDPDistributor,
                distributed_config=distributed_config,
-            )
+            )  # type: ignore[assignment]
        else:
            raise NotImplementedError(f"{distributed_config=} not supported!")

@@ -808,10 +808,7 @@ def _compute_and_log_root_inverse_residuals(
        Uses infinity norm to evaluate residuals and errors.
        """

-        # Accumulate relative errors/residuals
-        relative_errors = []
-        relative_residuals = []
-
+        # Compute relative errors/residuals for each group.
        for (group_index, group), state_lists in zip(
            enumerate(self.param_groups), self._per_group_state_lists, strict=True
        ):

@@ -827,12 +824,12 @@ def _compute_and_log_root_inverse_residuals(
                )
                continue

-            relative_errors, relative_residuals = state_lists[
-                SHAMPOO_PRECONDITIONER_LIST
-            ].compute_root_inverse_residuals()
-
-            relative_errors = torch.stack(relative_errors)
-            relative_residuals = torch.stack(relative_residuals)
+            relative_errors, relative_residuals = map(
+                torch.stack,
+                state_lists[
+                    SHAMPOO_PRECONDITIONER_LIST
+                ].compute_root_inverse_residuals(),
+            )

            quantiles = torch.as_tensor(
                [0, 0.25, 0.5, 0.75, 1],

@@ -1141,7 +1138,7 @@ def _per_group_step_impl(
        )

    @torch.no_grad()
-    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
+    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:  # type: ignore[override]
        """Performs a single optimization step.

        Args:
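Most of the suppressions in this file follow one pattern: a variable is first bound to a class, so mypy infers `type[...]` for it, and a later branch rebinds it to a `functools.partial`, which mypy reports as an incompatible assignment. A standalone sketch of that pattern (class and function names are illustrative, not copied from the repository):

# Minimal sketch of the mypy [assignment] error suppressed above.
from functools import partial


class BaseDistributor:
    def __init__(self, flavor: str = "default") -> None:
        self.flavor = flavor


def pick_distributor(use_ddp: bool) -> BaseDistributor:
    distributor = BaseDistributor  # mypy infers: type[BaseDistributor]
    if use_ddp:
        # mypy: Incompatible types in assignment (expression has type
        # "partial[BaseDistributor]", variable has type "type[BaseDistributor]")
        distributor = partial(BaseDistributor, flavor="ddp")  # type: ignore[assignment]
    return distributor()


if __name__ == "__main__":
    print(pick_distributor(use_ddp=True).flavor)

The `# type: ignore[override]` on `step` serves a related purpose: mypy compares the overriding signature against the base `torch.optim.Optimizer.step` stub and the suppression accepts the deliberate mismatch rather than changing the runtime signature.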

distributed_shampoo/examples/default_cifar10_example.py

Lines changed: 7 additions & 3 deletions

@@ -40,7 +40,7 @@ def train_default_model(
    loss_function: nn.Module,
    data_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
-    device: str,
+    device: torch.device,
    epochs: int = 1,
    window_size: int = 100,
) -> Tuple[float, float, int]:

@@ -62,7 +62,11 @@ def train_default_model(
        metrics.update(loss)
        metrics.log()

-    return metrics._lifetime_loss, metrics._window_loss, metrics._iteration
+    return (
+        metrics._lifetime_loss.item(),
+        metrics._window_loss.item(),
+        metrics._iteration,
+    )


if __name__ == "__main__":

@@ -97,7 +101,7 @@ def train_default_model(
    set_seed(args.seed)

    # check cuda availability and set device
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # instantiate model and loss function
    model, loss_function = get_model_and_loss_fn(device)
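The return-statement change above exists because the running losses are 0-d tensors while the function is annotated to return `Tuple[float, float, int]`; `.item()` converts each 0-d tensor to a plain Python number so the value matches the annotation. A self-contained sketch of the same pattern (the function name is illustrative):

# Minimal sketch of why .item() is needed for a Tuple[float, float, int] return.
from typing import Tuple

import torch


def summarize_losses(
    lifetime_loss: torch.Tensor, window_loss: torch.Tensor, iteration: int
) -> Tuple[float, float, int]:
    # .item() turns a 0-d tensor into a plain Python number; returning the
    # tensors directly would not satisfy the declared float components.
    return lifetime_loss.item(), window_loss.item(), iteration


if __name__ == "__main__":
    print(summarize_losses(torch.tensor(0.25), torch.tensor(0.10), 100))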

distributed_shampoo/examples/fully_shard_cifar10_example.py

Lines changed: 9 additions & 4 deletions

@@ -9,7 +9,7 @@

import logging
import os
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple

import torch
import torch.distributed as dist

@@ -53,7 +53,7 @@ def train_fully_shard_model(
    sampler: torch.utils.data.Sampler,
    data_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
-    device: Union[str, torch.device],
+    device: torch.device,
    epochs: int = 1,
    window_size: int = 100,
    use_distributed_checkpoint: bool = False,

@@ -71,7 +71,7 @@ def train_fully_shard_model(
    # main training loop
    for epoch in range(epochs):
        metrics._epoch = epoch
-        sampler.set_epoch(epoch)
+        sampler.set_epoch(epoch)  # type: ignore[attr-defined]

        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)

@@ -89,6 +89,7 @@ def train_fully_shard_model(

    # checkpoint optimizer and model using distributed checkpointing solution
    if use_distributed_checkpoint and isinstance(optimizer, DistributedShampoo):
+        assert checkpoint_dir is not None
        state_dict = {
            "model": model.state_dict(),
            "optim": optimizer.distributed_state_dict(

@@ -100,7 +101,11 @@ def train_fully_shard_model(
            storage_writer=dist_checkpoint.FileSystemWriter(checkpoint_dir),
        )

-    return metrics._lifetime_loss, metrics._window_loss, metrics._iteration
+    return (
+        metrics._lifetime_loss.item(),
+        metrics._window_loss.item(),
+        metrics._iteration,
+    )


def create_model_and_optimizer_and_loss_fn(args, device):
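The added `assert checkpoint_dir is not None` is a standard mypy narrowing idiom: after the assert, an `Optional[str]` argument is treated as `str`, so passing it to a writer that requires a non-optional path no longer errors. A standalone sketch of the idiom (the writer function below is a stand-in, not the torch.distributed.checkpoint API):

# Minimal sketch of Optional narrowing via assert, as used above.
from typing import Optional


def write_checkpoint(state: dict[str, object], path: str) -> None:
    # Stand-in for a writer that requires a non-optional path.
    print(f"writing {len(state)} entries to {path}")


def maybe_checkpoint(
    state: dict[str, object],
    use_checkpoint: bool,
    checkpoint_dir: Optional[str] = None,
) -> None:
    if use_checkpoint:
        # Narrows Optional[str] to str for mypy; without it, the call below
        # would be an [arg-type] error.
        assert checkpoint_dir is not None
        write_checkpoint(state, checkpoint_dir)


if __name__ == "__main__":
    maybe_checkpoint({"model": {}, "optim": {}}, True, "/tmp/checkpoint")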

distributed_shampoo/examples/trainer_utils.py

Lines changed: 15 additions & 13 deletions

@@ -12,7 +12,7 @@
import logging
import random
from abc import ABC, abstractmethod
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple

import numpy as np

@@ -39,7 +39,7 @@
    PreconditionerComputationConfig,
)
from torch import nn
-from torchvision import datasets, transforms
+from torchvision import datasets, transforms  # type: ignore[import-untyped]

logger = logging.getLogger(__name__)

@@ -79,10 +79,10 @@ class PreconditionerComputationType(enum.Enum):
###### ARGPARSER ######
def enum_type_parse(s: str, enum_type: enum.Enum):
    try:
-        return enum_type[s]
+        return enum_type[s]  # type: ignore[index]
    except KeyError:
        raise argparse.ArgumentTypeError(
-            "Use one of {}".format(", ".join([t.name for t in enum_type]))
+            "Use one of {}".format(", ".join([t.name for t in enum_type]))  # type: ignore[attr-defined]
        )


@@ -349,7 +349,7 @@ def log(self): ...
    def reset(self): ...

    @abstractmethod
-    def update(self): ...
+    def update(self, loss: torch.Tensor): ...


class LossMetrics(Metrics):

@@ -365,7 +365,7 @@ def __init__(
        self._device = device
        self._epoch = 0
        self._iteration = 0
-        self._window_losses = []
+        self._window_losses: list[torch.Tensor] = []
        self._window_loss = torch.tensor(0.0, device=device)
        self._accumulated_loss = torch.tensor(0.0, device=device)
        self._lifetime_loss = torch.tensor(0.0, device=device)

@@ -461,15 +461,15 @@ def instantiate_optimizer(
                betas=betas,
                eps=epsilon,
                weight_decay=weight_decay,
-            )
+            )  # type: ignore[assignment]
        else:
            optimizer = torch.optim.Adam(
                model.parameters(),
                lr=lr,
                betas=betas,
                eps=epsilon,
                weight_decay=weight_decay,
-            )
+            )  # type: ignore[assignment]
    elif optimizer_type == OptimizerType.DISTRIBUTED_SHAMPOO:
        optimizer = DistributedShampoo(
            model.parameters(),

@@ -500,7 +500,7 @@ def instantiate_optimizer(
            preconditioner_computation_config=instantiate_preconditioner_computation_config(
                preconditioner_computation_type
            ),
-        )
+        )  # type: ignore[assignment]
    else:
        raise ValueError(f"Invalid OptimizerType {optimizer_type}!")

@@ -576,8 +576,10 @@ def get_data_loader_and_sampler(
    dataset = datasets.CIFAR10(
        data_path, train=True, download=True, transform=transform
    )
-    sampler = torch.utils.data.distributed.DistributedSampler(
-        dataset, num_replicas=world_size, rank=rank, shuffle=True
+    sampler: torch.utils.data.distributed.DistributedSampler = (
+        torch.utils.data.distributed.DistributedSampler(
+            dataset, num_replicas=world_size, rank=rank, shuffle=True
+        )
    )
    return (
        torch.utils.data.DataLoader(

@@ -636,7 +638,7 @@ def train_model(
    sampler: torch.utils.data.Sampler,
    data_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
-    device: Union[str, torch.device],
+    device: torch.device,
    epochs: int = 1,
    window_size: int = 100,
    local_rank: int = 0,

@@ -647,7 +649,7 @@ def train_model(
    # main training loop
    for epoch in range(epochs):
        metrics._epoch = epoch
-        sampler.set_epoch(epoch)
+        sampler.set_epoch(epoch)  # type: ignore[attr-defined]

        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
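Two of the fixes above recur throughout this commit: suppressing `import-untyped` for `torchvision`, which does not ship type information mypy can use, and annotating empty containers so mypy can infer an element type. A small sketch of the second pattern (the dict's value type is an illustrative choice, not taken from the repository):

# Minimal sketch of annotating empty containers for mypy.
import torch

# mypy cannot infer an element type from "[]" or "{}" alone and reports
# "Need type annotation" (var-annotated) unless the variable is annotated,
# as below, or the line is suppressed with "# type: ignore[var-annotated]".
window_losses: list[torch.Tensor] = []
window_losses.append(torch.tensor(0.5))

expected_metadata: dict[str, int] = {}  # illustrative value type

print(len(window_losses), expected_metadata)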

distributed_shampoo/gpu_tests/shampoo_pt2_test.py

Lines changed: 2 additions & 2 deletions

@@ -86,12 +86,12 @@ def _test_shampoo_baseline_and_pt2(

    @staticmethod
    def _shampoo_optim_factory(
-        shampoo_pt2_compile_config: ShampooPT2CompileConfig,
+        shampoo_pt2_compile_config: ShampooPT2CompileConfig | None,
        precondition_frequency: int,
        start_preconditioning_step: int,
        weight_decay: float,
        betas: tuple[float, float],
-        grafting_config: GraftingConfig,
+        grafting_config: GraftingConfig | None,
    ) -> Callable[[ParamsT], torch.optim.Optimizer]:
        return lambda parameters: DistributedShampoo(
            parameters,
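Both annotations above are widened to `X | None` because the tests pass `None` for these arguments; with the non-optional annotations, mypy flags each such call as an `[arg-type]` error. A standalone sketch of the pattern using Python 3.10+ union syntax (the config class here is a stand-in, not the repository's `GraftingConfig`):

# Minimal sketch of widening a parameter annotation to "X | None".
from dataclasses import dataclass


@dataclass
class StubGraftingConfig:
    beta: float = 0.9


def make_optimizer(grafting_config: StubGraftingConfig | None) -> str:
    # With a bare "StubGraftingConfig" annotation, the call below that passes
    # None would be reported by mypy as [arg-type].
    if grafting_config is None:
        return "no grafting"
    return f"grafting(beta={grafting_config.beta})"


if __name__ == "__main__":
    print(make_optimizer(None))
    print(make_optimizer(StubGraftingConfig(beta=0.95)))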

distributed_shampoo/shampoo_types.py

Lines changed: 1 addition & 1 deletion

@@ -130,7 +130,7 @@ class PrecisionConfig:

@dataclass
class AbstractDataclass:
-    def __new__(cls, *args: Any, **kwargs: Any) -> Optional["AbstractDataclass"]:
+    def __new__(cls, *args: Any, **kwargs: Any) -> "AbstractDataclass":
        if cls == AbstractDataclass or cls.__bases__[0] == AbstractDataclass:
            raise TypeError(f"Cannot instantiate abstract class: {cls.__name__}.")
        return super().__new__(cls)
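The fix above tightens the `__new__` return annotation: the method either returns an instance or raises, so `Optional[...]` advertises a `None` that can never occur and makes mypy flag downstream constructions. A simplified, runnable sketch of the same idea (class names are illustrative and the base-class check is reduced to its essence):

# Minimal sketch of annotating __new__ with the class type instead of Optional.
from dataclasses import dataclass
from typing import Any


@dataclass
class AbstractConfig:
    # __new__ either raises or returns an instance, never None, so the
    # non-optional return annotation matches the actual behaviour.
    def __new__(cls, *args: Any, **kwargs: Any) -> "AbstractConfig":
        if cls is AbstractConfig:
            raise TypeError(f"Cannot instantiate abstract class: {cls.__name__}.")
        return super().__new__(cls)


@dataclass
class ConcreteConfig(AbstractConfig):
    value: int = 0


if __name__ == "__main__":
    print(ConcreteConfig(value=3))  # works; AbstractConfig() would raise TypeError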

distributed_shampoo/utils/gpu_tests/shampoo_dist_utils_test.py

Lines changed: 3 additions & 5 deletions

@@ -42,8 +42,7 @@ def _verify_deivce_mesh(self, device_mesh: DeviceMesh) -> None:
            (shard_mesh.get_group(), replicate_mesh.get_group()),
        )

-    # type: ignore
-    @with_comms
+    @with_comms  # type: ignore
    def test_get_device_mesh(self) -> None:
        mesh = tuple(
            map(

@@ -57,8 +56,7 @@ def test_get_device_mesh(self) -> None:

        self._verify_deivce_mesh(
            device_mesh=get_device_mesh(
-                # type: ignore
-                device_type=self.device_type,
+                device_type=self.device_type,  # type: ignore
                mesh=mesh,
                mesh_dim_names=("replicate", "shard"),
            )

@@ -72,7 +70,7 @@ def test_get_device_mesh(self) -> None:
                "__init__",
            ) as mock_device_mesh_init:
                device_mesh = get_device_mesh(
-                    device_type=self.device_type,
+                    device_type=self.device_type,  # type: ignore[attr-defined]
                    mesh=mesh,
                    mesh_dim_names=("replicate", "shard"),
                )
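The changes above only move existing `# type: ignore` comments: mypy applies such a comment to the line it sits on, so a bare comment on its own line above the decorator or argument suppressed nothing. A small sketch of the placement rule (the call and error code below are illustrative):

# Minimal sketch of "# type: ignore" placement.
def count_items(items: list[int]) -> int:
    return len(items)


# A bare "# type: ignore" on its own line above the call would not silence the
# error; mypy only honours the comment when it is on the offending line itself.
result = count_items(("a", "b"))  # type: ignore[arg-type]  # deliberate mismatch for illustration

print(result)  # runtime is unaffected: len() works on the tuple too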

distributed_shampoo/utils/gpu_tests/shampoo_fsdp_utils_test.py

Lines changed: 3 additions & 3 deletions

@@ -102,7 +102,7 @@ def test_compile_fsdp_parameter_metadata_with_no_flat_param(self) -> None:
        fsdp_model = FSDP(model, use_orig_params=True, ignored_states=params)
        actual_fsdp_parameter_metadata = compile_fsdp_parameter_metadata(fsdp_model)

-        expected_fsdp_parameter_metadata = {}
+        expected_fsdp_parameter_metadata = {}  # type: ignore[var-annotated]

        self.assertEqual(
            actual_fsdp_parameter_metadata, expected_fsdp_parameter_metadata

@@ -117,7 +117,7 @@ def world_size(self) -> int:

    @skip_if_lt_x_gpu(4)
    def test_parse_fsdp_params(self) -> None:
-        HYBRID_SHARDING_STRATEGIES_TO_EXPECTED_KEYS = {
+        HYBRID_SHARDING_STRATEGIES_TO_EXPECTED_KEYS = {  # type: ignore[var-annotated]
            ShardingStrategy.HYBRID_SHARD: (
                [],
                [

@@ -135,7 +135,7 @@ def test_parse_fsdp_params(self) -> None:
                ["1.weight"],
            ),
        }
-        SHARDING_STRATEGIES_TO_EXPECTED_KEYS = {
+        SHARDING_STRATEGIES_TO_EXPECTED_KEYS = {  # type: ignore[var-annotated]
            ShardingStrategy.NO_SHARD: (
                [],
                [],

distributed_shampoo/utils/gpu_tests/shampoo_fully_shard_distributor_test.py

Lines changed: 6 additions & 4 deletions

@@ -108,14 +108,16 @@ def _train_model(
        if uses_fully_shard:
            # When FullyShard is used, model parameters are DTensors. We obtain the full value of
            # parameters from DTensors.
-            params = []
+            params_list = []
            for param in model.parameters():
                # Need this assertion to get pass type-checking test.
                assert isinstance(param, DTensor)
-                params.append(param.full_tensor().view(-1).detach().cpu())
+                params_list.append(param.full_tensor().view(-1).detach().cpu())
        else:
-            params = [param.view(-1).detach().cpu() for param in model.parameters()]
-        return params, objective.detach().cpu()
+            params_list = [
+                param.view(-1).detach().cpu() for param in model.parameters()
+            ]
+        return params_list, objective.detach().cpu()

    @staticmethod
    def _test_two_configs(
