-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
60 lines (41 loc) · 1.47 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os, sys
import torch
import utils
import json
# Maps each supported --domain_name value to (module path, domain class name).
# Modules are imported lazily so only the selected domain's code is loaded.
_DOMAINS = {
    'layout': ('domains.layout', 'LAYOUT_DOMAIN'),
    'omni': ('domains.omni', 'OMNI_DOMAIN'),
    'shape': ('domains.shape', 'SHAPE_DOMAIN'),
}


def load_domain(args):
    """Instantiate the domain object selected by ``args.domain_name``.

    Args:
        args: Parsed command-line namespace; only ``args.domain_name`` is read.
            Expected to be one of 'layout', 'omni', 'shape'.

    Returns:
        A fresh instance of the selected domain class
        (``LAYOUT_DOMAIN``, ``OMNI_DOMAIN`` or ``SHAPE_DOMAIN``).

    Raises:
        ValueError: if ``args.domain_name`` is not a known domain (including None).
    """
    import importlib

    # Raise explicitly rather than `assert False`: asserts are stripped under
    # `python -O`, which would make this function silently return None.
    try:
        module_path, class_name = _DOMAINS[args.domain_name]
    except KeyError:
        raise ValueError(
            f'bad domain name {args.domain_name}; '
            f'expected one of {sorted(_DOMAINS)}'
        ) from None

    domain_cls = getattr(importlib.import_module(module_path), class_name)
    return domain_cls()
# Maps each supported --main_mode value to (module name, entry-point function name).
# Modules are imported lazily so only the selected mode's dependencies are loaded.
_MAIN_MODES = {
    'finetune': ('finetune', 'fine_tune'),
    'pretrain': ('pretrain', 'train'),
    'fsg_eval': ('fsg_eval', 'fsg_eval'),
    'coseg': ('coseg_task', 'eval_dom'),
    'train_magg': ('train_magg_net', 'train_magg_net'),
}


def main():
    """Parse command-line arguments and run the selected mode on the selected domain.

    Flags (parsed via ``utils.getArgs``):
      -mm / --main_mode    one of 'finetune', 'pretrain', 'fsg_eval', 'coseg', 'train_magg'
      -dn / --domain_name  one of 'layout', 'omni', 'shape'

    Returns:
        Whatever the selected mode's entry-point function returns
        (now consistently for every mode, including 'train_magg').

    Raises:
        ValueError: if --main_mode is not a recognised mode. An unrecognised
            --domain_name is rejected inside ``load_domain``.
    """
    import importlib

    main_args = utils.getArgs([
        ('-mm', '--main_mode', None, str),  # Set the main mode ['finetune', 'pretrain', 'fsg_eval', 'coseg', 'train_magg']
        ('-dn', '--domain_name', None, str),  # Set the domain ['layout', 'omni', 'shape']
    ])

    # Validate the mode before constructing the domain so a bad flag fails fast.
    # Raise explicitly rather than `assert False`, which `python -O` strips.
    try:
        module_name, func_name = _MAIN_MODES[main_args.main_mode]
    except KeyError:
        raise ValueError(
            f'bad main mode {main_args.main_mode}; '
            f'expected one of {sorted(_MAIN_MODES)}'
        ) from None

    domain = load_domain(main_args)
    entry_point = getattr(importlib.import_module(module_name), func_name)
    return entry_point(domain)
# Run the CLI entry point only when this file is executed as a script,
# not when it is imported. The value returned by main() is discarded.
if __name__ == '__main__':
    main()