-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
211 lines (195 loc) · 10.5 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import os
import time

import numpy as np
import skimage.measure
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
from torch import optim
from torch.utils.tensorboard import SummaryWriter

from model import FeatureExtractor
# Fix: the original line ended with a stray trailing comma
# ("...generate_train_test_data_NONE,"), which is a SyntaxError in a
# parenthesis-free "from ... import" statement.
from data import generate_train_test_data_DCT, generate_train_test_data_NONE
from config import args_config
from utils import get_para_log, get_model, ssim, get_time
from utils.mask_utils import get_mask
# Parse the run configuration once at import time; everything below reads
# this shared namespace.
args = args_config()
# Pin training to the configured GPU index. Must happen before torch
# initialises CUDA for the restriction to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.use_cuda_num}"
def main_train():
    """Run one full training session for the reconstruction network.

    Pipeline:
      1. Create a TensorBoard ``SummaryWriter`` and dump the run's
         hyper-parameters to ``runs/<model_log>/para_data.txt``.
      2. Load the sampling mask and the train/test tensors (raw images for
         ``Unet_conv``/``DQBCS``, DCT-measured otherwise).
      3. Build the generator, optionally resuming weights/epoch/best loss
         from ``./checkpoint/<model_checkpoint>``.
      4. Optimise with Adam on an MSE loss, optionally mixed with SSIM and
         VGG-feature losses; log train/test metrics to TensorBoard and
         checkpoint whenever a new best training loss is reached.

    Reads the module-level ``args`` namespace; side effects are TensorBoard
    event files, the parameter text file, and (optionally) checkpoints.
    """
    # =================================== BASIC CONFIGS =================================== #
    print('[*] run basic configs ... ')
    writer = SummaryWriter(os.path.join("runs/", args.model_log))
    # Persist the hyper-parameters next to the TensorBoard events file.
    para_data = os.path.join("runs/", args.model_log) + "/para_data.txt"
    with open(para_data, "w") as file:  # "w": overwrite the file on every run
        file.write(get_para_log(args))
    # ==================================== PREPARE DATA ==================================== #
    print('[*] loading mask ... ')
    mask = get_mask(mask_name=args.maskname,
                    mask_perc=args.maskperc, mask_path="data/mask")
    print('[*] load data ... ')
    if args.model in ("Unet_conv", "DQBCS"):
        # These models train directly on the raw images (input == target).
        [x_train, y_train, x_test, y_test] = generate_train_test_data_NONE(
            args.data_path, args.data_star_num, args.data_end_num,
            testselect=10, verbose=0)
        print("[===] tips x_train == y_train")
    else:
        [x_train, y_train, x_test, y_test] = generate_train_test_data_DCT(
            args.data_path, args.data_star_num, args.data_end_num, mask,
            testselect=10, verbose=0)
    # N,H,W -> N,1,H,W float tensors (single-channel images).
    x_train = torch.from_numpy(x_train[:]).float().unsqueeze(1)
    y_train = torch.from_numpy(y_train[:]).float().unsqueeze(1)
    x_test = torch.from_numpy(x_test[:]).float().unsqueeze(1)
    y_test = torch.from_numpy(y_test[:]).float().unsqueeze(1)
    if torch.cuda.is_available() and args.use_cuda:
        x_train, y_train, x_test, y_test = x_train.cuda(
        ), y_train.cuda(), x_test.cuda(), y_test.cuda()
        print('[*] ====> Running on GPU <==== [*]')
    print("x_data shape is [{}],y_data shape is [{}]".format(
        x_train.shape, y_train.shape))
    train_loader = Data.DataLoader(dataset=Data.TensorDataset(x_train, y_train),
                                   batch_size=args.batch_size, shuffle=True)
    # Log the fixed test batch once, so per-epoch reconstructions can be
    # compared against the same ground truth / network input.
    img_grid_y = torchvision.utils.make_grid(y_test, nrow=5)
    img_grid_x = torchvision.utils.make_grid(x_test, nrow=5)
    writer.add_image('img_test/ground', img_grid_y)
    writer.add_image('img_test/input', img_grid_x)
    # ==================================== DEFINE MODEL ==================================== #
    print('[*] define model ... ')
    device = torch.device('cuda' if torch.cuda.is_available() and args.use_cuda else 'cpu')
    my_net_G = get_model(
        model=args.model, n_channels=args.img_n_channels, n_classes=args.img_n_classes)
    # Trace the graph on CPU (assumes 256x256 inputs — TODO confirm against
    # args.img_size_x/y) before moving the model to the target device.
    demo_input = torch.rand(1, 1, 256, 256)
    writer.add_graph(my_net_G, input_to_model=demo_input)
    my_net_G.to(device)
    print('[*] Try resume from checkpoint')
    # Fix: set defaults up-front so that ANY failure while reading the
    # checkpoint can never leave start_epoch/best_loss unbound (the original
    # bound them only in the FileNotFoundError handler, so e.g. a corrupt
    # file raising RuntimeError crashed later with UnboundLocalError).
    # Fix: the missing-checkpoint path now also uses np.inf — the original's
    # best_loss = 10 silently disabled checkpointing while the loss was > 10.
    start_epoch = 0
    best_loss = np.inf
    if os.path.isdir('checkpoint'):
        try:
            checkpoint = torch.load('./checkpoint/' + args.model_checkpoint)
            print('==> Load last checkpoint data')
            my_net_G.load_state_dict(checkpoint['state'])  # restore weights
            start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            print("==> Loaded checkpoint '{}' (trained for {} epochs,the best loss is {:.6f})".format(
                args.model_checkpoint,
                checkpoint['epoch'],
                best_loss))
        except FileNotFoundError:
            print('==> Can\'t found ' + args.model_checkpoint)
    else:
        print('==> Start from scratch')
    if args.model_show:
        # summary(my_net_G, input_size=(1, args.img_size_x, args.img_size_y))
        pass
    # ==================================== DEFINE TRAIN OPTS ==================================== #
    print('[*] define training options ... ')
    # Optimise all generator parameters.
    optimizer_G = optim.Adam(my_net_G.parameters(), lr=args.lr)
    # ==================================== DEFINE LOSS ==================================== #
    print('[*] define loss functions ... ')
    loss_mse = nn.MSELoss()
    # Frozen VGG backbone for the perceptual (feature-space) loss.
    vgg_Feature_model = FeatureExtractor().to(device)
    # ==================================== TRAINING ==================================== #
    print('[*] start training ... ')
    start_time = time.time()
    for epoch in range(start_epoch, args.epochs):
        for step, (train_x, train_y) in enumerate(train_loader):
            # Global sample counter — the x-axis for every TensorBoard scalar.
            iter_num = epoch * len(train_loader) * args.batch_size + step * args.batch_size
            optimizer_G.zero_grad()  # clear gradients for this training step
            # Generate a batch of reconstructions (DQBCS also returns its
            # intermediate measurement tensor).
            if args.model == "DQBCS":
                measure, g_img = my_net_G(train_x)
            else:
                g_img = my_net_G(train_x)
            loss_g_mse = loss_mse(g_img, train_y)
            if args.loss_mse_only:
                g_loss = loss_g_mse
            else:
                # Weighted mix: alpha*MSE [+ gamma*(1-SSIM)] [+ beta*VGG].
                g_loss = args.alpha * loss_g_mse
                if args.loss_ssim:
                    loss_g_ssim = 1 - ssim(g_img, train_y)
                    g_loss += args.gamma * loss_g_ssim
                if args.loss_vgg:
                    loss_g_vgg = loss_mse(vgg_Feature_model(g_img),
                                          vgg_Feature_model(train_y))
                    g_loss += args.beta * loss_g_vgg
            g_loss.backward()   # backpropagation, compute gradients
            optimizer_G.step()  # apply gradients
            if step % 2 == 0:
                # Periodic evaluation on the fixed test batch, no gradients.
                with torch.no_grad():
                    if args.model == "DQBCS":
                        measure, test_output = my_net_G(x_test)
                    else:
                        test_output = my_net_G(x_test)
                # NOTE(review): compare_psnr/compare_mse were removed from
                # skimage.measure in scikit-image 0.18; this code pins an
                # older scikit-image (newer versions: skimage.metrics.*).
                psnr_num = skimage.measure.compare_psnr(
                    y_test.cpu().data.numpy(), test_output.cpu().data.numpy())
                mse_num = skimage.measure.compare_mse(
                    y_test.cpu().data.numpy() * 255, test_output.cpu().data.numpy() * 255)
                log = "[**] Epoch [{:02d}/{:02d}] Step [{:04d}/{:04d}]".format(
                    epoch + 1, args.epochs,
                    (step + 1) * args.batch_size,
                    len(train_loader) * args.batch_size)
                # -- TensorBoard scalars + matching console log --
                # Train total loss
                writer.add_scalar(
                    "train_loss", g_loss.cpu().data.numpy(), iter_num)
                log += " || TRAIN [loss: {:.6f}]".format(
                    g_loss.cpu().data.numpy())
                # Train detail losses (SSIM/VGG exist only in the mixed case)
                writer.add_scalar("train/MSE_loss",
                                  loss_g_mse.cpu().data.numpy(), iter_num)
                log += " [MSE: {:.6f}]".format(loss_g_mse.cpu().data.numpy())
                if not args.loss_mse_only and args.loss_ssim:
                    writer.add_scalar("train/SSIM_loss",
                                      loss_g_ssim.cpu().data.numpy(), iter_num)
                    log += " [SSIM: {:.4f}]".format(
                        loss_g_ssim.cpu().data.numpy())
                if not args.loss_mse_only and args.loss_vgg:
                    writer.add_scalar("train/VGG_loss",
                                      loss_g_vgg.cpu().data.numpy(), iter_num)
                    log += " [VGG: {:.4f}]".format(
                        loss_g_vgg.cpu().data.numpy())
                # Test metrics
                writer.add_scalar("test/test_MSE", mse_num, iter_num)
                log += " || TEST [MSE: {:.4f}]".format(mse_num)
                writer.add_scalar("test/test_PSNR", psnr_num, iter_num)
                log += " [PSNR: {:.4f}]".format(psnr_num)
                # ETA = average time per sample so far * samples remaining.
                use_time = time.time() - start_time
                ave_time = use_time / (
                    (epoch - start_epoch) * len(train_loader) * args.batch_size
                    + (step + 1) * args.batch_size)
                # Fix: the original added a spurious extra
                # len(train_loader) * batch_size term to the remaining work,
                # overestimating the ETA by one full epoch.
                resttime = ave_time * (
                    (args.epochs - epoch) * len(train_loader) * args.batch_size
                    - (step + 1) * args.batch_size)
                log += " || Use time :{} Rest time :{}".format(
                    get_time(use_time), get_time(resttime))
                print(log)
            if g_loss.cpu().data.numpy() < best_loss and step % 20 == 0 and args.model_save:
                # Checkpoint on a new best training loss.
                best_loss = g_loss.cpu().data.numpy()
                state = {
                    'state': my_net_G.state_dict(),
                    'epoch': epoch,  # saved so training can resume mid-run
                    'best_loss': g_loss.cpu().data.numpy()
                }
                if not os.path.isdir('checkpoint'):
                    os.mkdir('checkpoint')
                torch.save(state, './checkpoint/' + args.model_checkpoint)
                print(
                    "[*] Save checkpoints SUCCESS! || loss= {:.5f} epoch= {:03d}".format(best_loss, epoch + 1))
        # Log the latest test-batch reconstructions once per epoch.
        img_grid = torchvision.utils.make_grid(test_output, nrow=5)
        writer.add_image('img_epoch', img_grid, global_step=epoch)
    writer.close()
    # Append a timestamped footer to the parameter file to mark completion.
    with open(para_data, "a") as file:
        log = "[*] Time is " + time.asctime(time.localtime(time.time())) + "\n"
        log += "=" * 40 + "\n"
        file.write(log)
    print("[ ] train Done !")
# Script entry point: train only when executed directly, not on import.
if __name__ == '__main__':
    main_train()