-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgkernel.m
3960 lines (42 loc) · 4.9 KB
/
gkernel.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
function [grad] = gkernel(A,K,S,m,para,Obs)
%GKERNEL Gradient with respect to the coefficient matrix of view m.
%
%   grad = gkernel(A,K,S,m,para,Obs)
%
%   Inputs (shapes are inferred from how the arguments are indexed below;
%   confirm against the caller):
%     A    - coefficient array, indexed as A(:,ids,l); presumably N-by-N-by-M
%     K    - kernel array, indexed as K(ids,ids,l); M = size(K,3) views,
%            N = size(K,2) samples
%     S    - M-by-M weight matrix, read as S(l,l2); diagonal entries are
%            never read
%     m    - index of the view whose gradient is computed
%     para - struct with scalar fields c1 (weight of the loss term) and
%            c2 (weight of the relevance term)
%     Obs  - struct array; Obs(l).id lists the sample indices used for view l
%
%   Output:
%     grad - size(A,1)-by-size(A,2) matrix; only the entries
%            (Obs(m).id, Obs(m).id) are written, all others stay zero.

numViews = size(K,3);
N = size(K,2);

% Per-view products restricted to the indices listed for that view:
%   D{l}    = A_l(:,ids) * K_l(ids,ids)
%   hatK{l} = D{l} * A_l(:,ids)'
D = cell(1,numViews);
hatK = cell(1,numViews);
for l = 1:numViews
    ids = Obs(l).id;
    Al = A(:,ids,l);
    D{l} = Al*K(ids,ids,l);
    hatK{l} = D{l}*Al';
end

% Residual between the kernel of view m and its reconstruction, multiplied
% by D{m}. When every sample is listed the full matrices are used directly;
% this branch is kept separate because it does not apply the ordering of
% Obs(m).id, unlike the sub-indexed branch.
idm = Obs(m).id;
nObs = length(idm);
if nObs ~= N
    B = K(idm,idm,m) - hatK{m}(idm,idm);
    E = B*D{m}(idm,:);
else
    B = K(:,:,m) - hatK{m};
    E = B*D{m};
end

% C{l} = hatK{l} - sum over l2 ~= l of S(l,l2)*hatK{l2}
C = cell(1,numViews);
for l = 1:numViews
    C{l} = zeros(N,N);
    for l2 = 1:numViews
        if l2 == l
            C{l} = C{l} + hatK{l};
        else
            C{l} = C{l} - S(l,l2)*hatK{l2};
        end
    end
end

% Gradient for A of the m-th view only: loss-function part.
grad = zeros(size(A,1),size(A,2));
grad(idm,idm) = grad(idm,idm) - 4*para.c1*E/(numViews*nObs^2);

% Relevance part: C{m} minus the S(l,m)-weighted C{l} of every other view.
part = zeros(size(A,1),size(A,2));
for l = 1:numViews
    if l == m
        part = part + C{m};
    else
        part = part - S(l,m)*C{l};
    end
end
grad(idm,idm) = grad(idm,idm) + 4*para.c2*(part(idm,idm)*D{m}(idm,:))/(numViews*N);
end