-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsp.py
64 lines (50 loc) · 1.79 KB
/
sp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from typing import Optional
import os.path as osp
import json
import yaml
import nni
class SimpleParam:
    """Resolve a hyper-parameter dict from one of several sources.

    ``default`` holds baseline values. The ``'nni'`` and ``'local:...'``
    sources overlay their own values on top of a copy of it. The
    ``'default'`` source returns a copy of the defaults alone.
    """

    def __init__(self, local_dir: str = 'param', default: Optional[dict] = None):
        """
        Args:
            local_dir: directory that ``local:FILE_NAME`` sources are joined onto.
            default: baseline parameters; ``None`` means an empty dict.
        """
        if default is None:
            default = dict()
        self.local_dir = local_dir
        self.default = default

    def __call__(self, source: str, preprocess: str = 'none'):
        """Return the parameter dict for ``source``.

        Args:
            source: one of
                - ``'nni'``: defaults overlaid with ``nni.get_next_parameter()``;
                - ``'local:FILE_NAME'``: defaults overlaid with the contents of a
                  ``.json`` / ``.yaml`` / ``.yml`` file located via ``local_dir``;
                - ``'default'``: a copy of the defaults only.
            preprocess: ``'nni'`` shortens ``a/name/b``-style keys to ``name``
                (applied to local files only); any other value leaves keys as-is.

        Returns:
            A new dict; mutating it never affects ``self.default``.

        Raises:
            AssertionError: a ``local`` source not of the form ``local:FILE_NAME``.
            Exception: unsupported file extension, or unknown ``source``.
        """
        if source == 'nni':
            # get_next_parameter() can return None (e.g. no parameters available);
            # treat that as "no overrides" instead of failing in the ** merge.
            return {**self.default, **(nni.get_next_parameter() or {})}
        if source.startswith('local'):
            ts = source.split(':')
            # Explicit raise rather than `assert` so the check survives `python -O`;
            # the exception type and message are unchanged.
            if len(ts) != 2:
                raise AssertionError('local parameter file should be specified in a form of `local:FILE_NAME`')
            path = osp.join(self.local_dir, ts[-1])
            if path.endswith('.json'):
                loaded = parse_json(path)
            elif path.endswith(('.yaml', '.yml')):
                # NOTE(review): only YAML files are looked up one directory above
                # local_dir (JSON files are not); kept because existing callers may
                # rely on that layout -- confirm it is intended. osp.join keeps the
                # same target for relative paths and leaves absolute local_dir
                # paths intact (plain "../" + path would mangle them).
                path = osp.join('..', path)
                loaded = parse_yaml(path)
            else:
                raise Exception('Invalid file name. Should end with .yaml or .json.')
            if preprocess == 'nni':
                loaded = preprocess_nni(loaded)
            return {**self.default, **loaded}
        if source == 'default':
            # Copy so a caller mutating the result cannot corrupt the stored defaults,
            # matching the fresh dicts returned by the other branches.
            return dict(self.default)
        raise Exception('invalid source')
def preprocess_nni(params: dict):
    """Return a copy of ``params`` with slash-separated keys shortened.

    A key of the form ``'x/name/y'`` (exactly three ``/``-separated parts)
    becomes its middle part ``'name'``; a key with no ``/`` is kept as is.
    Any other number of parts raises ``Exception``. If two keys shorten to
    the same name, the later one in iteration order wins.

    NOTE(review): presumably mirrors NNI's nested search-space key naming —
    confirm against the search-space files in use.
    """
    def _short_name(name: str):
        # Map one full key to its short form, rejecting unexpected shapes.
        parts = name.split('/')
        if len(parts) == 1:
            return name
        if len(parts) == 3:
            return parts[1]
        raise Exception('Unexpected param name ' + name)

    shortened = {}
    for full_name, value in params.items():
        shortened[_short_name(full_name)] = value
    return shortened
def parse_yaml(path: str):
    """Load and return the YAML document stored at ``path``.

    The file is read as bytes so PyYAML picks the encoding itself (UTF-8,
    or UTF-16 when a BOM is present) instead of the platform's locale
    encoding. The ``with`` block guarantees the handle is closed.
    """
    with open(path, 'rb') as f:
        # NOTE(review): yaml.Loader can build arbitrary Python objects from tags
        # such as !!python/object, so this is only safe for trusted files. It is
        # kept for compatibility; consider yaml.SafeLoader if configs never use
        # Python-specific tags.
        return yaml.load(f, Loader=yaml.Loader)
def parse_json(path: str):
    """Load and return the JSON document stored at ``path``.

    The file is read as bytes so ``json`` detects UTF-8/16/32 (including a
    UTF-8 BOM) per RFC 8259, instead of decoding with the locale's default
    encoding. The ``with`` block closes the file even if parsing fails.
    """
    with open(path, 'rb') as f:
        return json.load(f)