Skip to content

Commit 05dbdfd

Browse files
committed
Merge branch 'main' into issue408-brand
2 parents 2c6ef5e + 7fd98c8 commit 05dbdfd

File tree

21 files changed

+1640
-60
lines changed

21 files changed

+1640
-60
lines changed

EpiAware/Project.toml

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1717
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1818
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1919
ManifoldDiff = "af67fdf4-a580-4b9f-bbec-742ef357defd"
20+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2021
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
2122
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
2223
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -33,19 +34,20 @@ AdvancedHMC = "0.6"
3334
DataFramesMeta = "0.15"
3435
Distributions = "0.25"
3536
DocStringExtensions = "0.9"
36-
DynamicPPL = "0.27, 0.28, 0.29, 0.30"
37+
DynamicPPL = "0.30"
3738
FillArrays = "1.11"
3839
HybridArrays = "0.4.16"
3940
LinearAlgebra = ">= 1.9"
4041
LogExpFunctions = "0.3"
4142
MCMCChains = "6.0"
4243
ManifoldDiff = "0.3.10"
43-
Pathfinder = "0.8"
44+
OrdinaryDiffEq = "6.89.0"
45+
Pathfinder = "0.9"
4446
QuadGK = "2.9"
45-
Random = ">= 1.9"
47+
Random = "1.11"
4648
Reexport = "1.2"
47-
SparseArrays = "1.10"
48-
Statistics = "1.10"
49+
SparseArrays = "1.11"
50+
Statistics = "1.11"
4951
Tables = "1.11"
50-
Turing = "0.32, 0.33, 0.34"
51-
julia = ">= 1.9"
52+
Turing = "0.35"
53+
julia = "1.11"

