-
Notifications
You must be signed in to change notification settings - Fork 49
/
Copy pathtrain.py
99 lines (85 loc) · 4.35 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""Train directional marking point detector."""
import math
import os
import random

import torch
from torch.utils.data import DataLoader

import config
import data
import util
from model import DirectionalPointDetector
def plot_prediction(logger, image, marking_points, prediction):
    """Visualize one randomly picked sample of a batch.

    The ground-truth marking points are sent to the 'gt_marking_points'
    window. If any points are extracted from the prediction, they are sent
    to the 'pred_marking_points' window; otherwise nothing more is plotted.

    Args:
        logger: Object providing ``plot_marking_points(image, points,
            win_name=...)``.
        image: Batch tensor; ``image.size(0)`` is taken as the batch size.
        marking_points: Per-sample ground-truth marking points, indexable by
            sample index.
        prediction: Per-sample detector output, indexable by sample index.
    """
    last_index = image.size(0) - 1
    sample_idx = random.randint(0, last_index)

    # Ground truth is plotted on its own converted copy of the sample.
    gt_canvas = util.tensor2im(image[sample_idx])
    logger.plot_marking_points(gt_canvas, marking_points[sample_idx],
                               win_name='gt_marking_points')

    # The sample is converted a second time so the prediction plot does not
    # reuse the image object that was handed to the ground-truth plot.
    pred_canvas = util.tensor2im(image[sample_idx])
    # NOTE(review): 0.01 is presumably a confidence threshold -- confirm
    # against data.get_predicted_points.
    detections = data.get_predicted_points(prediction[sample_idx], 0.01)
    if not detections:
        return

    # Only the second element of every detection entry is plotted.
    predicted_points = [entry[1] for entry in detections]
    logger.plot_marking_points(pred_canvas, predicted_points,
                               win_name='pred_marking_points')
def generate_objective(marking_points_batch, device):
    """Get regression objective and gradient for directional point detector.

    Args:
        marking_points_batch: One sequence of marking points per sample. Each
            point exposes ``x``, ``y``, ``shape`` and ``direction``. ``x`` and
            ``y`` are multiplied by ``config.FEATURE_MAP_SIZE`` to select a
            cell, so they are presumably normalized to [0, 1] -- confirm
            against the dataset loader.
        device: Device on which both returned tensors are allocated.

    Returns:
        A tuple ``(objective, gradient)`` of tensors shaped
        ``(batch, config.NUM_FEATURE_MAP_CHANNEL, FEATURE_MAP_SIZE,
        FEATURE_MAP_SIZE)``. ``objective`` holds the regression targets.
        ``gradient`` is a 0/1 mask: channel 0 is 1 in every cell, channels
        1-5 are 1 only in cells that contain a marking point.
    """
    batch_size = len(marking_points_batch)
    map_size = config.FEATURE_MAP_SIZE
    objective = torch.zeros(batch_size, config.NUM_FEATURE_MAP_CHANNEL,
                            map_size, map_size, device=device)
    gradient = torch.zeros_like(objective)
    # Confidence (channel 0) is supervised in every cell.
    gradient[:, 0].fill_(1.)
    for batch_idx, marking_points in enumerate(marking_points_batch):
        for marking_point in marking_points:
            # Scale with the same factor that selects the cell, so the
            # offsets stay consistent if the feature map size changes.
            scaled_x = marking_point.x * map_size
            scaled_y = marking_point.y * map_size
            # A coordinate of exactly 1.0 would index one past the last
            # cell; keep it in the last cell instead of raising IndexError.
            col = min(math.floor(scaled_x), map_size - 1)
            row = min(math.floor(scaled_y), map_size - 1)
            # Confidence Regression
            objective[batch_idx, 0, row, col] = 1.
            # Marking Point Shape Regression
            objective[batch_idx, 1, row, col] = marking_point.shape
            # Offset Regression
            objective[batch_idx, 2, row, col] = scaled_x - col
            objective[batch_idx, 3, row, col] = scaled_y - row
            # Direction Regression
            direction = marking_point.direction
            objective[batch_idx, 4, row, col] = math.cos(direction)
            objective[batch_idx, 5, row, col] = math.sin(direction)
            # Assign Gradient
            gradient[batch_idx, 1:6, row, col].fill_(1.)
    return objective, gradient
def _collate_samples(batch):
    """Transpose a list of dataset samples into per-field tuples."""
    return list(zip(*batch))


def train_detector(args):
    """Train directional point detector.

    Builds the detector and an Adam optimizer, optionally restores both from
    checkpoint files, then runs the training loop. After every epoch the
    detector weights are written to 'weights/dp_detector_<epoch>.pth' and the
    optimizer state to 'weights/optimizer.pth'.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``disable_cuda``, ``gpu_id``, ``depth_factor``,
            ``detector_weights``, ``optimizer_weights``, ``lr``,
            ``enable_visdom``, ``dataset_directory``, ``batch_size``,
            ``data_loading_workers`` and ``num_epochs``. ``args.cuda`` is set
            as a side effect.
    """
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
    torch.set_grad_enabled(True)

    dp_detector = DirectionalPointDetector(
        3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
    if args.detector_weights:
        print("Loading weights: %s" % args.detector_weights)
        # map_location lets a checkpoint saved on one device be restored on
        # another (e.g. GPU-trained weights on a CPU-only machine).
        dp_detector.load_state_dict(
            torch.load(args.detector_weights, map_location=device))
    dp_detector.train()

    optimizer = torch.optim.Adam(dp_detector.parameters(), lr=args.lr)
    if args.optimizer_weights:
        print("Loading weights: %s" % args.optimizer_weights)
        optimizer.load_state_dict(
            torch.load(args.optimizer_weights, map_location=device))

    logger = util.Logger(args.enable_visdom, ['train_loss'])
    # A module-level collate function (rather than a lambda) stays picklable
    # when DataLoader workers are started with the "spawn" method.
    data_loader = DataLoader(data.ParkingSlotDataset(args.dataset_directory),
                             batch_size=args.batch_size, shuffle=True,
                             num_workers=args.data_loading_workers,
                             collate_fn=_collate_samples)

    # Create the checkpoint directory up front so a missing folder fails
    # nothing after a full epoch of training.
    os.makedirs('weights', exist_ok=True)

    for epoch_idx in range(args.num_epochs):
        for iter_idx, (images, marking_points) in enumerate(data_loader):
            images = torch.stack(images).to(device)

            optimizer.zero_grad()
            prediction = dp_detector(images)
            objective, gradient = generate_objective(marking_points, device)
            # Element-wise squared error; `gradient` is passed to backward()
            # as the per-element weight, masking unsupervised cells.
            loss = (prediction - objective) ** 2
            loss.backward(gradient)
            optimizer.step()

            # Masked loss summed over all elements, averaged over the batch.
            train_loss = torch.sum(loss*gradient).item() / loss.size(0)
            logger.log(epoch=epoch_idx, iter=iter_idx, train_loss=train_loss)
            if args.enable_visdom:
                plot_prediction(logger, images, marking_points, prediction)
        torch.save(dp_detector.state_dict(),
                   'weights/dp_detector_%d.pth' % epoch_idx)
        torch.save(optimizer.state_dict(), 'weights/optimizer.pth')
if __name__ == '__main__':
    # Parse the training command-line options, then start training.
    training_parser = config.get_parser_for_training()
    training_args = training_parser.parse_args()
    train_detector(training_args)