Skip to content

Commit

Permalink
Changing Dropout Order (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-tramel committed Sep 10, 2015
1 parent f0899b3 commit 76bbde1
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions src/rbm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ function hid_means(rbm::RBM, vis::Mat{Float64})
end


function vis_means(rbm::RBM, hid::Mat{Float64}, suppressedUnits)
hid[suppressedUnits] = 0.0 # Suppress dropped hidden units
function vis_means(rbm::RBM, hid::Mat{Float64})
p = rbm.W' * hid .+ rbm.vbias
return logistic(p)
end
Expand All @@ -85,12 +84,8 @@ function sample_hiddens{V,H}(rbm::RBM{V, H}, vis::Mat{Float64})
end


function sample_visibles{V,H}(rbm::RBM{V,H}, hid::Mat{Float64}, suppressedUnits)
# At this point, `suppressedUnits` should no longee be an optional term.
# Only gibbs() calls this function, and we are now, in dropout mode, always
# generating the dropout pattern. It can, however, be a pattern of 0's, meaning
# that no units are dropped.
means = vis_means(rbm, hid, suppressedUnits)
function sample_visibles{V,H}(rbm::RBM{V,H}, hid::Mat{Float64})
means = vis_means(rbm, hid)
return sample(V, means)
end

Expand All @@ -113,11 +108,22 @@ function gibbs(rbm::RBM, vis::Mat{Float64}; n_times=1,dorate=0.0)
# units.
v_pos = vis
h_pos = sample_hiddens(rbm, v_pos)

# Applying Dropout
h_pos[suppressedUnits] = 0.0

v_neg = sample_visibles(rbm, h_pos, suppressedUnits)
h_neg = sample_hiddens(rbm, v_neg)

# Applying Dropout
h_neg[suppressedUnits] = 0.0

for i=1:n_times-1
v_neg = sample_visibles(rbm, h_neg, suppressedUnits)
h_neg = sample_hiddens(rbm, v_neg)

# Applying Dropout
h_neg[suppressedUnits] = 0.0
end
return v_pos, h_pos, v_neg, h_neg
end
Expand Down

0 comments on commit 76bbde1

Please sign in to comment.