-
Notifications
You must be signed in to change notification settings - Fork 6
/
main.py
77 lines (63 loc) · 1.79 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import random
import shutil
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from config import *
from train import train
from utils import generate_dataset, generate_model, show_config
def main():
    """Run one training session configured entirely by the ``config`` module.

    Steps, in order:
      1. print the BASIC / DATA / TRAIN configuration sections;
      2. seed all RNGs with ``BASIC_CONFIG['random_seed']``;
      3. make sure ``BASIC_CONFIG['save_path']`` exists;
      4. build the model with ``generate_model``;
      5. build the train/validation datasets with ``generate_dataset``;
      6. recreate an empty TensorBoard log directory at
         ``BASIC_CONFIG['record_path']``;
      7. hand everything to ``train``.

    The TensorBoard writer is closed once ``train`` returns or raises, so
    pending events are flushed to disk.
    """
    # print configuration
    show_config({
        'BASIC CONFIG': BASIC_CONFIG,
        'DATA CONFIG': DATA_CONFIG,
        'TRAIN CONFIG': TRAIN_CONFIG,
    })

    # reproducibility: seed before the model and datasets are created so that
    # weight init / data shuffling depend on the configured seed
    set_random_seed(BASIC_CONFIG['random_seed'])

    # create the output folder; exist_ok avoids the race between an
    # existence check and the creation call
    save_path = BASIC_CONFIG['save_path']
    os.makedirs(save_path, exist_ok=True)

    # build model
    device = BASIC_CONFIG['device']
    model = generate_model(
        BASIC_CONFIG['network'],
        NET_CONFIG,
        device,
        BASIC_CONFIG['pretrained'],
        BASIC_CONFIG['checkpoint'],
    )

    # create dataset
    train_dataset, val_dataset = generate_dataset(
        DATA_CONFIG,
        BASIC_CONFIG['data_path'],
        BASIC_CONFIG['data_index'],
    )

    # create logger; any logs left from a previous run at this path are
    # deliberately deleted so the directory only holds this run's events
    record_path = BASIC_CONFIG['record_path']
    if os.path.exists(record_path):
        shutil.rmtree(record_path)

    # The writer only flushes periodically (flush_secs), so close it via the
    # context manager: events are written out even if training raises.
    # close() is idempotent, so this is safe even if train() closes it too.
    with SummaryWriter(record_path) as logger:
        # create estimator and then train
        train(
            model=model,
            train_config=TRAIN_CONFIG,
            data_config=DATA_CONFIG,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            save_path=save_path,
            device=device,
            logger=logger,
        )
def set_random_seed(seed):
    """Seed every random number generator used during training.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every CUDA device), and configures cuDNN for deterministic behaviour.

    Args:
        seed: integer seed. NumPy's legacy ``np.random.seed`` requires a
            value in ``[0, 2**32 - 1]``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all visible GPUs, not only the current one; this is silently
    # ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels alone are not enough: benchmark mode autotunes
    # convolution algorithms and may select different ones on each run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Start training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()