-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.py
73 lines (55 loc) · 2.24 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import argparse
import pickle
from wrapper import Wrapper
from visualize import *
# Command-line options. They are parsed in the ``__main__`` guard and the
# resulting namespace is handed to ``main`` (and from there to ``Wrapper``,
# which presumably consumes the optimisation-related options -- verify there).
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='mnist2500',
                    help='dataset name; data is read from data/<name>.pkl')
parser.add_argument('--num_iters', type=int, default=1000)
parser.add_argument('--optimizer', type=str, default='gd-momentum')
parser.add_argument('--out_dim', type=int, default=2)
parser.add_argument('--svd_dim', type=int, default=50)
parser.add_argument('--lr', type=float, default=1000.0)
parser.add_argument('--anneal_scheme', type=int, default=0,
                    help='annealing scheme: 0 is no annealing, '
                         '1 is linear annealing from --t to --t_max after '
                         'half of iterations')
parser.add_argument('--t', type=float, default=2.0)
parser.add_argument('--t_max', type=float, default=3.0)
parser.add_argument('--save_fig', action='store_true',
                    help='save the embedding under figures/ instead of '
                         'showing a scatter plot')
parser.add_argument('--animate', action='store_true',
                    help='with --save_fig, save an animated GIF instead of '
                         'a PNG')
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--print_every', type=int, default=100)
def _load_dataset(name):
    """Return ``(X, labels)`` unpickled from ``data/<name>.pkl``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only point ``--dataset`` at pickle files you created or trust.
    """
    with open('data/%s.pkl' % name, 'rb') as f:
        X, labels = pickle.load(f)
    return X, labels


def _prepare_triplets(trimap, X, triplets_path):
    """Load cached triplets from *triplets_path*, or generate them from *X*.

    Generation is given *triplets_path* so the result can be cached there for
    the next run; its parent directory is created first.
    """
    if os.path.isfile(triplets_path):
        trimap.load_triplets(triplets_path)
        return
    # exist_ok avoids the check-then-create race, and using the real parent
    # directory also covers dataset names that contain a sub-directory.
    os.makedirs(os.path.dirname(triplets_path), exist_ok=True)
    trimap.generate_triplets(X, triplets_path)


def _next_free_path(template, stem):
    """Return ``template % (stem + str(i))`` for the smallest i >= 0 not on disk."""
    i = 0
    while os.path.exists(template % (stem + str(i))):
        i += 1
    return template % (stem + str(i))


def main(config):
    """Embed ``config.dataset`` with the trimap ``Wrapper`` and show or save it.

    Loads ``(X, labels)`` from ``data/<dataset>.pkl``, reuses triplets cached
    in ``triplets/<dataset>.pkl`` (generating them when absent), runs the
    embedding, and then either:

    * shows it with ``scatter`` (default), or
    * with ``--save_fig``, writes ``figures/<dataset>-<optimizer><i>.png``
      (``.gif`` with ``--animate``), where ``i`` is the first index whose file
      does not exist yet, so earlier figures are not overwritten.

    Args:
        config: parsed command-line namespace (see the module's ``parser``).
    """
    X, labels = _load_dataset(config.dataset)

    # initialize trimap
    trimap = Wrapper(config)
    _prepare_triplets(trimap, X, 'triplets/%s.pkl' % config.dataset)

    if not config.save_fig:
        scatter(trimap.embed(), labels)
        return

    # create and save an embedding
    fig_name = '%s-%s' % (config.dataset, config.optimizer)
    fig_temp = 'figures/%s.' + ('gif' if config.animate else 'png')
    fig_path = _next_free_path(fig_temp, fig_name)
    os.makedirs(os.path.dirname(fig_path), exist_ok=True)
    if config.animate:
        # return_seq=True presumably yields the intermediate embeddings that
        # make up the animation frames -- confirm against Wrapper.embed.
        savegif(trimap.embed(return_seq=True), labels, fig_name, fig_path)
    else:
        savepng(trimap.embed(), labels, fig_name, fig_path)
if __name__ == '__main__':
    import sys

    # parse_known_args tolerates extra arguments (e.g. ones injected by a
    # launcher); report them so a mistyped flag is not silently ignored.
    config, unparsed = parser.parse_known_args()
    if unparsed:
        sys.stderr.write('warning: ignoring unrecognized arguments: %s\n'
                         % ' '.join(unparsed))
    main(config)