EpiAware/src/EpiAwareUtils/prefix_submodel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ submodel = prefix_submodel(FixedIntercept(0.1), generate_latent, string(1), 2)
2020
2121
We can now draw a sample from the submodel.
2222
23-
```julia
23+
```@example
2424
rand(submodel)
2525
```
2626
"

EpiAware/src/EpiInfModels/EpiInfModels.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,15 @@ module EpiInfModels
66
using ..EpiAwareBase
77
using ..EpiAwareUtils
88

9-
using Turing, Distributions, DocStringExtensions, LinearAlgebra, LogExpFunctions
9+
using LogExpFunctions: xexpy
10+
11+
using Turing, Distributions, DocStringExtensions, LinearAlgebra, OrdinaryDiffEq
12+
13+
#Export parameter helpers
14+
export EpiData
1015

1116
#Export models
12-
export EpiData, DirectInfections, ExpGrowthRate, Renewal
17+
export DirectInfections, ExpGrowthRate, Renewal, ODEProcess
1318

1419
#Export functions
1520
export R_to_r, r_to_R, expected_Rt
@@ -20,6 +25,7 @@ include("DirectInfections.jl")
2025
include("ExpGrowthRate.jl")
2126
include("RenewalSteps.jl")
2227
include("Renewal.jl")
28+
include("ODEProcess.jl")
2329
include("utils.jl")
2430

2531
end
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
@doc raw"""
2+
A structure representing an infection process modeled by an Ordinary Differential Equation (ODE).
3+
At a high level, an `ODEProcess` struct object combines:
4+
5+
- An `AbstractTuringParamModel` which defines the ODE model in terms of `OrdinaryDiffEq` types,
6+
the parameters of the ODE model and a method to generate the parameters.
7+
- A technique for solving and interpreting the ODE model using the `SciML` ecosystem. This includes
8+
the solver used in the ODE solution, keyword arguments to send to the solver and a function
9+
to map the `ODESolution` solution object to latent infections.
10+
11+
# Constructors
12+
- `ODEProcess(prob::ODEProblem; ts, solver, sol2infs)`: Create an `ODEProcess`
13+
object with the ODE problem `prob`, time points `ts`, solver `solver`, and function `sol2infs`.
14+
15+
# Predefined ODE models
16+
Two basic ODE models are provided in the `EpiAware` package: `SIRParams` and `SEIRParams`.
17+
In both cases these are defined in terms of the proportions of the population in each compartment
18+
of the SIR and SEIR models respectively.
19+
20+
## SIR model
21+
22+
```math
23+
\begin{aligned}
24+
\frac{dS}{dt} &= -\beta SI \\
25+
\frac{dI}{dt} &= \beta SI - \gamma I \\
26+
\frac{dR}{dt} &= \gamma I
27+
\end{aligned}
28+
```
29+
Where `S` is the proportion of the population that is susceptible, `I` is the proportion of the
30+
population that is infected and `R` is the proportion of the population that is recovered. The
31+
parameters are the infectiousness `β` and the recovery rate `γ`.
32+
33+
```jldoctest sirexample; output = false
34+
using EpiAware, OrdinaryDiffEq, Distributions
35+
36+
# Create an instance of SIRParams
37+
sirparams = SIRParams(
38+
tspan = (0.0, 100.0),
39+
infectiousness = LogNormal(log(0.3), 0.05),
40+
recovery_rate = LogNormal(log(0.1), 0.05),
41+
initial_prop_infected = Beta(1, 99)
42+
)
43+
nothing
44+
45+
# output
46+
47+
```
48+
49+
## SEIR model
50+
51+
```math
52+
\begin{aligned}
53+
\frac{dS}{dt} &= -\beta SI \\
54+
\frac{dE}{dt} &= \beta SI - \alpha E \\
55+
\frac{dI}{dt} &= \alpha E - \gamma I \\
56+
\frac{dR}{dt} &= \gamma I
57+
\end{aligned}
58+
```
59+
Where `S` is the proportion of the population that is susceptible, `E` is the proportion of the
60+
population that is exposed, `I` is the proportion of the population that is infected and `R` is
61+
the proportion of the population that is recovered. The parameters are the infectiousness `β`,
62+
the incubation rate `α` and the recovery rate `γ`.
63+
64+
```jldoctest; output = false
65+
using EpiAware, OrdinaryDiffEq, Distributions, Random
66+
Random.seed!(1234)
67+
68+
# Create an instance of SIRParams
69+
seirparams = SEIRParams(
70+
tspan = (0.0, 100.0),
71+
infectiousness = LogNormal(log(0.3), 0.05),
72+
incubation_rate = LogNormal(log(0.1), 0.05),
73+
recovery_rate = LogNormal(log(0.1), 0.05),
74+
initial_prop_infected = Beta(1, 99)
75+
)
76+
nothing
77+
78+
# output
79+
80+
```
81+
82+
# Usage example with `ODEProcess` and predefined SIR model
83+
84+
In this example we define an `ODEProcess` object using the predefined `SIRParams` model from
85+
above. We then generate latent infections using the `generate_latent_infs` function, and refit
86+
the model using a `Turing` model.
87+
88+
We assume that the latent infections are observed with a Poisson likelihood around their
89+
ODE model prediction. The population size is `N = 1000`, which we put into the `sol2infs` function,
90+
which maps the ODE solution to the number of infections. Recall that the `EpiAware` default SIR
91+
implementation assumes the model is in density/proportion form. Also, note that since the `sol2infs`
92+
function is a link function that maps the ODE solution to the expected number of infections we also
93+
apply the `LogExpFunctions.softplus` function to ensure that the expected number of infections is non-negative.
94+
Note that the `softplus` function is a smooth approximation to the ReLU function `x -> max(0, x)`.
95+
The utility of this approach is that small negative output from the ODE solver (e.g. ~ -1e-10) will be
96+
mapped to small positive values, without needing to use strict positivity constraints in the model.
97+
98+
First, we define the `ODEProcess` object which combines the SIR model with the `sol2infs` link
99+
function and the solver options.
100+
101+
```jldoctest sirexample; output = false
102+
using Turing, LogExpFunctions
103+
N = 1000.0
104+
105+
sir_process = ODEProcess(
106+
params = sirparams,
107+
sol2infs = sol -> softplus.(N .* sol[2, :]),
108+
solver_options = Dict(:verbose => false, :saveat => 1.0)
109+
)
110+
nothing
111+
112+
# output
113+
114+
```
115+
116+
Second, we define a `PoissionError` observation model for linking the the number of infections.
117+
118+
```jldoctest sirexample; output = false
119+
pois_obs = PoissonError()
120+
nothing
121+
122+
# output
123+
124+
```
125+
126+
Next, we create a `Turing` model for the full generative process: this solves the ODE model for
127+
the latent infections and then samples the observed infections from a Poisson distribution with this
128+
as the average.
129+
130+
NB: The `nothing` argument is a dummy latent process, e.g. a log-Rt time series, that is not
131+
used in the SIR model, but might be used in other models.
132+
133+
```jldoctest sirexample; output = false
134+
@model function fit_ode_model(data)
135+
@submodel I_t = generate_latent_infs(sir_process, nothing)
136+
@submodel y_t = generate_observations(pois_obs, data, I_t)
137+
138+
return y_t
139+
end
140+
nothing
141+
142+
# output
143+
144+
```
145+
146+
We can generate some test data from the model by passing `missing` as the argument to the model.
147+
This tells `Turing` that there is no data to condition on, so it will sample from the prior parameters
148+
and then generate infections. In this case, we do it in a way where we cache the sampled parameters
149+
as `θ` for later use.
150+
151+
```jldoctest sirexample; output = false
152+
# Sampled parameters
153+
gen_mdl = fit_ode_model(missing)
154+
θ = rand(gen_mdl)
155+
test_data = (gen_mdl | θ)()
156+
nothing
157+
158+
# output
159+
160+
```
161+
162+
Now, we can refit the model but this time we condition on the test data. We suppress the
163+
output of the sampling process to keep the output clean, but you can remove the `@suppress` macro.
164+
165+
```jldoctest sirexample; output = false
166+
using Suppressor
167+
inference_mdl = fit_ode_model(test_data)
168+
chn = Suppressor.@suppress sample(inference_mdl, NUTS(), 2_000)
169+
summarize(chn)
170+
nothing
171+
172+
# output
173+
174+
```
175+
176+
We can compare the summarized chain to the sampled parameters in `θ` to see that the model is
177+
fitting the data well and recovering a credible interval containing the true parameters.
178+
179+
# Custom ODE models
180+
181+
To define a custom ODE model, you need to define:
182+
183+
- Some `CustomModel <: AbstractTuringLatentModel` struct
184+
that contains the ODE problem as a field called `prob`, as well as sufficient fields to
185+
define or sample the parameters of the ODE model.
186+
- A method for `EpiAwareBase.generate_latent(params::CustomModel, Z_t)` that generates the
187+
initial condition and parameters of the ODE model, potentially conditional on a sample from a latent process `Z_t`.
188+
This method must return a `Tuple` `(u0, p)` where `u0` is the initial condition and `p` is the parameters.
189+
190+
Here is an example of a simple custom ODE model for _specified_ exponential growth:
191+
192+
```jldoctest customexample; output = false
193+
using EpiAware, Turing, OrdinaryDiffEq
194+
# Define a simple exponential growth model for testing
195+
function expgrowth(du, u, p, t)
196+
du[1] = p[1] * u[1]
197+
end
198+
199+
r = log(2) / 7 # Growth rate corresponding to 7 day doubling time
200+
201+
# Define the ODE problem using SciML
202+
prob = ODEProblem(expgrowth, [1.0], (0.0, 10.0), [r])
203+
204+
# Define the custom parameters struct
205+
struct CustomModel <: AbstractTuringLatentModel
206+
prob::ODEProblem
207+
r::Float64
208+
u0::Float64
209+
end
210+
custom_ode = CustomModel(prob, r, 1.0)
211+
212+
# Define the custom generate_latent function
213+
@model function EpiAwareBase.generate_latent(params::CustomModel, n)
214+
return ([params.u0], [params.r])
215+
end
216+
nothing
217+
218+
# output
219+
220+
```
221+
222+
This model is not random! But we can still use it to generate latent infections.
223+
224+
```jldoctest customexample; output = false
225+
# Define the ODEProcess
226+
expgrowth_model = ODEProcess(
227+
params = custom_ode,
228+
sol2infs = sol -> sol[1, :]
229+
)
230+
infs = generate_latent_infs(expgrowth_model, nothing)()
231+
nothing
232+
233+
# output
234+
235+
```
236+
"""
237+
@kwdef struct ODEProcess{
238+
P <: AbstractTuringLatentModel, S, F <: Function, D <:
239+
Union{Dict, NamedTuple}} <:
240+
EpiAwareBase.AbstractTuringEpiModel
241+
"The ODE problem and parameters, where `P` is a subtype of `AbstractTuringLatentModel`."
242+
params::P
243+
"The solver used for the ODE problem. Default is `AutoVern7(Rodas5())`, which is an auto
244+
switching solver aimed at medium/low tolerances."
245+
solver::S = AutoVern7(Rodas5())
246+
"A function that maps the solution object of the ODE to infection counts."
247+
sol2infs::F
248+
"The extra solver options for the ODE problem. Can be either a `Dict` or a `NamedTuple`
249+
containing the solver options."
250+
solver_options::D = Dict(:verbose => false, :saveat => 1.0)
251+
end
252+
253+
@doc raw"""
254+
Implement the `generate_latent_infs` function for the `ODEProcess` model.
255+
256+
This function remakes the ODE problem with the provided initial conditions and parameters,
257+
solves it using the specified solver, and then transforms the solution into latent infections
258+
using the `sol2infs` function.
259+
260+
# Example usage with predefined SIR model
261+
262+
In this example we define an `ODEProcess` object using the predefined `SIRParams` model and
263+
generate an expected infection time series using SIR model parameters sampled from their priors.
264+
265+
```jldoctest; output = false
266+
using EpiAware, OrdinaryDiffEq, Distributions, Turing, LogExpFunctions
267+
268+
# Create an instance of SIRParams
269+
sirparams = SIRParams(
270+
tspan = (0.0, 100.0),
271+
infectiousness = LogNormal(log(0.3), 0.05),
272+
recovery_rate = LogNormal(log(0.1), 0.05),
273+
initial_prop_infected = Beta(1, 99)
274+
)
275+
276+
#Population size
277+
278+
N = 1000.0
279+
280+
sir_process = ODEProcess(
281+
params = sirparams,
282+
sol2infs = sol -> softplus.(N .* sol[2, :]),
283+
solver_options = Dict(:verbose => false, :saveat => 1.0)
284+
)
285+
286+
generated_It = generate_latent_infs(sir_process, nothing)()
287+
nothing
288+
289+
# output
290+
291+
```
292+
293+
"""
294+
@model function EpiAwareBase.generate_latent_infs(epi_model::ODEProcess, Z_t)
295+
prob, solver, sol2infs, solver_options = epi_model.params.prob,
296+
epi_model.solver, epi_model.sol2infs, epi_model.solver_options
297+
n = isnothing(Z_t) ? 0 : size(Z_t, 1)
298+
299+
@submodel u0, p = generate_latent(epi_model.params, n)
300+
301+
_prob = remake(prob; u0 = u0, p = p)
302+
sol = solve(_prob, solver; solver_options...)
303+
I_t = sol2infs(sol)
304+
305+
return I_t
306+
end

0 commit comments

Comments
 (0)