forked from princeton-vl/RAFT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
289 lines (217 loc) · 10.2 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
from __future__ import print_function, division
import sys
sys.path.append('core')
import argparse
import os
import cv2
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from raft import RAFT
import evaluate
import datasets
from torch.utils.tensorboard import SummaryWriter
try:
    from torch.cuda.amp import GradScaler
except ImportError:
    # torch.cuda.amp (or GradScaler inside it) is missing on PyTorch < 1.6;
    # only an import failure should trigger the fallback, so other errors
    # are no longer swallowed by a bare ``except``.
    class GradScaler:
        """No-op stand-in for ``torch.cuda.amp.GradScaler`` (PyTorch < 1.6).

        Exposes the subset of the GradScaler interface used by the training
        loop in this file. No loss scaling is performed: ``scale`` returns the
        loss unchanged and ``step`` simply steps the optimizer.
        """

        def __init__(self, enabled=False):
            # ``enabled`` is accepted so that ``GradScaler(enabled=...)``
            # works with both the real class and this fallback; the value is
            # only stored, because this class never scales anything.
            self.enabled = enabled

        def scale(self, loss):
            """Return ``loss`` unchanged (no scaling is applied)."""
            return loss

        def unscale_(self, optimizer):
            """Do nothing; gradients were never scaled."""
            pass

        def step(self, optimizer):
            """Perform a plain optimizer step."""
            optimizer.step()

        def update(self):
            """Do nothing; there is no scale factor to update."""
            pass
# exclude extremely large displacements: default threshold on the
# ground-truth flow magnitude used by sequence_loss (pixels at or above
# this magnitude are masked out)
MAX_FLOW = 400
# number of training steps the Logger accumulates metrics over before it
# prints and writes their average
SUM_FREQ = 100
# intended interval, in training steps, between intermediate checkpoints
# and validation runs
VAL_FREQ = 5000
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
    """Exponentially weighted L1 loss over a sequence of flow predictions.

    Args:
        flow_preds: sequence of predicted flow tensors, one per refinement
            iteration, each comparable element-wise with ``flow_gt``.
        flow_gt: ground-truth flow; its magnitude is taken over ``dim=1``.
        valid: per-pixel validity map; values ``>= 0.5`` count as valid.
        gamma: decay factor; prediction ``i`` of ``n`` is weighted by
            ``gamma**(n - i - 1)``, so the last prediction has weight 1.
        max_flow: pixels whose ground-truth magnitude is not below this
            value are excluded.

    Returns:
        ``(flow_loss, metrics)`` where ``flow_loss`` is the weighted sum of
        the per-iteration losses and ``metrics`` maps ``'epe'``, ``'1px'``,
        ``'3px'`` and ``'5px'`` to Python floats computed from the final
        prediction over the valid pixels.
    """
    num_preds = len(flow_preds)

    # exclude invalid pixels and extremely large displacements
    gt_magnitude = torch.sum(flow_gt**2, dim=1).sqrt()
    valid = (valid >= 0.5) & (gt_magnitude < max_flow)
    # extra channel axis so the mask broadcasts against the flow channels:
    # (batch, ht, wd) -> (batch, 1, ht, wd)
    valid_mask = valid[:, None]

    flow_loss = 0.0
    for step, prediction in enumerate(flow_preds):
        # earliest prediction gets the smallest weight, gamma**(n - 1);
        # the final one gets gamma**0 = 1
        weight = gamma**(num_preds - step - 1)
        # masked absolute error, averaged over all elements (masked-out
        # pixels contribute zeros to the mean)
        abs_error = (prediction - flow_gt).abs()
        flow_loss += weight * (valid_mask * abs_error).mean()

    # end-point error of the final prediction: euclidean distance between
    # predicted and ground-truth flow vectors, kept for valid pixels only
    final_diff = flow_preds[-1] - flow_gt
    epe = torch.sum(final_diff**2, dim=1).sqrt()
    epe = epe.view(-1)[valid.view(-1)]

    def fraction_below(threshold):
        # share of valid pixels whose end-point error is under ``threshold``
        return (epe < threshold).float().mean().item()

    metrics = {
        'epe': epe.mean().item(),
        '1px': fraction_below(1),
        '3px': fraction_below(3),
        '5px': fraction_below(5),
    }
    return flow_loss, metrics
