-
Notifications
You must be signed in to change notification settings - Fork 14
/
main.py
41 lines (34 loc) · 1.18 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import tensorflow as tf
from configs import parse_args
from utils import show_all_variables
from gan.spatchgan import SPatchGAN
def main():
    """General entry point for running GANs.

    Parses the command-line arguments, creates ``args.output_dir`` if it does
    not exist yet, opens a TensorFlow 1.x session and runs the phase selected
    by ``args.phase`` with the network selected by ``args.network``.

    Supported values:
        args.network: 'spatchgan'
        args.phase:   'train', 'test', 'freeze_graph'

    Raises:
        RuntimeError: if ``args.network`` or ``args.phase`` is not one of the
            supported values. Both are checked before the output directory is
            created and before a session is opened, so an invalid choice fails
            immediately instead of after the model has been constructed.
    """
    supported_networks = ('spatchgan',)
    supported_phases = ('train', 'test', 'freeze_graph')

    args = parse_args()

    # Fail fast on bad CLI choices, before touching the filesystem or the GPU.
    # The messages keep the 'Invalid network!' / 'Invalid phase!' prefix so
    # existing log searches still match, and add the value actually received.
    if args.network not in supported_networks:
        raise RuntimeError('Invalid network! Got {!r}, expected one of: {}.'.format(
            args.network, ', '.join(supported_networks)))
    if args.phase not in supported_phases:
        raise RuntimeError('Invalid phase! Got {!r}, expected one of: {}.'.format(
            args.phase, ', '.join(supported_phases)))

    os.makedirs(args.output_dir, exist_ok=True)

    # allow_growth: allocate GPU memory on demand rather than reserving it all
    # at session start. allow_soft_placement: let TF place an op on another
    # device when the requested device cannot run it.
    gpu_options = tf.GPUOptions(allow_growth=True)
    config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
    with tf.Session(config=config) as sess:
        # 'spatchgan' is the only supported network (validated above).
        gan = SPatchGAN('SPatchGAN', sess, args)

        if args.phase == 'train':
            gan.build_model_train()
            show_all_variables()
            gan.train()
            print(" [*] Training finished!")
        else:
            # 'test' and 'freeze_graph' both start from build_model_test().
            gan.build_model_test()
            show_all_variables()
            if args.phase == 'test':
                gan.test()
                print(" [*] Test finished!")
            else:  # 'freeze_graph'
                gan.freeze_graph()
                print(" [*] Graph frozen!")
# Run the entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()