Skip to content

Commit a7e184f

Browse files
committed
change definition of ∂
1 parent a9be404 commit a7e184f

File tree

5 files changed

+14
-13
lines changed

5 files changed

+14
-13
lines changed

src/Utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222
####################################################################################################
2323
# iterated derivatives
2424
(f) = x -> ForwardDiff.derivative(f, x)
25-
(f, n::Int) = n == 0 ? f : ((f), n-1)
25+
(f, ::Val{n}) where {n} = n == 0 ? f : ((f), Val(n-1))
2626
####################################################################################################
2727
function print_nonlinear_step(step, residual, itlinear = 0, lastRow = false)
2828
if lastRow

src/periodicorbit/PeriodicOrbitCollocation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1156,7 +1156,7 @@ function compute_error!(coll::PeriodicOrbitOCollProblem, x::AbstractVector{𝒯}
11561156
# sol is the piecewise polynomial approximation of y.
11571157
# However, sol is of degree m, hence ∂(sol, m+1) = 0
11581158
# we thus estimate yᵐ⁺¹ using ∂(sol, m)
1159-
dmsol = (sol, m)
1159+
dmsol = (sol, Val(m))
11601160
# we find the values of vm := ∂m(x) at the mid points
11611161
τsT = getmesh(coll) .* period
11621162
vm = [ dmsol( (τsT[i] + τsT[i+1]) / 2 ) for i = 1:Ntst ]

src/periodicorbit/PeriodicOrbitTrapeze.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ function potrap_functional_jac!(pb::AbstractPOFDProblem, out, u, par, du)
304304
end
305305

306306
(pb::PeriodicOrbitTrapProblem)(u::AbstractVector, par) = potrap_functional!(pb, similar(u), u, par)
307+
residual(pb::PeriodicOrbitTrapProblem, u::AbstractVector, par) = pb(u, par)
307308
(pb::PeriodicOrbitTrapProblem)(u::AbstractVector, par, du) = potrap_functional_jac!(pb, similar(du), u, par, du)
308309

309310
####################################################################################################
@@ -436,7 +437,7 @@ function (pb::PeriodicOrbitTrapProblem)(::Val{:JacFullSparse}, u0::AbstractVecto
436437
AγBlock = jacobian_potrap_block(pb, u0, par; γ = γ)
437438

438439
# we now set up the last line / column
439-
@views ∂TGpo = (pb(vcat(u0[begin:end-1], T + δ), par) .- pb(u0, par)) ./ δ
440+
@views ∂TGpo = (residual(pb, vcat(u0[begin:end-1], T + δ), par) .- residual(pb, u0, par)) ./ δ
440441

441442
# this is "bad" for performance. Get converted to SparseMatrix at the next line
442443
= block_to_sparse(AγBlock) # most of the computing time is here!!
@@ -493,7 +494,7 @@ This method returns the jacobian of the functional G encoded in PeriodicOrbitTra
493494
# J0[(M-1)*N+1:(M)*N, (M-1)*N+1:(M)*N] .= In
494495

495496
# we now set up the last line / column
496-
∂TGpo = (pb(vcat(u0[1:end-1], T + δ), par) .- pb(u0, par)) ./ δ
497+
∂TGpo = (residual(pb,vcat(u0[1:end-1], T + δ), par) .- residual(pb,u0, par)) ./ δ
497498
J0[:, end] .= ∂TGpo
498499

499500
# this following does not depend on u0, so it does not change. However we update it in case the caller updated the section somewhere else
@@ -547,7 +548,7 @@ end
547548

548549
if updateborder
549550
# we now set up the last line / column
550-
∂TGpo = (pb(vcat(u0[1:end-1], T + δ), par) .- pb(u0, par)) ./ δ
551+
∂TGpo = (residual(pb, vcat(u0[1:end-1], T + δ), par) .- residual(pb, u0, par)) ./ δ
551552
J0[:, end] .= ∂TGpo
552553

553554
# this following does not depend on u0, so it does not change. However we update it in case the caller updated the section somewhere else
@@ -776,7 +777,7 @@ function (J::POTrapJacobianBordered)(u0::AbstractVector, par; δ = convert(eltyp
776777
T = extract_period_fdtrap(J..prob, u0)
777778
# we compute the derivative of the problem w.r.t. the period TODO: remove this or improve!!
778779
# TODO REMOVE CE vcat!
779-
@views J.∂TGpo .= (J..prob(vcat(u0[begin:end-1], T + δ), par) .- J..prob(u0, par)) ./ δ
780+
@views J.∂TGpo .= (residual(J..prob, vcat(u0[begin:end-1], T + δ), par) .- residual(J..prob, u0, par)) ./ δ
780781

781782
J.(u0, par) # update Aγ
782783

@@ -848,10 +849,10 @@ function _newton_trap(probPO::PeriodicOrbitTrapProblem,
848849
_J = probPO(Val(:JacFullSparse), orbitguess, getparams(probPO.prob_vf)) |> Array
849850
jac = (x, p) -> probPO(Val(:JacFullSparseInplace), _J, x, p)
850851
elseif jacobianPO == :DenseAD
851-
jac = (x, p) -> ForwardDiff.jacobian(z -> probPO(z, p), x)
852+
jac = (x, p) -> ForwardDiff.jacobian(z -> residual(probPO, z, p), x)
852853
elseif jacobianPO == :FullMatrixFreeAD
853-
jac = (x, p) -> dx -> ForwardDiff.derivative(t -> probPO(x .+ t .* dx, p), 0)
854-
else
854+
jac = (x, p) -> dx -> ForwardDiff.derivative(t -> residual(probPO, x .+ t .* dx, p), 0)
855+
else
855856
jac = (x, p) -> ( dx -> probPO(x, p, dx))
856857
end
857858

@@ -984,9 +985,9 @@ function continuation_potrap(prob::PeriodicOrbitTrapProblem,
984985
_J = prob(Val(:JacFullSparse), orbitguess, getparams(prob.prob_vf)) |> Array
985986
jac = (x, p) -> (prob(Val(:JacFullSparseInplace), _J, x, p); FloquetWrapper(prob, _J, x, p));
986987
elseif jacobianPO == :DenseAD
987-
jac = (x, p) -> FloquetWrapper(prob, ForwardDiff.jacobian(z -> prob(z, p), x), x, p)
988+
jac = (x, p) -> FloquetWrapper(prob, ForwardDiff.jacobian(z -> residual(prob, z, p), x), x, p)
988989
elseif jacobianPO == :FullMatrixFreeAD
989-
jac = (x, p) -> FloquetWrapper(prob, dx -> ForwardDiff.derivative(t->prob(x .+ t .* dx, p), 0), x, p)
990+
jac = (x, p) -> FloquetWrapper(prob, dx -> ForwardDiff.derivative(t->residual(prob, x .+ t .* dx, p), 0), x, p)
990991
else
991992
jac = (x, p) -> FloquetWrapper(prob, x, p)
992993
end

src/periodicorbit/PeriodicOrbits.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ function continuation(br::AbstractResult{PeriodicOrbitCont, Tprob},
596596
# perform continuation
597597
pbnew = set_params_po(pbnew, setparam(br, newp))
598598

599-
pbnew(orbitguess, setparam(br, newp))[end] |> abs > 1 && @warn "PO constraint not satisfied"
599+
residual(pbnew, orbitguess, setparam(br, newp))[end] |> abs > 1 && @warn "PO constraint not satisfied"
600600

601601
branch = continuation( pbnew, orbitguess, alg, _contParams;
602602
kwargs..., # put this first to be overwritten just below!

test/stuartLandauCollocation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ _orbit(t) = [cos(2pi * t), 0, 0] * sqrt(par_sl.r / par_sl.c3)
5151
_ci = BK.generate_solution(prob_col, _orbit, 1.)
5252
BK.get_periodic_orbit(prob_col, _ci, par_sl)
5353
BK.getmaximum(prob_col, _ci, par_sl)
54-
@test BK.(sin, 2)(0.) == 0
54+
@test BK.(sin, Val(2))(0.) == 0
5555
prob_col(_ci, par_sl) #|> scatter
5656
BK.get_time_slices(prob_col, _ci)
5757

0 commit comments

Comments
 (0)