Skip to content

Add slicing to ContResult #189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ext/PlotsExt/RecipesPlots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ RecipesBase.@recipe function Plots(contres::AbstractBranchResult;
end

# display bifurcation points
bifpt = filter(x -> (x.type != :none) && (x.type != :endpoint) && (plotfold || x.type != :fold) && (x.idx <= length(contres)-1), contres.specialpoint)
bifpt = filter(x -> (x.type != :none) && (x.type != :endpoint) && (plotfold || x.type != :fold) && (x.idx <= length(contres)), contres.specialpoint)

if length(bifpt) >= 1 && plotspecialpoints #&& (ind1 == :param)
if filterspecialpoints == true
Expand Down
44 changes: 37 additions & 7 deletions src/Results.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,37 @@ _getfirstusertype(br::AbstractBranchResult) = keys(br.branch[1])[1]
Set the parameter value `p0` according to the `::Lens` stored in `br` for the parameters of the problem `br.prob`.
"""
setparam(br::AbstractBranchResult, p0) = setparam(br.prob, p0)
Base.getindex(br::ContResult, k::Int) = (br.branch[k]..., eigenvals = haseigenvalues(br) ? br.eig[k].eigenvals : nothing, eigenvecs = haseigenvector(br) ? br.eig[k].eigenvecs : nothing)
Base.lastindex(br::ContResult) = length(br)

function Base.getindex(br::ContResult, k::Int)
idx = isnothing(br.eig) ? nothing : findfirst(x -> x.step == br.branch[k].step, br.eig)
eigenvals = haseigenvalues(br) && !isnothing(idx) ? br.eig[idx].eigenvals : nothing
eigenvecs = haseigenvector(br) && !isnothing(idx) ? br.eig[idx].eigenvecs : nothing
return (; br.branch[k]..., eigenvals, eigenvecs)
end

function Base.getindex(br0::ContResult, k::UnitRange{<:Integer})
br = deepcopy(br0)

if ~isnothing(br.branch)
@reset br.branch = br.branch[k]
end

if ~isnothing(br.eig)
@reset br.eig = [pt for pt in br.eig if pt.step in br.branch.step]
end

if ~isnothing(br.sol)
@reset br.sol = [pt for pt in br.sol if pt.step in br.branch.step]
end

if ~isnothing(br.specialpoint)
@reset br.specialpoint = [setproperties(pt; idx=pt.idx + 1 - k[1]) for pt in br.specialpoint if pt.idx in k]
end

return br
end

"""
$(SIGNATURES)

Expand Down Expand Up @@ -236,12 +264,12 @@ Function is used to initialize the composite type `ContResult` according to the
- `lens`: lens to specify the continuation parameter
- `eiginfo`: eigen-elements (eigvals, eigvecs)
"""
function _contresult(iter,
state,
printsol,
br,
x0,
contParams::ContinuationPar{T, S, E}) where {T, S, E}
function _contresult(iter,
state,
printsol,
br,
x0,
contParams::ContinuationPar{T, S, E}) where {T, S, E}
# example of bifurcation point
bif0 = SpecialPoint(x0, state.τ, T, namedprintsol(printsol))
# save full solution?
Expand Down Expand Up @@ -306,6 +334,8 @@ Base.lastindex(br::Branch) = lastindex(br.γ)
# for example, it allows to use the plot recipe for ContResult as is
Base.getproperty(br::Branch, s::Symbol) = s in (:γ, :bp) ? getfield(br, s) : getproperty(br.γ, s)
Base.getindex(br::Branch, k::Int) = getindex(br.γ, k)
Base.getindex(br::Branch, k::UnitRange{<:Integer}) = setproperties(br; γ = getindex(br.γ, k))

####################################################################################################
_reverse!(x) = reverse!(x)
_reverse!(::Nothing) = nothing
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ using Base.Threads; println("--> There are ", Threads.nthreads(), " threads")
include("test_linear.jl")
end

@testset "Results" begin
include("test_results.jl")
end

@testset "Newton" begin
include("test_newton.jl")
end
Expand Down
58 changes: 58 additions & 0 deletions test/test_results.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
using Test, BifurcationKit
const BK = BifurcationKit

# Simple Test problem (Pitchfork bifurcation) to generate a ContResult and a Branch object
function f(u, p)
return p.r .* u - u .^ 3
end
p = (r=-1.0,)
u0 = [0.0]
prob = BK.BifurcationProblem(f, u0, p, (@optic _.r))

@testset "ContResult" begin
opt = BK.ContinuationPar(p_min=-1.0, p_max=1.0)
contres = BK.continuation(prob, PALC(), opt)
@assert typeof(contres) <: BK.ContResult
bp = contres.specialpoint[1] # pitchfork bifurcation

# Test slicing of ContResult object
@test contres[1:bp.step+1].specialpoint[1].step == bp.step
@test contres[bp.step+1:end].specialpoint[1].step == bp.step

# Slicing and indexing should match
@test contres[bp.step:bp.step][1].param == contres[bp.step].param

# Recursive slicing should work
@test contres[bp.step:end][1:1][1].param == contres[bp.step:end][1].param

# Slicing should still work when not evey sol/eig is saved
opt = BK.ContinuationPar(opt; detect_bifurcation=1, save_sol_every_step=2, save_eig_every_step=3)
contres = BK.continuation(prob, PALC(), opt)
@assert length(contres) != length(contres.sol) != length(contres.eig)
@test length(contres[1:end]) == length(contres)
@test length(contres[1:end].sol) == length(contres.sol)
@test length(contres[1:end].eig) == length(contres.eig)
end

@testset "Branch" begin
# Test slicing of Branch object
opt = BK.ContinuationPar(p_min=-1.0, p_max=1.0)
contres = BK.continuation(prob, PALC(), opt)
branch = BK.continuation(contres, 1)
@assert typeof(branch) <: BK.Branch
bp = branch.specialpoint[1] # pitchfork bifurcation

# Test slicing of Branch object
@test branch[1:bp.step+1].specialpoint[1].step == bp.step
@test branch[bp.step+1:end].specialpoint[1].step == bp.step

# Slicing and indexing should match
@test branch[bp.step:bp.step][1].param == branch[bp.step].param

# Recursive slicing should work
@test branch[bp.step:end][1:1][1].param == branch[bp.step:end][1].param
end