-
Notifications
You must be signed in to change notification settings - Fork 0
/
cvx_solve.m
65 lines (59 loc) · 1.61 KB
/
cvx_solve.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
function [B,Bcost,info] = cvx_solve(A,version)
%CVX_SOLVE Solve the convex problem over the slices of A with CVX.
%   [B,Bcost,info] = CVX_SOLVE(A) uses the "fast" formulation.
%   [B,Bcost,info] = CVX_SOLVE(A,version) selects the formulation:
%     "fast" - one trace_sqrtm term per slice A(:,:,j) (default)
%     "slow" - a single trace_sqrtm over an (n*k)-by-(n*k) block matrix
%   version may be a string or char vector; matching is case-insensitive.
%   A is read as an n-by-n-by-k array (see size(A) in the solvers).
%
%   Outputs:
%     B     - optimal n-by-n symmetric CVX variable value
%     Bcost - optimal objective value (cvx_optval)
%     info  - struct with CVX solver diagnostics (cost, status, iter, tol, time, ...)
%
%   Errors with identifier 'cvx_solve:unknownVersion' for any other version
%   (previously this surfaced as an unassigned-output error).
if nargin < 2 || isempty(version)
    version = "fast";
end
switch lower(string(version))
    case "fast"
        [B,Bcost,info] = cvx_solve_fast(A);
    case "slow"
        [B,Bcost,info] = cvx_solve_slow(A);
    otherwise
        error('cvx_solve:unknownVersion', ...
            'Unknown version "%s"; expected "fast" or "slow".', char(string(version)));
end
end
function [B,Bcost,info] = cvx_solve_fast(A)
%CVX_SOLVE_FAST "fast" formulation: one trace_sqrtm term per slice of A.
%   A is read as an n-by-n-by-k array. Over symmetric PSD B, minimizes
%     trace(B) + trA/k - (2/k) * sum_j trace_sqrtm(S_j * B * S_j),
%   where S_j = Asqrt(:,:,j) is the square root of slice j and trA is the
%   sum of the traces of all slices.
%   NOTE(review): this appears to be the mean squared Bures-Wasserstein
%   distance from B to the slices of A -- confirm.
%
%   Returns the optimal B, Bcost = cvx_optval, and an info struct of
%   CVX solver diagnostics.
[n,~,k] = size(A);
% Slice-wise matrix square roots. multisqrtm is defined outside this file
% (presumably Manopt's multisqrtm); assumes each slice is PSD -- confirm.
Asqrt = multisqrtm(A);
% Constant part of the objective: total trace over all slices.
trA = sum(multitrace(A));
% The convex problem:
cvx_begin sdp quiet
    variable B( n, n ) symmetric
    cost = trace(B) + trA/k;
    for j=1:k
        cost = cost - 2/k * trace_sqrtm(Asqrt(:,:,j) * B * Asqrt(:,:,j));
    end
    minimize(cost)
    % Under "cvx_begin sdp" this is a matrix inequality: B is PSD.
    B >= 0
cvx_end
Bcost = cvx_optval;
info.cost = cvx_optval;
info.status = cvx_status;
info.iter = cvx_slvitr;
info.tol = cvx_slvtol;
% optbnd is intentionally not recorded for this formulation.
%info.optbnd = cvx_optbnd;
info.time = cvx_cputime;
% if info.status~="Solved"
%     warning(info.status)
% end
end
function [B,Bcost,info] = cvx_solve_slow(A)
%CVX_SOLVE_SLOW "slow" formulation: one trace_sqrtm over a large block matrix.
%   A is read as an n-by-n-by-k array. The slice square roots are placed on
%   the diagonal of an (n*k)-by-(n*k) matrix D, and over symmetric PSD B
%   the solver minimizes
%     trace(B) + trA/k - (2/k) * trace_sqrtm(D * kron(eye(k),B) * D),
%   where trA is the sum of the traces of all slices. Because the product is
%   block diagonal, the single trace_sqrtm term equals the sum of the
%   per-slice terms, but it is a much larger expression for CVX.
%
%   Returns the optimal B, Bcost = cvx_optval, and an info struct of
%   CVX solver diagnostics (including optbnd).
[n,~,k] = size(A);
% multisqrtm / multitrace are defined outside this file
% (presumably Manopt helpers) -- confirm.
Asqrt = multisqrtm(A);

% Dense block-diagonal matrix with Asqrt(:,:,blk) in diagonal block blk.
diagAsqrt = zeros(n*k);
for blk = 1:k
    span = n*(blk-1) + (1:n);
    diagAsqrt(span,span) = Asqrt(:,:,blk);
end

trA = sum(multitrace(A));

cvx_begin sdp quiet
    variable B( n, n ) symmetric
    blockTerm = trace_sqrtm(diagAsqrt * kron(eye(k),B) * diagAsqrt);
    minimize(trace(B) + trA/k - 2/k*blockTerm)
    % Under "cvx_begin sdp" this is a matrix inequality: B is PSD.
    B >= 0
cvx_end

Bcost = cvx_optval;
info = struct('cost', cvx_optval, ...
              'status', cvx_status, ...
              'iter', cvx_slvitr, ...
              'tol', cvx_slvtol, ...
              'optbnd', cvx_optbnd, ...
              'time', cvx_cputime);
end