From 2e2ba5879fe29a354163dae3864f3644588b2703 Mon Sep 17 00:00:00 2001 From: Paul Piho Date: Sun, 3 Jul 2022 13:34:43 +0100 Subject: [PATCH 01/15] Extrande algorithm --- src/JumpProcesses.jl | 2 + src/aggregators/aggregators.jl | 15 +++- src/aggregators/extrande.jl | 133 +++++++++++++++++++++++++++++++++ src/jumps.jl | 33 +++++++- src/problem.jl | 20 ++++- test/extrande.jl | 20 +++++ test/runtests.jl | 1 + 7 files changed, 217 insertions(+), 7 deletions(-) create mode 100644 src/aggregators/extrande.jl create mode 100644 test/extrande.jl diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 39f301707..af690a58e 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -49,6 +49,7 @@ include("aggregators/prioritytable.jl") include("aggregators/directcr.jl") include("aggregators/rssacr.jl") include("aggregators/rdirect.jl") +include("aggregators/extrande.jl") # spatial: include("spatial/spatial_massaction_jump.jl") @@ -82,6 +83,7 @@ export Direct, DirectFW, SortingDirect, DirectCR export BracketData, RSSA export FRM, FRMFW, NRM export RSSACR, RDirect +export Extrande export get_num_majumps, needs_depgraph, needs_vartojumps_map diff --git a/src/aggregators/aggregators.jl b/src/aggregators/aggregators.jl index bc5e354fb..a6662431d 100644 --- a/src/aggregators/aggregators.jl +++ b/src/aggregators/aggregators.jl @@ -144,8 +144,18 @@ doi: 10.1063/1.4928635 """ struct DirectCRDirect <: AbstractAggregatorAlgorithm end +""" +The Extrande method for simulating variable rate jumps with user-defined bounds +on jumps rates and validity intervals via rejection. + +Stochastic Simulation of Biomolecular Networks in Dynamic Environments, Voliotis +M, Thomas P, Grima R, Bowsher CG, PLOS Computational Biology 12(6): e1004923. +(2016); doi.org/10.1371/journal.pcbi.1004923 +""" +struct Extrande <: AbstractAggregatorAlgorithm end + const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(), - FRMFW(), NRM(), RSSACR(), RDirect()) + FRMFW(), NRM(), RSSACR(), RDirect(), Extrande()) # For JumpProblem construction without an aggregator struct NullAggregator <: AbstractAggregatorAlgorithm end @@ -167,3 +177,6 @@ needs_vartojumps_map(aggregator::RSSACR) = true is_spatial(aggregator::AbstractAggregatorAlgorithm) = false is_spatial(aggregator::NSM) = true is_spatial(aggregator::DirectCRDirect) = true + +is_ficticious(aggregator::AbstractAggregatorAlgorithm) = false +is_ficticious(aggregator::Extrande) = true diff --git a/src/aggregators/extrande.jl b/src/aggregators/extrande.jl new file mode 100644 index 000000000..b689b078b --- /dev/null +++ b/src/aggregators/extrande.jl @@ -0,0 +1,133 @@ +# Define the aggregator. +struct Extrande <: AbstractAggregatorAlgorithm end + +""" +Extrande sampling method for jumps with defined rate bounds. +""" + +nullaffect!(integrator) = nothing +const NullAffectJump = ConstantRateJump((u,p,t) -> 0.0, nullaffect!) + +mutable struct ExtrandeJumpAggregation{T,S,F1,F2,F3,F4,RNG} <: AbstractSSAJumpAggregator + next_jump::Int + prev_jump::Int + next_jump_time::T + end_time::T + cur_rates::Vector{T} + sum_rate::T + ma_jumps::S + rate_bnds::F3 + wds::F4 + rates::F1 + affects!::F2 + save_positions::Tuple{Bool,Bool} + rng::RNG +end + +function ExtrandeJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool,Bool}, rng::RNG; rate_bounds::F3, windows::F4, kwargs...) where {T,S,F1,F2,F3,F4,RNG} + + ExtrandeJumpAggregation{T,S,F1,F2,F3,F4,RNG}(nj, nj, njt, et, crs, sr, maj, rate_bounds, windows, rs, affs!, sps, rng) +end + + +############################# Required Functions ############################## +function aggregate(aggregator::Extrande, u, p, t, end_time, constant_jumps, + ma_jumps, save_positions, rng; bounded_va_jumps, kwargs...) + + rates, affects! = get_jump_info_fwrappers(u, p, t, (constant_jumps..., bounded_va_jumps..., NullAffectJump)) + rbnds, wnds = get_va_jump_bound_info_fwrapper(u, p, t, (constant_jumps..., bounded_va_jumps...,NullAffectJump)) + build_jump_aggregation(ExtrandeJumpAggregation, u, p, t, end_time, ma_jumps, + rates, affects!, save_positions, rng; u=u, rate_bounds=rbnds, windows=wnds, kwargs...) +end + +# set up a new simulation and calculate the first jump / jump time +function initialize!(p::ExtrandeJumpAggregation, integrator, u, params, t) + p.end_time = integrator.sol.prob.tspan[2] + generate_jumps!(p, integrator, u, params, t) +end + +# execute one jump, changing the system state +@inline function execute_jumps!(p::ExtrandeJumpAggregation, integrator, u, params, t) + # execute jump + u = update_state!(p, integrator, u) + nothing +end + +@fastmath function next_ma_jump(p::ExtrandeJumpAggregation, u, params, t) + ttnj = typemax(typeof(t)) + nextrx = zero(Int) + majumps = p.ma_jumps + @inbounds for i in 1:get_num_majumps(majumps) + p.cur_rates[i] = evalrxrate(u, i, majumps) + dt = randexp(p.rng) / p.cur_rates[i] + if dt < ttnj + ttnj = dt + nextrx = i + end + end + nextrx, ttnj +end + +@fastmath function next_extrande_jump(p::ExtrandeJumpAggregation, u, params, t) + ttnj = typemax(typeof(t)) + nextrx = zero(Int) + Wmin = typemax(typeof(t)) + Bmax = typemax(typeof(t)) + + # Calculate the total rate bound and the largest common validity window. + Ws = zeros(typeof(t), length(p.wds)) + Bs = zeros(typeof(t), length(p.rate_bnds)) + if !isempty(p.rate_bnds) + idx = get_num_majumps(p.ma_jumps) + 1 + @inbounds for i in 1:length(p.wds) + Ws[i] = p.wds[i](u,params,t) + Bs[i] = p.rate_bnds[i](u,params,t) + end + Wmin = minimum(Ws) + Bmax = sum(Bs) + end + + # Rejection sampling. + if !isempty(p.rates) + nextrx = length(p.rates) + idx = get_num_majumps(p.ma_jumps) + 1 + prop_ttnj = randexp(p.rng) / Bmax + if prop_ttnj < t + Wmin + fill_cur_rates(u, params, t, p.cur_rates, idx, p.rates...) + + prev_rate = zero(t) + cur_rates = p.cur_rates + @inbounds for i in idx:length(cur_rates) + cur_rates[i] = cur_rates[i] + prev_rate + prev_rate = cur_rates[i] + end + + UBmax = rand(p.rng) * Bmax + ttnj = prop_ttnj + if p.cur_rates[end] ≥ UBmax + @inbounds nextrx = findfirst(x -> x ≥ UBmax, p.cur_rates) + end + else + ttnj = Wmin + end + end + + return nextrx, ttnj +end + + +function generate_jumps!(p::ExtrandeJumpAggregation, integrator, u, params, t) + nextmaj, ttnmaj = next_ma_jump(p, u, params, t) + nextexj, ttnexj = next_extrande_jump(p, u, params, t) + + # execute reaction with minimal time + if ttnmaj < ttnexj + p.next_jump = nextmaj + p.next_jump_time = t + ttnmaj + else + p.next_jump = nextexj + p.next_jump_time = t + ttnexj + end + + nothing +end diff --git a/src/jumps.jl b/src/jumps.jl index 6c0a231af..b86cef90c 100644 --- a/src/jumps.jl +++ b/src/jumps.jl @@ -67,11 +67,15 @@ crj = VariableRateJump(rate, affect!) (5), DOI:10.1063/1.1835951 is used for calculating jump times with `VariableRateJump`s within ODE/SDE integrators. """ -struct VariableRateJump{R, F, I, T, T2} <: AbstractJump +struct VariableRateJump{R1,F,R2,R3,I,T,T2} <: AbstractJump """Function `rate(u,p,t)` that returns the jump's current rate.""" - rate::R + rate::R1 """Function `affect(integrator)` that updates the state for one occurrence of the jump.""" affect!::F + rbnd::R2 + """Function `rwnd(u,p,t)` that returns the time window length t* for which the + rate bound rbnd(u,p,t) is valid. Used for ficticious jump methods.""" + rwnd::R3 idxs::I rootfind::Bool interp_points::Int @@ -83,10 +87,12 @@ end function VariableRateJump(rate, affect!; idxs = nothing, rootfind = true, + rbnd=nothing, + rwnd=nothing, save_positions = (true, true), interp_points = 10, abstol = 1e-12, reltol = 0) - VariableRateJump(rate, affect!, idxs, + VariableRateJump(rate, affect!, rbnd, rwnd, idxs, rootfind, interp_points, save_positions, abstol, reltol) end @@ -577,3 +583,24 @@ function get_jump_info_fwrappers(u, p, t, constant_jumps) rates, affects! end + +##### helpers for splitting variable rate jumps with rate bounds and without ##### + +function split_variable_jumps(variable_jumps) + condition(v) = v.rbnd !== nothing + return filter(v -> condition(v), variable_jumps), filter(v -> !condition(v), variable_jumps) +end + +function get_va_jump_bound_info_fwrapper(u,p,t,jumps) + RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t),Tuple{typeof(u), typeof(p), typeof(t)}} + + if (jumps !== nothing) && !isempty(jumps) + rates = [j isa VariableRateJump ? RateWrapper(j.rbnd) : RateWrapper(j.rate) for j in jumps] + wnds = [j isa VariableRateJump ? RateWrapper(j.rwnd) : RateWrapper((u,p,t) -> Inf) for j in jumps] + else + rates = Vector{RateWrapper}() + wnds = Vector{RateWrapper}() + end + + rates, wnds +end diff --git a/src/problem.jl b/src/problem.jl index 6ce8bcb47..6d3b46da1 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -188,12 +188,25 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS else disc = aggregate(aggregator, u, prob.p, t, end_time, jumps.constant_jumps, maj, save_positions, rng; spatial_system = spatial_system, - hopping_constants = hopping_constants, kwargs...) + hopping_constants = hopping_constants, + bounded_va_jumps = Tuple{}(), kwargs...) constant_jump_callback = DiscreteCallback(disc) end iip = isinplace_jump(prob, jumps.regular_jump) + # Ficticious rate handling. + if !is_ficticious(aggregator) + unbnd_var_jumps = jumps.variable_jumps + else + bounded_va_jumps, unbnd_var_jumps = split_variable_jumps(jumps.variable_jumps) + disc = aggregate(aggregator, u, prob.p, t, end_time, jumps.constant_jumps, maj, + save_positions, rng; spatial_system = spatial_system, + hopping_constants = hopping_constants, + bounded_va_jumps=bounded_va_jumps, kwargs...) + constant_jump_callback = DiscreteCallback(disc) + end + ## Variable Rate Handling if typeof(jumps.variable_jumps) <: Tuple{} new_prob = prob @@ -206,11 +219,11 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS callbacks = CallbackSet(constant_jump_callback, variable_jump_callback) JumpProblem{iip, typeof(new_prob), typeof(aggregator), typeof(callbacks), - typeof(disc), typeof(jumps.variable_jumps), + typeof(disc), typeof(unbnd_var_jumps), typeof(jumps.regular_jump), typeof(maj), typeof(rng)}(new_prob, aggregator, disc, callbacks, - jumps.variable_jumps, + unbnd_var_jumps, jumps.regular_jump, maj, rng) end @@ -359,3 +372,4 @@ function Base.show(io::IO, mime::MIME"text/plain", A::JumpProblem) end TreeViews.hastreeview(x::JumpProblem) = true + diff --git a/test/extrande.jl b/test/extrande.jl new file mode 100644 index 000000000..a378eae94 --- /dev/null +++ b/test/extrande.jl @@ -0,0 +1,20 @@ +using DiffEqBase, JumpProcesses, OrdinaryDiffEq, Test +using StableRNGs +rng = StableRNG(12345) + +rate = (u,p,t) -> t +affect! = (integrator) -> (integrator.u[1] = integrator.u[1]+1) +rbound = (u,p,t) -> (t + 0.1) +rwindow = (u,p,t) -> 0.1 +jump = VariableRateJump(rate,affect!,interp_points=1000,rbnd=rbound,rwnd=rwindow) +jump2 = deepcopy(jump) + +f = function (du,u,p,t) + du[1] = 0.0 +end + +prob = ODEProblem(f,[0.0],(0.0,10.0)) +jump_prob = JumpProblem(prob,Extrande(),jump; rng=rng) + +integrator = init(jump_prob,Tsit5()) +sol = solve(jump_prob,Tsit5()) diff --git a/test/runtests.jl b/test/runtests.jl index 067b35250..3c25fa846 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,4 +25,5 @@ using JumpProcesses, DiffEqBase, SafeTestsets @time @safetestset "Spatial A + B <--> C" begin include("spatial/ABC.jl") end @time @safetestset "Spatially Varying Reaction Rates" begin include("spatial/spatial_majump.jl") end @time @safetestset "Pure diffusion" begin include("spatial/diffusion.jl") end + @time @safetestset "Ficticious Jump " begin include("extrande.jl") end end From d47e5dafa8772983a507f6163b4d495eef8f48cb Mon Sep 17 00:00:00 2001 From: Paul Piho Date: Thu, 14 Jul 2022 17:40:36 +0100 Subject: [PATCH 02/15] if validity window not given assume bound holds everywhere --- src/jumps.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/jumps.jl b/src/jumps.jl index b86cef90c..d4e8eb6ec 100644 --- a/src/jumps.jl +++ b/src/jumps.jl @@ -591,12 +591,17 @@ function split_variable_jumps(variable_jumps) return filter(v -> condition(v), variable_jumps), filter(v -> !condition(v), variable_jumps) end +function rate_window_function(jump) + # Assumes that if no window is given the rate bound is valid for all times. + return !(jump.rwnd isa Nothing) ? jump.rwnd : (u,p,t) -> Inf +end + function get_va_jump_bound_info_fwrapper(u,p,t,jumps) RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t),Tuple{typeof(u), typeof(p), typeof(t)}} if (jumps !== nothing) && !isempty(jumps) rates = [j isa VariableRateJump ? RateWrapper(j.rbnd) : RateWrapper(j.rate) for j in jumps] - wnds = [j isa VariableRateJump ? RateWrapper(j.rwnd) : RateWrapper((u,p,t) -> Inf) for j in jumps] + wnds = [j isa VariableRateJump ? RateWrapper(rate_window_function(j)) : RateWrapper((u,p,t) -> Inf) for j in jumps] else rates = Vector{RateWrapper}() wnds = Vector{RateWrapper}() From 05532c36c311767b0fe5c805c4b26bc80ab3f7e4 Mon Sep 17 00:00:00 2001 From: Paul Piho Date: Thu, 14 Jul 2022 17:42:55 +0100 Subject: [PATCH 03/15] bug fix. compute rates at proposed jump times --- src/aggregators/extrande.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aggregators/extrande.jl b/src/aggregators/extrande.jl index b689b078b..ce36d35d9 100644 --- a/src/aggregators/extrande.jl +++ b/src/aggregators/extrande.jl @@ -92,8 +92,8 @@ end nextrx = length(p.rates) idx = get_num_majumps(p.ma_jumps) + 1 prop_ttnj = randexp(p.rng) / Bmax - if prop_ttnj < t + Wmin - fill_cur_rates(u, params, t, p.cur_rates, idx, p.rates...) + if prop_ttnj < Wmin + fill_cur_rates(u, params, prop_ttnj + t, p.cur_rates, idx, p.rates...) prev_rate = zero(t) cur_rates = p.cur_rates From a0bd99f2de831dd610b7450ebcb5487093aa65fd Mon Sep 17 00:00:00 2001 From: Paul Piho Date: Thu, 14 Jul 2022 17:44:56 +0100 Subject: [PATCH 04/15] check simulation with step rate function --- test/extrande.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/extrande.jl b/test/extrande.jl index a378eae94..e1a389f8b 100644 --- a/test/extrande.jl +++ b/test/extrande.jl @@ -18,3 +18,13 @@ jump_prob = JumpProblem(prob,Extrande(),jump; rng=rng) integrator = init(jump_prob,Tsit5()) sol = solve(jump_prob,Tsit5()) + +rate2 = (u,p,t) -> t < 5.0 ? 1.0 : 0.0 +rbound2 = (u,p,t) -> 1.0 +jump3 = VariableRateJump(rate2,affect2!,interp_points=1000;rbnd=rbound2) + +prob2 = ODEProblem(f,[0.0],(0.0,10.0)) +jump_prob2 = JumpProblem(prob2,Extrande(),jump3; rng=rng) + +sol2 = solve(jump_prob2,Tsit5()) +@test sol2(5.0)[1] == sol2[end][1] From 0cf67a7c91a3ecb4e08cc75f4b455e26f2781ff7 Mon Sep 17 00:00:00 2001 From: Paul Piho Date: Tue, 19 Jul 2022 16:11:58 +0100 Subject: [PATCH 05/15] avoid duplicating bounded variable jumps --- src/problem.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/problem.jl b/src/problem.jl index 6d3b46da1..8ae0608e8 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -208,13 +208,13 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS end ## Variable Rate Handling - if typeof(jumps.variable_jumps) <: Tuple{} + if typeof(unbnd_var_jumps) <: Tuple{} new_prob = prob variable_jump_callback = CallbackSet() else new_prob = extend_problem(prob, jumps; rng = rng) variable_jump_callback = build_variable_callback(CallbackSet(), 0, - jumps.variable_jumps...; rng = rng) + unbnd_var_jumps...; rng = rng) end callbacks = CallbackSet(constant_jump_callback, variable_jump_callback) From e1df00d8512cfbd672a48a4481cd8629d24ba4ea Mon Sep 17 00:00:00 2001 From: Paul Piho Date: Tue, 19 Jul 2022 16:28:26 +0100 Subject: [PATCH 06/15] typo corrected --- test/extrande.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/extrande.jl b/test/extrande.jl index e1a389f8b..40addaeb9 100644 --- a/test/extrande.jl +++ b/test/extrande.jl @@ -7,7 +7,6 @@ affect! = (integrator) -> (integrator.u[1] = integrator.u[1]+1) rbound = (u,p,t) -> (t + 0.1) rwindow = (u,p,t) -> 0.1 jump = VariableRateJump(rate,affect!,interp_points=1000,rbnd=rbound,rwnd=rwindow) -jump2 = deepcopy(jump) f = function (du,u,p,t) du[1] = 0.0 @@ -21,7 +20,7 @@ sol = solve(jump_prob,Tsit5()) rate2 = (u,p,t) -> t < 5.0 ? 1.0 : 0.0 rbound2 = (u,p,t) -> 1.0 -jump3 = VariableRateJump(rate2,affect2!,interp_points=1000;rbnd=rbound2) +jump3 = VariableRateJump(rate2,affect!,interp_points=1000;rbnd=rbound2) prob2 = ODEProblem(f,[0.0],(0.0,10.0)) jump_prob2 = JumpProblem(prob2,Extrande(),jump3; rng=rng) From 934858618e97a01e97b206b08c3b900a0d8dfaf3 Mon Sep 17 00:00:00 2001 From: Paul Piho Date: Wed, 11 Jan 2023 19:20:53 +0000 Subject: [PATCH 07/15] Test time-dependent birth death process mean against ODE solution --- test/extrande.jl | 65 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/test/extrande.jl b/test/extrande.jl index 40addaeb9..a4db5f083 100644 --- a/test/extrande.jl +++ b/test/extrande.jl @@ -1,29 +1,64 @@ using DiffEqBase, JumpProcesses, OrdinaryDiffEq, Test using StableRNGs -rng = StableRNG(12345) - -rate = (u,p,t) -> t -affect! = (integrator) -> (integrator.u[1] = integrator.u[1]+1) -rbound = (u,p,t) -> (t + 0.1) -rwindow = (u,p,t) -> 0.1 -jump = VariableRateJump(rate,affect!,interp_points=1000,rbnd=rbound,rwnd=rwindow) +using Statistics +rng = StableRNG(48572) f = function (du,u,p,t) du[1] = 0.0 end +rate = (u,p,t) -> t < 5.0 ? 1.0 : 0.0 +rbound = (u,p,t) -> 1.0 +rinterval = (u,p,t) -> Inf +affect! = (integrator) -> (integrator.u[1] = integrator.u[1]+1) +jump = VariableRateJump(rate,affect!;urate=rbound,rateinterval=rinterval) + prob = ODEProblem(f,[0.0],(0.0,10.0)) jump_prob = JumpProblem(prob,Extrande(),jump; rng=rng) -integrator = init(jump_prob,Tsit5()) +# Test that process doesn't jump when rate switches to 0. sol = solve(jump_prob,Tsit5()) +@test sol(5.0)[1] == sol[end][1] + +# Birth-death process with time-varying birth rates. +Nsims = 1000000 +u0 = [10.0,] + +function runsimulations(jump_prob, testts) + Psamp = zeros(Int, length(testts), Nsims) + for i in 1:Nsims + sol_ = solve(jump_prob, Tsit5()) + Psamp[:, i] = getindex.(sol_(testts).u, 1) + end + mean(Psamp, dims=2) +end + +# Variable rate birth jumps. +rateb = (u,p,t) -> (0.1*sin(t) + 0.2) +ratebbound = (u,p,t) -> 0.3 +ratebwindow = (u,p,t) -> Inf +affectb! = (integrator) -> (integrator.u[1] = integrator.u[1] + 1) +jumpb = VariableRateJump(rateb, affectb!;urate=ratebbound, rateinterval=ratebwindow) + +# Constant rate death jumps. +rated = (u,p,t) -> u[1] * 0.08 +affectd! = (integrator) -> (integrator.u[1] = integrator.u[1] - 1) +jumpd = ConstantRateJump(rated, affectd!) -rate2 = (u,p,t) -> t < 5.0 ? 1.0 : 0.0 -rbound2 = (u,p,t) -> 1.0 -jump3 = VariableRateJump(rate2,affect!,interp_points=1000;rbnd=rbound2) +# Problem definition. +bd_prob = ODEProblem(f,u0,(0.0,2pi)) +jump_bd_prob = JumpProblem(bd_prob, Extrande(), jumpb, jumpd) + +test_times = range(1.0, stop=2pi, length=3) +means = runsimulations(jump_bd_prob, test_times) + +# ODE for the mean. +fu = function (du, u, p, t) + du[1] = (0.1*sin(t) + 0.2) - (u[1] * 0.08) +end -prob2 = ODEProblem(f,[0.0],(0.0,10.0)) -jump_prob2 = JumpProblem(prob2,Extrande(),jump3; rng=rng) +ode_prob = ODEProblem(fu,u0,(0.0,2*pi)) +ode_sol = solve(ode_prob, Tsit5()) -sol2 = solve(jump_prob2,Tsit5()) -@test sol2(5.0)[1] == sol2[end][1] +# Test extrande against the ODE mean. +@test prod(isapprox.(means, getindex.(ode_sol(test_times).u, 1), rtol=1e-3)) From 930bb078deef4f66851fb058a21fb6232db3faa2 Mon Sep 17 00:00:00 2001 From: Paul Piho Date: Wed, 11 Jan 2023 20:18:21 +0000 Subject: [PATCH 08/15] allow extrande be used with no variable rate jumps --- src/aggregators/extrande.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aggregators/extrande.jl b/src/aggregators/extrande.jl index 4742b08be..fc8134ec7 100644 --- a/src/aggregators/extrande.jl +++ b/src/aggregators/extrande.jl @@ -32,7 +32,7 @@ end ############################# Required Functions ############################## function aggregate(aggregator::Extrande, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; variable_jumps, kwargs...) + ma_jumps, save_positions, rng; variable_jumps=(), kwargs...) rates, affects! = get_jump_info_fwrappers(u, p, t, (constant_jumps..., variable_jumps..., NullAffectJump)) rbnds, wnds = get_va_jump_bound_info_fwrapper(u, p, t, (constant_jumps..., variable_jumps...,NullAffectJump)) From 2d35fb46849f0d93b8828526788e4fad56807904 Mon Sep 17 00:00:00 2001 From: Paul Piho Date: Thu, 12 Jan 2023 09:34:52 +0000 Subject: [PATCH 09/15] format --- src/aggregators/extrande.jl | 75 ++++++++++++++++++++----------------- src/jumps.jl | 25 +++++++------ src/problem.jl | 1 - test/extrande.jl | 46 +++++++++++------------ 4 files changed, 78 insertions(+), 69 deletions(-) diff --git a/src/aggregators/extrande.jl b/src/aggregators/extrande.jl index fc8134ec7..fc23c8cd9 100644 --- a/src/aggregators/extrande.jl +++ b/src/aggregators/extrande.jl @@ -6,9 +6,10 @@ Extrande sampling method for jumps with defined rate bounds. """ nullaffect!(integrator) = nothing -const NullAffectJump = ConstantRateJump((u,p,t) -> 0.0, nullaffect!) +const NullAffectJump = ConstantRateJump((u, p, t) -> 0.0, nullaffect!) -mutable struct ExtrandeJumpAggregation{T,S,F1,F2,F3,F4,RNG} <: AbstractSSAJumpAggregator +mutable struct ExtrandeJumpAggregation{T, S, F1, F2, F3, F4, RNG} <: + AbstractSSAJumpAggregator next_jump::Int prev_jump::Int next_jump_time::T @@ -20,24 +21,31 @@ mutable struct ExtrandeJumpAggregation{T,S,F1,F2,F3,F4,RNG} <: AbstractSSAJumpAg wds::F4 rates::F1 affects!::F2 - save_positions::Tuple{Bool,Bool} + save_positions::Tuple{Bool, Bool} rng::RNG end -function ExtrandeJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool,Bool}, rng::RNG; rate_bounds::F3, windows::F4, kwargs...) where {T,S,F1,F2,F3,F4,RNG} - - ExtrandeJumpAggregation{T,S,F1,F2,F3,F4,RNG}(nj, nj, njt, et, crs, sr, maj, rate_bounds, windows, rs, affs!, sps, rng) +function ExtrandeJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, + rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; + rate_bounds::F3, windows::F4, + kwargs...) where {T, S, F1, F2, F3, F4, RNG} + ExtrandeJumpAggregation{T, S, F1, F2, F3, F4, RNG}(nj, nj, njt, et, crs, sr, maj, + rate_bounds, windows, rs, affs!, sps, + rng) end - ############################# Required Functions ############################## function aggregate(aggregator::Extrande, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; variable_jumps=(), kwargs...) - - rates, affects! = get_jump_info_fwrappers(u, p, t, (constant_jumps..., variable_jumps..., NullAffectJump)) - rbnds, wnds = get_va_jump_bound_info_fwrapper(u, p, t, (constant_jumps..., variable_jumps...,NullAffectJump)) + ma_jumps, save_positions, rng; variable_jumps = (), kwargs...) + rates, affects! = get_jump_info_fwrappers(u, p, t, + (constant_jumps..., variable_jumps..., + NullAffectJump)) + rbnds, wnds = get_va_jump_bound_info_fwrapper(u, p, t, + (constant_jumps..., variable_jumps..., + NullAffectJump)) build_jump_aggregation(ExtrandeJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; u=u, rate_bounds=rbnds, windows=wnds, kwargs...) + rates, affects!, save_positions, rng; u = u, rate_bounds = rbnds, + windows = wnds, kwargs...) end # set up a new simulation and calculate the first jump / jump time @@ -54,25 +62,25 @@ end end @fastmath function next_ma_jump(p::ExtrandeJumpAggregation, u, params, t) - ttnj = typemax(typeof(t)) - nextrx = zero(Int) - majumps = p.ma_jumps + ttnj = typemax(typeof(t)) + nextrx = zero(Int) + majumps = p.ma_jumps @inbounds for i in 1:get_num_majumps(majumps) p.cur_rates[i] = evalrxrate(u, i, majumps) dt = randexp(p.rng) / p.cur_rates[i] if dt < ttnj - ttnj = dt + ttnj = dt nextrx = i end end nextrx, ttnj end -@fastmath function next_extrande_jump(p::ExtrandeJumpAggregation, u, params, t) - ttnj = typemax(typeof(t)) +@fastmath function next_extrande_jump(p::ExtrandeJumpAggregation, u, params, t) + ttnj = typemax(typeof(t)) nextrx = zero(Int) Wmin = typemax(typeof(t)) - Bmax = typemax(typeof(t)) + Bmax = typemax(typeof(t)) # Calculate the total rate bound and the largest common validity window. Ws = zeros(typeof(t), length(p.wds)) @@ -80,13 +88,13 @@ end if !isempty(p.rate_bnds) idx = get_num_majumps(p.ma_jumps) + 1 @inbounds for i in 1:length(p.wds) - Ws[i] = p.wds[i](u,params,t) - Bs[i] = p.rate_bnds[i](u,params,t) + Ws[i] = p.wds[i](u, params, t) + Bs[i] = p.rate_bnds[i](u, params, t) end Wmin = minimum(Ws) - Bmax = sum(Bs) + Bmax = sum(Bs) end - + # Rejection sampling. if !isempty(p.rates) nextrx = length(p.rates) @@ -98,16 +106,16 @@ end prev_rate = zero(t) cur_rates = p.cur_rates @inbounds for i in idx:length(cur_rates) - cur_rates[i] = cur_rates[i] + prev_rate - prev_rate = cur_rates[i] + cur_rates[i] = cur_rates[i] + prev_rate + prev_rate = cur_rates[i] end UBmax = rand(p.rng) * Bmax ttnj = prop_ttnj - if p.cur_rates[end] ≥ UBmax - @inbounds nextrx = findfirst(x -> x ≥ UBmax, p.cur_rates) + if p.cur_rates[end] ≥ UBmax + @inbounds nextrx = findfirst(x -> x ≥ UBmax, p.cur_rates) end - else + else ttnj = Wmin end end @@ -115,18 +123,17 @@ end return nextrx, ttnj end - function generate_jumps!(p::ExtrandeJumpAggregation, integrator, u, params, t) nextmaj, ttnmaj = next_ma_jump(p, u, params, t) nextexj, ttnexj = next_extrande_jump(p, u, params, t) - + # execute reaction with minimal time if ttnmaj < ttnexj - p.next_jump = nextmaj - p.next_jump_time = t + ttnmaj + p.next_jump = nextmaj + p.next_jump_time = t + ttnmaj else - p.next_jump = nextexj - p.next_jump_time = t + ttnexj + p.next_jump = nextexj + p.next_jump_time = t + ttnexj end nothing diff --git a/src/jumps.jl b/src/jumps.jl index 9c3c1b66c..b8b2ec825 100644 --- a/src/jumps.jl +++ b/src/jumps.jl @@ -707,19 +707,22 @@ end function rate_window_function(jump) # Assumes that if no window is given the rate bound is valid for all times. - return !(jump.rateinterval isa Nothing) ? jump.rateinterval : (u,p,t) -> Inf + return !(jump.rateinterval isa Nothing) ? jump.rateinterval : (u, p, t) -> Inf end -function get_va_jump_bound_info_fwrapper(u,p,t,jumps) - RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t),Tuple{typeof(u), typeof(p), typeof(t)}} +function get_va_jump_bound_info_fwrapper(u, p, t, jumps) + RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t), + Tuple{typeof(u), typeof(p), typeof(t)}} - if (jumps !== nothing) && !isempty(jumps) - rates = [j isa VariableRateJump ? RateWrapper(j.urate) : RateWrapper(j.rate) for j in jumps] - wnds = [j isa VariableRateJump ? RateWrapper(rate_window_function(j)) : RateWrapper((u,p,t) -> Inf) for j in jumps] - else - rates = Vector{RateWrapper}() - wnds = Vector{RateWrapper}() - end + if (jumps !== nothing) && !isempty(jumps) + rates = [j isa VariableRateJump ? RateWrapper(j.urate) : RateWrapper(j.rate) + for j in jumps] + wnds = [j isa VariableRateJump ? RateWrapper(rate_window_function(j)) : + RateWrapper((u, p, t) -> Inf) for j in jumps] + else + rates = Vector{RateWrapper}() + wnds = Vector{RateWrapper}() + end - rates, wnds + rates, wnds end diff --git a/src/problem.jl b/src/problem.jl index 3f88f0271..d84c7ff6d 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -448,4 +448,3 @@ function Base.show(io::IO, mime::MIME"text/plain", A::JumpProblem) end TreeViews.hastreeview(x::JumpProblem) = true - diff --git a/test/extrande.jl b/test/extrande.jl index a4db5f083..68fe37610 100644 --- a/test/extrande.jl +++ b/test/extrande.jl @@ -3,26 +3,26 @@ using StableRNGs using Statistics rng = StableRNG(48572) -f = function (du,u,p,t) - du[1] = 0.0 +f = function (du, u, p, t) + du[1] = 0.0 end -rate = (u,p,t) -> t < 5.0 ? 1.0 : 0.0 -rbound = (u,p,t) -> 1.0 -rinterval = (u,p,t) -> Inf -affect! = (integrator) -> (integrator.u[1] = integrator.u[1]+1) -jump = VariableRateJump(rate,affect!;urate=rbound,rateinterval=rinterval) +rate = (u, p, t) -> t < 5.0 ? 1.0 : 0.0 +rbound = (u, p, t) -> 1.0 +rinterval = (u, p, t) -> Inf +affect! = (integrator) -> (integrator.u[1] = integrator.u[1] + 1) +jump = VariableRateJump(rate, affect!; urate = rbound, rateinterval = rinterval) -prob = ODEProblem(f,[0.0],(0.0,10.0)) -jump_prob = JumpProblem(prob,Extrande(),jump; rng=rng) +prob = ODEProblem(f, [0.0], (0.0, 10.0)) +jump_prob = JumpProblem(prob, Extrande(), jump; rng = rng) # Test that process doesn't jump when rate switches to 0. -sol = solve(jump_prob,Tsit5()) +sol = solve(jump_prob, Tsit5()) @test sol(5.0)[1] == sol[end][1] # Birth-death process with time-varying birth rates. Nsims = 1000000 -u0 = [10.0,] +u0 = [10.0] function runsimulations(jump_prob, testts) Psamp = zeros(Int, length(testts), Nsims) @@ -30,35 +30,35 @@ function runsimulations(jump_prob, testts) sol_ = solve(jump_prob, Tsit5()) Psamp[:, i] = getindex.(sol_(testts).u, 1) end - mean(Psamp, dims=2) + mean(Psamp, dims = 2) end # Variable rate birth jumps. -rateb = (u,p,t) -> (0.1*sin(t) + 0.2) -ratebbound = (u,p,t) -> 0.3 -ratebwindow = (u,p,t) -> Inf +rateb = (u, p, t) -> (0.1 * sin(t) + 0.2) +ratebbound = (u, p, t) -> 0.3 +ratebwindow = (u, p, t) -> Inf affectb! = (integrator) -> (integrator.u[1] = integrator.u[1] + 1) -jumpb = VariableRateJump(rateb, affectb!;urate=ratebbound, rateinterval=ratebwindow) +jumpb = VariableRateJump(rateb, affectb!; urate = ratebbound, rateinterval = ratebwindow) # Constant rate death jumps. -rated = (u,p,t) -> u[1] * 0.08 +rated = (u, p, t) -> u[1] * 0.08 affectd! = (integrator) -> (integrator.u[1] = integrator.u[1] - 1) jumpd = ConstantRateJump(rated, affectd!) # Problem definition. -bd_prob = ODEProblem(f,u0,(0.0,2pi)) -jump_bd_prob = JumpProblem(bd_prob, Extrande(), jumpb, jumpd) +bd_prob = ODEProblem(f, u0, (0.0, 2pi)) +jump_bd_prob = JumpProblem(bd_prob, Extrande(), jumpb, jumpd) -test_times = range(1.0, stop=2pi, length=3) +test_times = range(1.0, stop = 2pi, length = 3) means = runsimulations(jump_bd_prob, test_times) # ODE for the mean. fu = function (du, u, p, t) - du[1] = (0.1*sin(t) + 0.2) - (u[1] * 0.08) + du[1] = (0.1 * sin(t) + 0.2) - (u[1] * 0.08) end -ode_prob = ODEProblem(fu,u0,(0.0,2*pi)) +ode_prob = ODEProblem(fu, u0, (0.0, 2 * pi)) ode_sol = solve(ode_prob, Tsit5()) # Test extrande against the ODE mean. -@test prod(isapprox.(means, getindex.(ode_sol(test_times).u, 1), rtol=1e-3)) +@test prod(isapprox.(means, getindex.(ode_sol(test_times).u, 1), rtol = 1e-3)) From fd59fe1ca43b9af675ecb7b445bac7e8aeb2b49c Mon Sep 17 00:00:00 2001 From: Paul Piho Date: Tue, 7 Feb 2023 10:59:28 +0000 Subject: [PATCH 10/15] Add extrande to the tested algorithms --- test/hawkes_test.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/hawkes_test.jl b/test/hawkes_test.jl index 0de428e36..b2801705f 100644 --- a/test/hawkes_test.jl +++ b/test/hawkes_test.jl @@ -105,7 +105,7 @@ h = [Float64[]] Eλ, Varλ = expected_stats_hawkes_problem(p, tspan) -algs = (Direct(), Coevolve(), Coevolve()) +algs = (Direct(), Coevolve(), Coevolve(), Extrande()) uselrate = zeros(Bool, length(algs)) uselrate[3] = true Nsims = 250 @@ -122,7 +122,7 @@ for (i, alg) in enumerate(algs) reset_history!(h) sols[n] = solve(jump_prob, stepper) end - if typeof(alg) <: Coevolve + if typeof(alg) <: Union{Coevolve, Extrande} λs = permutedims(mapreduce((sol) -> empirical_rate(sol), hcat, sols)) else cols = length(sols[1].u[1].u) From 0d7d2e7c21ead5c2f8945246779bdc8303ebbd6b Mon Sep 17 00:00:00 2001 From: Paul Piho Date: Tue, 7 Feb 2023 11:01:18 +0000 Subject: [PATCH 11/15] Remove unused function is_fictitious --- src/aggregators/aggregators.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/aggregators/aggregators.jl b/src/aggregators/aggregators.jl index 9359d0ad7..ed046e57d 100644 --- a/src/aggregators/aggregators.jl +++ b/src/aggregators/aggregators.jl @@ -196,6 +196,3 @@ supports_variablerates(aggregator::Extrande) = true is_spatial(aggregator::AbstractAggregatorAlgorithm) = false is_spatial(aggregator::NSM) = true is_spatial(aggregator::DirectCRDirect) = true - -is_ficticious(aggregator::AbstractAggregatorAlgorithm) = false -is_ficticious(aggregator::Extrande) = true From 80187044b8ad292be2db1bf2dff6be9ffc73a09b Mon Sep 17 00:00:00 2001 From: Paul Piho Date: Tue, 7 Feb 2023 11:03:39 +0000 Subject: [PATCH 12/15] Fixes, ma jumps now part of the algorithm --- src/aggregators/extrande.jl | 49 ++++++++++--------------------------- 1 file changed, 13 insertions(+), 36 deletions(-) diff --git a/src/aggregators/extrande.jl b/src/aggregators/extrande.jl index fc23c8cd9..8f38c0524 100644 --- a/src/aggregators/extrande.jl +++ b/src/aggregators/extrande.jl @@ -37,11 +37,12 @@ end ############################# Required Functions ############################## function aggregate(aggregator::Extrande, u, p, t, end_time, constant_jumps, ma_jumps, save_positions, rng; variable_jumps = (), kwargs...) + ma_jumps_ = !isnothing(ma_jumps) ? ma_jumps : () rates, affects! = get_jump_info_fwrappers(u, p, t, - (constant_jumps..., variable_jumps..., + (constant_jumps..., variable_jumps..., ma_jumps_..., NullAffectJump)) rbnds, wnds = get_va_jump_bound_info_fwrapper(u, p, t, - (constant_jumps..., variable_jumps..., + (constant_jumps..., variable_jumps..., ma_jumps_..., NullAffectJump)) build_jump_aggregation(ExtrandeJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, save_positions, rng; u = u, rate_bounds = rbnds, @@ -61,21 +62,6 @@ end nothing end -@fastmath function next_ma_jump(p::ExtrandeJumpAggregation, u, params, t) - ttnj = typemax(typeof(t)) - nextrx = zero(Int) - majumps = p.ma_jumps - @inbounds for i in 1:get_num_majumps(majumps) - p.cur_rates[i] = evalrxrate(u, i, majumps) - dt = randexp(p.rng) / p.cur_rates[i] - if dt < ttnj - ttnj = dt - nextrx = i - end - end - nextrx, ttnj -end - @fastmath function next_extrande_jump(p::ExtrandeJumpAggregation, u, params, t) ttnj = typemax(typeof(t)) nextrx = zero(Int) @@ -83,22 +69,18 @@ end Bmax = typemax(typeof(t)) # Calculate the total rate bound and the largest common validity window. - Ws = zeros(typeof(t), length(p.wds)) - Bs = zeros(typeof(t), length(p.rate_bnds)) if !isempty(p.rate_bnds) - idx = get_num_majumps(p.ma_jumps) + 1 + Bmax = typeof(t)(0.) @inbounds for i in 1:length(p.wds) - Ws[i] = p.wds[i](u, params, t) - Bs[i] = p.rate_bnds[i](u, params, t) + Wmin = min(Wmin, p.wds[i](u, params, t)) + Bmax += p.rate_bnds[i](u, params, t) end - Wmin = minimum(Ws) - Bmax = sum(Bs) end # Rejection sampling. if !isempty(p.rates) nextrx = length(p.rates) - idx = get_num_majumps(p.ma_jumps) + 1 + idx = 1 prop_ttnj = randexp(p.rng) / Bmax if prop_ttnj < Wmin fill_cur_rates(u, params, prop_ttnj + t, p.cur_rates, idx, p.rates...) @@ -113,7 +95,10 @@ end UBmax = rand(p.rng) * Bmax ttnj = prop_ttnj if p.cur_rates[end] ≥ UBmax - @inbounds nextrx = findfirst(x -> x ≥ UBmax, p.cur_rates) + nextrx = 1 + @inbounds while p.cur_rates[nextrx] < UBmax + nextrx += 1 + end end else ttnj = Wmin @@ -124,17 +109,9 @@ end end function generate_jumps!(p::ExtrandeJumpAggregation, integrator, u, params, t) - nextmaj, ttnmaj = next_ma_jump(p, u, params, t) nextexj, ttnexj = next_extrande_jump(p, u, params, t) - - # execute reaction with minimal time - if ttnmaj < ttnexj - p.next_jump = nextmaj - p.next_jump_time = t + ttnmaj - else - p.next_jump = nextexj - p.next_jump_time = t + ttnexj - end + p.next_jump = nextexj + p.next_jump_time = t + ttnexj nothing end From db901fd82cb868e22c208715e52eb9a56d420ba9 Mon Sep 17 00:00:00 2001 From: Paul Piho Date: Tue, 7 Feb 2023 19:54:37 +0000 Subject: [PATCH 13/15] format --- src/aggregators/extrande.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/aggregators/extrande.jl b/src/aggregators/extrande.jl index 8f38c0524..761884f1e 100644 --- a/src/aggregators/extrande.jl +++ b/src/aggregators/extrande.jl @@ -39,10 +39,12 @@ function aggregate(aggregator::Extrande, u, p, t, end_time, constant_jumps, ma_jumps, save_positions, rng; variable_jumps = (), kwargs...) ma_jumps_ = !isnothing(ma_jumps) ? ma_jumps : () rates, affects! = get_jump_info_fwrappers(u, p, t, - (constant_jumps..., variable_jumps..., ma_jumps_..., + (constant_jumps..., variable_jumps..., + ma_jumps_..., NullAffectJump)) rbnds, wnds = get_va_jump_bound_info_fwrapper(u, p, t, - (constant_jumps..., variable_jumps..., ma_jumps_..., + (constant_jumps..., variable_jumps..., + ma_jumps_..., NullAffectJump)) build_jump_aggregation(ExtrandeJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, save_positions, rng; u = u, rate_bounds = rbnds, @@ -70,7 +72,7 @@ end # Calculate the total rate bound and the largest common validity window. if !isempty(p.rate_bnds) - Bmax = typeof(t)(0.) + Bmax = typeof(t)(0.0) @inbounds for i in 1:length(p.wds) Wmin = min(Wmin, p.wds[i](u, params, t)) Bmax += p.rate_bnds[i](u, params, t) @@ -80,7 +82,7 @@ end # Rejection sampling. if !isempty(p.rates) nextrx = length(p.rates) - idx = 1 + idx = 1 prop_ttnj = randexp(p.rng) / Bmax if prop_ttnj < Wmin fill_cur_rates(u, params, prop_ttnj + t, p.cur_rates, idx, p.rates...) From 15a31f9759e585aa06564afce127cc0c5c0d7b61 Mon Sep 17 00:00:00 2001 From: Paul Piho Date: Wed, 8 Feb 2023 18:59:41 +0000 Subject: [PATCH 14/15] add extrande to more test sets --- test/degenerate_rx_cases.jl | 2 +- test/linearreaction_test.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/degenerate_rx_cases.jl b/test/degenerate_rx_cases.jl index b81bb2b34..79e9fb9cf 100644 --- a/test/degenerate_rx_cases.jl +++ b/test/degenerate_rx_cases.jl @@ -13,7 +13,7 @@ doprint = false doplot = false methods = (RDirect(), RSSACR(), Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(), - NRM(), RSSA(), DirectCR(), Coevolve()) + NRM(), RSSA(), DirectCR(), Coevolve(), Extrande()) # one reaction case, mass action jump, vector of data rate = [2.0] diff --git a/test/linearreaction_test.jl b/test/linearreaction_test.jl index d169b5713..4bc6c5b63 100644 --- a/test/linearreaction_test.jl +++ b/test/linearreaction_test.jl @@ -16,7 +16,7 @@ tf = 0.1 baserate = 0.1 A0 = 100 exactmean = (t, ratevec) -> A0 * exp(-sum(ratevec) * t) -SSAalgs = [RSSACR(), Direct(), RSSA()] +SSAalgs = [RSSACR(), Direct(), RSSA(), Extrande()] spec_to_dep_jumps = [collect(1:Nrxs), []] jump_to_dep_specs = [[1, 2] for i in 1:Nrxs] From fad63ca17111c78e476a83c273b967b69f587edf Mon Sep 17 00:00:00 2001 From: Paul Piho Date: Wed, 8 Feb 2023 19:01:16 +0000 Subject: [PATCH 15/15] another attempt at correcting the mass action jump treatment --- src/aggregators/extrande.jl | 55 +++++++++++++++++++------------------ test/extrande.jl | 10 +++++++ 2 files changed, 39 insertions(+), 26 deletions(-) diff --git a/src/aggregators/extrande.jl b/src/aggregators/extrande.jl index 761884f1e..c59155403 100644 --- a/src/aggregators/extrande.jl +++ b/src/aggregators/extrande.jl @@ -37,14 +37,11 @@ end ############################# Required Functions ############################## function aggregate(aggregator::Extrande, u, p, t, end_time, constant_jumps, ma_jumps, save_positions, rng; variable_jumps = (), kwargs...) - ma_jumps_ = !isnothing(ma_jumps) ? ma_jumps : () rates, affects! = get_jump_info_fwrappers(u, p, t, (constant_jumps..., variable_jumps..., - ma_jumps_..., NullAffectJump)) rbnds, wnds = get_va_jump_bound_info_fwrapper(u, p, t, (constant_jumps..., variable_jumps..., - ma_jumps_..., NullAffectJump)) build_jump_aggregation(ExtrandeJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, save_positions, rng; u = u, rate_bounds = rbnds, @@ -66,13 +63,26 @@ end @fastmath function next_extrande_jump(p::ExtrandeJumpAggregation, u, params, t) ttnj = typemax(typeof(t)) - nextrx = zero(Int) Wmin = typemax(typeof(t)) - Bmax = typemax(typeof(t)) + Bmax = zero(t) + + prev_rate = zero(t) + new_rate = zero(t) + cur_rates = p.cur_rates + + # Mass action rates + majumps = p.ma_jumps + idx = get_num_majumps(majumps) + + @inbounds for i in 1:idx + new_rate = evalrxrate(u, i, majumps) + cur_rates[i] = add_fast(new_rate, prev_rate) + prev_rate = cur_rates[i] + Bmax += prev_rate + end # Calculate the total rate bound and the largest common validity window. if !isempty(p.rate_bnds) - Bmax = typeof(t)(0.0) @inbounds for i in 1:length(p.wds) Wmin = min(Wmin, p.wds[i](u, params, t)) Bmax += p.rate_bnds[i](u, params, t) @@ -80,31 +90,24 @@ end end # Rejection sampling. - if !isempty(p.rates) - nextrx = length(p.rates) - idx = 1 - prop_ttnj = randexp(p.rng) / Bmax - if prop_ttnj < Wmin + nextrx = length(cur_rates) + prop_ttnj = randexp(p.rng) / Bmax + if prop_ttnj < Wmin + if !isempty(p.rates) + idx += 1 fill_cur_rates(u, params, prop_ttnj + t, p.cur_rates, idx, p.rates...) - - prev_rate = zero(t) - cur_rates = p.cur_rates @inbounds for i in idx:length(cur_rates) - cur_rates[i] = cur_rates[i] + prev_rate + cur_rates[i] = add_fast(cur_rates[i], prev_rate) prev_rate = cur_rates[i] end - - UBmax = rand(p.rng) * Bmax - ttnj = prop_ttnj - if p.cur_rates[end] ≥ UBmax - nextrx = 1 - @inbounds while p.cur_rates[nextrx] < UBmax - nextrx += 1 - end - end - else - ttnj = Wmin end + UBmax = rand(p.rng) * Bmax + ttnj = prop_ttnj + if p.cur_rates[end] ≥ UBmax + nextrx = searchsortedfirst(p.cur_rates, UBmax) + end + else + ttnj = Wmin end return nextrx, ttnj diff --git a/test/extrande.jl b/test/extrande.jl index 68fe37610..444ed1ed1 100644 --- a/test/extrande.jl +++ b/test/extrande.jl @@ -62,3 +62,13 @@ ode_sol = solve(ode_prob, Tsit5()) # Test extrande against the ODE mean. @test prod(isapprox.(means, getindex.(ode_sol(test_times).u, 1), rtol = 1e-3)) + +# Make sure interfaces correctly with Mass Action Jumps. +reactant_stoich = [[1 => 1]] +net_stoich = [[1 => -1]] +majd = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1]) +bmajd_prob = ODEProblem(f, u0, (0.0, 2pi), [0.08]) +jump_bmajd_prob = JumpProblem(bmajd_prob, Extrande(), jumpb, majd) + +means_mass_action = runsimulations(jump_bmajd_prob, test_times) +@test prod(isapprox.(means_mass_action, getindex.(ode_sol(test_times).u, 1), rtol = 1e-3))