-
Notifications
You must be signed in to change notification settings - Fork 58
/
run_sr.lua
172 lines (143 loc) · 4.78 KB
/
run_sr.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
require 'torch'
require 'nn'
require 'optim'
require 'image'
require 'nngraph'
require 'weight-init'
local G=require 'adversarial_G.lua'
local D=require 'adversarial_D.lua'
util = paths.dofile('util.lua')
-- Default run configuration. Every key can be overridden from the environment
-- with a variable of the same name; values that parse as numbers are converted
-- with tonumber, anything else is kept as the raw string.
-- NOTE(review): `opt` and `data` are deliberately left global; files loaded via
-- paths.dofile may read them — confirm before making them local.
opt = {
  dataset = 'folder',
  lr = 0.001,
  beta1 = 0.9,
  batchSize = 32,
  niter = 250,
  loadSize = 96,
  ntrain = math.huge,
  name = 'super_resolution',
  gpu = 1,
  nThreads = 4,
  t_folder = '',
  model_folder = '',
}
torch.manualSeed(1)
for key in pairs(opt) do
  local override = os.getenv(key)
  opt[key] = tonumber(override) or override or opt[key]
end
print(opt)

-- Threaded batch loader built from the (possibly overridden) options.
local DataLoader = paths.dofile('data/data.lua')
data = DataLoader.new(opt.nThreads, opt)
print("Dataset: " .. opt.dataset, " Size: ", data:size())
-- BCE targets used for the discriminator / adversarial loss.
local real_label = 1
local fake_label = 0

-- Generator and discriminator, Kaiming-initialized. G and D are the network
-- constructors required at the top of the file (require caches modules, so
-- they are reused here instead of being re-required and shadowed).
local modelG = require('weight-init')(G(), 'kaiming')
local modelD = require('weight-init')(D(opt.loadSize), 'kaiming')

-- Adversarial (BCE) and pixel-wise reconstruction (MSE) losses.
local criterion = nn.BCECriterion()
local criterion_mse = nn.MSECriterion()

-- Adam configuration tables (kept local: only the training loop uses them).
-- D uses a learning rate ten times smaller than G.
local optimStateG = {
  learningRate = opt.lr,
  beta1 = opt.beta1,
  weightDecay = 0.0001,
}
local optimStateD = {
  learningRate = opt.lr * 0.1,
  beta1 = opt.beta1,
  weightDecay = 0.0001,
}

-- Batch buffers: low-resolution inputs are loadSize/4 on each side.
-- Both are rebound to the loader's tensors inside fDx.
local input = torch.Tensor(opt.batchSize, 1, opt.loadSize / 4, opt.loadSize / 4)
local real_uncropped = torch.Tensor(opt.batchSize, 1, opt.loadSize, opt.loadSize)
local errD, errG
local epoch_tm = torch.Timer()
local tm = torch.Timer()
-- CPU staging tensors used to save single samples with image.save.
local test = torch.Tensor(1, opt.loadSize, opt.loadSize)
local test2 = torch.Tensor(1, opt.loadSize / 4, opt.loadSize / 4)
local label = torch.Tensor(opt.batchSize)

if opt.gpu > 0 then
  require 'cunn'
  print('cunn used')
  cutorch.setDevice(opt.gpu)
  input = input:cuda()
  modelG = modelG:cuda()
  modelD = modelD:cuda()
  criterion:cuda()
  criterion_mse:cuda()
  label = label:cuda()
end

