"""Main training script."""
from contextlib import contextmanager
import argparse
import logging
import os
import pickle

import torch

import common
from common import PHASES, HASHER, G_ML, D_ML, ADV
from common import RUN_DIR, STATE_FILE, OPTS_FILE, LOG_FILE
import environ


def main():
"""Trains the model."""
parser = argparse.ArgumentParser()
parser.add_argument('--env', choices=environ.ENVS, default=environ.SYNTH)
parser.add_argument('--resume', action='store_true')
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--prefix')
parser.add_argument('--rerun', nargs='+', default=[], choices=PHASES)
init_opts, remaining_opts = parser.parse_known_args()

    opts_file = os.path.join(RUN_DIR, OPTS_FILE)
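    # On --resume, reload the options pickled by the original run; any flag
    # passed explicitly on this invocation (non-None when parsed without
    # defaults) overrides the saved value.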
    if init_opts.resume:
        new_opts = environ.parse_env_opts(
            init_opts, remaining_opts, no_defaults=True)
        opts = argparse.Namespace(**common.unpickle(opts_file))
        for k, v in vars(new_opts).items():
            if k not in opts or v is not None:
                setattr(opts, k, v)
    else:
        opts = environ.parse_env_opts(init_opts, remaining_opts)
        os.mkdir(RUN_DIR)
        with open(opts_file, 'wb') as f_opts:
            pickle.dump(vars(opts), f_opts)

    logging.basicConfig(format='%(message)s', level=logging.DEBUG, filemode='w')
    torch.manual_seed(opts.seed)
    torch.cuda.manual_seed_all(opts.seed)
    env = environ.create(opts.env, opts)
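
    # Run each phase in order, skipping the hasher phase when no exploration
    # bonus is used; reseed before every phase so each is reproducible on
    # its own.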
    for phase in PHASES:
        if phase == HASHER and not opts.exploration_bonus:
            continue
        torch.manual_seed(opts.seed)
        torch.cuda.manual_seed_all(opts.seed)
        with _phase(env, phase, opts) as phase_runner:
            if phase_runner:
                logging.debug(f'# running phase: {phase}')
                phase_runner()  # pylint: disable=not-callable


@contextmanager
def _phase(env, phase, opts):
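    """Yields a runner callable for `phase`, or None if a completed snapshot
    already exists.

    When the phase actually runs, control returns here afterwards to save
    `env.state` and detach the per-phase log handlers.
    """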
    phase_dir = os.path.join(RUN_DIR, phase)
    if not os.path.isdir(phase_dir):
        os.mkdir(phase_dir)
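
    # [opts.prefix] * bool(opts.prefix) is [] when --prefix is unset, so
    # _prefix() then reduces to joining only the suffixes.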
    prefixes = [opts.prefix] * bool(opts.prefix)

    def _prefix(suffixes):
        suffixes = suffixes if isinstance(suffixes, list) else [suffixes]
        return '_'.join(prefixes + suffixes)
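
    # Prefer a prefixed snapshot for this run if one already exists.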
    snap_file = os.path.join(phase_dir, STATE_FILE)
    prefix_snap_file = os.path.join(phase_dir, _prefix(STATE_FILE))
    if os.path.isfile(prefix_snap_file):
        snap_file = prefix_snap_file
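
    # A completed snapshot short-circuits the phase unless --rerun names it:
    # restore the saved state and yield no runner.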
    if os.path.isfile(snap_file) and phase not in opts.rerun:
        env.state = torch.load(snap_file)
        yield None
        return

    if phase == HASHER:
        # import functools
        # def _saver(env, epoch):
        #     torch.save(env.hasher.state_dict(),
        #                os.path.join(phase_dir, f'{epoch}.pth'))
        # runner = functools.partial(env.train_hasher, hook=_saver)
        runner = env.train_hasher
    elif phase == G_ML:
        runner = env.pretrain_g
    elif phase == D_ML:
        runner = env.pretrain_d
    elif phase == ADV:
        runner = env.train_adv
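
    # Attach per-phase log files: one at INFO level and a 'debug'-prefixed
    # one at DEBUG level, both written into the phase directory.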
    logger = logging.getLogger()

    def _add_file_handler(lvl, log_prefix=None):
        suffixes = [log_prefix] * bool(log_prefix) + [LOG_FILE]
        log_path = os.path.join(phase_dir, _prefix(suffixes))
        handler = logging.FileHandler(log_path, mode='w')
        handler.setLevel(lvl)
        logger.addHandler(handler)
        return handler

    file_handlers = [
        _add_file_handler(logging.INFO),
        _add_file_handler(logging.DEBUG, 'debug'),
    ]

    yield runner

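    # Control returns here once the caller has run the phase: persist the
    # resulting env state and detach this phase's log handlers.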
    torch.save(env.state, prefix_snap_file)
    for handler in file_handlers:
        logger.removeHandler(handler)


if __name__ == '__main__':
    main()