generated from dongliangcao/pytorch-framework
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathoptions.py
139 lines (109 loc) · 4.09 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import argparse
import random
import yaml
from collections import OrderedDict
from os import path as osp
import torch
from .misc import make_exp_dirs, set_random_seed
from .dist_util import get_dist_info, init_dist
def ordered_yaml():
    """Build a YAML Loader/Dumper pair that preserves mapping order.

    The C-accelerated ``CLoader``/``CDumper`` are preferred; the pure-Python
    ``Loader``/``Dumper`` are used when libyaml is unavailable. Mappings are
    loaded as ``OrderedDict`` and ``OrderedDict`` instances are dumped as plain
    YAML mappings.

    NOTE(review): the hooks are registered on the imported Loader/Dumper
    classes themselves, so the change is visible to every other user of those
    classes in the process. The Loader is also the full (non-safe) loader,
    which can construct arbitrary Python objects — only use it on trusted files.

    Returns:
        tuple: ``(Loader, Dumper)`` classes to pass to ``yaml.load`` /
            ``yaml.dump``.
    """
    try:
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def _represent_ordered(dumper, data):
        # Emit an OrderedDict as an ordinary mapping, in insertion order.
        return dumper.represent_dict(data.items())

    def _construct_ordered(loader, node):
        # Build every YAML mapping as an OrderedDict, keeping file order.
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, _represent_ordered)
    Loader.add_constructor(mapping_tag, _construct_ordered)
    return Loader, Dumper
def parse(opt_path, root_path, is_train=True):
    """Parse option file.

    Args:
        opt_path (str): Option YAML file path.
        root_path (str): Root path; ``experiments/<name>`` (training) or
            ``results/<name>`` (testing) is created under it.
        is_train (bool): Indicate whether in training or not. Default True.

    Returns:
        (dict): Options. ``is_train`` is set, ``num_gpu == 'auto'`` is
            resolved to the visible CUDA device count, ``resume_state`` /
            ``pretrain_network*`` paths are ``~``-expanded, and the derived
            output paths are added under ``opt['path']``.

    Raises:
        ValueError: If the YAML file does not contain a top-level mapping
            (e.g. it is empty).
    """
    # YAML files are UTF-8; be explicit instead of relying on the locale.
    # NOTE(review): ordered_yaml() returns the full (non-safe) PyYAML loader,
    # which can construct arbitrary Python objects; only parse trusted files.
    with open(opt_path, mode='r', encoding='utf-8') as f:
        Loader, _ = ordered_yaml()
        opt = yaml.load(f, Loader=Loader)
    if not isinstance(opt, dict):
        # An empty file loads as None; fail with a clear message instead of a
        # TypeError on the item assignment below.
        raise ValueError(f'Option file {opt_path} must contain a YAML mapping, '
                         f'got {type(opt).__name__}.')
    opt['is_train'] = is_train

    # set number of gpus: 'auto' means every CUDA device visible to this process
    if opt['num_gpu'] == 'auto':
        opt['num_gpu'] = torch.cuda.device_count()

    # paths: a missing or empty `path:` section becomes an empty mapping so
    # the derived output paths below can still be added.
    if opt.get('path') is None:
        opt['path'] = OrderedDict()
    for key, val in opt['path'].items():
        if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
            # Only values are replaced (no keys added), so iterating is safe.
            opt['path'][key] = osp.expanduser(val)

    if is_train:  # specify training log paths
        experiments_root = osp.join(root_path, 'experiments', opt['name'])
        opt['path']['experiments_root'] = experiments_root
        opt['path']['models'] = osp.join(experiments_root, 'models')
        opt['path']['log'] = osp.join(experiments_root, 'log')
    else:  # specify test log paths
        results_root = osp.join(root_path, 'results', opt['name'])
        opt['path']['results_root'] = results_root
        opt['path']['log'] = osp.join(results_root, 'log')
        opt['path']['visualization'] = osp.join(results_root, 'visualization')
    return opt
def dict2str(opt, indent_level=1):
    """dict to string for printing options.

    Scalar entries print as ``key: value``. Dict and list values print as a
    ``key:[`` ... ``]`` block whose contents are indented two more spaces.
    List elements appear one per line, and dict elements inside a list are
    expanded recursively.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level. Default 1.

    Returns:
        (str): Option string for printing.
    """
    indent = ' ' * (indent_level * 2)
    child_indent = ' ' * ((indent_level + 1) * 2)
    msg = '\n'
    for k, v in opt.items():
        if isinstance(v, dict):
            msg += indent + str(k) + ':['
            msg += dict2str(v, indent_level + 1)
            msg += indent + ']\n'
        elif isinstance(v, list):
            # Append (never reset `msg`) so keys printed earlier are kept,
            # and print the key name like the dict branch does.
            msg += indent + str(k) + ':['
            for iv in v:
                if isinstance(iv, dict) and iv:
                    # A non-empty dict2str result starts with '\n' and ends
                    # with '\n'; drop the trailing one so each element is
                    # introduced by exactly one newline.
                    msg += dict2str(iv, indent_level + 1)[:-1]
                else:
                    msg += '\n' + child_indent + str(iv)
            msg += '\n' + indent + ']\n'
        else:
            msg += indent + str(k) + ': ' + str(v) + '\n'
    return msg
def parse_options(root_path, is_train=True):
    """Read ``--opt`` from the command line and build the full option dict.

    Parses the YAML option file, configures the parallel backend (``'dp'`` or
    ``'ddp'``; the latter initialises the distributed process group on
    ``opt['port']``, default 29500), records rank and world size, creates the
    experiment directories, and seeds the RNGs with ``manual_seed + rank``.

    Args:
        root_path (str): Root path forwarded to ``parse``.
        is_train (bool): Whether options are for training. Default True.

    Returns:
        (dict): Options.

    Raises:
        ValueError: If ``opt['backend']`` is neither ``'dp'`` nor ``'ddp'``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--opt', type=str, required=True, help='Path to option YAML file.')
    cli_args = cli.parse_args()
    opt = parse(cli_args.opt, root_path, is_train=is_train)

    # distributed settings
    backend = opt['backend']
    if backend == 'ddp':
        opt['dist'] = True
        init_dist(port=opt.get('port', 29500))
        print('Backend DistributedDataParallel.', flush=True)
    elif backend == 'dp':
        opt['dist'] = False
        print('Backend DataParallel.', flush=True)
    else:
        raise ValueError(f'Invalid backend option: {backend}, only supports "dp" and "ddp"')

    rank, world_size = get_dist_info()
    opt['rank'] = rank
    opt['world_size'] = world_size

    # NOTE(review): this runs on every rank; confirm make_exp_dirs is safe to
    # call concurrently from several processes.
    make_exp_dirs(opt)

    # NOTE(review): when manual_seed is unset, each process draws its own base
    # seed here (nothing is broadcast), so the logged value may differ by rank.
    if opt.get('manual_seed') is None:
        opt['manual_seed'] = random.randint(1, 10000)
    set_random_seed(opt['manual_seed'] + opt['rank'])
    return opt