From 71c6808fde9a9d1b49fcbb8477538ff5b34afda2 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Fri, 12 Sep 2025 06:38:47 -0400 Subject: [PATCH] Replace internal ITP implementation with SimpleNonlinearSolve.jl's ITP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR removes the internal `InternalITP` implementation from DiffEqBase and replaces it with the more robust and feature-complete `ITP` algorithm from SimpleNonlinearSolve.jl. ## Changes: 1. **Removed internal ITP implementation:** - Deleted `src/internal_itp.jl` containing `InternalITP` struct and solver - Removed helper functions like `prevfloat_tdir`, `nextfloat_tdir`, `max_tdir` 2. **Added SimpleNonlinearSolve dependency:** - Added SimpleNonlinearSolve as a dependency in Project.toml - Set minimum version bound to 2.7 to ensure compatibility 3. **Updated usage throughout codebase:** - Modified `src/callbacks.jl` to use `ITP()` instead of `InternalITP()` - Updated `src/DiffEqBase.jl` to import `ITP` from SimpleNonlinearSolve - Updated `ext/DiffEqBaseForwardDiffExt.jl` to handle the new ITP algorithm - Updated `test/internal_rootfinder.jl` to test with the new ITP implementation ## Benefits: - **Better maintenance:** Eliminates duplicate code and reduces maintenance burden - **Improved reliability:** Uses the well-tested and optimized ITP implementation from SimpleNonlinearSolve - **Better performance:** SimpleNonlinearSolve's ITP has additional optimizations and tuning parameters - **Consistency:** Aligns with the broader SciML ecosystem's approach of using specialized packages ## Testing: All existing tests pass, including the specific internal rootfinder tests that verify the ITP algorithm works correctly for various edge cases. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- Project.toml | 2 + ext/DiffEqBaseForwardDiffExt.jl | 10 ++-- src/DiffEqBase.jl | 3 +- src/callbacks.jl | 6 +-- src/internal_itp.jl | 87 --------------------------------- test/internal_rootfinder.jl | 5 +- 6 files changed, 15 insertions(+), 98 deletions(-) delete mode 100644 src/internal_itp.jl diff --git a/Project.toml b/Project.toml index fb1f3a5ed..017bced67 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" +SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -97,6 +98,7 @@ ReverseDiff = "1" SciMLBase = "2.115.0" SciMLOperators = "1" SciMLStructures = "1.5" +SimpleNonlinearSolve = "2.7" Setfield = "1" SparseArrays = "1.9" Static = "1" diff --git a/ext/DiffEqBaseForwardDiffExt.jl b/ext/DiffEqBaseForwardDiffExt.jl index c5c0b4e26..f2ad064cf 100644 --- a/ext/DiffEqBaseForwardDiffExt.jl +++ b/ext/DiffEqBaseForwardDiffExt.jl @@ -1,13 +1,13 @@ module DiffEqBaseForwardDiffExt using DiffEqBase, ForwardDiff +using SimpleNonlinearSolve: ITP using DiffEqBase.ArrayInterface using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag, AbstractTimeseriesSolution, RecursiveArrayTools, reduce_tup, _promote_tspan, has_continuous_callback import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, prob2dtmin, - promote_tspan, ODE_DEFAULT_NORM, - InternalITP, nextfloat_tdir + promote_tspan, ODE_DEFAULT_NORM import SciMLBase: isdualtype, DualEltypeChecker, sse, __sum const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1} @@ -153,7 +153,7 @@ end # Differentiation of internal solver -function scalar_nlsolve_ad(prob, alg::InternalITP, args...; kwargs...) +function scalar_nlsolve_ad(prob, alg::ITP, args...; kwargs...) f = prob.f p = value(prob.p) @@ -186,7 +186,7 @@ end function SciMLBase.solve( prob::IntervalNonlinearProblem{uType, iip, <:ForwardDiff.Dual{T, V, P}}, - alg::InternalITP, args...; + alg::ITP, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials), @@ -202,7 +202,7 @@ function SciMLBase.solve( V, P}, }}, - alg::InternalITP, args...; + alg::ITP, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index 1a0466ef3..645b04922 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -44,6 +44,8 @@ using SciMLBase using SciMLOperators: AbstractSciMLOperator, AbstractSciMLScalarOperator +using SimpleNonlinearSolve: ITP + using SciMLBase: @def, DEIntegrator, AbstractDEProblem, AbstractDiffEqInterpolation, DECallback, AbstractDEOptions, DECache, AbstractContinuousCallback, @@ -140,7 +142,6 @@ include("utils.jl") include("stats.jl") include("calculate_residuals.jl") include("tableaus.jl") -include("internal_itp.jl") include("callbacks.jl") include("common_defaults.jl") diff --git a/src/callbacks.jl b/src/callbacks.jl index a84d11e4c..a714cccac 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -358,17 +358,17 @@ end # rough implementation, needs multiple type handling # always ensures that if r = bisection(f, (x0, x1)) # then either f(nextfloat(r)) == 0 or f(nextfloat(r)) * f(r) < 0 -# note: not really using bisection - uses the ITP method +# note: not really using bisection - uses the ITP method function bisection( f, tup, t_forward::Bool, rootfind::SciMLBase.RootfindOpt, abstol, reltol; maxiters = 1000) if rootfind == SciMLBase.LeftRootFind solve(IntervalNonlinearProblem{false}(f, tup), - InternalITP(), abstol = abstol, + ITP(), abstol = abstol, reltol = reltol).left else solve(IntervalNonlinearProblem{false}(f, tup), - InternalITP(), abstol = abstol, + ITP(), abstol = abstol, reltol = reltol).right end end diff --git a/src/internal_itp.jl b/src/internal_itp.jl deleted file mode 100644 index 06f685045..000000000 --- a/src/internal_itp.jl +++ /dev/null @@ -1,87 +0,0 @@ -""" - prevfloat_tdir(x, x0, x1) - -Move `x` one floating point towards x0. -""" -function prevfloat_tdir(x, x0, x1) - x1 > x0 ? prevfloat(x) : nextfloat(x) -end - -function nextfloat_tdir(x, x0, x1) - x1 > x0 ? nextfloat(x) : prevfloat(x) -end - -function max_tdir(a, b, x0, x1) - x1 > x0 ? max(a, b) : min(a, b) -end - -""" -`InternalITP`: A non-allocating ITP method, internal to DiffEqBase for -simpler dependencies. -""" -struct InternalITP - scaled_k1::Float64 - n0::Int -end - -InternalITP() = InternalITP(0.2, 10) - -function SciMLBase.solve(prob::IntervalNonlinearProblem{IP, Tuple{T, T}}, alg::InternalITP, - args...; - maxiters = 1000, kwargs...) where {IP, T} - f = Base.Fix2(prob.f, prob.p) - left, right = minmax(prob.tspan...) # a and b - fl, fr = f(left), f(right) - ϵ = eps(T) - if iszero(fl) - return SciMLBase.build_solution(prob, alg, left, fl; - retcode = ReturnCode.ExactSolutionLeft, left, right) - elseif iszero(fr) - return SciMLBase.build_solution(prob, alg, right, fr; - retcode = ReturnCode.ExactSolutionRight, left, right) - end - span = right - left - k1 = T(alg.scaled_k1) / span - n0 = T(alg.n0) - n_h = exponent(span / (2 * ϵ)) - ϵ_s = ϵ * exp2(n_h + n0) - T0 = zero(fl) - - i = 1 - while i ≤ maxiters - span = right - left - mid = (left + right) / 2 - r = ϵ_s - (span / 2) - - x_f = left + span * (fl / (fl - fr)) # Interpolation Step - - δ = max(k1 * span^2, eps(x_f)) - diff = mid - x_f - - xt = ifelse(δ ≤ abs(diff), x_f + copysign(δ, diff), mid) # Truncation Step - - xp = ifelse(abs(xt - mid) ≤ r, xt, mid - copysign(r, diff)) # Projection Step - yp = f(xp) - yps = yp * sign(fr) - if yps > T0 - right, fr = xp, yp - elseif yps < T0 - left, fl = xp, yp - else - return SciMLBase.build_solution( - prob, alg, xp, yps; retcode = ReturnCode.Success, left = xp, right = xp - ) - end - - i += 1 - ϵ_s /= 2 - - if nextfloat_tdir(left, left, right) == right - return SciMLBase.build_solution( - prob, alg, right, fr; retcode = ReturnCode.FloatingPointLimit, left, right - ) - end - end - return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters, - left = left, right = right) -end diff --git a/test/internal_rootfinder.jl b/test/internal_rootfinder.jl index 82b0632e7..29947cecd 100644 --- a/test/internal_rootfinder.jl +++ b/test/internal_rootfinder.jl @@ -1,8 +1,9 @@ using DiffEqBase -using DiffEqBase: InternalITP, IntervalNonlinearProblem +using DiffEqBase: IntervalNonlinearProblem +using SimpleNonlinearSolve: ITP using ForwardDiff -for Rootfinder in (InternalITP,) +for Rootfinder in (ITP,) rf = Rootfinder() # From SimpleNonlinearSolve f = (u, p) -> u * u - p