-
Notifications
You must be signed in to change notification settings - Fork 0
/
measureStability.m
69 lines (58 loc) · 1.43 KB
/
measureStability.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
function [stab_mu, stab_y] = measureStability(w, X, k, F, S, type, kappa)
%MEASURESTABILITY Worst-case change of inference under single-node perturbations.
%
% For every node i and every alternative value j, column i of X is replaced
% by the indicator of j, the local features are recomputed, inference is
% rerun, and the result is compared with inference on the unperturbed input.
% The largest deviations observed over all perturbations are returned.
%
% Inputs:
%   w     : weight vector (passed through to the inference routine).
%   X     : d x n matrix, representing n nodes taking d values.
%           NOTE(review): presumably every column is one-hot -- confirm
%           against callers. Columns with no nonzero entry are skipped.
%   k     : number of labels in the hidden states.
%   F     : features. Its leading block is overwritten with
%           localFeatures(X,k) for each perturbation (local copy only).
%   S     : constraints structure (passed through to the inference routine).
%   type  : type of inference to use: 1 selects crfInference, any other
%           value (0 by convention) selects dualInference.
%   kappa : modulus of convexity; only used by dualInference.
%
% Outputs:
%   stab_mu : largest 1-norm difference between the first k*n entries of
%             the baseline and perturbed inference outputs
%             (stability of marginals).
%   stab_y  : largest number of positions in which the predictMax decodings
%             of baseline and perturbed outputs differ (Hamming stability).

    % dimensions
    [d, n] = size(X);
    nodeIdx = 1:k*n;    % entries of the inference output that are compared

    % choose the inference routine once instead of branching at every call
    if type == 1
        runInference = @(feat) crfInference(w, feat, n*k, S);
    else
        runInference = @(feat) dualInference(w, feat, kappa, S);
    end

    % initialize stability to 0
    stab_mu = 0;
    stab_y = 0;

    % run initial inference
    y_0 = runInference(F);
    pred_0 = predictMax(y_0(nodeIdx), n, k);

    % run perturbed inference
    for i = 1:n
        % Keep the untouched column so it can be restored exactly; writing
        % a literal 1 back would corrupt columns whose entries are not 1.
        origCol = X(:, i);
        x_i = find(origCol);

        for j = 1:d
            % Skip values the node already takes. A column with no nonzero
            % entry yields no perturbations at all (unchanged behaviour).
            if isempty(x_i) || any(j == x_i)
                continue;
            end

            % perturb x_i in X
            X(:, i) = zeros(d, 1);
            X(j, i) = 1;

            % recompute local features and splice them into F
            localF = localFeatures(X, k);
            [localm, localn] = size(localF);
            F(1:localm, 1:localn) = localF;

            % run inference
            y_1 = runInference(F);

            % measure 1-norm and store max
            stab_mu = max(stab_mu, norm(y_0(nodeIdx) - y_1(nodeIdx), 1));

            % measure Hamming distance of decoding and store max
            pred_1 = predictMax(y_1(nodeIdx), n, k);
            stab_y = max(stab_y, nnz(pred_0 ~= pred_1));

            % replace perturbed value with the original column
            X(:, i) = origCol;
        end
    end