Skip to content

Commit

Permalink
cleanup, activate asp tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Ortner committed Sep 9, 2024
1 parent 59723c5 commit 83dbee5
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 66 deletions.
37 changes: 24 additions & 13 deletions src/asp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,24 @@ function solve(solver::ASP, A, y, Aval=A, yval=y)

tracer = asp_homotopy(AP, y; solver.params...)

q = length(tracer)
every = max(1, q ÷ solver.nstore)
new_tracer = Vector{NamedTuple{(:solution, :λ), Tuple{Any, Any}}}(undef, q)
new_tracer = [(solution = solver.P \ tracer[i][1], λ = tracer[i][2]) for i in [1:every:q; q]]
q = length(tracer)
every = max(1, q ÷ solver.nstore)
istore = unique([1:every:q; q])
new_tracer = [ (solution = solver.P \ tracer[i][1], λ = tracer[i][2], σ = 0.0 )
for i in istore ]

if solver.tsvd # Post-processing if tsvd is true
post = post_asp_tsvd(new_tracer, A, y, Aval, yval)
new_post = [(solution = p.θ, λ = p.λ) for p in post]
new_post = [ (solution = p.θ, λ = p.λ, σ = p.σ) for p in post ]
else
new_post = new_tracer
end

xs, in = select_solution(new_post, solver, Aval, yval)

# println("done.")
return Dict( "C" => xs,
"path" => new_post,
"nnzs" => length((new_tracer[in][:solution]).nzind) )
"nnzs" => length( (new_tracer[in][:solution]).nzind) )
end


Expand Down Expand Up @@ -156,13 +156,24 @@ function post_asp_tsvd(path, At, yt, Av, yv)
Qt, Rt = qr(At); zt = Matrix(Qt)' * yt
Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv

post = []
for (θ, λ) in path
if isempty.nzind); push!(post,= θ, λ = λ, σ = Inf)); continue; end
function _post(θλ)
(θ, λ) = θλ
if isempty.nzind); return= θ, λ = λ, σ = Inf); end
inz = θ.nzind
θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv)
θ2 = copy(θ); θ2[inz] .= θ1
push!(post, (θ = θ2, λ = λ, σ = σ))
end
return identity.(post)
return= θ2, λ = λ, σ = σ)
end

return _post.(path)

# post = []
# for (θ, λ) in path
# if isempty(θ.nzind); push!(post, (θ = θ, λ = λ, σ = Inf)); continue; end
# inz = θ.nzind
# θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv)
# θ2 = copy(θ); θ2[inz] .= θ1
# push!(post, (θ = θ2, λ = λ, σ = σ))
# end
# return identity.(post)
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@ using Test

@testset "Linear Solvers" begin include("test_linearsolvers.jl") end

@testset "ASP" begin include("test_asp.jl") end

@testset "MLJ Solvers" begin include("test_mlj.jl") end
end
54 changes: 1 addition & 53 deletions test/test_asp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Random

Random.seed!(1234)
Nobs = 10_000
Nfeat = 300
Nfeat = 100
A1 = randn(Nobs, Nfeat) / sqrt(Nobs)
U, S1, V = svd(A1)
S = 1e-4 .+ ((S1 .- S1[end]) / (S1[1] - S1[end])).^2
Expand Down Expand Up @@ -109,55 +109,3 @@ for (select, tolr, tolc) in [ (:final, 20*epsn, 1.5),
end
end


##

# Experimental Implementation of tsvd postprocessing


# using SparseArrays

# function solve_tsvd(At, yt, Av, yv)
# Ut, Σt, Vt = svd(At); zt = Ut' * yt
# Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv
# @assert issorted(Σt, rev=true)

# Rv_Vt = Rv * Vt

# θv = zeros(size(Av, 2))
# θv[1] = zt[1] / Σt[1]
# rv = Rv_Vt[:, 1] * θv[1] - zv

# tsvd_errs = Float64[]
# push!(tsvd_errs, norm(rv))

# for k = 2:length(Σt)
# θv[k] = zt[k] / Σt[k]
# rv += Rv_Vt[:, k] * θv[k]
# push!(tsvd_errs, norm(rv))
# end

# imin = argmin(tsvd_errs)
# θv[imin+1:end] .= 0
# return Vt * θv, Σt[imin]
# end

# function post_asp_tsvd(path, At, yt, Av, yv)
# Qt, Rt = qr(At); zt = Matrix(Qt)' * yt
# Qv, Rv = qr(Av); zv = Matrix(Qv)' * yv

# post = []
# for (θ, λ) in path
# if isempty(θ.nzind); push!(post, (θ = θ, λ = λ, σ = Inf)); continue; end
# inz = θ.nzind
# θ1, σ = solve_tsvd(Rt[:, inz], zt, Rv[:, inz], zv)
# θ2 = copy(θ); θ2[inz] .= θ1
# push!(post, (θ = θ2, λ = λ, σ = σ))
# end
# return identity.(post)
# end

# solver = ACEfit.ASP(P=I, select = :final, loglevel=0, traceFlag=true)
# result = ACEfit.solve(solver, At, yt);
# post = post_asp_tsvd(result["path"], At, yt, Av, yv);

0 comments on commit 83dbee5

Please sign in to comment.