-
Notifications
You must be signed in to change notification settings - Fork 3
/
config_util.py
216 lines (198 loc) · 8.08 KB
/
config_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import argparse
import configparser
import os
from typing import List, Dict, Any
# Absolute path of this module's file.
FILEPATH = os.path.abspath(__file__)
# Parent of the directory that contains this file. Taking the head of
# os.path.split twice walks up two path components from FILEPATH.
REPODIR = os.path.split(os.path.split(FILEPATH)[0])[0]
def parse_str_list(lines) -> List[str]:
    """Split a multi-line string into a list of its non-empty lines.

    Args:
        lines: Text holding one entry per line.

    Returns:
        The lines of ``lines`` in their original order, with zero-length
        lines dropped. Lines are not stripped, so a whitespace-only line is
        kept.
    """
    return [entry for entry in lines.splitlines() if entry]
# Template that the config should match, providing types for each option
TEMPLATE = {
    # Each regular section maps option name -> converter. int/float/bool/str
    # select the matching typed getter in parse_config; any other callable
    # is applied to the raw option string.
    'Training': dict(
        batch_size=int,
        epochs=int,
        lr=float,
        seed=int,
        l1_loss=float,
        gpu=int,
    ),
    'Model': dict(
        model=str,
        fixed_filts=str,
        saved_model=str,
    ),
    # The bare `dict` type marks a free-form section: parse_config evaluates
    # every value in it as a Python expression instead of using converters.
    'Model Kwargs': dict,
    'Data Files': dict(
        base_dir=str,
        datasets=parse_str_list,
        loader_func=str,
    ),
    'Loader Kwargs': dict,
    'Preprocessing': dict(
        augment=bool,
        oversample=bool,
    ),
    'Logging': dict(
        save_root=str,
        log_root=str,
        tag=str,
    ),
}
# Fallback values, organised by the same section names as TEMPLATE.
# parse_config consults the regular (non-kwargs) sections when an option is
# absent from the config file.
DEFAULTS = {
    'Training': dict(
        batch_size=256,
        epochs=400,
        lr=0.005,
        seed=1111,
        l1_loss=0.005,
        gpu=0,
    ),
    'Model': dict(
        model='CCNN',
        fixed_filts=None,
        saved_model=None,
    ),
    'Model Kwargs': dict(
        in_channels=1,
        num_filts=2,
        filter_size=3,
        order=4,
        num_classes=2,
        absbeta=False,
    ),
    'Data Files': dict(
        # A 'QGasData' directory that sits next to the repository directory.
        base_dir=os.path.join(os.path.dirname(REPODIR), 'QGasData'),
        datasets=['FullInfoAS', 'FullInfoPi'],
        loader_func='load_qgm_data',
    ),
    'Loader Kwargs': dict(
        doping_level=9.0,
        crop=None,
        circle_crop=False,
    ),
    'Preprocessing': dict(
        augment=False,
        oversample=False,
    ),
    'Logging': dict(
        save_root=os.path.join(REPODIR, 'model'),
        log_root=os.path.join(REPODIR, 'log'),
        tag='',
    ),
}
def eval_parse_dict(items) -> Dict:
    """Build a dict from (key, value) pairs by evaluating each value.

    Args:
        items: Iterable of ``(key, val)`` pairs. Each ``val`` is passed to
            ``eval``; parse_config supplies the string pairs returned by
            ``ConfigParser.items``.

    Returns:
        A dict mapping every key to the result of ``eval(val)``. If a key
        repeats, the last pair wins.
    """
    parsed_dict = dict()
    for key, val in items:
        # NOTE(review): eval() executes arbitrary Python, so values must come
        # from a trusted config file. It is called without explicit
        # namespaces, so the expression can also see this function's locals
        # (items, key, val, parsed_dict) and the module globals.
        parsed_dict[key] = eval(val)
    return parsed_dict
# noinspection PyTypeChecker
def parse_config(filename: str) -> Dict:
    """Read a config file and convert it into a nested dict of typed values.

    Every section named in TEMPLATE is looked up in the file. Sections whose
    template entry is the ``dict`` type itself are parsed by evaluating each
    value as a Python expression. All other sections are converted option by
    option, using the value from DEFAULTS when an option is absent.

    Args:
        filename: Path of the config file to read.

    Returns:
        A dict mapping section name to a dict of parsed options, in TEMPLATE
        order.

    Raises:
        KeyError: If a section listed in TEMPLATE is missing from the file.
    """
    # Converters that have a dedicated typed getter on a config section.
    # Matched by identity, so only these exact built-in types qualify.
    typed_getters = (
        (int, 'getint'),
        (float, 'getfloat'),
        (bool, 'getboolean'),
        (str, 'get'),
    )

    config = configparser.ConfigParser()
    with open(filename, 'r') as f:
        config.read_file(f)

    parsed_dict = {}
    for section, options_dict in TEMPLATE.items():
        config_sect = config[section]

        # `options_dict is dict` tests for the dict *type* used as a marker in
        # TEMPLATE, not for "is a dictionary instance".
        if options_dict is dict:
            parsed_dict[section] = eval_parse_dict(config.items(section))
            continue

        parsed_section = {}
        parsed_dict[section] = parsed_section
        default_section = DEFAULTS[section]

        for option, converter in options_dict.items():
            getter_name = next(
                (name for kind, name in typed_getters if converter is kind),
                None,
            )
            if getter_name is not None:
                typed_get = getattr(config_sect, getter_name)
                parsed_section[option] = typed_get(
                    option, fallback=default_section[option]
                )
                continue

            # Custom converter: apply it to the raw string only when the
            # option is present, otherwise take the default unchanged.
            option_str = config_sect.get(option)
            parsed_section[option] = (
                default_section[option] if option_str is None
                else converter(option_str)
            )
    return parsed_dict
