Skip to content

Commit

Permalink
Changing test parameters (#1)
Browse files Browse the repository at this point in the history
Adjusting the testing parameters for the dropout version of RBM
training to be closer to the experiments shown in (Srivastava *et al*,
2014).
  • Loading branch information
eric-tramel committed Sep 21, 2015
1 parent 916ca6b commit 7ac812a
Showing 1 changed file with 8 additions and 27 deletions.
35 changes: 8 additions & 27 deletions test/testdropout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,20 @@ using ImageView
using Gadfly
using DataFrames

function plot_weights(W, imsize, padding=10)
h, w = imsize
n = size(W, 1)
rows = int(floor(sqrt(n)))
cols = int(ceil(n / rows))
halfpad = div(padding, 2)
dat = zeros(rows * (h + padding), cols * (w + padding))
for i=1:n
wt = W[i, :]
wim = reshape(wt, imsize)
wim = wim ./ (maximum(wim) - minimum(wim))
r = div(i - 1, cols) + 1
c = rem(i - 1, cols) + 1
dat[(r-1)*(h+padding)+halfpad+1 : r*(h+padding)-halfpad,
(c-1)*(w+padding)+halfpad+1 : c*(w+padding)-halfpad] = wim
end
view(dat)
return dat
end


function run_mnist()
# Configure Test
X, y = testdata()
HiddenUnits = 100
Epochs = 5
HiddenUnits = 256
Epochs = 15
X = X ./ (maximum(X) - minimum(X))
m_do = BernoulliRBM(28*28, HiddenUnits)
m = BernoulliRBM(28*28, HiddenUnits)
m_do = BernoulliRBM(28*28, HiddenUnits; momentum=0.95)
m = BernoulliRBM(28*28, HiddenUnits; momentum = 0.5)

# Fit Models
m_do, historical_pl_do = fit(m_do, X; persistent=true, lr=0.1, n_iter=Epochs, batch_size=100, n_gibbs=1, dorate=0.5)
m, historical_pl = fit(m, X; persistent=true, lr=0.1, n_iter=Epochs, batch_size=100, n_gibbs=1, dorate=0.0)
m_do, historical_pl_do = fit(m_do, X; persistent=false, lr=0.1, n_iter=Epochs, batch_size=100,
n_gibbs=1, dorate=0.5, weight_decay="l1",decay_magnitude=0.1)
m, historical_pl = fit(m, X; persistent=true, lr=0.1, n_iter=Epochs, batch_size=100,
n_gibbs=1, dorate=0.0, weight_decay="l1",decay_magnitude=0.1)

# Put results in dataframe
NoDropoutActivations = Boltzmann.transform(m,X)
Expand Down

0 comments on commit 7ac812a

Please sign in to comment.