diff --git a/ext/PlotsExt/RecipesPlots.jl b/ext/PlotsExt/RecipesPlots.jl index fcc29bc9..57f4c430 100644 --- a/ext/PlotsExt/RecipesPlots.jl +++ b/ext/PlotsExt/RecipesPlots.jl @@ -28,7 +28,7 @@ RecipesBase.@recipe function Plots(contres::AbstractBranchResult; end # display bifurcation points - bifpt = filter(x -> (x.type != :none) && (x.type != :endpoint) && (plotfold || x.type != :fold) && (x.idx <= length(contres)-1), contres.specialpoint) + bifpt = filter(x -> (x.type != :none) && (x.type != :endpoint) && (plotfold || x.type != :fold) && (x.idx <= length(contres)), contres.specialpoint) if length(bifpt) >= 1 && plotspecialpoints #&& (ind1 == :param) if filterspecialpoints == true diff --git a/src/Results.jl b/src/Results.jl index d746e9c4..a440fe0a 100644 --- a/src/Results.jl +++ b/src/Results.jl @@ -131,9 +131,37 @@ _getfirstusertype(br::AbstractBranchResult) = keys(br.branch[1])[1] Set the parameter value `p0` according to the `::Lens` stored in `br` for the parameters of the problem `br.prob`. """ setparam(br::AbstractBranchResult, p0) = setparam(br.prob, p0) -Base.getindex(br::ContResult, k::Int) = (br.branch[k]..., eigenvals = haseigenvalues(br) ? br.eig[k].eigenvals : nothing, eigenvecs = haseigenvector(br) ? br.eig[k].eigenvecs : nothing) Base.lastindex(br::ContResult) = length(br) +function Base.getindex(br::ContResult, k::Int) + idx = isnothing(br.eig) ? nothing : findfirst(x -> x.step == br.branch[k].step, br.eig) + eigenvals = haseigenvalues(br) && !isnothing(idx) ? br.eig[idx].eigenvals : nothing + eigenvecs = haseigenvector(br) && !isnothing(idx) ? br.eig[idx].eigenvecs : nothing + return (; br.branch[k]..., eigenvals, eigenvecs) +end + +function Base.getindex(br0::ContResult, k::UnitRange{<:Integer}) + br = deepcopy(br0) + + if ~isnothing(br.branch) + @reset br.branch = br.branch[k] + end + + if ~isnothing(br.eig) + @reset br.eig = [pt for pt in br.eig if pt.step in br.branch.step] + end + + if ~isnothing(br.sol) + @reset br.sol = [pt for pt in br.sol if pt.step in br.branch.step] + end + + if ~isnothing(br.specialpoint) + @reset br.specialpoint = [setproperties(pt; idx=pt.idx + 1 - k[1]) for pt in br.specialpoint if pt.idx in k] + end + + return br +end + """ $(SIGNATURES) @@ -236,12 +264,12 @@ Function is used to initialize the composite type `ContResult` according to the - `lens`: lens to specify the continuation parameter - `eiginfo`: eigen-elements (eigvals, eigvecs) """ - function _contresult(iter, - state, - printsol, - br, - x0, - contParams::ContinuationPar{T, S, E}) where {T, S, E} +function _contresult(iter, + state, + printsol, + br, + x0, + contParams::ContinuationPar{T, S, E}) where {T, S, E} # example of bifurcation point bif0 = SpecialPoint(x0, state.τ, T, namedprintsol(printsol)) # save full solution? @@ -306,6 +334,8 @@ Base.lastindex(br::Branch) = lastindex(br.γ) # for example, it allows to use the plot recipe for ContResult as is Base.getproperty(br::Branch, s::Symbol) = s in (:γ, :bp) ? getfield(br, s) : getproperty(br.γ, s) Base.getindex(br::Branch, k::Int) = getindex(br.γ, k) +Base.getindex(br::Branch, k::UnitRange{<:Integer}) = setproperties(br; γ = getindex(br.γ, k)) + #################################################################################################### _reverse!(x) = reverse!(x) _reverse!(::Nothing) = nothing diff --git a/test/runtests.jl b/test/runtests.jl index 80ddedf3..2f90de0c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,6 +11,10 @@ using Base.Threads; println("--> There are ", Threads.nthreads(), " threads") include("test_linear.jl") end + @testset "Results" begin + include("test_results.jl") + end + @testset "Newton" begin include("test_newton.jl") end diff --git a/test/test_results.jl b/test/test_results.jl new file mode 100644 index 00000000..f67b859a --- /dev/null +++ b/test/test_results.jl @@ -0,0 +1,58 @@ +using Test, BifurcationKit +const BK = BifurcationKit + +# Simple Test problem (Pitchfork bifurcation) to generate a ContResult and a Branch object +function f(u, p) + return p.r .* u - u .^ 3 +end +p = (r=-1.0,) +u0 = [0.0] +prob = BK.BifurcationProblem(f, u0, p, (@optic _.r)) + +@testset "ContResult" begin + opt = BK.ContinuationPar(p_min=-1.0, p_max=1.0) + contres = BK.continuation(prob, PALC(), opt) + @assert typeof(contres) <: BK.ContResult + bp = contres.specialpoint[1] # pitchfork bifurcation + + # Test slicing of ContResult object + @test contres[1:bp.step+1].specialpoint[1].step == bp.step + @test contres[bp.step+1:end].specialpoint[1].step == bp.step + + # Slicing and indexing should match + @test contres[bp.step:bp.step][1].param == contres[bp.step].param + + # Recursive slicing should work + @test contres[bp.step:end][1:1][1].param == contres[bp.step:end][1].param + + # Slicing should still work when not evey sol/eig is saved + opt = BK.ContinuationPar(opt; detect_bifurcation=1, save_sol_every_step=2, save_eig_every_step=3) + contres = BK.continuation(prob, PALC(), opt) + @assert length(contres) != length(contres.sol) != length(contres.eig) + @test length(contres[1:end]) == length(contres) + @test length(contres[1:end].sol) == length(contres.sol) + @test length(contres[1:end].eig) == length(contres.eig) +end + +@testset "Branch" begin + # Test slicing of Branch object + opt = BK.ContinuationPar(p_min=-1.0, p_max=1.0) + contres = BK.continuation(prob, PALC(), opt) + branch = BK.continuation(contres, 1) + @assert typeof(branch) <: BK.Branch + bp = branch.specialpoint[1] # pitchfork bifurcation + + # Test slicing of Branch object + @test branch[1:bp.step+1].specialpoint[1].step == bp.step + @test branch[bp.step+1:end].specialpoint[1].step == bp.step + + # Slicing and indexing should match + @test branch[bp.step:bp.step][1].param == branch[bp.step].param + + # Recursive slicing should work + @test branch[bp.step:end][1:1][1].param == branch[bp.step:end][1].param +end + + + +