train_vqvae.py
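"""
Training script for a VQ-VAE.

Builds a data loader for the selected dataset (cityscapes, yosemite, or the
default loader), then trains the model with an LPIPS perceptual reconstruction
loss plus the codebook and commitment losses. With training = False the script
loads the saved checkpoint instead of training.
"""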
import torch
from data.load_data import get_data_loader
from data.load_cityscapes import load_cityscapes
from data.load_yosemite import get_yosemite_loader
from tqdm import tqdm
from statistics import mean
# from modules.vq_vae import VQVAE
from modules.vqvae import VQVAE
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
#####
BATCH_SIZE = 32
IMG_SIZE = 256
EPOCHS = 8
BETA = 0.25  # commitment loss weight (beta in the VQ-VAE objective)
lr = 1e-3
device = "cuda"
training = True       # set to False to load the saved checkpoint instead of training
dataset = "yosemite"  # "cityscapes", "yosemite", or anything else for the default loader
#####
if dataset == "cityscapes":
    print("RUNNING CITYSCAPES")
    data_loader = load_cityscapes(
        BATCH_SIZE, 128, 256,
        real_path="/home/ulrik/datasets/cityscapes/full_dataset_kaggle/train/image/",
        seg_path="/home/ulrik/datasets/cityscapes/full_dataset_kaggle/train/label/"
    )
elif dataset == "yosemite":
    print("RUNNING YOSEMITE")
    data_loader = get_yosemite_loader(
        BATCH_SIZE, IMG_SIZE, IMG_SIZE,
        path_A="/home/ulrik/datasets/yosemite_translation/trainA",
        path_B="/home/ulrik/datasets/yosemite_translation/trainB",
        split_domains=False
    )
else:
    data_loader = get_data_loader("", BATCH_SIZE, IMG_SIZE)
def perceptual_loss_function(input_image, generated_image):
    """
    Perceptual loss: the LPIPS distance between the input image and the
    generated image, computed with torchmetrics' pretrained LPIPS metric.
    """
    input_image_adjusted = (input_image + 1) / 2
    generated_image_adjusted = (generated_image + 1) / 2
    perceptual_loss = learned_perceptual_image_patch_similarity(
        generated_image_adjusted, input_image_adjusted, normalize=True
    )
    return perceptual_loss
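# Note: the (x + 1) / 2 shift assumes the model works with images in [-1, 1];
# normalize=True then tells torchmetrics' LPIPS that its inputs are in [0, 1].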
#####
model = VQVAE()
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
cross_entropy_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)  # not used in the training loop below
if training:
    for epoch in range(EPOCHS):
        losses = []
        for batch in tqdm(data_loader):
            ### forward real modality ###
            optimizer.zero_grad()
            real_x = batch.to(device)
            x_hat, z_e, z_q = model(real_x, enc_modality="real", dec_modality="real")
            # real_reconstruction_loss = torch.mean((real_x - x_hat) ** 2)
            real_reconstruction_loss = perceptual_loss_function(real_x, x_hat)
            # codebook loss: pull the codebook vectors toward the (detached) encoder outputs
            real_vq_loss = torch.mean((z_e.detach() - z_q) ** 2)
            # commitment loss: keep the encoder outputs close to the (detached) codebook vectors
            real_commitment_loss = torch.mean((z_e - z_q.detach()) ** 2)
            loss = real_reconstruction_loss + real_vq_loss + BETA * real_commitment_loss
            losses.append(loss.item())
            loss.backward()
            optimizer.step()
        print(f"Epoch: {epoch}, loss: {mean(losses)}")
    # save weights
    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }, "weights/vqvae_yosemite.pth")
    print("Saved weights!")
else:
    # load weights
    checkpoint = torch.load("weights/vqvae_yosemite.pth")
    model.load_state_dict(checkpoint["model_state_dict"])
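    # ----- sketch, not part of the original script -----
    # A minimal example of using the loaded weights: reconstruct one batch and
    # display an input/reconstruction pair. Assumes, as in perceptual_loss_function,
    # that images are in [-1, 1], and reuses the forward signature from the
    # training loop above.
    model.eval()
    with torch.no_grad():
        real_x = next(iter(data_loader)).to(device)
        x_hat, z_e, z_q = model(real_x, enc_modality="real", dec_modality="real")
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    axes[0].imshow(((real_x[0] + 1) / 2).clamp(0, 1).permute(1, 2, 0).cpu().numpy())
    axes[0].set_title("input")
    axes[1].imshow(((x_hat[0] + 1) / 2).clamp(0, 1).permute(1, 2, 0).cpu().numpy())
    axes[1].set_title("reconstruction")
    plt.show()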