Skip to content

Commit 3fa15bf

Browse files
authored
Merge branch 'main' into compathelper/new_version/2024-10-24-12-28-19-253-01404902091
2 parents 74cec0c + 0ddac54 commit 3fa15bf

File tree

21 files changed

+1634
-54
lines changed

21 files changed

+1634
-54
lines changed

EpiAware/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1515
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1616
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1717
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
18+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
1819
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
1920
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
2021
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -37,6 +38,8 @@ LinearAlgebra = ">= 1.9"
3738
LogExpFunctions = "0.3"
3839
MCMCChains = "6.0"
3940
Pathfinder = "0.9"
41+
OrdinaryDiffEq = "6.89.0"
42+
Pathfinder = "0.8, 0.9"
4043
QuadGK = "2.9"
4144
Random = "1.11"
4245
Reexport = "1.2"

EpiAware/src/EpiAwareUtils/prefix_submodel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ submodel = prefix_submodel(FixedIntercept(0.1), generate_latent, string(1), 2)
2020
2121
We can now draw a sample from the submodel.
2222
23-
```julia
23+
```@example
2424
rand(submodel)
2525
```
2626
"

EpiAware/src/EpiInfModels/EpiInfModels.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,15 @@ module EpiInfModels
66
using ..EpiAwareBase
77
using ..EpiAwareUtils
88

9-
using Turing, Distributions, DocStringExtensions, LinearAlgebra, LogExpFunctions
9+
using LogExpFunctions: xexpy
10+
11+
using Turing, Distributions, DocStringExtensions, LinearAlgebra, OrdinaryDiffEq
12+
13+
#Export parameter helpers
14+
export EpiData
1015

1116
#Export models
12-
export EpiData, DirectInfections, ExpGrowthRate, Renewal
17+
export DirectInfections, ExpGrowthRate, Renewal, ODEProcess
1318

1419
#Export functions
1520
export R_to_r, r_to_R, expected_Rt
@@ -20,6 +25,7 @@ include("DirectInfections.jl")
2025
include("ExpGrowthRate.jl")
2126
include("RenewalSteps.jl")
2227
include("Renewal.jl")
28+
include("ODEProcess.jl")
2329
include("utils.jl")
2430

2531
end
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
@doc raw"""
2+
A structure representing an infection process modeled by an Ordinary Differential Equation (ODE).
3+
At a high level, an `ODEProcess` struct object combines:
4+
5+
- An `AbstractTuringParamModel` which defines the ODE model in terms of `OrdinaryDiffEq` types,
6+
the parameters of the ODE model and a method to generate the parameters.
7+
- A technique for solving and interpreting the ODE model using the `SciML` ecosystem. This includes
8+
the solver used in the ODE solution, keyword arguments to send to the solver and a function
9+
to map the `ODESolution` solution object to latent infections.
10+
11+
# Constructors
12+
- `ODEProcess(prob::ODEProblem; ts, solver, sol2infs)`: Create an `ODEProcess`
13+
object with the ODE problem `prob`, time points `ts`, solver `solver`, and function `sol2infs`.
14+
15+
# Predefined ODE models
16+
Two basic ODE models are provided in the `EpiAware` package: `SIRParams` and `SEIRParams`.
17+
In both cases these are defined in terms of the proportions of the population in each compartment
18+
of the SIR and SEIR models respectively.
19+
20+
## SIR model
21+
22+
```math
23+
\begin{aligned}
24+
\frac{dS}{dt} &= -\beta SI \\
25+
\frac{dI}{dt} &= \beta SI - \gamma I \\
26+
\frac{dR}{dt} &= \gamma I
27+
\end{aligned}
28+
```
29+
Where `S` is the proportion of the population that is susceptible, `I` is the proportion of the
30+
population that is infected and `R` is the proportion of the population that is recovered. The
31+
parameters are the infectiousness `β` and the recovery rate `γ`.
32+
33+
```jldoctest sirexample; output = false
34+
using EpiAware, OrdinaryDiffEq, Distributions
35+
36+
# Create an instance of SIRParams
37+
sirparams = SIRParams(
38+
tspan = (0.0, 100.0),
39+
infectiousness = LogNormal(log(0.3), 0.05),
40+
recovery_rate = LogNormal(log(0.1), 0.05),
41+
initial_prop_infected = Beta(1, 99)
42+
)
43+
nothing
44+
45+
# output
46+
47+
```
48+
49+
## SEIR model
50+
51+
```math
52+
\begin{aligned}
53+
\frac{dS}{dt} &= -\beta SI \\
54+
\frac{dE}{dt} &= \beta SI - \alpha E \\
55+
\frac{dI}{dt} &= \alpha E - \gamma I \\
56+
\frac{dR}{dt} &= \gamma I
57+
\end{aligned}
58+
```
59+
Where `S` is the proportion of the population that is susceptible, `E` is the proportion of the
60+
population that is exposed, `I` is the proportion of the population that is infected and `R` is
61+
the proportion of the population that is recovered. The parameters are the infectiousness `β`,
62+
the incubation rate `α` and the recovery rate `γ`.
63+
64+
```jldoctest; output = false
65+
using EpiAware, OrdinaryDiffEq, Distributions, Random
66+
Random.seed!(1234)
67+
68+
# Create an instance of SIRParams
69+
seirparams = SEIRParams(
70+
tspan = (0.0, 100.0),
71+
infectiousness = LogNormal(log(0.3), 0.05),
72+
incubation_rate = LogNormal(log(0.1), 0.05),
73+
recovery_rate = LogNormal(log(0.1), 0.05),
74+
initial_prop_infected = Beta(1, 99)
75+
)
76+
nothing
77+
78+
# output
79+
80+
```
81+
82+
# Usage example with `ODEProcess` and predefined SIR model
83+
84+
In this example we define an `ODEProcess` object using the predefined `SIRParams` model from
85+
above. We then generate latent infections using the `generate_latent_infs` function, and refit
86+
the model using a `Turing` model.
87+
88+
We assume that the latent infections are observed with a Poisson likelihood around their
89+
ODE model prediction. The population size is `N = 1000`, which we put into the `sol2infs` function,
90+
which maps the ODE solution to the number of infections. Recall that the `EpiAware` default SIR
91+
implementation assumes the model is in density/proportion form. Also, note that since the `sol2infs`
92+
function is a link function that maps the ODE solution to the expected number of infections we also
93+
apply the `LogExpFunctions.softplus` function to ensure that the expected number of infections is non-negative.
94+
Note that the `softplus` function is a smooth approximation to the ReLU function `x -> max(0, x)`.
95+
The utility of this approach is that small negative output from the ODE solver (e.g. ~ -1e-10) will be
96+
mapped to small positive values, without needing to use strict positivity constraints in the model.
97+
98+
First, we define the `ODEProcess` object which combines the SIR model with the `sol2infs` link
99+
function and the solver options.
100+
101+
```jldoctest sirexample; output = false
102+
using Turing, LogExpFunctions
103+
N = 1000.0
104+
105+
sir_process = ODEProcess(
106+
params = sirparams,
107+
sol2infs = sol -> softplus.(N .* sol[2, :]),
108+
solver_options = Dict(:verbose => false, :saveat => 1.0)
109+
)
110+
nothing
111+
112+
# output
113+
114+
```
115+
116+
Second, we define a `PoissionError` observation model for linking the the number of infections.
117+
118+
```jldoctest sirexample; output = false
119+
pois_obs = PoissonError()
120+
nothing
121+
122+
# output
123+
124+
```
125+
126+
Next, we create a `Turing` model for the full generative process: this solves the ODE model for
127+
the latent infections and then samples the observed infections from a Poisson distribution with this
128+
as the average.
129+
130+
NB: The `nothing` argument is a dummy latent process, e.g. a log-Rt time series, that is not
131+
used in the SIR model, but might be used in other models.
132+
133+
```jldoctest sirexample; output = false
134+
@model function fit_ode_model(data)
135+
@submodel I_t = generate_latent_infs(sir_process, nothing)
136+
@submodel y_t = generate_observations(pois_obs, data, I_t)
137+
138+
return y_t
139+
end
140+
nothing
141+
142+
# output
143+
144+
```
145+
146+
We can generate some test data from the model by passing `missing` as the argument to the model.
147+
This tells `Turing` that there is no data to condition on, so it will sample from the prior parameters
148+
and then generate infections. In this case, we do it in a way where we cache the sampled parameters
149+
as `θ` for later use.
150+
151+
```jldoctest sirexample; output = false
152+
# Sampled parameters
153+
gen_mdl = fit_ode_model(missing)
154+
θ = rand(gen_mdl)
155+
test_data = (gen_mdl | θ)()
156+
nothing
157+
158+
# output
159+
160+
```
161+
162+
Now, we can refit the model but this time we condition on the test data. We suppress the
163+
output of the sampling process to keep the output clean, but you can remove the `@suppress` macro.
164+
165+
```jldoctest sirexample; output = false
166+
using Suppressor
167+
inference_mdl = fit_ode_model(test_data)
168+
chn = Suppressor.@suppress sample(inference_mdl, NUTS(), 2_000)
169+
summarize(chn)
170+
nothing
171+
172+
# output
173+
174+
```
175+
176+
We can compare the summarized chain to the sampled parameters in `θ` to see that the model is
177+
fitting the data well and recovering a credible interval containing the true parameters.
178+
179+
# Custom ODE models
180+
181+
To define a custom ODE model, you need to define:
182+
183+
- Some `CustomModel <: AbstractTuringLatentModel` struct
184+
that contains the ODE problem as a field called `prob`, as well as sufficient fields to
185+
define or sample the parameters of the ODE model.
186+
- A method for `EpiAwareBase.generate_latent(params::CustomModel, Z_t)` that generates the
187+
initial condition and parameters of the ODE model, potentially conditional on a sample from a latent process `Z_t`.
188+
This method must return a `Tuple` `(u0, p)` where `u0` is the initial condition and `p` is the parameters.
189+
190+
Here is an example of a simple custom ODE model for _specified_ exponential growth:
191+
192+
```jldoctest customexample; output = false
193+
using EpiAware, Turing, OrdinaryDiffEq
194+
# Define a simple exponential growth model for testing
195+
function expgrowth(du, u, p, t)
196+
du[1] = p[1] * u[1]
197+
end
198+
199+
r = log(2) / 7 # Growth rate corresponding to 7 day doubling time
200+
201+
# Define the ODE problem using SciML
202+
prob = ODEProblem(expgrowth, [1.0], (0.0, 10.0), [r])
203+
204+
# Define the custom parameters struct
205+
struct CustomModel <: AbstractTuringLatentModel
206+
prob::ODEProblem
207+
r::Float64
208+
u0::Float64
209+
end
210+
custom_ode = CustomModel(prob, r, 1.0)
211+
212+
# Define the custom generate_latent function
213+
@model function EpiAwareBase.generate_latent(params::CustomModel, n)
214+
return ([params.u0], [params.r])
215+
end
216+
nothing
217+
218+
# output
219+
220+
```
221+
222+
This model is not random! But we can still use it to generate latent infections.
223+
224+
```jldoctest customexample; output = false
225+
# Define the ODEProcess
226+
expgrowth_model = ODEProcess(
227+
params = custom_ode,
228+
sol2infs = sol -> sol[1, :]
229+
)
230+
infs = generate_latent_infs(expgrowth_model, nothing)()
231+
nothing
232+
233+
# output
234+
235+
```
236+
"""
237+
@kwdef struct ODEProcess{
238+
P <: AbstractTuringLatentModel, S, F <: Function, D <:
239+
Union{Dict, NamedTuple}} <:
240+
EpiAwareBase.AbstractTuringEpiModel
241+
"The ODE problem and parameters, where `P` is a subtype of `AbstractTuringLatentModel`."
242+
params::P
243+
"The solver used for the ODE problem. Default is `AutoVern7(Rodas5())`, which is an auto
244+
switching solver aimed at medium/low tolerances."
245+
solver::S = AutoVern7(Rodas5())
246+
"A function that maps the solution object of the ODE to infection counts."
247+
sol2infs::F
248+
"The extra solver options for the ODE problem. Can be either a `Dict` or a `NamedTuple`
249+
containing the solver options."
250+
solver_options::D = Dict(:verbose => false, :saveat => 1.0)
251+
end
252+
253+
@doc raw"""
254+
Implement the `generate_latent_infs` function for the `ODEProcess` model.
255+
256+
This function remakes the ODE problem with the provided initial conditions and parameters,
257+
solves it using the specified solver, and then transforms the solution into latent infections
258+
using the `sol2infs` function.
259+
260+
# Example usage with predefined SIR model
261+
262+
In this example we define an `ODEProcess` object using the predefined `SIRParams` model and
263+
generate an expected infection time series using SIR model parameters sampled from their priors.
264+
265+
```jldoctest; output = false
266+
using EpiAware, OrdinaryDiffEq, Distributions, Turing, LogExpFunctions
267+
268+
# Create an instance of SIRParams
269+
sirparams = SIRParams(
270+
tspan = (0.0, 100.0),
271+
infectiousness = LogNormal(log(0.3), 0.05),
272+
recovery_rate = LogNormal(log(0.1), 0.05),
273+
initial_prop_infected = Beta(1, 99)
274+
)
275+
276+
#Population size
277+
278+
N = 1000.0
279+
280+
sir_process = ODEProcess(
281+
params = sirparams,
282+
sol2infs = sol -> softplus.(N .* sol[2, :]),
283+
solver_options = Dict(:verbose => false, :saveat => 1.0)
284+
)
285+
286+
generated_It = generate_latent_infs(sir_process, nothing)()
287+
nothing
288+
289+
# output
290+
291+
```
292+
293+
"""
294+
@model function EpiAwareBase.generate_latent_infs(epi_model::ODEProcess, Z_t)
295+
prob, solver, sol2infs, solver_options = epi_model.params.prob,
296+
epi_model.solver, epi_model.sol2infs, epi_model.solver_options
297+
n = isnothing(Z_t) ? 0 : size(Z_t, 1)
298+
299+
@submodel u0, p = generate_latent(epi_model.params, n)
300+
301+
_prob = remake(prob; u0 = u0, p = p)
302+
sol = solve(_prob, solver; solver_options...)
303+
I_t = sol2infs(sol)
304+
305+
return I_t
306+
end

EpiAware/src/EpiLatentModels/EpiLatentModels.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@ using LogExpFunctions: softmax
1111

1212
using FillArrays: Fill
1313

14-
using Turing, Distributions, DocStringExtensions, LinearAlgebra
14+
using Turing, Distributions, DocStringExtensions, LinearAlgebra, SparseArrays,
15+
OrdinaryDiffEq
1516

1617
#Export models
1718
export FixedIntercept, Intercept, RandomWalk, AR, HierarchicalNormal
1819

20+
#Export ODE definitions
21+
export SIRParams, SEIRParams
22+
1923
# Export tools for manipulating latent models
2024
export CombineLatentModels, ConcatLatentModels, BroadcastLatentModel
2125

@@ -29,10 +33,13 @@ export broadcast_rule, broadcast_dayofweek, broadcast_weekly, equal_dimensions
2933
export DiffLatentModel, TransformLatentModel, PrefixLatentModel, RecordExpectedLatent
3034

3135
include("docstrings.jl")
36+
include("utils.jl")
3237
include("models/Intercept.jl")
3338
include("models/RandomWalk.jl")
3439
include("models/AR.jl")
3540
include("models/HierarchicalNormal.jl")
41+
include("odemodels/SIRParams.jl")
42+
include("odemodels/SEIRParams.jl")
3643
include("modifiers/DiffLatentModel.jl")
3744
include("modifiers/TransformLatentModel.jl")
3845
include("modifiers/PrefixLatentModel.jl")
@@ -42,6 +49,5 @@ include("manipulators/ConcatLatentModels.jl")
4249
include("manipulators/broadcast/LatentModel.jl")
4350
include("manipulators/broadcast/rules.jl")
4451
include("manipulators/broadcast/helpers.jl")
45-
include("utils.jl")
4652

4753
end

0 commit comments

Comments
 (0)