train.py
import sys
import random
import os
import torch
import torchvision
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
from torchvision import transforms
sys.path.append(".")
from utils import get_gen_loss, get_disc_loss
def train(
    dataset,
    device,
    gen,
    gen_opt,
    disc,
    disc_opt,
    adv_criterion,
    lambda_recon,
    recon_criterion,
    n_epochs,
    display_step,
    batch_size,
    model_name,
    save_model=True,
    cur_step=0,
):
    """Train a conditional GAN (generator/discriminator pair) on paired images.

    Each sample from `dataset` is expected to be an image whose first 256 columns
    are the real target and whose remaining columns are the condition. Losses and
    image grids are logged to TensorBoard every `display_step` steps, and a
    checkpoint is saved every 2000 steps when `save_model` is True.
    """
    writer_real = SummaryWriter(f'logs/logs_{model_name}/real')
    writer_fake = SummaryWriter(f'logs/logs_{model_name}/fake')
    writer_condition = SummaryWriter(f'logs/logs_{model_name}/condition')

    # Create the checkpoint directory if it does not already exist.
    try:
        os.mkdir('saved_model_paths')
    except FileExistsError:
        pass

    # Hold out 5% of the dataset for validation.
    train_dataset, val_dataset = random_split(
        dataset,
        [int(len(dataset) * 0.95), len(dataset) - int(len(dataset) * 0.95)]
    )
    dataloader = DataLoader(train_dataset, batch_size=batch_size)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

    # Deterministic horizontal flip (p=1); applied below to both halves of a pair at once.
    flip = torch.jit.script(torch.nn.Sequential(transforms.RandomHorizontalFlip(p=1)))

    mean_generator_loss = 0
    mean_discriminator_loss = 0
    for epoch in range(n_epochs):
        for image, _ in tqdm(dataloader, file=sys.stdout):
            # Input setup: columns 256 onward are the condition, the first 256 columns are the real target.
            condition = image[:, :, :, 256:].to(device)
            real = image[:, :, :, :256].to(device)

            # 50% chance of flipping the images horizontally. Either both are flipped
            # or both are left as-is, so the pair stays aligned.
            if random.random() > 0.5:
                real = flip(real)
                condition = flip(condition)

            # Update discriminator
            disc_opt.zero_grad()
            disc_loss = get_disc_loss(gen, disc, real, condition, adv_criterion)
            disc_loss.backward(retain_graph=True)
            disc_opt.step()

            # Update generator
            gen_opt.zero_grad()
            gen_loss = get_gen_loss(
                gen,
                disc,
                real,
                condition,
                adv_criterion,
                recon_criterion,
                lambda_recon,
            )
            gen_loss.backward()
            gen_opt.step()

            # Keep track of the average loss over the current display window
            mean_discriminator_loss += disc_loss.item() / display_step
            mean_generator_loss += gen_loss.item() / display_step
            # Visualization code
            if cur_step % display_step == 0:
                print()
                mean_val_loss = 0
                val_condition = None
                val_real = None
                for val_image, _ in tqdm(val_dataloader, file=sys.stdout, position=0, leave=True):
                    val_condition = val_image[:, :, :, 256:].to(device)
                    val_real = val_image[:, :, :, :256].to(device)
                    with torch.no_grad():
                        gen_loss = get_gen_loss(
                            gen,
                            disc,
                            val_real,
                            val_condition,
                            adv_criterion,
                            recon_criterion,
                            lambda_recon,
                        )
                    mean_val_loss += gen_loss.item() / len(val_dataloader)
                print(
                    f"Epoch {epoch}: Step {cur_step}: "
                    f"Generator loss: {mean_generator_loss}, "
                    f"Generator Val Loss: {mean_val_loss}, "
                    f"Discriminator loss: {mean_discriminator_loss}"
                )
                # Log with tensorboard
                with torch.no_grad():
                    fake = gen(val_condition)
                    img_grid_real = torchvision.utils.make_grid(val_real, normalize=True)
                    img_grid_condition = torchvision.utils.make_grid(val_condition, normalize=True)
                    img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                    writer_real.add_image("Real", img_grid_real, global_step=cur_step)
                    writer_fake.add_image("Fake", img_grid_fake, global_step=cur_step)
                    writer_condition.add_image("Condition", img_grid_condition, global_step=cur_step)
                mean_generator_loss = 0
                mean_discriminator_loss = 0
            # Periodically checkpoint the generator, discriminator, and their optimizers.
            if save_model and cur_step % 2000 == 0:
                torch.save(
                    {
                        'gen': gen.state_dict(),
                        'gen_opt': gen_opt.state_dict(),
                        'disc': disc.state_dict(),
                        'disc_opt': disc_opt.state_dict(),
                    },
                    f"saved_model_paths/{model_name}_{cur_step}.pth"
                )
            cur_step += 1
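

# Example usage (illustrative sketch, not part of the original training code).
# It assumes a Pix2Pix-style setup: `UNetGenerator` and `PatchGANDiscriminator`
# are hypothetical model classes standing in for whatever generator/discriminator
# the rest of this repository defines, and "data/" is a folder of side-by-side
# paired images (target on the left, condition on the right), matching the
# slicing in train(). Loss functions and optimizer settings follow common
# Pix2Pix defaults; adjust them to the actual project.
#
# if __name__ == "__main__":
#     import torch.nn as nn
#     from torchvision import datasets
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#
#     dataset = datasets.ImageFolder("data/", transform=transforms.ToTensor())
#
#     gen = UNetGenerator().to(device)           # hypothetical generator class
#     disc = PatchGANDiscriminator().to(device)  # hypothetical discriminator class
#     gen_opt = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
#     disc_opt = torch.optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))
#
#     train(
#         dataset=dataset,
#         device=device,
#         gen=gen,
#         gen_opt=gen_opt,
#         disc=disc,
#         disc_opt=disc_opt,
#         adv_criterion=nn.BCEWithLogitsLoss(),
#         lambda_recon=200,
#         recon_criterion=nn.L1Loss(),
#         n_epochs=20,
#         display_step=200,
#         batch_size=4,
#         model_name="pix2pix",
#     )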