# train.py
import os
import torch
import torchvision
import torch.utils.data
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from modules import utils
from modules.engine import train_one_epoch, evaluate
import config
from utils import create_folder
from early_stopping import EarlyStopping
from dataloader import CustomDataset, get_transform
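# The local `config` module is assumed to expose at least the fields used
# below. Representative values for orientation only (a sketch, not the
# project's actual config.py):
#
#   anchor_size = (64, 128, 256)      # RPN anchor sizes, in pixels
#   anchor_ratio = (0.5, 1.0, 2.0)    # RPN anchor aspect ratios
#   class_map = {'background': 0, 'object': 1}
#   min_size, max_size = 600, 1000    # image resize bounds
#   detections_per_img = 100          # max boxes kept per image
#   train_batch = 4
#   learning_rate = 0.005
#   patience = 5                      # early-stopping patience, in epochs
#   num_epochs = 50
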
def create_model():
    # ImageNet-pretrained VGG16 feature extractor as the backbone. FasterRCNN
    # needs to know how many channels the backbone outputs (512 for VGG16).
    backbone = torchvision.models.vgg16(weights="VGG16_Weights.IMAGENET1K_V1").features
    backbone.out_channels = 512

    # RPN anchors over a single feature-map level, sized per config.
    anchor_generator = AnchorGenerator(sizes=(config.anchor_size,),
                                       aspect_ratios=(config.anchor_ratio,))

    model = FasterRCNN(backbone=backbone,
                       num_classes=len(config.class_map),
                       min_size=config.min_size,
                       max_size=config.max_size,
                       rpn_anchor_generator=anchor_generator,
                       box_detections_per_img=config.detections_per_img)
    return model
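
# Quick smoke test for create_model() (hypothetical usage, not part of the
# training pipeline; the dummy image shape is illustrative):
#
#   model = create_model()
#   model.eval()
#   with torch.no_grad():
#       preds = model([torch.rand(3, 600, 800)])
#   # preds[0] is a dict with 'boxes', 'labels', and 'scores' tensors.
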
def create_dataloader():
    train_batch = config.train_batch
    train_dataset = CustomDataset('/hdd/thaihq/qnet_search/ori_data/train', get_transform(train=True))
    val_dataset = CustomDataset('/hdd/thaihq/qnet_search/ori_data/val', get_transform(train=False))

    train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch,
                                                    shuffle=True, num_workers=4,
                                                    collate_fn=utils.collate_fn)
    # Validation uses half the training batch size and no shuffling.
    val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=train_batch // 2,
                                                  shuffle=False, num_workers=4,
                                                  collate_fn=utils.collate_fn)
    return train_data_loader, val_data_loader
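
# `utils.collate_fn` is assumed to match the torchvision detection reference,
# which simply regroups a batch of (image, target) pairs rather than stacking
# them, since detection images and targets vary in size. Roughly:
#
#   def collate_fn(batch):
#       return tuple(zip(*batch))
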
def val(model, data_loader, device):
    # Compute validation losses. The model is left in training mode so the
    # forward pass returns a loss dict (see the note below); gradients are
    # simply disabled instead of calling model.eval().
    metric_logger = utils.MetricLogger(delimiter="  ")
    with torch.no_grad():
        for images, targets in metric_logger.log_every(data_loader, print_freq=100):
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in t.items()} for t in targets]
            with torch.cuda.amp.autocast(enabled=False):
                loss_dict = model(images, targets)
            # Reduce losses over all GPUs for logging purposes.
            loss_dict_reduced = utils.reduce_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
    return metric_logger
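
# Note: torchvision detection models return a loss dict only when called with
# targets while in training mode; in eval mode they return predictions.
# Hypothetical illustration of the two call signatures:
#
#   model.train()
#   loss_dict = model(images, targets)   # {'loss_classifier': ..., 'loss_box_reg': ..., ...}
#   model.eval()
#   detections = model(images)           # [{'boxes': ..., 'labels': ..., 'scores': ...}, ...]
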
def train():
    model = create_model()
    train_data_loader, val_data_loader = create_dataloader()

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    # Construct an optimizer: SGD, following the Faster R-CNN paper.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=config.learning_rate,
                                momentum=0.9, weight_decay=0.0005)
    # Learning rate scheduler: decrease the learning rate by 10x every 5 epochs.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    # Early stopping
    stopper = EarlyStopping(config.patience)

    # Paths for logs and weights
    log_folder, weight_folder = create_folder('training', True)
    print('Training logs and weights at', log_folder, '\n')
    train_log_path = os.path.join(log_folder, 'train_log.txt')
    val_log_path = os.path.join(log_folder, 'val_log.txt')
    last_weight_path = os.path.join(weight_folder, 'last.pt')
    best_weight_path = os.path.join(weight_folder, 'best.pt')

    print('*** Start training ***')
    for epoch in range(config.num_epochs):
        # Train for one epoch, printing every 100 iterations.
        train_loss = train_one_epoch(model, optimizer, train_data_loader,
                                     device, epoch, print_freq=100)
        # Log the training loss.
        stopper.log(train_log_path, str(train_loss))

        # Save the full model (not just the state dict) as the last checkpoint.
        print('--- Save the last weight ---')
        torch.save(model, last_weight_path)

        # Update the learning rate.
        lr_scheduler.step()

        # Evaluate on the validation set.
        print('*** Start validating ***')
        val_loss = val(model, val_data_loader, device)
        # Log the validation loss.
        stopper.log(val_log_path, str(val_loss))

        # Extract the global average of the total loss from the MetricLogger's
        # string form, e.g. "loss: 0.1234 (0.1300)" -> 0.1300.
        loss = float(str(val_loss).split('  ')[0].split(' ')[2][1:-1])

        # Save the best weight (best_fitness still holds the previous best here).
        if loss <= stopper.best_fitness:
            print('+++ Save the best weight +++')
            torch.save(model, best_weight_path)

        # Stop early if the validation loss has stopped improving.
        if stopper(epoch, loss):
            break
if __name__ == '__main__':
train()
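
# Since checkpoints store the full module via torch.save(model, ...), loading
# one later is assumed to look like this (sketch; the original class and
# module definitions must be importable):
#
#   model = torch.load('best.pt', map_location='cpu')
#   model.eval()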