-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathinit_config.py
81 lines (68 loc) · 2.14 KB
/
init_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from easydict import EasyDict as edict
import yaml
from utils.flatwhite import *
import torch
import numpy as np
import random
import platform
def easy_dic(dic):
    """Wrap a plain dict as an EasyDict so keys can be read as attributes.

    Any top-level value that is itself a dict is wrapped again with
    ``edict``. EasyDict may already do this recursively, in which case the
    re-wrap is redundant but harmless.

    Args:
        dic: Mapping to wrap (e.g. the result of loading a YAML file).

    Returns:
        The wrapped EasyDict.
    """
    wrapped = edict(dic)
    # Collect the keys first so the mapping is not mutated while iterating.
    nested_keys = [name for name, item in wrapped.items() if isinstance(item, dict)]
    for name in nested_keys:
        wrapped[name] = edict(wrapped[name])
    return wrapped
def show_config(config, sub=False):
    """Render the top-level scalar entries of ``config`` as aligned text.

    Each entry whose value is not a dict becomes one line: the key
    right-aligned to 25 columns, `` : ``, then the value left-aligned to
    15 columns. Entries whose value is a dict (nested sections) are skipped.

    Args:
        config: Mapping of config keys to values.
        sub: Accepted for backward compatibility; currently unused.

    Returns:
        The rendered text, one ``\\n``-terminated line per scalar entry, or
        an empty string when there are none.
    """
    lines = []
    for key, value in config.items():
        if isinstance(value, dict):
            continue
        # Convert to str before applying the width spec: format(None, '<15')
        # and format([...], '<15') raise TypeError, and bools would otherwise
        # be rendered with int formatting as 1/0.
        lines.append('{:>25} : {:<15}\n'.format(str(key), str(value)))
    return ''.join(lines)
def type_align(source, target):
    """Convert ``target`` to the type of the existing value ``source``.

    Used to coerce a ``key=value`` override (whose value is a string) to the
    type already stored in the config for that key.

    Args:
        source: Current config value; its type selects the conversion.
        target: Raw override value, normally a string.

    Returns:
        ``target`` converted to bool, int or float to match ``source``, or
        ``target`` unchanged when ``source`` is a str. For any other type of
        ``source`` a warning is printed and ``target`` is returned unchanged.

    Raises:
        ValueError: If ``target`` cannot be parsed as the required type.
    """
    # bool must be checked before int: bool is a subclass of int, so the
    # int branch would otherwise capture it (and int('False') raises).
    if isinstance(source, bool):
        text = str(target).strip().lower()
        if text in ('true', '1', 'yes', 'y', 'on'):
            return True
        if text in ('false', '0', 'no', 'n', 'off'):
            return False
        raise ValueError('Cannot interpret {!r} as a boolean'.format(target))
    if isinstance(source, int):
        return int(target)
    if isinstance(source, float):
        return float(target)
    if isinstance(source, str):
        return target
    print("Unsupported type: {}".format(type(source)))
    # Returning None here would silently wipe the config entry; keep the
    # user's raw override instead.
    return target
def config_parser(config, args):
    """Apply ``key=value`` overrides from ``args`` to ``config`` in place.

    Arguments that contain no ``=`` are skipped. Each override value is
    converted with ``type_align`` to the type of the existing config value.

    Args:
        config: Mapping holding the current configuration; mutated in place.
        args: Iterable of argument strings.

    Returns:
        The same ``config`` object, after the overrides are applied.

    Raises:
        KeyError: If an override names a key that is not in ``config``.
        ValueError: If an override value cannot be converted.
    """
    print(args)
    for arg in args:
        if '=' not in arg:
            continue
        # Split on the first '=' only, so values may themselves contain '='
        # (e.g. "note=lr=0.1").
        key, value = arg.split('=', 1)
        if key not in config:
            raise KeyError('Unknown config key: {!r}'.format(key))
        config[key] = type_align(config[key], value)
    return config
def init_config(config_path, argvs):
    """Load a YAML config, apply command-line overrides, and prepare the run.

    Steps:
    1. Parse the YAML file at ``config_path`` and wrap it with ``easy_dic``.
    2. Apply ``key=value`` overrides from ``argvs`` via ``config_parser``.
    3. Set ``config.snapshot`` to ``<snapshot>/<note>``, create it with
       ``mkdir``, and print it. ``osp`` and ``mkdir`` come from the
       ``utils.flatwhite`` star import; presumably they are ``os.path`` and a
       make-directory helper. TODO confirm.
    4. If ``config.fix_seed`` is truthy, seed torch, CUDA, NumPy and
       ``random`` with 1234.
    5. Print the scalar config entries.

    Args:
        config_path: Path to the YAML configuration file.
        argvs: Argument strings; those of the form ``key=value`` override
            config entries.

    Returns:
        The resulting EasyDict configuration.
    """
    # Explicit encoding so UTF-8 config files load regardless of the
    # platform's default locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        # NOTE(review): FullLoader is kept for compatibility with existing
        # configs. If config files can come from untrusted sources, prefer
        # yaml.safe_load.
        config = yaml.load(f, Loader=yaml.FullLoader)
    config = easy_dic(config)
    config = config_parser(config, argvs)
    config.snapshot = osp.join(config.snapshot, config.note)
    mkdir(config.snapshot)
    print('Snapshot stored in: {}'.format(config.snapshot))
    if config.fix_seed:
        torch.manual_seed(1234)
        torch.cuda.manual_seed(1234)
        np.random.seed(1234)
        random.seed(1234)
    message = show_config(config)
    print(message)
    return config
def set_exp_name(config):
    """Build an experiment name string from fields of ``config``.

    The base name is ``<note>_D-<source>_tar-<target>_B-<batch_size>``. When
    ``config.step_0`` is truthy, ``Avg_ece-start-<ece_train>`` and
    ``_ece-w<alpha>`` are appended.

    Args:
        config: Object exposing ``note``, ``source``, ``target``,
            ``batch_size``, ``step_0``, ``ece_train`` and ``alpha`` as
            attributes.

    Returns:
        The experiment name.
    """
    parts = [
        '{}'.format(config.note),
        '_D-{}'.format(config.source),
        '_tar-{}'.format(config.target),
        '_B-{}'.format(config.batch_size),
    ]
    if config.step_0:
        # NOTE(review): unlike the other segments this one has no leading '_'
        # separator, so it fuses with the batch size ("..._B-4Avg_ece...").
        # Left unchanged because existing run names may depend on it.
        parts.append('Avg_ece-start-{}'.format(config.ece_train))
        # NOTE(review): indentation was lost in this copy of the file; this
        # assumes the alpha suffix belongs to the step_0 branch. If it was
        # unconditional, dedent this line. TODO confirm.
        parts.append('_ece-w{}'.format(config.alpha))
    return ''.join(parts)