|
| 1 | +@doc raw""" |
| 2 | +A structure representing an infection process modeled by an Ordinary Differential Equation (ODE). |
| 3 | +At a high level, an `ODEProcess` struct object combines: |
| 4 | +
|
| 5 | +- An `AbstractTuringParamModel` which defines the ODE model in terms of `OrdinaryDiffEq` types, |
| 6 | + the parameters of the ODE model and a method to generate the parameters. |
| 7 | +- A technique for solving and interpreting the ODE model using the `SciML` ecosystem. This includes |
| 8 | + the solver used in the ODE solution, keyword arguments to send to the solver and a function |
| 9 | + to map the `ODESolution` solution object to latent infections. |
| 10 | +
|
| 11 | +# Constructors |
| 12 | +- `ODEProcess(prob::ODEProblem; ts, solver, sol2infs)`: Create an `ODEProcess` |
| 13 | +object with the ODE problem `prob`, time points `ts`, solver `solver`, and function `sol2infs`. |
| 14 | +
|
| 15 | +# Predefined ODE models |
| 16 | +Two basic ODE models are provided in the `EpiAware` package: `SIRParams` and `SEIRParams`. |
| 17 | +In both cases these are defined in terms of the proportions of the population in each compartment |
| 18 | +of the SIR and SEIR models respectively. |
| 19 | +
|
| 20 | +## SIR model |
| 21 | +
|
| 22 | +```math |
| 23 | +\begin{aligned} |
| 24 | +\frac{dS}{dt} &= -\beta SI \\ |
| 25 | +\frac{dI}{dt} &= \beta SI - \gamma I \\ |
| 26 | +\frac{dR}{dt} &= \gamma I |
| 27 | +\end{aligned} |
| 28 | +``` |
| 29 | +Where `S` is the proportion of the population that is susceptible, `I` is the proportion of the |
| 30 | +population that is infected and `R` is the proportion of the population that is recovered. The |
| 31 | +parameters are the infectiousness `β` and the recovery rate `γ`. |
| 32 | +
|
| 33 | +```jldoctest sirexample; output = false |
| 34 | +using EpiAware, OrdinaryDiffEq, Distributions |
| 35 | +
|
| 36 | +# Create an instance of SIRParams |
| 37 | +sirparams = SIRParams( |
| 38 | + tspan = (0.0, 100.0), |
| 39 | + infectiousness = LogNormal(log(0.3), 0.05), |
| 40 | + recovery_rate = LogNormal(log(0.1), 0.05), |
| 41 | + initial_prop_infected = Beta(1, 99) |
| 42 | +) |
| 43 | +nothing |
| 44 | +
|
| 45 | +# output |
| 46 | +
|
| 47 | +``` |
| 48 | +
|
| 49 | +## SEIR model |
| 50 | +
|
| 51 | +```math |
| 52 | +\begin{aligned} |
| 53 | +\frac{dS}{dt} &= -\beta SI \\ |
| 54 | +\frac{dE}{dt} &= \beta SI - \alpha E \\ |
| 55 | +\frac{dI}{dt} &= \alpha E - \gamma I \\ |
| 56 | +\frac{dR}{dt} &= \gamma I |
| 57 | +\end{aligned} |
| 58 | +``` |
| 59 | +Where `S` is the proportion of the population that is susceptible, `E` is the proportion of the |
| 60 | +population that is exposed, `I` is the proportion of the population that is infected and `R` is |
| 61 | +the proportion of the population that is recovered. The parameters are the infectiousness `β`, |
| 62 | +the incubation rate `α` and the recovery rate `γ`. |
| 63 | +
|
| 64 | +```jldoctest; output = false |
| 65 | +using EpiAware, OrdinaryDiffEq, Distributions, Random |
| 66 | +Random.seed!(1234) |
| 67 | +
|
| 68 | +# Create an instance of SIRParams |
| 69 | +seirparams = SEIRParams( |
| 70 | + tspan = (0.0, 100.0), |
| 71 | + infectiousness = LogNormal(log(0.3), 0.05), |
| 72 | + incubation_rate = LogNormal(log(0.1), 0.05), |
| 73 | + recovery_rate = LogNormal(log(0.1), 0.05), |
| 74 | + initial_prop_infected = Beta(1, 99) |
| 75 | +) |
| 76 | +nothing |
| 77 | +
|
| 78 | +# output |
| 79 | +
|
| 80 | +``` |
| 81 | +
|
| 82 | +# Usage example with `ODEProcess` and predefined SIR model |
| 83 | +
|
| 84 | +In this example we define an `ODEProcess` object using the predefined `SIRParams` model from |
| 85 | +above. We then generate latent infections using the `generate_latent_infs` function, and refit |
| 86 | +the model using a `Turing` model. |
| 87 | +
|
| 88 | +We assume that the latent infections are observed with a Poisson likelihood around their |
| 89 | +ODE model prediction. The population size is `N = 1000`, which we put into the `sol2infs` function, |
| 90 | +which maps the ODE solution to the number of infections. Recall that the `EpiAware` default SIR |
| 91 | +implementation assumes the model is in density/proportion form. Also, note that since the `sol2infs` |
| 92 | +function is a link function that maps the ODE solution to the expected number of infections we also |
| 93 | +apply the `LogExpFunctions.softplus` function to ensure that the expected number of infections is non-negative. |
| 94 | +Note that the `softplus` function is a smooth approximation to the ReLU function `x -> max(0, x)`. |
| 95 | +The utility of this approach is that small negative output from the ODE solver (e.g. ~ -1e-10) will be |
| 96 | +mapped to small positive values, without needing to use strict positivity constraints in the model. |
| 97 | +
|
| 98 | +First, we define the `ODEProcess` object which combines the SIR model with the `sol2infs` link |
| 99 | +function and the solver options. |
| 100 | +
|
| 101 | +```jldoctest sirexample; output = false |
| 102 | +using Turing, LogExpFunctions |
| 103 | +N = 1000.0 |
| 104 | +
|
| 105 | +sir_process = ODEProcess( |
| 106 | + params = sirparams, |
| 107 | + sol2infs = sol -> softplus.(N .* sol[2, :]), |
| 108 | + solver_options = Dict(:verbose => false, :saveat => 1.0) |
| 109 | +) |
| 110 | +nothing |
| 111 | +
|
| 112 | +# output |
| 113 | +
|
| 114 | +``` |
| 115 | +
|
| 116 | +Second, we define a `PoissionError` observation model for linking the the number of infections. |
| 117 | +
|
| 118 | +```jldoctest sirexample; output = false |
| 119 | +pois_obs = PoissonError() |
| 120 | +nothing |
| 121 | +
|
| 122 | +# output |
| 123 | +
|
| 124 | +``` |
| 125 | +
|
| 126 | +Next, we create a `Turing` model for the full generative process: this solves the ODE model for |
| 127 | +the latent infections and then samples the observed infections from a Poisson distribution with this |
| 128 | +as the average. |
| 129 | +
|
| 130 | +NB: The `nothing` argument is a dummy latent process, e.g. a log-Rt time series, that is not |
| 131 | +used in the SIR model, but might be used in other models. |
| 132 | +
|
| 133 | +```jldoctest sirexample; output = false |
| 134 | +@model function fit_ode_model(data) |
| 135 | + @submodel I_t = generate_latent_infs(sir_process, nothing) |
| 136 | + @submodel y_t = generate_observations(pois_obs, data, I_t) |
| 137 | +
|
| 138 | + return y_t |
| 139 | +end |
| 140 | +nothing |
| 141 | +
|
| 142 | +# output |
| 143 | +
|
| 144 | +``` |
| 145 | +
|
| 146 | +We can generate some test data from the model by passing `missing` as the argument to the model. |
| 147 | +This tells `Turing` that there is no data to condition on, so it will sample from the prior parameters |
| 148 | +and then generate infections. In this case, we do it in a way where we cache the sampled parameters |
| 149 | +as `θ` for later use. |
| 150 | +
|
| 151 | +```jldoctest sirexample; output = false |
| 152 | +# Sampled parameters |
| 153 | +gen_mdl = fit_ode_model(missing) |
| 154 | +θ = rand(gen_mdl) |
| 155 | +test_data = (gen_mdl | θ)() |
| 156 | +nothing |
| 157 | +
|
| 158 | +# output |
| 159 | +
|
| 160 | +``` |
| 161 | +
|
| 162 | +Now, we can refit the model but this time we condition on the test data. We suppress the |
| 163 | +output of the sampling process to keep the output clean, but you can remove the `@suppress` macro. |
| 164 | +
|
| 165 | +```jldoctest sirexample; output = false |
| 166 | +using Suppressor |
| 167 | +inference_mdl = fit_ode_model(test_data) |
| 168 | +chn = Suppressor.@suppress sample(inference_mdl, NUTS(), 2_000) |
| 169 | +summarize(chn) |
| 170 | +nothing |
| 171 | +
|
| 172 | +# output |
| 173 | +
|
| 174 | +``` |
| 175 | +
|
| 176 | +We can compare the summarized chain to the sampled parameters in `θ` to see that the model is |
| 177 | +fitting the data well and recovering a credible interval containing the true parameters. |
| 178 | +
|
| 179 | +# Custom ODE models |
| 180 | +
|
| 181 | +To define a custom ODE model, you need to define: |
| 182 | +
|
| 183 | +- Some `CustomModel <: AbstractTuringLatentModel` struct |
| 184 | + that contains the ODE problem as a field called `prob`, as well as sufficient fields to |
| 185 | + define or sample the parameters of the ODE model. |
| 186 | +- A method for `EpiAwareBase.generate_latent(params::CustomModel, Z_t)` that generates the |
| 187 | + initial condition and parameters of the ODE model, potentially conditional on a sample from a latent process `Z_t`. |
| 188 | + This method must return a `Tuple` `(u0, p)` where `u0` is the initial condition and `p` is the parameters. |
| 189 | +
|
| 190 | +Here is an example of a simple custom ODE model for _specified_ exponential growth: |
| 191 | +
|
| 192 | +```jldoctest customexample; output = false |
| 193 | +using EpiAware, Turing, OrdinaryDiffEq |
| 194 | +# Define a simple exponential growth model for testing |
| 195 | +function expgrowth(du, u, p, t) |
| 196 | + du[1] = p[1] * u[1] |
| 197 | +end |
| 198 | +
|
| 199 | +r = log(2) / 7 # Growth rate corresponding to 7 day doubling time |
| 200 | +
|
| 201 | +# Define the ODE problem using SciML |
| 202 | +prob = ODEProblem(expgrowth, [1.0], (0.0, 10.0), [r]) |
| 203 | +
|
| 204 | +# Define the custom parameters struct |
| 205 | +struct CustomModel <: AbstractTuringLatentModel |
| 206 | + prob::ODEProblem |
| 207 | + r::Float64 |
| 208 | + u0::Float64 |
| 209 | +end |
| 210 | +custom_ode = CustomModel(prob, r, 1.0) |
| 211 | +
|
| 212 | +# Define the custom generate_latent function |
| 213 | +@model function EpiAwareBase.generate_latent(params::CustomModel, n) |
| 214 | + return ([params.u0], [params.r]) |
| 215 | +end |
| 216 | +nothing |
| 217 | +
|
| 218 | +# output |
| 219 | +
|
| 220 | +``` |
| 221 | +
|
| 222 | +This model is not random! But we can still use it to generate latent infections. |
| 223 | +
|
| 224 | +```jldoctest customexample; output = false |
| 225 | +# Define the ODEProcess |
| 226 | +expgrowth_model = ODEProcess( |
| 227 | + params = custom_ode, |
| 228 | + sol2infs = sol -> sol[1, :] |
| 229 | +) |
| 230 | +infs = generate_latent_infs(expgrowth_model, nothing)() |
| 231 | +nothing |
| 232 | +
|
| 233 | +# output |
| 234 | +
|
| 235 | +``` |
| 236 | +""" |
| 237 | +@kwdef struct ODEProcess{ |
| 238 | + P <: AbstractTuringLatentModel, S, F <: Function, D <: |
| 239 | + Union{Dict, NamedTuple}} <: |
| 240 | + EpiAwareBase.AbstractTuringEpiModel |
| 241 | + "The ODE problem and parameters, where `P` is a subtype of `AbstractTuringLatentModel`." |
| 242 | + params::P |
| 243 | + "The solver used for the ODE problem. Default is `AutoVern7(Rodas5())`, which is an auto |
| 244 | + switching solver aimed at medium/low tolerances." |
| 245 | + solver::S = AutoVern7(Rodas5()) |
| 246 | + "A function that maps the solution object of the ODE to infection counts." |
| 247 | + sol2infs::F |
| 248 | + "The extra solver options for the ODE problem. Can be either a `Dict` or a `NamedTuple` |
| 249 | + containing the solver options." |
| 250 | + solver_options::D = Dict(:verbose => false, :saveat => 1.0) |
| 251 | +end |
| 252 | + |
| 253 | +@doc raw""" |
| 254 | +Implement the `generate_latent_infs` function for the `ODEProcess` model. |
| 255 | +
|
| 256 | +This function remakes the ODE problem with the provided initial conditions and parameters, |
| 257 | + solves it using the specified solver, and then transforms the solution into latent infections |
| 258 | + using the `sol2infs` function. |
| 259 | +
|
| 260 | +# Example usage with predefined SIR model |
| 261 | +
|
| 262 | +In this example we define an `ODEProcess` object using the predefined `SIRParams` model and |
| 263 | +generate an expected infection time series using SIR model parameters sampled from their priors. |
| 264 | +
|
| 265 | +```jldoctest; output = false |
| 266 | +using EpiAware, OrdinaryDiffEq, Distributions, Turing, LogExpFunctions |
| 267 | +
|
| 268 | +# Create an instance of SIRParams |
| 269 | +sirparams = SIRParams( |
| 270 | + tspan = (0.0, 100.0), |
| 271 | + infectiousness = LogNormal(log(0.3), 0.05), |
| 272 | + recovery_rate = LogNormal(log(0.1), 0.05), |
| 273 | + initial_prop_infected = Beta(1, 99) |
| 274 | +) |
| 275 | +
|
| 276 | +#Population size |
| 277 | +
|
| 278 | +N = 1000.0 |
| 279 | +
|
| 280 | +sir_process = ODEProcess( |
| 281 | + params = sirparams, |
| 282 | + sol2infs = sol -> softplus.(N .* sol[2, :]), |
| 283 | + solver_options = Dict(:verbose => false, :saveat => 1.0) |
| 284 | +) |
| 285 | +
|
| 286 | +generated_It = generate_latent_infs(sir_process, nothing)() |
| 287 | +nothing |
| 288 | +
|
| 289 | +# output |
| 290 | +
|
| 291 | +``` |
| 292 | +
|
| 293 | +""" |
| 294 | +@model function EpiAwareBase.generate_latent_infs(epi_model::ODEProcess, Z_t) |
| 295 | + prob, solver, sol2infs, solver_options = epi_model.params.prob, |
| 296 | + epi_model.solver, epi_model.sol2infs, epi_model.solver_options |
| 297 | + n = isnothing(Z_t) ? 0 : size(Z_t, 1) |
| 298 | + |
| 299 | + @submodel u0, p = generate_latent(epi_model.params, n) |
| 300 | + |
| 301 | + _prob = remake(prob; u0 = u0, p = p) |
| 302 | + sol = solve(_prob, solver; solver_options...) |
| 303 | + I_t = sol2infs(sol) |
| 304 | + |
| 305 | + return I_t |
| 306 | +end |
0 commit comments