jointObjectiveEnt.m
function [f, g] = jointObjectiveEnt(x, F, labels, scope, S, C, F_labels, za, zb, varargin)
% Returns the objective value f and gradient g of the joint learning
% objective, using the dual of loss-augmented inference so that the
% objective becomes a minimization.
% x stacks the parameters as [w; kappa; lambda], where w is d x 1 for a
% d x m feature matrix F, kappa is a scalar, and lambda holds the dual
% variables for the equality constraints S.Aeq * y = S.beq.
% scope is an index vector (or logical vector) indicating which entries of
% the marginal vector are counted in the loss.
% precompute F*labels if it was not supplied
if ~exist('F_labels', 'var') || isempty(F_labels)
    F_labels = F*labels;
end
% set numeric stability constants (defaults if not supplied)
if ~exist('za', 'var') || isempty(za)
    za = 1e-5;
end
if ~exist('zb', 'var') || isempty(zb)
    zb = 1e-5;
end
% unpack parameters from the stacked vector x
[d, m] = size(F);
w = x(1:d);
kappa = max(0, x(d+1)); % don't let kappa be negative
lambda = x(d+2:end);
% to isolate learning of w, fix kappa here:
% kappa = 1;
A = S.Aeq;
b = S.beq;
% loss-augmentation vector: +1 where the true label is 0, -1 where it is 1,
% restricted to the entries in scope (a Hamming-style loss, up to a constant)
ell = zeros(size(labels));
ell(scope) = 1 - 2*labels(scope);
% delta = sum(labels(scope));
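% The next two lines appear to recover the maximizer of the entropy-smoothed
% loss-augmented inference problem from its dual: with smoothing strength
% (kappa + zb), the closed-form solution is
% y = exp((F'*w + ell + A'*lambda)/(kappa + zb) - 1).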
logy = (F'*w + ell + A'*lambda)/(kappa + zb) - 1;
y = exp(logy);
loss = C * ((kappa + zb) * sum(y) - w'*F_labels - b'*lambda);
f = 0.5*(w'*w) / ((kappa + za)^2) + loss;
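% Written out, the objective computed above is
%   f = 0.5 * ||w||^2 / (kappa + za)^2
%       + C * ( (kappa + zb) * sum(y) - w'*F*labels - b'*lambda )
% where y = exp((F'*w + ell + A'*lambda)/(kappa + zb) - 1).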
maxG = 1e16; % threshold for the (currently disabled) gradient clipping below
if nargout == 2
    % gradient w.r.t. w: regularizer term plus expected minus empirical features
    gradW = w / ((kappa + za)^2) + C*(F * y) - C*F_labels;
    % gradient w.r.t. kappa; the y > 0 mask skips entries where y has
    % underflowed to zero
    gradKappa = - w'*w / ((kappa + za)^3) - C*y(y>0)'*logy(y>0);
    % gradient w.r.t. the dual variables lambda: equality-constraint residual
    gradLambda = C * (A * y - b);
    g = [gradW; gradKappa; gradLambda];
    % g(g > maxG) = maxG;
    % g(g < -maxG) = -maxG;
end
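
A minimal usage sketch, not part of the original file: it checks the analytic gradient against finite differences and then minimizes the objective with MATLAB's fminunc (Optimization Toolbox). All problem sizes, the constraint structure S, and the constant C below are invented for illustration.

% --- hypothetical example data ---
d = 5; m = 10; nCon = 3;
F = randn(d, m);
labels = double(rand(m, 1) > 0.5);
scope = 1:m;                              % count every entry in the loss
S.Aeq = randn(nCon, m);
S.beq = randn(nCon, 1);
C = 1;
x0 = [zeros(d, 1); 1; zeros(nCon, 1)];    % [w; kappa; lambda]

% finite-difference check of the analytic gradient at x0
[f0, g0] = jointObjectiveEnt(x0, F, labels, scope, S, C);
h = 1e-6;
gNum = zeros(size(x0));
for i = 1:numel(x0)
    e = zeros(size(x0)); e(i) = h;
    gNum(i) = (jointObjectiveEnt(x0 + e, F, labels, scope, S, C) - f0) / h;
end
fprintf('max abs gradient error: %g\n', max(abs(gNum - g0)));

% minimize, supplying the analytic gradient to fminunc
obj = @(x) jointObjectiveEnt(x, F, labels, scope, S, C);
opts = optimoptions('fminunc', 'SpecifyObjectiveGradient', true);
xOpt = fminunc(obj, x0, opts);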