-
Notifications
You must be signed in to change notification settings - Fork 1
/
my_dbn.py
111 lines (90 loc) · 3.78 KB
/
my_dbn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from model.dbn import DBN
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
import numpy as np
import os
import matplotlib.pyplot as plt
import pickle
from scipy.io.arff import loadarff
def unison_shuffle(a, b):
    """Shuffle two equal-length arrays with the same random permutation.

    Keeps rows of ``a`` and ``b`` aligned (here: training examples and
    their one-hot labels).

    Args:
        a: np.ndarray indexable along its first axis.
        b: np.ndarray with the same length as ``a``.

    Returns:
        Tuple ``(a[p], b[p])`` for a single random permutation ``p``
        drawn from ``np.random`` (so it honors ``np.random.seed``).

    Raises:
        ValueError: if ``a`` and ``b`` differ in length.
    """
    # Explicit check instead of `assert`: asserts are stripped under -O,
    # which would silently allow misaligned data/labels through.
    if len(a) != len(b):
        raise ValueError('a and b must have the same length: %d != %d'
                         % (len(a), len(b)))
    p = np.random.permutation(len(a))
    return a[p], b[p]
# Session settings
# SAVE_DIR receives training checkpoints; OUTPUT_DIR receives metrics,
# a text summary, and plots for this experiment run.
SAVE_DIR = 'saves/t1'
OUTPUT_DIR = 'results/final/mnist_small_nopretrain2'
# Refuse to clobber a previous run unless the user explicitly confirms.
if os.path.exists(OUTPUT_DIR):
    choice = input(OUTPUT_DIR + ' already exists. Do you want to overwrite these results? y/n')
    if choice != 'y':
        print('Exiting')
        exit()
# Experiment hyperparameters (all forwarded to DBN / its train methods).
PRETRAIN_ITERATIONS = 1          # passed to deep.pretrain
LEARNING_RATE = 0.01             # used for both pretraining and fine-tuning
DECAY_LR = False                 # passed to deep.train as decay_lr
FREEZE_RBMS = False              # passed to DBN; presumably fixes RBM weights during fine-tuning — confirm in model/dbn.py
SAMPLE=True                      # passed to deep.train as is_sampled
RBM_ACTIVATION = 'sigmoid'
RBM_LAYERS = [600,625,650,600]   # hidden-layer sizes of the stacked RBMs
KEEP_CHANCE = 0.9                # dropout keep probability
NUM_TRAIN = 5000                 # labeled examples kept for supervised fine-tuning
IMAGE_SIZE = 28
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
# MNIST with 6000 images held out for validation; reshape=False keeps the
# 28x28 image shape, so the sets are flattened manually below.
mnist = read_data_sets("data", one_hot=True, reshape=False, validation_size=6000)
# DBN: stacked RBMs followed by two ReLU fully-connected layers.
deep = DBN(hidden_layers=RBM_LAYERS,
           rbm_activation=RBM_ACTIVATION,
           freeze_rbms=FREEZE_RBMS,
           keep_chance=KEEP_CHANCE,
           fully_connected_layers=[150,200],
           connected_activations=['relu','relu'])
# Flatten each image set to (num_examples, IMAGE_PIXELS).
flattened_train = np.reshape(mnist.train.images, [mnist.train.images.shape[0], -1])
flattened_validation = np.reshape(mnist.validation.images,
                                  [mnist.validation.images.shape[0], -1])
flattened_test = np.reshape(mnist.test.images, [mnist.test.images.shape[0], -1])
# Unsupervised greedy pretraining on the full training set (labels unused).
deep.pretrain(train_set=flattened_train,
              pretrain_iterations=PRETRAIN_ITERATIONS,
              learning_rate=LEARNING_RATE)
# Remove "unlabeled" data for experiment
# Shuffle examples and labels together, then keep only the first NUM_TRAIN
# pairs — simulates having few labeled examples after pretraining on many.
flattened_train, train_labels = unison_shuffle(flattened_train, mnist.train.labels)
flattened_train = flattened_train[0:NUM_TRAIN,:]
train_labels = train_labels[0:NUM_TRAIN,:]
# Supervised fine-tuning; returns per-epoch histories (loss, validation accuracy).
loss_hist, acc_hist = deep.train(train_set=flattened_train,
                                 train_labels=train_labels,
                                 validation_set=flattened_validation,
                                 validation_labels=mnist.validation.labels,
                                 save_dir=SAVE_DIR,
                                 learning_rate=LEARNING_RATE,
                                 decay_lr=DECAY_LR,
                                 is_sampled=SAMPLE)
# Final evaluation on the test set; presumably restores the checkpoint from
# SAVE_DIR — confirm in model/dbn.py.
accuracy, loss = deep.measure_test_accuracy(test_set=flattened_test,
                                            test_labels=mnist.test.labels,
                                            save_dir=SAVE_DIR)
print('Final Accuracy:', accuracy)
print('Final Loss:', loss)
# Persist raw histories (pickle), a human-readable summary, and plots.
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)
with open(os.path.join(OUTPUT_DIR, 'loss.pickle'), 'wb') as f:
    pickle.dump(loss_hist, f)
with open(os.path.join(OUTPUT_DIR, 'accuracy.pickle'), 'wb') as f:
    pickle.dump(acc_hist, f)
with open(os.path.join(OUTPUT_DIR, 'summary.txt'), 'w') as f:
    f.write('RESULTS:\n')
    # len(loss_hist) == number of fine-tuning epochs actually run.
    f.write('Trained in ' + str(len(loss_hist)) + ' epochs.\n')
    f.write('LR = ' + str(LEARNING_RATE) + '\n')
    f.write('Training dropout keep chance = ' + str(KEEP_CHANCE) + '\n')
    f.write('DECAY_LR = ' + str(DECAY_LR) + '\n')
    f.write('PRETRAIN_ITERATIONS = ' + str(PRETRAIN_ITERATIONS) + '\n')
    f.write('RBM_LAYERS = ' + str(RBM_LAYERS) + '\n')
    f.write('Freeze RBMS = ' + str(FREEZE_RBMS) + '\n')
    f.write('SAMPLE = ' + str(SAMPLE) + '\n')
    f.write('RBM activation = ' + str(RBM_ACTIVATION) + '\n')
    f.write('Final Accuracy = ' + str(accuracy) + '\n')
    f.write('Final Cross Entropy Loss = ' + str(loss) + '\n')
# Training-loss curve.
plt.plot(loss_hist)
plt.ylabel('Mean Cross Entropy Loss')
plt.xlabel('Epochs')
plt.savefig(os.path.join(OUTPUT_DIR, 'loss.png'))
plt.close()
# Validation-accuracy curve.
plt.plot(acc_hist)
plt.ylabel('Validation Accuracy')
plt.xlabel('Epochs')
plt.savefig(os.path.join(OUTPUT_DIR, 'acc.png'))
plt.close()