-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
executable file
·73 lines (59 loc) · 1.9 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Main entry point for training and testing experiments.
Results are saved under ./out/ (relative to this file), or under the
directory named by the SPANPARSER_BASEDIR environment variable if it is set.
"""
import argparse
import json
import os
import sys
from spanparser.configs import (
load_config,
merge_configs,
ConfigDict,
)
from spanparser.experiments import Experiment
from spanparser.outputter import Outputter
# Root directory for experiment outputs (passed to Outputter as ``basedir``).
# Taken from the SPANPARSER_BASEDIR environment variable when it is set to a
# non-empty value; otherwise defaults to the ``out`` directory next to this file.
# An empty SPANPARSER_BASEDIR falls back to the default instead of crashing in
# ``os.makedirs('')``.
BASEDIR = os.environ.get('SPANPARSER_BASEDIR') or os.path.abspath(
    os.path.join(os.path.dirname(__file__), 'out')
)
# Create the directory up front. ``exist_ok=True`` avoids the check-then-create
# race of testing ``isdir`` first; a non-directory already at this path still
# raises FileExistsError, as before.
os.makedirs(BASEDIR, exist_ok=True)
def main(argv=None):
    """Parse the command line, build the merged config, and run the action.

    Config files (JSON or YAML) are loaded in the order given and merged into
    one dict. The optional ``--config-string`` JSON object is then merged on
    top. Presumably later sources override earlier ones inside
    ``merge_configs`` -- TODO confirm. The merged config is printed, wrapped
    in a ``ConfigDict``, and handed to an ``Experiment``, which then trains
    or tests.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) reads the real command line, as before.

    Raises:
        SystemExit: Via argparse (status 2) on invalid arguments, including a
            ``--config-string`` that is not valid JSON or not a JSON object.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-l', '--load-prefix',
                        help='Load a model from this file')
    parser.add_argument('-c', '--config-string',
                        help='Additional config (as JSON)')
    parser.add_argument('-o', '--outdir',
                        help='Force the output directory')
    parser.add_argument('-s', '--seed', type=int,
                        help='Set seed')
    parser.add_argument('action', choices=['train', 'test'])
    parser.add_argument('configs', nargs='+',
                        help='Config JSON or YAML files')
    args = parser.parse_args(argv)

    # Validate the inline override before loading any config files, and report
    # problems as a usage error instead of a raw traceback. JSONDecodeError is
    # a ValueError subclass.
    override = None
    if args.config_string:
        try:
            override = json.loads(args.config_string)
        except ValueError as e:
            parser.error('--config-string is not valid JSON: {}'.format(e))
        if not isinstance(override, dict):
            # Only a mapping can be merged into the config dict.
            parser.error('--config-string must be a JSON object')

    # merge_configs' return value is unused, so it is relied on to update
    # `config` in place.
    config = {}
    for path in args.configs:
        merge_configs(config, load_config(path))
    if override is not None:
        # The inline override is merged last, after all config files.
        merge_configs(config, override)
    print(json.dumps(config, indent=2))
    config = ConfigDict(config)

    outputter = Outputter(config, basedir=BASEDIR, force_outdir=args.outdir)
    experiment = Experiment(config, outputter, args.load_prefix, args.seed)
    if args.action == 'train':
        experiment.train()
    elif args.action == 'test':
        experiment.test()
    else:
        # Unreachable through argparse (`choices` restricts `action`); kept as
        # a guard in case the choices list grows without a matching branch.
        raise ValueError('Unknown action: {}'.format(args.action))


if __name__ == '__main__':
    main()