main.py
import torch
import torch.distributed as dist
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms import v2 as T

from nice import Nice
from trainer import Trainer

# Shared DataLoader settings; with DistributedSampler this is the per-rank batch size.
DATALOADER_ARGS = {
    "batch_size": 1024,
    "num_workers": 4,
    "pin_memory": True,
}


def build_dataloader(dataset: Dataset):
    # Wrap the dataset with DistributedSampler so every rank draws its own shard.
    return DataLoader(dataset,
                      sampler=DistributedSampler(dataset),
                      **DATALOADER_ARGS)


def main():
    # Expects the launcher (e.g. torchrun) to have set the process-group
    # environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT).
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    # One GPU per process on a single node, so the global rank doubles as
    # the local device index.
    torch.cuda.set_device(rank)
    print(f"Start running basic DDP example on rank {rank}.")

    # Flatten each 28x28 MNIST image into a 784-dim float vector in [0, 1].
    transform = T.Compose([
        T.ToImage(),
        T.ToDtype(torch.float32, scale=True),
        T.Lambda(lambda x: x.reshape(-1)),
    ])
    test_dataset = torchvision.datasets.MNIST(
        root="./data",
        train=False,
        download=True,
        transform=transform,
    )
    train_dataset = torchvision.datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=transform,
    )

    # Merge the official train/test splits, then re-split 85/15 into
    # training and validation sets.
    dataset = ConcatDataset([train_dataset, test_dataset])
    train_dataset, val_dataset = random_split(dataset, [0.85, 0.15])
    train_dataloader = build_dataloader(train_dataset)
    val_dataloader = build_dataloader(val_dataset)

    # One NICE model per process, placed on this rank's GPU and synchronised
    # by DistributedDataParallel.
    model = Nice(784, device=rank).to(rank)
    model = DDP(model, device_ids=[rank])

    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=1e-3,
                                    momentum=1e-8,
                                    weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.995)

    trainer = Trainer(model,
                      [train_dataloader, val_dataloader],
                      optimizer,
                      scheduler,
                      rank)
    trainer.run()

    dist.destroy_process_group()

if __name__ == "__main__":
main()
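
Note: trainer.py and nice.py are separate modules in this repository and are not shown on this page. Purely as an illustration of the interface main() assumes, a minimal Trainer sketch compatible with the call above might look like the following; the epoch count and the loss term (which assumes the model returns a per-sample log-likelihood) are assumptions, not part of the original code.

class Trainer:
    def __init__(self, model, dataloaders, optimizer, scheduler, rank, epochs=50):
        self.model = model
        self.train_loader, self.val_loader = dataloaders
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.rank = rank
        self.epochs = epochs  # assumed default, not taken from the repo

    def run(self):
        for epoch in range(self.epochs):
            # DistributedSampler needs the epoch number to reshuffle shards
            # consistently across ranks.
            self.train_loader.sampler.set_epoch(epoch)
            self.model.train()
            for x, _ in self.train_loader:
                x = x.to(self.rank, non_blocking=True)
                loss = -self.model(x).mean()  # assumes the model returns log-likelihood
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            self.scheduler.step()

The script is typically launched with torchrun, e.g. `torchrun --nproc_per_node=<num_gpus> main.py`, which supplies the environment variables that init_process_group reads.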