-- Flattened weight/gradient views (taken after any :cuda() move); optim.adam
-- updates these tensors in place.
local parametersG, gradientsG = modelG:getParameters()
local parametersD, gradientsD = modelD:getParameters()
--- Closure evaluated by optim.adam for the discriminator (D) step.
-- Fetches a new batch from `data`, then:
--   1. runs D on `real_uncropped` with target `real_label`;
--   2. runs G on `input` and D on the result with target `fake_label`.
-- D's parameter gradients from both passes accumulate in gradientsD.
-- Side effects: rebinds the chunk locals `real_uncropped`, `input` and `errD`,
-- and the global `fake` (G's output), which fGx and the training loop read.
-- fGx also relies on modelD.output still holding D's output for `fake`.
-- @param x flat parameter tensor passed by optim (normally parametersD itself)
-- @return errD (real BCE + fake BCE), gradientsD
local fDx = function(x)
  if x ~= parametersD then
    parametersD:copy(x)
  end
  modelD:zeroGradParameters()

  real_uncropped, input = data:getBatch()
  -- Convert the batch to the models' tensor type (CudaTensor when opt.gpu > 0).
  -- An unconditional :cuda() fails on CPU runs, where cutorch is never loaded.
  real_uncropped = real_uncropped:typeAs(parametersD)
  input = input:typeAs(parametersG)

  -- Pass 1: real images should be classified as real.
  label:fill(real_label)
  local output_real = modelD:forward(real_uncropped)
  local errD_real = criterion:forward(output_real, label)
  local df_do_real = criterion:backward(output_real, label)
  modelD:backward(real_uncropped, df_do_real)

  -- Pass 2: generator output should be classified as fake.
  fake = modelG:forward(input)
  label:fill(fake_label)
  local output_fake = modelD:forward(fake)
  local errD_fake = criterion:forward(output_fake, label)
  local df_do_fake = criterion:backward(output_fake, label)
  modelD:backward(fake, df_do_fake)

  errD = errD_real + errD_fake
  return errD, gradientsD
end
--- Closure evaluated by optim.adam for the generator (G) step.
-- D is not run again. The adversarial term reuses modelD.output, which holds
-- D's output for `fake` from the second pass in fDx, so fDx must be evaluated
-- first on the same batch (as the training loop does).
-- Loss = adv_weight * BCE(D(fake), real_label)
--      + mse_weight * MSE(fake, real_uncropped).
-- Side effects: sets the chunk local `errG` and the global `err_all`, which the
-- training loop prints.
-- @param x flat parameter tensor passed by optim (normally parametersG itself)
-- @return err_all (weighted total loss), gradientsG
local fGx = function(x)
  -- Same guard as fDx: honour the parameter vector optim hands in.
  if x ~= parametersG then
    parametersG:copy(x)
  end
  local adv_weight, mse_weight = 0.001, 0.999

  modelG:zeroGradParameters()
  label:fill(real_label)
  local output = modelD.output
  -- No-op when fDx already converted `input`. Uses the models' tensor type
  -- instead of an unconditional :cuda(), so CPU runs (opt.gpu == 0) work.
  input = input:typeAs(parametersG)

  errG = criterion:forward(output, label)
  local errG_mse = criterion_mse:forward(fake, real_uncropped)
  local df_do = criterion:backward(output, label)
  local df_do_mse = criterion_mse:backward(fake, real_uncropped)
  -- Gradient of the adversarial loss w.r.t. G's output. updateGradInput only
  -- computes gradInput, so D's parameter gradients are left untouched.
  local df_dg = modelD:updateGradInput(fake, df_do)
  modelG:backward(input, adv_weight * df_dg + mse_weight * df_do_mse)

  err_all = adv_weight * errG + mse_weight * errG_mse
  return err_all, gradientsG
end
-- Main training loop. For each batch: one Adam step on D (fDx), then one on G
-- (fGx). Every 10 iterations, the first sample's target, input and clamped
-- output are saved as PNGs. At the end of each epoch, both networks are
-- checkpointed.
local counter = 0
local n_train = math.min(data:size(), opt.ntrain)
local n_batches = math.floor(n_train / opt.batchSize)
for epoch = 1, opt.niter do
  epoch_tm:reset()
  for i = 1, n_train, opt.batchSize do
    tm:reset()
    optim.adam(fDx, parametersD, optimStateD)
    optim.adam(fGx, parametersG, optimStateG)
    counter = counter + 1
    print('count: ' .. counter)

    if counter % 10 == 0 then
      test:copy(real_uncropped[1])
      image.save(opt.name .. counter .. '_real.png', test)
      test2:copy(input[1])
      image.save(opt.name .. counter .. '_input.png', test2)
      -- Clamp only the copied sample to [0, 1] for saving, rather than
      -- rewriting G's whole output tensor (`fake`) in place.
      test:copy(fake[1]):clamp(0, 1)
      image.save(opt.name .. counter .. '_fake.png', test)
    end

    -- Err_G reports err_all (the weighted adversarial + MSE generator loss).
    print(('Epoch: [%d][%8d / %8d]\t Time: %.3f '
      .. ' Err_G: %.4f Err_D: %.4f'):format(
      epoch, (i - 1) / opt.batchSize, n_batches,
      tm:time().real,
      err_all or -1, errD or -1))
  end
  --paths.mkdir('/media/DATA/MODELS/SUPER_RES/checkpoints')
  -- Drop the flattened parameter views before saving, then recreate them.
  -- The closures read these upvalues on the next epoch.
  parametersD, gradientsD = nil, nil
  parametersG, gradientsG = nil, nil
  util.save(opt.model_folder .. opt.name .. '_adversarial_G_' .. epoch, modelG, opt.gpu)
  util.save(opt.model_folder .. opt.name .. '_adversarial_D_' .. epoch, modelD, opt.gpu)
  parametersG, gradientsG = modelG:getParameters()
  parametersD, gradientsD = modelD:getParameters()
  print(('End of epoch %d / %d \t Time Taken: %.3f'):format(
    epoch, opt.niter, epoch_tm:time().real))
end