-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
75 lines (66 loc) · 2.77 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import transforms
from tqdm import tqdm
from model import VICReg
from loss import variance, invariance, covariance
class Augmentation:
    """
    Produce two independently augmented views of a single input.

    An instance is intended to be passed as the ``transform`` of a dataset
    (as done with CIFAR10 in this script), so that each sample delivered by
    a DataLoader is a pair of differently augmented tensors instead of one.
    """

    # Shared pipeline applied once per view. ToTensor runs first, so every
    # later transform operates on a tensor; presumably that is why the
    # solarization threshold is 0.5 rather than a uint8-style value.
    augment = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
            p=0.8,
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.5),
        transforms.RandomSolarize(threshold=0.5, p=0.2),
        # Per-channel mean/std; these appear to be CIFAR-10 statistics.
        transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)),
    ])

    def __call__(self, x):
        """Return ``(view_a, view_b)``: two separate random augmentations of ``x``."""
        view_a = self.augment(x)
        view_b = self.augment(x)
        return view_a, view_b
# Path of the single rolling checkpoint used to resume training.
CHECKPOINT_PATH = "checkpoint.pt"


def _vicreg_loss(z1, z2, la=25, mu=25, nu=1):
    """Weight and sum the VICReg terms for the two projected views.

    ``variance``, ``invariance`` and ``covariance`` are the helpers imported
    from the local ``loss`` module; this function only combines them.

    Args:
        z1, z2: model outputs for the two augmented views of one batch.
        la: weight of the invariance term ``invariance(z1, z2)``.
        mu: weight of ``variance(z1) + variance(z2)``.
        nu: weight of ``covariance(z1) + covariance(z2)``.

    Returns:
        ``la*inv + mu*(var1 + var2) + nu*(cov1 + cov2)`` (a scalar tensor,
        assuming the loss helpers return scalar tensors).
    """
    var1, var2 = variance(z1), variance(z2)
    inv = invariance(z1, z2)
    cov1, cov2 = covariance(z1), covariance(z2)
    return la * inv + mu * (var1 + var2) + nu * (cov1 + cov2)


def _save_checkpoint(state, path=CHECKPOINT_PATH):
    """Write ``state`` to ``path`` atomically.

    The data is saved to a sibling temporary file first and then moved over
    ``path`` with ``os.replace``. An interruption mid-save therefore leaves
    the previous checkpoint intact rather than a truncated file that would
    make the next resume fail.
    """
    tmp_path = f"{path}.tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


def main():
    """Train VICReg on CIFAR-10, resuming from ``CHECKPOINT_PATH`` when it exists."""
    # define model and move to GPU (if one is available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder_dim, projector_dim = 512, 1024
    model = VICReg(encoder_dim, projector_dim).to(device)

    # prepare data, optimizer, and training hyperparams
    num_epochs, batch_size = 500, 256
    data = CIFAR10(root=".", train=True, download=True, transform=Augmentation())
    dataloader = DataLoader(data, batch_size, shuffle=True, num_workers=2)
    opt = Adam(model.parameters(), lr=2e-4, weight_decay=1e-6)

    # load from checkpoint if it exists; its "epoch" entry is the next epoch to run
    start_epoch = 0
    if os.path.exists(CHECKPOINT_PATH):
        print("Loading from checkpoint...")
        # map_location lets a checkpoint written on a GPU machine be resumed on CPU
        cp = torch.load(CHECKPOINT_PATH, map_location=device)
        model.load_state_dict(cp["model_state_dict"])
        opt.load_state_dict(cp["optimizer_state_dict"])
        start_epoch = cp["epoch"]
    # Build the progress bar once the start epoch is known, so resuming does
    # not leave a second, abandoned bar on screen.
    progress = tqdm(range(start_epoch, num_epochs))

    # train the model and regularly save to disk
    for epoch in progress:
        for images, _ in dataloader:
            # Augmentation makes every sample a pair of views
            x1, x2 = [x.to(device) for x in images]
            z1, z2 = model(x1, x2)
            loss = _vicreg_loss(z1, z2)

            opt.zero_grad()
            loss.backward()
            opt.step()
            progress.set_description(f"Loss: {loss.item()}")

        # checkpoint every 10 epochs and after the final epoch
        if epoch % 10 == 0 or epoch == num_epochs - 1:
            _save_checkpoint({
                "epoch": epoch + 1,
                "encoder_dim": encoder_dim,
                "projector_dim": projector_dim,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": opt.state_dict(),
            })


# DataLoader workers (num_workers=2) re-import this module when the "spawn"
# start method is used (the default on Windows and macOS). Without this guard
# every worker would re-run the whole training script and fail while
# bootstrapping. Augmentation must stay defined at module level so the
# workers can unpickle the dataset.
if __name__ == "__main__":
    main()