Skip to content

Commit

Permalink
Merge pull request #189 from twildi/master
Browse files Browse the repository at this point in the history
Add slicing to ContResult
  • Loading branch information
rveltz authored Nov 27, 2024
2 parents 997cae3 + cf14395 commit 85d4273
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 8 deletions.
2 changes: 1 addition & 1 deletion ext/PlotsExt/RecipesPlots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ RecipesBase.@recipe function Plots(contres::AbstractBranchResult;
end

# display bifurcation points
bifpt = filter(x -> (x.type != :none) && (x.type != :endpoint) && (plotfold || x.type != :fold) && (x.idx <= length(contres)-1), contres.specialpoint)
bifpt = filter(x -> (x.type != :none) && (x.type != :endpoint) && (plotfold || x.type != :fold) && (x.idx <= length(contres)), contres.specialpoint)

if length(bifpt) >= 1 && plotspecialpoints #&& (ind1 == :param)
if filterspecialpoints == true
Expand Down
44 changes: 37 additions & 7 deletions src/Results.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,37 @@ _getfirstusertype(br::AbstractBranchResult) = keys(br.branch[1])[1]
Set the parameter value `p0` according to the `::Lens` stored in `br` for the parameters of the problem `br.prob`.
"""
setparam(br::AbstractBranchResult, p0) = setparam(br.prob, p0)
Base.getindex(br::ContResult, k::Int) = (br.branch[k]..., eigenvals = haseigenvalues(br) ? br.eig[k].eigenvals : nothing, eigenvecs = haseigenvector(br) ? br.eig[k].eigenvecs : nothing)
Base.lastindex(br::ContResult) = length(br)

function Base.getindex(br::ContResult, k::Int)
idx = isnothing(br.eig) ? nothing : findfirst(x -> x.step == br.branch[k].step, br.eig)
eigenvals = haseigenvalues(br) && !isnothing(idx) ? br.eig[idx].eigenvals : nothing
eigenvecs = haseigenvector(br) && !isnothing(idx) ? br.eig[idx].eigenvecs : nothing
return (; br.branch[k]..., eigenvals, eigenvecs)
end

function Base.getindex(br0::ContResult, k::UnitRange{<:Integer})
br = deepcopy(br0)

if ~isnothing(br.branch)
@reset br.branch = br.branch[k]
end

if ~isnothing(br.eig)
@reset br.eig = [pt for pt in br.eig if pt.step in br.branch.step]
end

if ~isnothing(br.sol)
@reset br.sol = [pt for pt in br.sol if pt.step in br.branch.step]
end

if ~isnothing(br.specialpoint)
@reset br.specialpoint = [setproperties(pt; idx=pt.idx + 1 - k[1]) for pt in br.specialpoint if pt.idx in k]
end

return br
end

"""
$(SIGNATURES)
Expand Down Expand Up @@ -236,12 +264,12 @@ Function is used to initialize the composite type `ContResult` according to the
- `lens`: lens to specify the continuation parameter
- `eiginfo`: eigen-elements (eigvals, eigvecs)
"""
function _contresult(iter,
state,
printsol,
br,
x0,
contParams::ContinuationPar{T, S, E}) where {T, S, E}
function _contresult(iter,
state,
printsol,
br,
x0,
contParams::ContinuationPar{T, S, E}) where {T, S, E}
# example of bifurcation point
bif0 = SpecialPoint(x0, state.τ, T, namedprintsol(printsol))
# save full solution?
Expand Down Expand Up @@ -306,6 +334,8 @@ Base.lastindex(br::Branch) = lastindex(br.γ)
# for example, it allows to use the plot recipe for ContResult as is
Base.getproperty(br::Branch, s::Symbol) = s in (, :bp) ? getfield(br, s) : getproperty(br.γ, s)
Base.getindex(br::Branch, k::Int) = getindex(br.γ, k)
Base.getindex(br::Branch, k::UnitRange{<:Integer}) = setproperties(br; γ = getindex(br.γ, k))

####################################################################################################
_reverse!(x) = reverse!(x)
_reverse!(::Nothing) = nothing
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ using Base.Threads; println("--> There are ", Threads.nthreads(), " threads")
include("test_linear.jl")
end

@testset "Results" begin
include("test_results.jl")
end

@testset "Newton" begin
include("test_newton.jl")
end
Expand Down
58 changes: 58 additions & 0 deletions test/test_results.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
using Test, BifurcationKit
const BK = BifurcationKit

# Simple Test problem (Pitchfork bifurcation) to generate a ContResult and a Branch object
function f(u, p)
return p.r .* u - u .^ 3
end
p = (r=-1.0,)
u0 = [0.0]
prob = BK.BifurcationProblem(f, u0, p, (@optic _.r))

@testset "ContResult" begin
opt = BK.ContinuationPar(p_min=-1.0, p_max=1.0)
contres = BK.continuation(prob, PALC(), opt)
@assert typeof(contres) <: BK.ContResult
bp = contres.specialpoint[1] # pitchfork bifurcation

# Test slicing of ContResult object
@test contres[1:bp.step+1].specialpoint[1].step == bp.step
@test contres[bp.step+1:end].specialpoint[1].step == bp.step

# Slicing and indexing should match
@test contres[bp.step:bp.step][1].param == contres[bp.step].param

# Recursive slicing should work
@test contres[bp.step:end][1:1][1].param == contres[bp.step:end][1].param

# Slicing should still work when not evey sol/eig is saved
opt = BK.ContinuationPar(opt; detect_bifurcation=1, save_sol_every_step=2, save_eig_every_step=3)
contres = BK.continuation(prob, PALC(), opt)
@assert length(contres) != length(contres.sol) != length(contres.eig)
@test length(contres[1:end]) == length(contres)
@test length(contres[1:end].sol) == length(contres.sol)
@test length(contres[1:end].eig) == length(contres.eig)
end

@testset "Branch" begin
# Test slicing of Branch object
opt = BK.ContinuationPar(p_min=-1.0, p_max=1.0)
contres = BK.continuation(prob, PALC(), opt)
branch = BK.continuation(contres, 1)
@assert typeof(branch) <: BK.Branch
bp = branch.specialpoint[1] # pitchfork bifurcation

# Test slicing of Branch object
@test branch[1:bp.step+1].specialpoint[1].step == bp.step
@test branch[bp.step+1:end].specialpoint[1].step == bp.step

# Slicing and indexing should match
@test branch[bp.step:bp.step][1].param == branch[bp.step].param

# Recursive slicing should work
@test branch[bp.step:end][1:1][1].param == branch[bp.step:end][1].param
end




0 comments on commit 85d4273

Please sign in to comment.