-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathGaussianBernoulliRBM.m
110 lines (82 loc) · 3.55 KB
/
GaussianBernoulliRBM.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
classdef GaussianBernoulliRBM < matlab.mixin.Heterogeneous & handle
    %GAUSSIANBERNOULLIRBM Restricted Boltzmann machine with a per-visible-unit scale vector.
    %   rbm = GaussianBernoulliRBM(visibleunits, hiddenunits) records the two
    %   layer sizes. Call init(rbm) to allocate the parameters, then
    %   fit(rbm, X, options) to train on the columns of X. rbmup and rbmdown
    %   push activations through the current weights.
    %
    %   Training keeps the negative-phase Gibbs chain alive across
    %   mini-batches and epochs (it is seeded once, from the first batch).
    %
    %   NOTE(review): this class calls gibbssample and sigmoid, neither of
    %   which is defined in this file. Their exact semantics (presumably
    %   "sample a binary state from a logistic activation" and "logistic
    %   function") should be confirmed against their definitions.

    properties
        options         % not read or written by any method in this class
        visibleunits    % number of visible units (rows of the input data)
        hiddenunits     % number of hidden units
        nextrbm = 0     % not used in this class; presumably a link to the next RBM in a stack - TODO confirm
        prevrbm = 0     % not used in this class; presumably a link to the previous RBM in a stack - TODO confirm
        W               % weights, hiddenunits-by-visibleunits
        vW              % momentum (velocity) term for W
        b               % visible biases, visibleunits-by-1
        vb              % momentum (velocity) term for b
        c               % hidden biases, hiddenunits-by-1
        vc              % momentum (velocity) term for c
        sig             % per-visible-unit scale, visibleunits-by-1, set to ones by init
    end

    methods
        function rbm = GaussianBernoulliRBM(visibleunits, hiddenunits)
            %GAUSSIANBERNOULLIRBM Store the layer sizes; parameters are allocated by init.
            rbm.visibleunits = visibleunits;
            rbm.hiddenunits = hiddenunits;
        end

        function init(rbm)
            %INIT Allocate all parameters and their momentum terms.
            %   Weights, biases and velocities start at zero; sig starts at one.
            rbm.W = zeros(rbm.hiddenunits, rbm.visibleunits);
            rbm.vW = zeros(rbm.hiddenunits, rbm.visibleunits);
            rbm.b = zeros(rbm.visibleunits, 1);
            rbm.vb = zeros(rbm.visibleunits, 1);
            rbm.c = zeros(rbm.hiddenunits, 1);
            rbm.vc = zeros(rbm.hiddenunits, 1);
            rbm.sig = ones(rbm.visibleunits, 1);
        end

        function fit(rbm, X, options)
            %FIT Train the RBM on the columns of X using shuffled mini-batches.
            %   X       - data matrix, one training sample per column; its row
            %             count must be compatible with rbm.W (visibleunits).
            %   options - struct with the fields read below:
            %               epochs    number of passes over X
            %               batchsize number of columns per mini-batch
            %               k         Gibbs steps per update in the negative phase
            %               momentum  multiplier applied to the previous velocity
            %               alpha     learning rate
            %               decay     weight-decay coefficient
            %
            %   init(rbm) must have been called first so that the parameters exist.
            m = size(X, 2);
            h2 = [];    % negative-phase chain; seeded from the first batch, then carried over

            for i = 1:options.epochs
                randpos = randperm(m);
                fprintf('[GaussianBernoulliRBM] Epoch: %d / %d\n', i, options.epochs);

                for idx = 1:options.batchsize:m
                    endidx = min(idx + options.batchsize - 1, m);
                    currentbatchsize = endidx - idx + 1;

                    % Gather the shuffled columns of this batch as a full double matrix.
                    v1 = full(double(X(:, randpos(idx:endidx))));

                    % Positive phase: hidden sample driven by the data.
                    h1 = gibbssample(repmat(rbm.c, 1, currentbatchsize) + rbm.W * v1);

                    % Seed the chain once; afterwards it is never reset from data.
                    if isempty(h2)
                        h2 = h1;
                    end

                    % The first batch is the widest one, so the chain always has at
                    % least currentbatchsize columns. A short final batch advances
                    % only the leading columns, which keeps every matrix below the
                    % same width as v1/h1 instead of mismatching or broadcasting.
                    hneg = h2(:, 1:currentbatchsize);

                    % k steps of Gibbs sampling for the negative phase.
                    for j = 1:options.k
                        v2 = gibbssample(repmat(rbm.b, 1, currentbatchsize) + rbm.W' * hneg);
                        hneg = gibbssample(repmat(rbm.c, 1, currentbatchsize) + rbm.W * v2);
                    end
                    h2(:, 1:currentbatchsize) = hneg;

                    % Data and model correlations.
                    c1 = h1 * v1';
                    c2 = hneg * v2';

                    % NOTE(review): these two lines rescale the stored parameters
                    % in place on every batch. With sig left at ones (as set by
                    % init) they are a no-op; with any other sig the division
                    % compounds across batches. Kept as-is - confirm the intent.
                    rbm.W = bsxfun(@rdivide, rbm.W, rbm.sig');
                    rbm.b = bsxfun(@rdivide, rbm.b, rbm.sig .* rbm.sig);

                    % Momentum updates. The sums run explicitly over the batch
                    % dimension (2) so a one-sample batch still yields one
                    % gradient entry per unit rather than a single scalar.
                    rbm.vW = options.momentum * rbm.vW + ...
                        options.alpha * (c1 - c2 - options.decay * rbm.W) / currentbatchsize;
                    rbm.vb = options.momentum * rbm.vb + ...
                        options.alpha * (sum(v1 - v2, 2) - options.decay * rbm.b) / currentbatchsize;
                    rbm.vc = options.momentum * rbm.vc + ...
                        options.alpha * (sum(h1 - hneg, 2) - options.decay * rbm.c) / currentbatchsize;

                    rbm.W = rbm.W + rbm.vW;
                    rbm.b = rbm.b + rbm.vb;
                    rbm.c = rbm.c + rbm.vc;
                end
            end
        end

        function X = rbmup(rbm, X)
            %RBMUP Map visible columns to hidden activations: sigmoid(c + W * X).
            X = sigmoid(repmat(rbm.c, 1, size(X, 2)) + rbm.W * X);
        end

        function X = rbmdown(rbm, X)
            %RBMDOWN Map hidden columns to visible activations: sigmoid(b + W' * X).
            %   NOTE(review): sig is not used here - confirm that a sigmoid
            %   output is what callers expect from the visible layer.
            X = sigmoid(repmat(rbm.b, 1, size(X, 2)) + rbm.W' * X);
        end
    end
end