def make_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for training runs.

    The parser takes one positional ``config`` path plus optional overrides.
    Optional arguments default to ``None`` when not given (``--fold`` is the
    exception, defaulting to 1), so that update_config, which skips ``None``
    values, leaves the corresponding config-file settings untouched.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', type=str,
                        help='Config file for training.'
                             ' Settings of the config are overridden by command line arguments.')
    parser.add_argument('--batch-size', type=int,
                        help='Training batch size (Default: 256)')
    parser.add_argument('-e', '--epochs', type=int,
                        help='Number of epochs to train (Default: 400)')
    parser.add_argument('--data-dir', type=str,
                        help='Directory where data is located (Default: <REPODIR>/../QGasData)')
    parser.add_argument('--lr', type=float,
                        help='Starting learning rate (Default: 0.005)')
    parser.add_argument('--save-root', type=str,
                        help='Directory to save model in')
    parser.add_argument('--log-root', type=str,
                        help='Directory to save training logs to')
    parser.add_argument('--model', type=str,
                        help='Sets the model used. See bottom of nn_models.py for options.'
                             ' (Default: CCNN).')
    parser.add_argument('--num-filts', type=int,
                        help='Number of filters to use in the trained model.'
                             ' (Default: 2)')
    parser.add_argument('--filter-size', type=int,
                        help='The spatial size of the learned filters.'
                             ' (Default: 3)')
    parser.add_argument('--order', type=int,
                        help='Sets the order of the correlators used in the model. '
                             'Only has an effect with Correlator architectures.'
                             ' (Default: 4)')
    parser.add_argument('--crop', type=int,
                        help='If set, crops snapshots to a square of the given size.')
    # store_true flags use default=None (not argparse's implicit False) so an
    # absent flag reads as "not specified" rather than as an explicit False
    # that would overwrite the config-file value.
    parser.add_argument('--circle-crop', action='store_true', default=None,
                        help='If set, crops snapshots to the circular area of the Fermi-Hubbard'
                             ' experiment')
    parser.add_argument('--fixed-filts', type=str,
                        help='Only use along with fixed filter models. '
                             'Name of numpy file containing fixed filters to use.')
    parser.add_argument('--saved-model', type=str,
                        help='Load from saved model and continue training')
    parser.add_argument('--group', type=int,
                        help='If set, collects snapshots into groups of the given size which'
                             ' are classified together.')
    parser.add_argument('--seed', type=int, default=None,
                        help='Sets the random seed controlling parameter initialization/batching.')
    parser.add_argument('--augment', action='store_true', default=None,
                        help='If set, performs data augmentation on the training data')
    parser.add_argument('--l1-loss', type=float,
                        help='Coefficient on L1 norm regularization loss for convolutional filters'
                             ' (Default: 0.005)')
    parser.add_argument('--absbeta', action='store_true', default=None,
                        help="If set, forces logistic beta coefficients to be positive")
    parser.add_argument('--reach', type=int,
                        help="Keyword argument for kagome models, defining filter sizes.")
    parser.add_argument('--tag', type=str,
                        help='A tag to append to the log/model filenames')
    parser.add_argument('--gpu', type=int, help='Index of GPU to use')
    parser.add_argument('--doping-level', type=float, default=None,
                        help='For Fermi-Hubbard data, set the doping level to load.')
    # NOTE(review): unlike the other options this one has a non-None default,
    # so a `fold` key present in a config section is always replaced by the
    # command-line value (1 unless given). Confirm that is intended.
    parser.add_argument('--fold', type=int, default=1,
                        help='Fold for 10-fold cross validation (Rydberg data only).')
    return parser
# Updates a config dict with options provided through the command line
def update_config(config: Dict[str, Dict], args: argparse.Namespace):
    """Override config values in place with command-line arguments.

    For every option key in every section of ``config``, the attribute of the
    same name on ``args`` replaces the config value when that attribute exists
    and is not ``None``. Keys that are absent from ``config`` are never added.

    Args:
        config: Mapping of section name to a dict of options; modified in
            place.
        args: Parsed command-line arguments. Attributes set to ``None`` are
            treated as "not specified".

    Returns:
        None.
    """
    # Config keys whose command-line attribute has a different name. The
    # `--data-dir` flag is stored as `data_dir`, while the config option it
    # is meant to override is `base_dir`; without this mapping it would
    # never take effect.
    arg_aliases = {'base_dir': 'data_dir'}

    for sect_dict in config.values():
        # Only existing keys are reassigned, so the dict does not change size
        # while it is being iterated.
        for key in sect_dict:
            # The exact key name is tried first, then its alias (if any).
            candidates = (key, arg_aliases[key]) if key in arg_aliases else (key,)
            for arg_name in candidates:
                if arg_name not in args:
                    continue
                arg_val = getattr(args, arg_name)
                if arg_val is not None:
                    sect_dict[key] = arg_val
                    break