def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def fetch_optimizer(args, model):
    """Build the AdamW optimizer and its one-cycle learning-rate schedule.

    Args:
        args: namespace providing ``lr``, ``wdecay``, ``epsilon`` and
            ``num_steps``.
        model: module whose parameters are optimised.

    Returns:
        ``(optimizer, scheduler)``.
    """
    optimizer = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.wdecay,
        eps=args.epsilon,
    )

    # the schedule is 100 steps longer than the requested number of
    # training steps
    schedule_length = args.num_steps + 100
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
        total_steps=schedule_length,
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy='linear',
    )
    return optimizer, scheduler
class Logger:
    """Accumulate training metrics and report their running averages.

    Metrics passed to ``push`` are summed per key. Whenever
    ``total_steps % SUM_FREQ == SUM_FREQ - 1`` the averages are printed
    together with the current learning rate and written to a
    ``SummaryWriter``, which is created lazily on first use.
    """

    def __init__(self, model, scheduler):
        self.model = model
        self.scheduler = scheduler
        self.total_steps = 0
        self.running_loss = {}
        # number of push() calls summed into running_loss since the last
        # report; used as the divisor so that a window shorter than
        # SUM_FREQ (the first one holds SUM_FREQ - 1 steps) is still
        # averaged correctly
        self._window_steps = 0
        # created on demand so that no event file is written unless
        # something is actually logged
        self.writer = None

    def _get_writer(self):
        """Return the SummaryWriter, creating it on first use."""
        if self.writer is None:
            self.writer = SummaryWriter()
        return self.writer

    def _print_training_status(self):
        """Print and write the averages of the current window, then reset it."""
        # guard against a zero divisor if this is called with nothing pushed
        count = max(self._window_steps, 1)
        averages = {k: self.running_loss[k] / count for k in sorted(self.running_loss)}

        training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
        metrics_str = ("{:10.4f}, "*len(averages)).format(*averages.values())

        # print the training status
        print(training_str + metrics_str)

        writer = self._get_writer()
        for key, value in averages.items():
            writer.add_scalar(key, value, self.total_steps)
            self.running_loss[key] = 0.0
        self._window_steps = 0

    def push(self, metrics):
        """Add one step's metrics (a dict of numbers) to the running sums."""
        self.total_steps += 1
        self._window_steps += 1

        for key in metrics:
            self.running_loss[key] = self.running_loss.get(key, 0.0) + metrics[key]

        if self.total_steps % SUM_FREQ == SUM_FREQ-1:
            self._print_training_status()
            self.running_loss = {}

    def write_dict(self, results):
        """Write every ``key: value`` pair of ``results`` at the current step."""
        writer = self._get_writer()
        for key in results:
            writer.add_scalar(key, results[key], self.total_steps)

    def close(self):
        """Close the SummaryWriter if one was created; safe to call otherwise."""
        # the writer only exists after the first report or write_dict call,
        # so a short run may reach close() without ever creating one
        if self.writer is not None:
            self.writer.close()
            self.writer = None
