Skip to content

Commit a410cda

Browse files
author
Dominik Kuzinowicz
committed
Adjusted IHB, added more FW variant options
1 parent 56d8ac0 commit a410cda

11 files changed

+109
-108
lines changed

src/oracle_avi.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ function fit_oavi(
8282
term_evaluated_squared = norm(term_evaluated, 2)^2
8383

8484
# built-in Frank-Wolfe oracle
85-
if oracle in ["CG", "BCG", "BPCG"]
85+
if oracle in ["CG", "Away", "AFW", "PCG", "Lazy", "LCG", "BCG", "BPCG"]
8686
coefficient_vector, loss = conditional_gradients( oracle,
8787
data, term_evaluated,
8888
lambda,

src/oracle_constructors.jl

+30-17
Original file line numberDiff line numberDiff line change
@@ -42,38 +42,47 @@ function conditional_gradients(
4242

4343
# determine oracle
4444
if oracle_type == "CG"
45-
oracle = frank_wolfe
45+
oracle = FrankWolfe.frank_wolfe
46+
elseif oracle_type == "Away" || oracle_type == "AFW"
47+
oracle = FrankWolfe.away_frank_wolfe
48+
elseif oracle_type == "PCG"
49+
oracle = FrankWolfe.pairwise_frank_wolfe
50+
elseif oracle_type == "Lazy" || oracle_type == "LCG"
51+
oracle = FrankWolfe.lazified_conditional_gradient
4652
elseif oracle_type == "BCG"
47-
oracle = blended_conditional_gradient
53+
oracle = FrankWolfe.blended_conditional_gradient
4854
elseif oracle_type == "BPCG"
4955
oracle = FrankWolfe.blended_pairwise_conditional_gradient
5056
end
5157

5258
# create L1 ball as feasible region
5359
region = FrankWolfe.LpNormLMO{1}(tau-1)
54-
55-
# compute starting point
60+
61+
62+
# call oracles
5663
if inverse_hessian_boost in ["weak", "full"]
64+
display("Inverse Hessian Boosting (IHB) is active. Vanilla Frank-Wolfe is used for the IHB run.")
65+
66+
# compute starting point for IHB
5767
x0 = l1_projection(solution; radius=tau-1)
5868
x0 = reshape(x0, length(x0))
59-
else
60-
x0 = compute_extreme_point(region, zeros(Float64, n))
61-
x0 = Vector(x0)
62-
end
63-
64-
# run oracle to find coefficient vector
65-
if inverse_hessian_boost == "weak"
66-
coefficient_vector, _ = oracle(f, grad!, region, x0; epsilon=epsilon, max_iteration=max_iters)
69+
70+
# IHB oracle call
71+
coefficient_vector, _ = FrankWolfe.frank_wolfe(f, grad!, region, x0; epsilon=epsilon, max_iteration=max_iters)
6772
if typeof(coefficient_vector) <: FrankWolfe.ScaledHotVector
6873
coefficient_vector = convert(Vector, coefficient_vector)
6974
end
7075
coefficient_vector = vcat(coefficient_vector, [1])
7176

7277
loss = 1/m * norm(data_with_labels * coefficient_vector, 2)^2
73-
74-
if loss <= psi
78+
79+
# attempt to find sparse solution if IHB solution found
80+
if inverse_hessian_boost == "weak" && loss <= psi
81+
display("IHB solution found. Attempting to find sparse solution.")
82+
7583
x0 = compute_extreme_point(region, zeros(Float64, n))
7684
x0 = Vector(x0)
85+
7786
tmp_coefficient_vector, _ = oracle(f, grad!, region, x0; epsilon=epsilon, max_iteration=max_iters)
7887
tmp_coefficient_vector = vcat(tmp_coefficient_vector, [1])
7988

@@ -83,14 +92,18 @@ function conditional_gradients(
8392
loss = loss2
8493
coefficient_vector = tmp_coefficient_vector
8594
end
86-
8795
end
88-
else
96+
else
97+
# compute starting vertex
98+
x0 = compute_extreme_point(region, zeros(Float64, n))
99+
x0 = Vector(x0)
100+
101+
# oracle call
89102
coefficient_vector, _ = oracle(f, grad!, region, x0; epsilon=epsilon, max_iteration=max_iters)
90103
coefficient_vector = vcat(coefficient_vector, [1])
91104

92105
loss = 1/m * norm(data_with_labels * coefficient_vector, 2)^2
93-
end
106+
end
94107
return coefficient_vector, loss
95108
end
96109

test/test_auxiliary_functions.jl

+9-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using Test
2+
using LinearAlgebra
3+
using ApproximateVanishingIdeals
4+
const AVI = ApproximateVanishingIdeals
25

3-
include("../src/auxiliary_functions.jl")
46

57
matrix = Matrix([[1 2 3 4 3];
68
[0 1 2 3 2];
@@ -16,26 +18,26 @@ matrix_unique = Matrix([[1 3 2 4];
1618

1719

1820
@testset "Test suite for deg_lex_sort" begin
19-
matrix_sorted_1, matrix_sorted_2, _ = deg_lex_sort(matrix, 1. * matrix)
21+
matrix_sorted_1, matrix_sorted_2, _ = AVI.deg_lex_sort(matrix, 1. * matrix)
2022
@test matrix_sorted_1 == matrix_sorted
2123
@test matrix_sorted_2 == matrix_sorted
2224
end
2325

2426

2527
@testset "Test suite for get_unique_columns" begin
26-
mat1_unique, mat2_unique, unique_inds = get_unique_columns(matrix, matrix)
28+
mat1_unique, mat2_unique, unique_inds = AVI.get_unique_columns(matrix, matrix)
2729
@test mat1_unique == matrix_unique
2830
@test mat2_unique == matrix_unique
2931
@test unique_inds == [1, 2, 4, 5]
3032

31-
mat1_unique, mat2_unique, _ = get_unique_columns(matrix)
33+
mat1_unique, mat2_unique, _ = AVI.get_unique_columns(matrix)
3234
@test mat1_unique == matrix_unique
3335
@test mat2_unique == zeros(Float64, 0, 0)
3436
end
3537

3638

3739
@testset "Test suite for compute_degree" begin
38-
@test compute_degree(matrix) == [2 6 6 9 6]
40+
@test AVI.compute_degree(matrix) == [2 6 6 9 6]
3941
end
4042

4143

@@ -45,9 +47,9 @@ matrix_non_zero = Matrix([[1 0 3 0 0 0];
4547
[1 0 0 0 3 0]])
4648

4749
@testset "Test suite for finding non-zero entries" begin
48-
first_ids = find_first_non_zero_entries(matrix_non_zero)
50+
first_ids = AVI.find_first_non_zero_entries(matrix_non_zero)
4951
@test first_ids == [1, 2, 1, 3, 4, 1]
5052

51-
last_ids = find_last_non_zero_entries(matrix_non_zero)
53+
last_ids = AVI.find_last_non_zero_entries(matrix_non_zero)
5254
@test last_ids == [4, 3, 3, 3, 4, 4]
5355
end

test/test_auxiliary_functions_avi.jl

+14-18
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,24 @@
11
using Test
22
using FrankWolfe
3+
using ApproximateVanishingIdeals
4+
using LinearAlgebra
5+
const AVI = ApproximateVanishingIdeals
36

4-
include("../src/auxiliary_functions.jl")
5-
include("../src/terms_and_polynomials.jl")
6-
include("../src/oracle_constructors.jl")
7-
include("../src/border_construction.jl")
8-
include("../src/objective_functions.jl")
9-
include("../src/auxiliary_functions_avi.jl")
10-
include("../src/oracle_avi.jl")
117

128
@testset "Test suite for update_coefficient_vectors" begin
139
G_coefficient_vectors = reshape([1, 2, 0], 3, 1)
1410

1511
vec1 = reshape([1, 2], 2, 1)
1612

17-
G_coefficient_vectors = update_coefficient_vectors(G_coefficient_vectors, vec1)
13+
G_coefficient_vectors = AVI.update_coefficient_vectors(G_coefficient_vectors, vec1)
1814
G_coefficient_vectors = vcat(G_coefficient_vectors, zeros(1, size(G_coefficient_vectors, 2)))
19-
G_coefficient_vectors = update_coefficient_vectors(G_coefficient_vectors, vec1)
15+
G_coefficient_vectors = AVI.update_coefficient_vectors(G_coefficient_vectors, vec1)
2016
G_coefficient_vectors = vcat(G_coefficient_vectors, zeros(1, size(G_coefficient_vectors, 2)))
2117
G_coefficient_vectors = vcat(G_coefficient_vectors, zeros(1, size(G_coefficient_vectors, 2)))
2218

2319
vec1 = reshape([1, 2, 3], 3, 1)
2420

25-
G_coefficient_vectors = update_coefficient_vectors(G_coefficient_vectors, vec1)
21+
G_coefficient_vectors = AVI.update_coefficient_vectors(G_coefficient_vectors, vec1)
2622

2723
@test G_coefficient_vectors == Matrix([[1. 1. 1. 1.];
2824
[2. 0. 0. 0.];
@@ -37,14 +33,14 @@ end
3733
vec1 = rand(1:10, 20)
3834
radius_1 = 2.5
3935
radius_2 = 3.0
40-
@test norm(l1_projection(vec1), 1) ≈ 1
41-
@test norm(l1_projection(vec1; radius=radius_1), 1) ≈ radius_1
42-
@test norm(l1_projection(vec1; radius=radius_2), 1) ≈ radius_2
36+
@test norm(AVI.l1_projection(vec1), 1) ≈ 1
37+
@test norm(AVI.l1_projection(vec1; radius=radius_1), 1) ≈ radius_1
38+
@test norm(AVI.l1_projection(vec1; radius=radius_2), 1) ≈ radius_2
4339

44-
vec2 = l1_projection(vec1)
45-
@test vec2 ≈ l1_projection(vec2)
46-
@test norm(l1_projection(vec2), 1) ≈ 1
47-
@test norm(vec2) ≈ norm(l1_projection(vec2))
40+
vec2 = AVI.l1_projection(vec1)
41+
@test vec2 ≈ AVI.l1_projection(vec2)
42+
@test norm(AVI.l1_projection(vec2), 1) ≈ 1
43+
@test norm(vec2) ≈ norm(AVI.l1_projection(vec2))
4844
end;
4945

5046

@@ -60,7 +56,7 @@ end;
6056

6157
A_a = transpose(A) * a
6258

63-
B, B_2, B_2_1 = streaming_matrix_updates(A, A_sq, A_a, a, a_sq; A_squared_inv=A_sq_inv)
59+
B, B_2, B_2_1 = AVI.streaming_matrix_updates(A, A_sq, A_a, a, a_sq; A_squared_inv=A_sq_inv)
6460

6561
C = hcat(A, a)
6662

test/test_border_construction.jl

+5-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
include("../src/border_construction.jl")
2-
include("../src/auxiliary_functions.jl")
3-
41
using LinearAlgebra
2+
using ApproximateVanishingIdeals
3+
const AVI = ApproximateVanishingIdeals
54
using Test
65

76

@@ -14,7 +13,7 @@ using Test
1413
[0 1];
1514
[0 1];])
1615

17-
matrix_2_purged, matrix_2_purged_2, _ = purge(matrix_2, 1. * matrix_2, matrix_1)
16+
matrix_2_purged, matrix_2_purged_2, _ = AVI.purge(matrix_2, 1. * matrix_2, matrix_1)
1817

1918
@test size(matrix_2_purged, 2) == 0
2019
@test size(matrix_2_purged_2, 2) == 0
@@ -27,7 +26,7 @@ using Test
2726
[0 2];
2827
[2 1];])
2928

30-
matrix_2_purged, matrix_2_purged_2, _ = purge(matrix_2, 1. * matrix_2, matrix_1)
29+
matrix_2_purged, matrix_2_purged_2, _ = AVI.purge(matrix_2, 1. * matrix_2, matrix_1)
3130

3231
@test matrix_2_purged == matrix_2_purged_2
3332
@test matrix_2_purged == Matrix([[1 2 3];
@@ -52,7 +51,7 @@ end
5251
# duplicate indices: 3, 6, 8, 12; purged indices: 7, 8, 11, 12, 13, 15
5352
unique_non_purging_indices = [1, 2, 4, 5, 9, 10, 14]
5453

55-
terms_raw, _, non_purging_indices, _ = construct_border(terms, 1. * terms, zeros(Float64, 0, 0), degree_1_terms, 1. * degree_1_terms, purging_terms)
54+
terms_raw, _, non_purging_indices, _ = AVI.construct_border(terms, 1. * terms, zeros(Float64, 0, 0), degree_1_terms, 1. * degree_1_terms, purging_terms)
5655

5756
@test terms_raw[:, non_purging_indices] == raw_border[:, unique_non_purging_indices]
5857
end

test/test_objective_functions.jl

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,32 @@
11
using Test
22
using LinearAlgebra
3+
using ApproximateVanishingIdeals
4+
const AVI = ApproximateVanishingIdeals
35

4-
include("../src/objective_functions.jl")
56

6-
@testset "Test suite for 'evaluate_function' in L2Loss" begin
7+
@testset "Test suite for evaluate_function in L2Loss" begin
78
for m in 1:5
89
for n in 1:10
910
A = rand(m, n)
1011
b = rand(m)
1112
x = rand(n)
1213
lambda = rand() * n
13-
_, evaluate_function, _ = L2Loss(A, b, lambda, A' * A, A' * b, b' * b)
14+
_, evaluate_function, _ = AVI.L2Loss(A, b, lambda, A' * A, A' * b, b' * b)
1415

1516
@test 1/m * norm(A * x + b, 2)^2 + lambda * norm(x, 2)^2 / 2 ≈ evaluate_function(x)
1617
end
1718
end
1819
end;
1920

2021

21-
@testset "Test suite for 'evaluate_gradient!' in L2Loss" begin
22+
@testset "Test suite for evaluate_gradient! in L2Loss" begin
2223
for m in 1:5
2324
for n in 1:10
2425
A = rand(m, n)
2526
b = rand(m)
2627
x = rand(n)
2728
lambda = rand() * n
28-
_, _, evaluate_gradient! = L2Loss(A, b, lambda, A' * A, A' * b, b' * b)
29+
_, _, evaluate_gradient! = AVI.L2Loss(A, b, lambda, A' * A, A' * b, b' * b)
2930

3031
gradient = 2/m * (A' * A * x + A' * b + m/2 * lambda * x)
3132
approx_vec = (gradient .≈ evaluate_gradient!(zeros(n), x))

test/test_oracle_avi.jl

+9-14
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,16 @@ using Test
22
using Random
33
using LinearAlgebra
44
using FrankWolfe
5+
using ApproximateVanishingIdeals
6+
const AVI = ApproximateVanishingIdeals
57

6-
include("../src/oracle_avi.jl")
7-
include("../src/auxiliary_functions.jl")
8-
include("../src/terms_and_polynomials.jl")
9-
include("../src/objective_functions.jl")
10-
include("../src/auxiliary_functions_avi.jl")
11-
include("../src/oracle_constructors.jl")
12-
include("../src/border_construction.jl")
138

149
@testset "Test suite for fit_oavi" begin
15-
for oracle in ["CG", "BCG", "BPCG"]
10+
for oracle in ["CG", "Away", "PCG", "Lazy", "BCG", "BPCG"]
1611
m, n = rand(15:25), rand(4:10)
1712
X_train = rand(m, n)
1813
for ihb in ["false", "weak", "full"]
19-
X_train_transformed, sets = fit_oavi(X_train; oracle=oracle, inverse_hessian_boost=ihb)
14+
X_train_transformed, sets = AVI.fit_oavi(X_train; oracle=oracle, inverse_hessian_boost=ihb)
2015
loss_list = Vector{Float64}([])
2116
for col in 1:size(sets.G_evaluations, 2)
2217
cur_col = sets.G_evaluations[:, col]
@@ -33,7 +28,7 @@ end;
3328
for _ in 1:5
3429
m, n = rand(15:25), rand(4:10)
3530
X_train = rand(m, n)
36-
X_train_transformed, sets = fit_oavi(X_train; oracle="ABM", psi=0.05)
31+
X_train_transformed, sets = AVI.fit_oavi(X_train; oracle="ABM", psi=0.05)
3732
loss_list = Vector{Float64}([])
3833
for col in 1:size(sets.G_evaluations, 2)
3934
cur_col = sets.G_evaluations[:, col]
@@ -49,15 +44,15 @@ end;
4944
for oracle in ["CG", "BPCG", "ABM"]
5045
m, n = rand(15:25), rand(4:10)
5146
X_tr = rand(m, n)
52-
X_tr_transformed, sets_tr = fit_oavi(X_tr; oracle=oracle)
53-
X_te_transformed, sets_te = evaluate_oavi(sets_tr, X_tr)
47+
X_tr_transformed, sets_tr = AVI.fit_oavi(X_tr; oracle=oracle)
48+
X_te_transformed, sets_te = AVI.evaluate_oavi(sets_tr, X_tr)
5449

5550
@test all(X_tr_transformed .- X_te_transformed .<= 1.0e-10)
5651

5752
if oracle !== "ABM"
5853
X_train = rand(10, 3)
59-
X_train_transformed, sets_train = fit_oavi(X_train; psi=0.01, lambda=0.1, oracle=oracle)
60-
X_test_transformed, sets_test = evaluate_oavi(sets_train, X_train)
54+
X_train_transformed, sets_train = AVI.fit_oavi(X_train; psi=0.01, lambda=0.1, oracle=oracle)
55+
X_test_transformed, sets_test = AVI.evaluate_oavi(sets_train, X_train)
6156
@test all(X_train_transformed .- X_test_transformed .<= 1.0e-10)
6257
end
6358
end

test/test_print_polynomials.jl

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
using Test
2+
using LinearAlgebra
3+
using ApproximateVanishingIdeals
4+
const AVI = ApproximateVanishingIdeals
25

3-
include("../src/print_polynomials.jl")
46

57
@testset "Test suite for print_polynomials" begin
68
A = rand(10, 3)
79

8-
sets = construct_SetsOandG(A)
10+
sets = AVI.construct_SetsOandG(A)
911

1012
append!(sets.G_coefficient_vectors, [reshape( [ -0.77,
1113
0.44,
@@ -39,7 +41,7 @@ include("../src/print_polynomials.jl")
3941
"x_{3}^{2} - x_{3} - 0.03x_{1} + 0.17"
4042
]
4143

42-
constructed_polys = print_polynomials(sets; ret=true)
44+
constructed_polys = AVI.print_polynomials(sets; ret=true)
4345

4446
@test all(polys .== constructed_polys)
4547
end;
@@ -55,6 +57,6 @@ end;
5557

5658
for i in 1:size(terms, 2)
5759
term = terms[:, i]
58-
@test convert_term_to_latex(term) == converted_terms[i]
60+
@test AVI.convert_term_to_latex(term) == converted_terms[i]
5961
end
6062
end;

0 commit comments

Comments
 (0)