-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathKL_rdms.m
64 lines (45 loc) · 1.07 KB
/
KL_rdms.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
% KL_rdms: reveal the edges of the graph in 'solway4.txt' one at a time and
% track how a set of M particle hypotheses H is re-weighted after each edge.
% For every edge the script records the KL divergence between the particle
% weights after vs. before that edge, refreshes each particle with a few MCMC
% steps, and finally saves pairwise cosine distances between the per-edge
% weight vectors (RDM) together with the whole workspace.

clear all;   % 'clear all' (not 'clear') also resets persistent variables in functions
rng default; % fixed RNG seed so runs are reproducible

% standard error of the mean; not used below, but kept because save() writes
% the whole workspace and downstream analysis may load it from the .mat file
sem = @(x) std(x) / sqrt(length(x));

h = init_hyperparams();
D = init_D_from_txt('solway4.txt');
D.G.E(:) = 0; % erase edges; start empty (edges are re-added one by one below)

M = 100;       % # particles
nsamples = 10; % # MCMC rejuvenation steps per particle per edge
filename = sprintf('KL_rdms_M=%d_nsamples=%d_solway4.mat', M, nsamples);

% Draw the initial particles and score them on the edgeless graph.
P = zeros(1, M); % log weights, preallocated as a row vector (same shape as before)
for i = 1:M
    H(i) = init_H(D, h);
    P(i) = logpost(H(i), D, h);
end
% Normalize in log space so that sum(exp(P)) == 1.
% NOTE(review): this relies on the project's logsumexp reducing over all
% entries of a ROW vector; some implementations default to dim 1, which would
% make every P(i) == 0 here -- confirm against logsumexp.m.
P = P - logsumexp(P);
tic
nEdges = size(D.G.edges, 1); % each row of D.G.edges is one (u, v) edge

% Preallocate the per-edge outputs instead of growing them inside the loop.
KL = zeros(1, nEdges);    % KL(k): divergence of particle weights after vs. before edge k
P_all = zeros(nEdges, M); % P_all(k,:): normalized log weights after rejuvenation at step k

for k = 1:nEdges
    fprintf('edge %d / %d\n', k, nEdges);

    % Read the new edge and add it symmetrically to the adjacency matrix.
    u = D.G.edges(k, 1);
    v = D.G.edges(k, 2);
    D.G.E(u, v) = 1;
    D.G.E(v, u) = 1;

    % Re-score the SAME particles under the graph that now includes edge k,
    % so Q and P below differ only because of the new data.
    Q = P; % normalized log weights before edge k
    for i = 1:M
        P(i) = logpost(H(i), D, h);
    end
    P = P - logsumexp(P);

    % Argument order (new, old) as expected by the project's KL_divergence.
    KL(k) = KL_divergence(exp(P), exp(Q));
    fprintf('KL(%d) = %f\n', k, KL(k));

    % MCMC rejuvenation: run a short chain from each particle and keep its
    % final state and log posterior.
    % NOTE(review): the two literal 1s are presumably burn-in and lag --
    % confirm against sample.m.
    for i = 1:M
        [samples, post] = sample(D, h, nsamples, 1, 1, H(i));
        H(i) = samples(end);
        P(i) = post(end);
    end
    P = P - logsumexp(P);
    P_all(k, :) = P;
end
toc

% Pairwise cosine distances between the per-edge log-weight vectors (rows of
% P_all); squareRDMs presumably converts pdist's condensed vector to a square
% matrix -- confirm against the project's squareRDMs.m.
RDM = squareRDMs(pdist(P_all, 'cosine'));
save(filename); % saves the entire workspace, including KL, P_all, H and RDM