def train(args):
    """Train a RAFT model and return the path of the final checkpoint.

    Wraps ``RAFT(args)`` in ``nn.DataParallel``, optionally restores weights
    from ``args.restore_ckpt`` and then iterates over the training data
    loader until more than ``args.num_steps`` optimisation steps have run.
    Every ``VAL_FREQ`` steps an intermediate checkpoint is written under
    ``checkpoints/`` and the datasets named in ``args.validation`` (if any)
    are evaluated.

    Args:
        args: parsed command-line namespace. This function reads ``gpus``,
            ``restore_ckpt``, ``stage``, ``mixed_precision``, ``add_noise``,
            ``iters``, ``gamma``, ``clip``, ``validation``, ``num_steps``
            and ``name``; it is also passed on to ``RAFT``,
            ``datasets.fetch_dataloader`` and ``fetch_optimizer``.

    Returns:
        The path ``'checkpoints/<args.name>.pth'`` of the final checkpoint.
    """
    # DataParallel splits input to the module into chunks in the batch dimension
    # the forward and backward pass can be done independently on each of the devices
    model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
    print("Parameter Count: %d" % count_parameters(model))

    # load model parameters from file
    # NOTE: torch.load unpickles the file, so only restore checkpoints
    # that come from a trusted source
    if args.restore_ckpt is not None:
        model.load_state_dict(torch.load(args.restore_ckpt), strict=False)

    model.cuda()
    model.train()

    # collect mean/variance statistics on chairs training
    # use these statistics during the rest of the training stages
    if args.stage != 'chairs':
        model.module.freeze_bn()

    # get the data loader that corresponds to the current training stage
    train_loader = datasets.fetch_dataloader(args)
    # get the optimizer and (learning rate) scheduler
    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0
    # gradient scaler multiplies loss by some factor to prevent underflows that can happen with float16 precision
    # after gradient is accumulated, gradients are divided by this factor such that they do not interfere with the
    # learning rate
    scaler = GradScaler(enabled=args.mixed_precision)
    logger = Logger(model, scheduler)

    # --validation is optional on the command line; argparse then leaves it
    # as None, which must be treated as "no validation datasets"
    validation_sets = args.validation or []

    should_keep_training = True
    while should_keep_training:

        # iterate over batches
        for data_blob in train_loader:
            # set weight tensor gradients to zero
            optimizer.zero_grad()
            # send image1/image2/flow/valid_flow batches to the gpu
            image1, image2, flow, valid = [x.cuda() for x in data_blob]

            # add noise to images
            if args.add_noise:
                stdv = np.random.uniform(0.0, 5.0)
                image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
                image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)

            # forward pass
            flow_predictions = model(image1, image2, iters=args.iters)

            # calculate loss
            loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
            # backward pass with scaled loss/gradients
            scaler.scale(loss).backward()
            # unscale gradients on the weights so clipping sees their true magnitude
            scaler.unscale_(optimizer)
            # rescale the gradients if their overall norm exceeds args.clip
            # especially useful for recurrent neural networks suffering from exploding gradient problem
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            # step optimizer with scaler, with unscaling if not already done and checks for over/underflow
            scaler.step(optimizer)
            # step the (learning rate) schedule
            scheduler.step()
            # updates the scale factor (in reaction to occurrence of nan/inf and thus skipped optimization steps)
            scaler.update()

            logger.push(metrics)

            # checkpoint + validation during training for monitoring
            if total_steps % VAL_FREQ == VAL_FREQ - 1:
                checkpoint_path = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
                torch.save(model.state_dict(), checkpoint_path)

                results = {}
                for val_dataset in validation_sets:
                    if val_dataset == 'chairs':
                        results.update(evaluate.validate_chairs(model.module))
                    elif val_dataset == 'sintel':
                        results.update(evaluate.validate_sintel(model.module))
                    elif val_dataset == 'kitti':
                        results.update(evaluate.validate_kitti(model.module))

                logger.write_dict(results)

                # return to training mode after validation
                model.train()
                if args.stage != 'chairs':
                    model.module.freeze_bn()

            total_steps += 1

            # stop training after specified number of steps (batches)
            if total_steps > args.num_steps:
                should_keep_training = False
                break

    # write the final weights first, so that a failure while shutting down
    # the logger cannot cost the trained model
    final_path = 'checkpoints/%s.pth' % args.name
    torch.save(model.state_dict(), final_path)

    # the logger only owns a writer once something has been logged
    if logger.writer is not None:
        logger.close()

    return final_path
if __name__ == '__main__':
    # command-line options, registered in this order: (flags, keyword arguments)
    cli_options = [
        (('--name',), dict(default='raft', help="name your experiment")),
        (('--stage',), dict(help="determines which dataset to use for training")),
        (('--restore_ckpt',), dict(help="restore checkpoint")),
        (('--small',), dict(action='store_true', help='use small model')),
        (('--validation',), dict(type=str, nargs='+')),

        (('--lr',), dict(type=float, default=0.00002)),
        (('--num_steps',), dict(type=int, default=100000)),
        (('--batch_size',), dict(type=int, default=6)),
        (('--image_size',), dict(type=int, nargs='+', default=[384, 512])),
        (('--gpus',), dict(type=int, nargs='+', default=[0, 1])),
        (('--mixed_precision',), dict(action='store_true', help='use mixed precision')),

        (('--iters',), dict(type=int, default=12)),
        (('--wdecay',), dict(type=float, default=.00005)),
        (('--epsilon',), dict(type=float, default=1e-8)),
        (('--clip',), dict(type=float, default=1.0)),
        (('--dropout',), dict(type=float, default=0.0)),
        (('--gamma',), dict(type=float, default=0.8, help='exponential weighting')),
        (('--add_noise',), dict(action='store_true')),
    ]

    parser = argparse.ArgumentParser()
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)
    args = parser.parse_args()

    # fixed seeds for the torch and numpy random number generators
    torch.manual_seed(1234)
    np.random.seed(1234)

    # checkpoints are written below this directory during training
    checkpoint_dir = 'checkpoints'
    if not os.path.isdir(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    train(args)