-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathmain_train.py
104 lines (70 loc) · 2.87 KB
/
main_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import random
from easydict import EasyDict as edict
import json
import logging
import sys
import torch.backends.cudnn as cudnn
import torch.utils.data
from torch.utils.data import DataLoader
from datetime import date
# Log to stdout with a short timestamp prefix.
# logging.basicConfig only takes effect while the root logger has no handlers,
# so all formatting options are passed in this single call. The root level is
# also set explicitly, so INFO records still pass even if basicConfig turns out
# to be a no-op because something configured logging first.
ch = logging.StreamHandler(sys.stdout)
logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(
    format='%(asctime)s %(message)s', datefmt='%m/%d %H:%M:%S', handlers=[ch])
# dataset
from lib.dataloader import PluckerData3D_precompute
# configuration
from config import get_config
# trainer
from trainer_plucker import PluckerTrainer
def main(configs):
    """Build the train/validation data loaders and run the Plucker trainer.

    Args:
        configs: The resolved run configuration (the ``__main__`` block passes
            an EasyDict). ``train_batch_size`` is read here; the whole object
            is also forwarded to the datasets and to ``PluckerTrainer``.
    """
    # Training split: shuffled mini-batches, incomplete final batch dropped.
    train_set = PluckerData3D_precompute(phase='train', config=configs)
    train_loader = DataLoader(
        train_set,
        batch_size=configs.train_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=6,
    )

    # Validation split: one sample per batch, fixed order, nothing dropped.
    val_set = PluckerData3D_precompute(phase='valid', config=configs)
    val_loader = DataLoader(
        val_set,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        num_workers=1,
    )

    PluckerTrainer(configs, train_loader, val_loader).train()
def _load_resume_config(resume_dir, dataset, model_nb):
    """Return the ``config.json`` saved by an earlier run.

    The file is read from ``<resume_dir>/<dataset>/<model_nb>/config.json``.
    Raises ``FileNotFoundError`` if that file does not exist.
    """
    config_path = os.path.join(resume_dir, dataset, model_nb, 'config.json')
    # Context manager so the file handle is closed right after parsing.
    with open(config_path, 'r') as f:
        return json.load(f)


def _checkpoint_path(resume_config, model_nb, best=False):
    """Return the checkpoint to resume from for the run in ``resume_config``.

    The path is ``<out_dir>/<dataset>/<model_nb>/checkpoint.pth`` (most recent
    checkpoint), or ``.../best_val_checkpoint.pth`` when ``best`` is true.
    """
    filename = 'best_val_checkpoint.pth' if best else 'checkpoint.pth'
    return os.path.join(
        resume_config['out_dir'], resume_config['dataset'], model_nb, filename)


if __name__ == '__main__':
    configs = get_config()
    # -------------------------------------------------------------
    # You can change the configurations here or in the file config.py.
    # select dataset
    configs.dataset = "structured3D"
    # configs.dataset = "semantic3D"
    # dataset path
    configs.data_dir = "./dataset"
    # select which GPU to be used
    configs.gpu_inds = 0
    # Model number: set it to whatever you want. It is also used below to
    # locate the saved run when resuming.
    # NOTE: when resuming, this must equal the original run's model_nb. The
    # default (today's date) only matches if you resume on the same day.
    configs.model_nb = str(date.today())
    # training batch size
    configs.train_batch_size = 12
    # learning rate
    configs.train_lr = 1e-3
    # If training was terminated unexpectedly, uncomment the following line
    # and set resume_dir to continue from the saved run.
    # configs.resume_dir = "./output"
    # When resuming: True loads the best validation checkpoint,
    # False loads the most recent one.
    resume_from_best = False

    # vars() exposes the config object's own attribute dict, so writes to
    # dconfig are visible through configs until it is rebuilt as an EasyDict.
    dconfig = vars(configs)
    if configs.resume_dir:
        resume_config = _load_resume_config(
            configs.resume_dir, configs.dataset, configs.model_nb)
        # Restore every setting the earlier run saved; saved keys that the
        # current config does not define are ignored.
        for k in dconfig:
            if k in resume_config:
                dconfig[k] = resume_config[k]
        dconfig['resume'] = _checkpoint_path(
            resume_config, configs.model_nb, best=resume_from_best)
    else:
        dconfig['resume'] = None

    logging.info('===> Configurations')
    for k, v in dconfig.items():
        logging.info(' %s: %s', k, v)

    # Convert to an attribute-accessible dict
    configs = edict(dconfig)
    if configs.train_seed is not None:
        random.seed(configs.train_seed)
        torch.manual_seed(configs.train_seed)
        torch.cuda.manual_seed(configs.train_seed)
        cudnn.deterministic = True
        # Also disable cuDNN autotuning, which may choose different algorithms
        # between runs; PyTorch's reproducibility notes recommend both flags.
        cudnn.benchmark = False
    main(configs)