From 11b90342451e1abc91ac4dfd19eb470c6f662776 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Fri, 27 Sep 2024 13:04:25 +0100 Subject: [PATCH 01/17] Plots to Makie Remove Statsplots and convert all plots to CairoMakie in Mishra replication --- EpiAware/docs/Project.toml | 4 +- .../replications/mishra-2020/index.jl | 220 +++++++++--------- 2 files changed, 116 insertions(+), 108 deletions(-) diff --git a/EpiAware/docs/Project.toml b/EpiAware/docs/Project.toml index f3fd8f2b7..819b6412f 100644 --- a/EpiAware/docs/Project.toml +++ b/EpiAware/docs/Project.toml @@ -2,6 +2,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" Changelog = "5217a498-cd5d-4ec6-b8c2-9b85a09b6e3e" DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -9,10 +10,11 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da" Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781" PlutoStaticHTML = "359b1769-a58e-495b-9770-312e911026ad" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" +TimeSeries = "9e3dc215-6440-5c97-bce1-76c03772f85e" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl b/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl index bac2c2820..b745e99c7 100644 --- a/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl +++ b/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl @@ -28,10 +28,7 @@ using Distributions, Statistics #Statistics packages using CSV, DataFramesMeta #Data wrangling # ╔═╡ 9eb03a0b-c6ca-4e23-8109-fb68f87d7fdf -begin #Plotting backend - using StatsPlots - using StatsPlots.PlotMeasures -end +using CairoMakie, PairPlots, TimeSeries #Plotting backend # ╔═╡ 97b5374e-7653-4b3b-98eb-d8f73aa30580 using ReverseDiff #Automatic differentiation backend @@ -155,17 +152,19 @@ We can spaghetti plot generative samples from the AR(2) process with the priors plt_ar_sample = let n_samples = 100 ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _ - ar_mdl() #Sample Z_t trajectories for the model + ar_mdl() .|> exp #Sample Z_t trajectories for the model end - plot(ar_mdl_samples .|> exp, #R_t = exp(Z_t) - lab = "", - c = :grey, - alpha = 0.25, - title = "$(n_samples) draws from the prior Rₜ model", + fig = Figure() + ax = Axis(fig[1, 1]; + yscale = log10, ylabel = "Time varying Rₜ", - yticks = [10.0^n for n in -4:4], - yscale = :log10) + title = "$(n_samples) draws from the prior Rₜ model" + ) + for col in eachcol(ar_mdl_samples) + lines!(ax, col, color = (:grey, 0.1)) + end + fig end # ╔═╡ 9f84dec1-70f1-442e-8bef-a9494921549e @@ -208,14 +207,21 @@ We can compare the discretized generation interval with the continuous estimate, # ╔═╡ 71d08f7e-c409-4fbe-b154-b21d09010683 let - bar(model_data.gen_int, - fillalpha = 0.5, - lw = 0, - lab = "Discretized next gen pmf", + fig = Figure() + ax = Axis(fig[1, 1]; xticks = 0:14, xlabel = "Days", - title = "Continuous and discrete generation intervals") - plot!(truth_GI, lab = "Continuous serial interval") + title = "Continuous and discrete generation intervals" + ) + barplot!(ax, model_data.gen_int; + label = "Discretized next gen pmf" + ) + lines!(truth_GI; + label = "Continuous serial interval", + color = :green + ) + axislegend(ax) + fig end # ╔═╡ 4a2b5cf1-623c-4fe7-8365-49fb7972af5a @@ -259,24 +265,27 @@ latent_inf_mdl = generate_latent_infs(epi, log.(R_t_fixed)) # ╔═╡ 7a6d4b14-58d3-40c1-81f2-713c830f875f plt_epi = let n_samples = 100 + #Sample unconditionally the underlying parameters of the model epi_mdl_samples = mapreduce(hcat, 1:n_samples) do _ - latent_inf_mdl() #Sample unconditionally the underlying parameters of the model + latent_inf_mdl() end - - p1 = plot(epi_mdl_samples, - lab = "", - c = :grey, - alpha = 0.25, + fig = Figure() + ax1 = Axis(fig[1, 1]; title = "$(n_samples) draws from renewal model with chosen Rt", ylabel = "Latent infections" ) - p2 = plot(R_t_fixed, - lab = "", - lw = 2, + ax2 = Axis(fig[2, 1]; ylabel = "Rt" ) - - plot(p1, p2, layout = (2, 1)) + for col in eachcol(epi_mdl_samples) + lines!(ax1, col; + color = (:grey, 0.1) + ) + end + lines!(ax2, R_t_fixed; + linewidth = 2 + ) + fig end # ╔═╡ c8ef8a60-d087-4ae9-ae92-abeea5afc7ae @@ -296,7 +305,6 @@ A prior for $\phi$ was not specified in _Mishra et al_, we select one below but # ╔═╡ 714908a1-dc85-476f-a99f-ec5c95a78b60 obs = NegativeBinomialError(cluster_factor_prior = HalfNormal(0.1)) -# obs = PoissonError() # ╔═╡ dacb8094-89a4-404a-8243-525c0dbfa482 md" @@ -324,17 +332,23 @@ plt_obs = let obs_mdl_samples = mapreduce(hcat, 1:n_samples) do _ θ = obs_mdl() #Sample unconditionally the underlying parameters of the model end - scatter(obs_mdl_samples, - lab = "", - c = :grey, - alpha = 0.25, + fig = Figure() + ax = Axis(fig[1, 1]; title = "$(n_samples) draws from neg. bin. obs model", ylabel = "Observed cases" ) - plot!(expected_cases, - c = :red, - lw = 3, - lab = "Expected cases") + for col in eachcol(obs_mdl_samples) + scatter!(ax, col; + color = (:grey, 0.2) + ) + end + lines!(ax, expected_cases; + color = :red, + linewidth = 3, + label = "Expected cases" + ) + axislegend(ax) + fig end # ╔═╡ a06065e1-0e20-4cf8-8d5a-2d588da20bee @@ -473,9 +487,10 @@ let C = south_korea_data.y_t D = south_korea_data.dates - #Unconditional model for posterior predictive sampling - mdl_unconditional = generate_epiaware(epi_prob, (y_t = fill(missing, length(C)),)) | - (var"obs.cluster_factor" = fixed_cluster_factor,) + #Case unconditional model for posterior predictive sampling + mdl_unconditional = generate_epiaware(epi_prob, + (y_t = fill(missing, length(C)),) + ) | (var"obs.cluster_factor" = fixed_cluster_factor,) posterior_gens = generated_quantities(mdl_unconditional, inference_results.samples) #plotting quantiles @@ -486,30 +501,55 @@ let predicted_R_t = generated_quantiles( posterior_gens, :Z_t, qs; transformation = x -> exp.(x)) - #Plots - p1 = plot(D, predicted_y_t[:, 3], lw = 2, lab = "post. median", c = :purple) - plot!(p1, D, predicted_y_t[:, 2], fillrange = predicted_y_t[:, 4], - fillalpha = 0.5, lw = 0, c = :purple, lab = "50%") - plot!(p1, D, predicted_y_t[:, 1], fillrange = predicted_y_t[:, 5], - fillalpha = 0.2, lw = 0, c = :purple, lab = "95%") - - scatter!(p1, D, C, - lab = "Actual cases", - ylabel = "Daily Cases", - title = "Posterior predictive: Cases", - ylims = (-50, maximum(C) * 2), - c = :black + ts = D .|> d -> d - minimum(D) .|> d -> d.value + 1 + t_ticks = string.(D) + fig = Figure() + ax1 = Axis(fig[1, 1]; + ylabel = "Daily cases", + xticks = (ts[1:14:end], t_ticks[1:14:end]), + title = "Posterior predictive: Cases" + ) + ax2 = Axis(fig[2, 1]; + yscale = log10, + title = "Prediction: Reproduction number", + xticks = (ts[1:14:end], t_ticks[1:14:end]) ) + linkxaxes!(ax1, ax2) - p2 = plot(D, predicted_R_t[:, 3], lw = 2, lab = "post. median", c = :green, - yscale = :log10, title = "Prediction: Reproduction number") - plot!(p2, D, predicted_R_t[:, 2], fillrange = predicted_R_t[:, 4], - fillalpha = 0.5, lw = 0, c = :green, lab = "50%") - plot!(p2, D, predicted_R_t[:, 1], fillrange = predicted_R_t[:, 5], - fillalpha = 0.2, lw = 0, c = :green, lab = "95%") - hline!(p2, [1.0], lab = "Rt = 1", lw = 2, c = :blue) + lines!(ax1, ts, predicted_y_t[:, 3]; + color = :purple, + linewidth = 2, + label = "Post. median" + ) + band!(ax1, 1:size(predicted_y_t, 1), predicted_y_t[:, 2], predicted_y_t[:, 4]; + color = (:purple, 0.4), + label = "50%" + ) + band!(ax1, 1:size(predicted_y_t, 1), predicted_y_t[:, 1], predicted_y_t[:, 5]; + color = (:purple, 0.2), + label = "95%" + ) + scatter!(ax1, C; + color = :black, + label = "Actual cases") + axislegend(ax1) + + lines!(ax2, ts, predicted_R_t[:, 3]; + color = :green, + linewidth = 2, + label = "Post. median" + ) + band!(ax2, 1:size(predicted_R_t, 1), predicted_R_t[:, 2], predicted_R_t[:, 4]; + color = (:green, 0.4), + label = "50%" + ) + band!(ax2, 1:size(predicted_R_t, 1), predicted_R_t[:, 1], predicted_R_t[:, 5]; + color = (:green, 0.2), + label = "95%" + ) + axislegend(ax2) - plot(p1, p2, layout = (2, 1), size = (500, 700), left_margin = 5mm) + fig end # ╔═╡ c05ed977-7a89-4ac8-97be-7078d69fce9f @@ -521,51 +561,17 @@ We can interrogate the sampled chains directly from the `samples` field of the ` # ╔═╡ ff21c9ec-1581-405f-8db1-0f522b5bc296 let - p1 = histogram(inference_results.samples["latent.σ_AR"], - lab = "chain " .* string.([1 2 3 4]), - fillalpha = 0.4, - lw = 0, - norm = :pdf, - title = "Posterior dist: AR noise std") - plot!(p1, ar.std_prior, - lw = 3, - c = :black, - lab = "prior") - - p2 = histogram(inference_results.samples[:init_incidence], - lab = "chain " .* string.([1 2 3 4]), - fillalpha = 0.4, - lw = 0, - norm = :pdf, - title = "Posterior dist: log-initial incidence") - plot!(p2, epi.initialisation_prior, - lw = 3, - c = :black, - lab = "prior") - - p3 = histogram(inference_results.samples["latent.damp_AR[2]"], - lab = "chain " .* string.([1 2 3 4]), - fillalpha = 0.4, - lw = 0, - norm = :pdf, - title = "Posterior dist: rho_1") - plot!(p3, ar.damp_prior.v[2], - lw = 3, - c = :black, - lab = "prior") - - p4 = histogram(inference_results.samples["latent.damp_AR[1]"], - lab = "chain " .* string.([1 2 3 4]), - fillalpha = 0.4, - lw = 0, - norm = :pdf, - title = "Posterior dist: rho_2") - plot!(p4, ar.damp_prior.v[1], - lw = 3, - c = :black, - lab = "prior") - - plot(p1, p2, p3, p4, layout = (2, 2), size = (800, 600)) + sub_chn = inference_results.samples[inference_results.samples.name_map.parameters[[1:5; + end]]] + fig = pairplot(sub_chn) + lines!(fig[1, 1], ar.std_prior, label = "Prior") + lines!(fig[2, 2], ar.init_prior.v[1], label = "Prior") + lines!(fig[3, 3], ar.init_prior.v[2], label = "Prior") + lines!(fig[4, 4], ar.damp_prior.v[1], label = "Prior") + lines!(fig[5, 5], ar.damp_prior.v[2], label = "Prior") + lines!(fig[6, 6], epi.initialisation_prior, label = "Prior") + + fig end # ╔═╡ Cell order: From a1763f1a55612c387bbbd97e672f395c056234e4 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Fri, 27 Sep 2024 13:40:21 +0100 Subject: [PATCH 02/17] get_data script for English boarding school flu --- .../replications/chatzilena-2019/get_data.R | 8 ++++++++ .../influenza_england_1978_school.csv2 | 15 +++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 EpiAware/docs/src/showcase/replications/chatzilena-2019/get_data.R create mode 100644 EpiAware/docs/src/showcase/replications/chatzilena-2019/influenza_england_1978_school.csv2 diff --git a/EpiAware/docs/src/showcase/replications/chatzilena-2019/get_data.R b/EpiAware/docs/src/showcase/replications/chatzilena-2019/get_data.R new file mode 100644 index 000000000..0bc8b334f --- /dev/null +++ b/EpiAware/docs/src/showcase/replications/chatzilena-2019/get_data.R @@ -0,0 +1,8 @@ +install.packages("outbreaks") +library(outbreaks) + +# Get data + +data <- outbreaks::influenza_england_1978_school +write.csv(data, + "EpiAware/docs/src/showcase/replications/chatzilena-2019/influenza_england_1978_school.csv") diff --git a/EpiAware/docs/src/showcase/replications/chatzilena-2019/influenza_england_1978_school.csv2 b/EpiAware/docs/src/showcase/replications/chatzilena-2019/influenza_england_1978_school.csv2 new file mode 100644 index 000000000..1eb88145c --- /dev/null +++ b/EpiAware/docs/src/showcase/replications/chatzilena-2019/influenza_england_1978_school.csv2 @@ -0,0 +1,15 @@ +"","date","in_bed","convalescent" +"1",1978-01-22,3,0 +"2",1978-01-23,8,0 +"3",1978-01-24,26,0 +"4",1978-01-25,76,0 +"5",1978-01-26,225,9 +"6",1978-01-27,298,17 +"7",1978-01-28,258,105 +"8",1978-01-29,233,162 +"9",1978-01-30,189,176 +"10",1978-01-31,128,166 +"11",1978-02-01,68,150 +"12",1978-02-02,29,85 +"13",1978-02-03,14,47 +"14",1978-02-04,4,20 From 422b8d97560408a6de98298de4e8a9a12631857f Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Tue, 1 Oct 2024 14:45:28 +0100 Subject: [PATCH 03/17] fixed broken inference --- EpiAware/docs/Project.toml | 2 + .../replications/chatzilena-2019/get_data.R | 4 +- .../replications/chatzilena-2019/index.jl | 410 ++++++++++++++++++ 3 files changed, 414 insertions(+), 2 deletions(-) create mode 100644 EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl diff --git a/EpiAware/docs/Project.toml b/EpiAware/docs/Project.toml index 819b6412f..8b9e975a6 100644 --- a/EpiAware/docs/Project.toml +++ b/EpiAware/docs/Project.toml @@ -10,10 +10,12 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da" Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781" PlutoStaticHTML = "359b1769-a58e-495b-9770-312e911026ad" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" TimeSeries = "9e3dc215-6440-5c97-bce1-76c03772f85e" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" diff --git a/EpiAware/docs/src/showcase/replications/chatzilena-2019/get_data.R b/EpiAware/docs/src/showcase/replications/chatzilena-2019/get_data.R index 0bc8b334f..0bd32d59d 100644 --- a/EpiAware/docs/src/showcase/replications/chatzilena-2019/get_data.R +++ b/EpiAware/docs/src/showcase/replications/chatzilena-2019/get_data.R @@ -4,5 +4,5 @@ library(outbreaks) # Get data data <- outbreaks::influenza_england_1978_school -write.csv(data, - "EpiAware/docs/src/showcase/replications/chatzilena-2019/influenza_england_1978_school.csv") +write.csv(data, + "EpiAware/docs/src/showcase/replications/chatzilena-2019/influenza_england_1978_school.csv2") diff --git a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl new file mode 100644 index 000000000..e40bb2a44 --- /dev/null +++ b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl @@ -0,0 +1,410 @@ +### A Pluto.jl notebook ### +# v0.19.46 + +using Markdown +using InteractiveUtils + +# ╔═╡ e34cec5a-a173-4e92-a860-340c7a9e9c72 +let + docs_dir = dirname(dirname(dirname(dirname(@__DIR__)))) + pkg_dir = dirname(docs_dir) + + using Pkg: Pkg + Pkg.activate(docs_dir) + Pkg.develop(; path = pkg_dir) + Pkg.instantiate() +end; + +# ╔═╡ b1468db3-7ab0-468c-8e27-70013a8f512f +using EpiAware + +# ╔═╡ a4710701-6315-459d-b677-f24b77ff3e80 +using Turing + +# ╔═╡ 7263d714-2ce4-4d57-8881-6b60db018dd5 +using OrdinaryDiffEq, SciMLSensitivity #ODE solvers and adjoint methods + +# ╔═╡ 261420cd-4650-402b-b126-7a431f93f37e +using Distributions, Statistics #Statistics packages + +# ╔═╡ 9c19a98b-a08b-4560-966d-61ff0ece2ad5 +using CSV, DataFramesMeta #Data wrangling + +# ╔═╡ 3897e773-ed07-4860-bb62-35605d0dacb0 +using CairoMakie, PairPlots + +# ╔═╡ 14641441-dbea-4fdf-88e0-64a57da60ef7 +using ReverseDiff #Automatic differentiation backend + +# ╔═╡ a0d91258-8ab5-4adc-98f2-8f17b4bd685c +begin #Date utility and set Random seed + using Dates + using Random + Random.seed!(1234) +end + +# ╔═╡ 33384fc6-7cca-11ef-3567-ab7df9200cde +md" +# Example: Contemporary statistical inference for infectious disease models +# Introduction +## What are we going to do in this Vignette +In this vignette, we'll demonstrate how to use `EpiAware` in conjunction with [SciML ecosystem](https://sciml.ai/) for Bayesian inference of infectious disease dynamics. The model and data is heavily based on [Contemporary statistical inference for infectious disease models using Stan _Chatzilena et al. 2019_](https://www.sciencedirect.com/science/article/pii/S1755436519300325). + +We'll cover the following key points: + +1. Defining the deterministic ODE model from Chatzilena et al section 2.2.2 using SciML ODE functionality and an `EpiAware` observation model. +2. Build on this to define the stochastic ODE model from Chatzilena et al section 2.2.3 using an `EpiAware` observation model. +3. Fitting the deterministic ODE model to data from an Influenza outbreak in an English boarding school. +4. Fitting the stochastic ODE model to data from an Influenza outbreak in an English boarding school. + +## What might I need to know before starting + +This vignette builds on concepts from `EpiAware` observation models and a familarity with the `SciML` and `Turing` ecosystems would be useful but not essential. + +## Packages used in this vignette + +Alongside the `EpiAware` package we will use the `OrdinaryDiffEq` package for interfacing with `SciML` ecosystem; this is a lower dependency usage of `DifferentialEquations.jl` that only exposes ODE solvers. Bayesian inference will be done with `NUTS` from the `Turing` ecosystem. We will also use the `CairoMakie` package for plotting and `DataFramesMeta` for data manipulation. +" + +# ╔═╡ 943b82ec-b4dc-4537-8183-d6c73cd74a37 +md" +# SIR models from _Chatzilena et al_ + +As mentioned in _Chatzilena et al_ disease spread is frequently modelled in terms +of ODE-based models. The study population is divided into compartments representing a specific stage of the epidemic status. In this case, susceptible, infected, and recovered individuals. + +```math +\begin{aligned} +{dS \over dt} &= - \beta \frac{I(t)}{N} S(t) \\ +{dI \over dt} &= \beta \frac{I(t)}{N} S(t) - \gamma I(t) \\ +{dR \over dt} &= \gamma I(t). \\ +\end{aligned} +``` +where S(t) represents the number of susceptible, I(t) the number of +infected and R(t) the number of recovered individuals at time t. The +total population size is denoted by N (with N = S(t) + I(t) + R(t)), β +denotes the transmission rate and γ denotes the recovery rate. +" + +# ╔═╡ ab4269b1-e292-466f-8bfb-713d917c18f9 +function sir!(du, u, p, t) + S, I, R = u + β, γ = p + du[1] = -β * I * S + du[2] = β * I * S - γ * I + du[3] = γ * I + + return nothing +end + +# ╔═╡ bb07a580-6d86-48b3-a79f-d2ed9306e87c +sir_prob = ODEProblem( + sir!, + [0.99, 0.01, 0.0], + (0.0, (Date(1978, 2, 4) - Date(1978, 1, 22)).value + 1), + [3.0, 2.0] +) + +# ╔═╡ aba3f1db-c290-409c-9b9e-6065935ede54 +N = 763 + +# ╔═╡ 7c9cbbc1-71ef-4d81-b93a-c2b3a8683d53 +url = "https://raw.githubusercontent.com/CDCgov/Rt-without-renewal/refs/heads/446-add-chatzilena-et-al-as-a-replication-example/EpiAware/docs/src/showcase/replications/chatzilena-2019/influenza_england_1978_school.csv2" + +# ╔═╡ eb247c93-1512-4927-9f39-ae408be0dc89 +data = CSV.read(download(url), DataFrame) + +# ╔═╡ 3f54bb44-76c4-4744-885a-46dedfaffeca +md" +## Deterministic SIR model + +" + +# ╔═╡ 87509792-e28d-4618-9bf5-e06b2e5dbe8b +obs = PoissonError() + +# ╔═╡ 1d287c8e-7000-4b23-ae7e-f7008c3e53bd +@model function deterministic_ode_mdl(Yt, obs, prob, N) + nobs = length(Yt) + + β ~ LogNormal(0.0, 1.0) + γ ~ Gamma(0.004, 1 / 0.002) + S₀ ~ Beta(0.5, 0.5) + + # try + _prob = remake(prob; + u0 = [S₀, 1 - S₀, 0.0], + p = [β, γ] + ) + + sol = solve(_prob, AutoTsit5(Rosenbrock23()); saveat = 1.0:nobs, verbose = false, sensealg = ForwardDiffSensitivity()) + λt = N * sol[2, :] .+ 1e-3 + + @submodel obsYt = generate_observations(obs, Yt, λt) + + return (; sol, obsYt, R0 = β / γ) + # catch + # Turing.@addlogprob! -Inf + # return + # end +end + +# ╔═╡ dbc1b453-1c29-4f82-bec9-098d67f9e63f +mdl = deterministic_ode_mdl(data.in_bed, obs, sir_prob, N) + +# ╔═╡ e795c2bf-0861-4e96-9921-db47f41af206 +uncond_mdl = deterministic_ode_mdl(fill(missing,length(data.in_bed)), obs, sir_prob, N) + +# ╔═╡ ba35cebd-0d29-43c5-8db7-f550d7f821bc +map_fit = map(1:10) do _ + fit = maximum_a_posteriori(mdl; + initial_params=[1, 0.1, 0.99], + ) +end |> +fits -> (findmax(fit -> fit.lp, fits)[2], fits) |> +min_and_fits -> min_and_fits[2][min_and_fits[1]] + +# ╔═╡ 0be912c1-22dc-4978-b86a-84273062f5da +map_fit.optim_result.retcode + +# ╔═╡ 2cf64ba3-ff8d-40b0-9bd8-9e80393156f5 +chn = sample(mdl, NUTS(), MCMCThreads(), 1000, 4; initial_params=fill(map_fit.values.array,4)) + +# ╔═╡ 6d8a1903-ffcf-47a9-a02a-4ef77525f133 +map_fit.values + +# ╔═╡ b2429b68-dd75-499f-a4e1-1b7d72e209c7 +describe(chn) + +# ╔═╡ 1e7f37c5-4cb4-4d06-8f68-55d80f7a00ad +pairplot(chn) + +# ╔═╡ 03d1ecf8-543d-444d-b1a3-7a19acd88499 +let +ts = 1:size(data, 1) +gens = generated_quantities(uncond_mdl, chn) +fig = Figure() +ax = Axis(fig[1,1]; + title = "Fitted deterministic model", + xticks = (ts[1:3:end], data.date[1:3:end] .|> string), + ylabel = "Number of Infected students" + ) +pred_Yt = mapreduce(hcat, gens) do gen + gen.obsYt +end |> X -> mapreduce(vcat, eachrow(X)) do row + quantile(row, [0.5, 0.025, 0.975])' +end + +lines!(ax,ts, pred_Yt[:,1]; linewidth = 3, label = "Fitted deterministic model", color = :green) +band!(ax, ts, pred_Yt[:,2], pred_Yt[:,3], color = (:green, 0.5)) +scatter!(ax, data.in_bed) + +fig +end + +# ╔═╡ 506855ac-57f1-40cf-9ee1-c3097b9b554a + + +# ╔═╡ e023770d-25f7-4b7a-b509-8a4372f42b76 +md" +## Stochastic model +" + +# ╔═╡ 71a26408-1c26-46cf-bc72-c6ba528dfadd +ar = AR(HalfNormal(0.01), + HalfNormal(0.3), + Normal(0, 0.001) +) + +# ╔═╡ 178e0048-069a-4953-bb24-5116eb81cc41 +ϕs = rand(truncated(Normal(0,100), lower = 0.), 1000) + +# ╔═╡ e6bcf0c0-3cc4-41f3-ad20-fa11bf2ca37b +σs = rand(InverseGamma(0.1,0.1), 1000) .|> x -> 1/x + +# ╔═╡ f9c1bcd4-bfb4-45d4-ae06-f114a0923bd7 +mean(InverseGamma(0.1,0.1)) + +# ╔═╡ 4f07e8ba-30d0-411f-8c3e-b6d5bc1bb5fa +AR_damps = ϕs .|> ϕ -> exp(-ϕ) + +# ╔═╡ 7235289e-28f0-43c2-986b-81b96c42d9fe +mean(AR_damps) + +# ╔═╡ 48032d21-53fa-4c0a-85cb-c22327b55073 +AR_stds = zip(ϕs, σs) .|> ϕ_σ -> (1 - exp(-2*ϕ_σ[1])) * ϕ_σ[2] / (2 * ϕ_σ[1]) + +# ╔═╡ 4089aea2-3946-48b0-bf7c-dcdc73fe87fa +mean(AR_stds) + +# ╔═╡ ec63fd4b-4323-4a9e-9aa7-46ba4115ec4f + + +# ╔═╡ 2dcb4034-b138-4c3e-b65f-ba13f230439c +hist(AR_stds) + +# ╔═╡ 7271886d-2f87-4dc1-833b-182f4b726738 +# xs = rand(truncated(Normal(0,100), lower = 0.), 1000) .|> x -> exp(-x) +xs = rand(InverseGamma(1/0.1,1/0.1), 1000) + +# ╔═╡ 68b75d5b-2b45-44bd-a973-12cba31d0e53 + + +# ╔═╡ f0f02012-e0fe-4d11-a60a-dc27b6dd510c +density(xs) + +# ╔═╡ e15d0532-0c8a-4cd2-a576-567fc0c625c5 +gmdl = generate_latent(ar, 10) + +# ╔═╡ 0be4b20e-5f16-43dc-90f6-84a6f29ae8cc +gmdl() + +# ╔═╡ 9309f7f8-0896-4686-8bfc-b9f82d91bc0f +@model function stochastic_ode_mdl(Yt, logobsprob, obs, prob, N) + nobs = length(Yt) + + β ~ LogNormal(0.0, 1.0) + γ ~ Gamma(0.004, 1 / 0.002) + S₀ ~ Beta(0.5, 0.5) + + + # try + _prob = remake(prob; + u0 = [S₀, 1 - S₀, 0.0], + p = [β, γ] + ) + + sol = solve(_prob, AutoTsit5(Rosenbrock23()); + sensealg = ForwardDiffSensitivity(), + saveat = 1.0:nobs, verbose = false) + # μ = log.(N * sol[2, :]) + @submodel κ = generate_latent(logobsprob, nobs) + λt = @. N * sol[2, :] * exp(κ) + 0.1 + + @submodel obsYt = generate_observations(obs, Yt, λt) + + return (; sol, obsYt, R0 = β / γ) + # catch + # Turing.@addlogprob! -Inf + # return + # end +end + +# ╔═╡ 6dbd3935-dada-4cac-903e-2dec1a197304 + + +# ╔═╡ 4330c83f-de39-44c7-bdab-87e5f5830145 +mdl2 = stochastic_ode_mdl(data.in_bed, ar, obs, sir_prob, N) + +# ╔═╡ 8071c92f-9fe8-48cf-b1a0-79d1e34ec7e7 +uncond_mdl2 = stochastic_ode_mdl(fill(missing,length(data.in_bed)), ar, obs, sir_prob, N) + +# ╔═╡ bbe9a87a-a212-4d9d-9c75-8a863d6fb0be +rand(mdl2) + +# ╔═╡ d4502528-d058-4899-b3dd-576316116c18 +map_fit2 = map(1:10) do _ + fit = maximum_a_posteriori(mdl2; + initial_params=vcat([1, 0.1, 0.99, 0.01, 0., 0.01], zeros(13)), + adtype=AutoReverseDiff() + ) +end |> +fits -> (findmax(fit -> fit.lp, fits)[2], fits) |> +min_and_fits -> min_and_fits[2][min_and_fits[1]] + +# ╔═╡ 6a246854-601b-4d5a-9fb8-52b0e1620e7d +mdl2() + +# ╔═╡ 156272d7-56c4-4ac4-bf3e-7882f4edc144 +chn2 = sample(mdl2, NUTS(; adtype = AutoReverseDiff(true)), MCMCThreads(), 1000, 4; initial_params = fill(map_fit2.values.array,4)) + +# ╔═╡ 00b90e6d-732f-41c9-a603-cabe9740e329 +describe(chn2) + +# ╔═╡ 37a016d8-8384-41c9-abdd-23e88b1f988d +pairplot(chn2[[:β, :γ, :S₀]]) + +# ╔═╡ 0e7bbf13-9187-41ea-8b46-294b93be4c6d +let +ts = 1:size(data, 1) +gens = generated_quantities(uncond_mdl2, chn2) +fig = Figure() +ax = Axis(fig[1,1]; + title = "Fitted Stochastic model", + xticks = (ts[1:3:end], data.date[1:3:end] .|> string), + ylabel = "Number of Infected students" + ) +pred_Yt = mapreduce(hcat, gens) do gen + gen.obsYt +end |> X -> mapreduce(vcat, eachrow(X)) do row + quantile(row, [0.5, 0.025, 0.975])' +end + +lines!(ax,ts, pred_Yt[:,1]; linewidth = 3, label = "Fitted deterministic model", color = :green) +band!(ax, ts, pred_Yt[:,2], pred_Yt[:,3], color = (:green, 0.5)) +scatter!(ax, data.in_bed) + +fig +end + +# ╔═╡ 36efe6e0-643f-42e6-9d64-de2f5a76b764 + + +# ╔═╡ Cell order: +# ╟─e34cec5a-a173-4e92-a860-340c7a9e9c72 +# ╠═33384fc6-7cca-11ef-3567-ab7df9200cde +# ╠═b1468db3-7ab0-468c-8e27-70013a8f512f +# ╠═a4710701-6315-459d-b677-f24b77ff3e80 +# ╠═7263d714-2ce4-4d57-8881-6b60db018dd5 +# ╠═261420cd-4650-402b-b126-7a431f93f37e +# ╠═9c19a98b-a08b-4560-966d-61ff0ece2ad5 +# ╠═3897e773-ed07-4860-bb62-35605d0dacb0 +# ╠═14641441-dbea-4fdf-88e0-64a57da60ef7 +# ╠═a0d91258-8ab5-4adc-98f2-8f17b4bd685c +# ╠═943b82ec-b4dc-4537-8183-d6c73cd74a37 +# ╠═ab4269b1-e292-466f-8bfb-713d917c18f9 +# ╠═bb07a580-6d86-48b3-a79f-d2ed9306e87c +# ╠═aba3f1db-c290-409c-9b9e-6065935ede54 +# ╠═7c9cbbc1-71ef-4d81-b93a-c2b3a8683d53 +# ╠═eb247c93-1512-4927-9f39-ae408be0dc89 +# ╠═3f54bb44-76c4-4744-885a-46dedfaffeca +# ╠═87509792-e28d-4618-9bf5-e06b2e5dbe8b +# ╠═1d287c8e-7000-4b23-ae7e-f7008c3e53bd +# ╠═dbc1b453-1c29-4f82-bec9-098d67f9e63f +# ╠═e795c2bf-0861-4e96-9921-db47f41af206 +# ╠═ba35cebd-0d29-43c5-8db7-f550d7f821bc +# ╠═0be912c1-22dc-4978-b86a-84273062f5da +# ╠═2cf64ba3-ff8d-40b0-9bd8-9e80393156f5 +# ╠═6d8a1903-ffcf-47a9-a02a-4ef77525f133 +# ╠═b2429b68-dd75-499f-a4e1-1b7d72e209c7 +# ╠═1e7f37c5-4cb4-4d06-8f68-55d80f7a00ad +# ╠═03d1ecf8-543d-444d-b1a3-7a19acd88499 +# ╠═506855ac-57f1-40cf-9ee1-c3097b9b554a +# ╠═e023770d-25f7-4b7a-b509-8a4372f42b76 +# ╠═71a26408-1c26-46cf-bc72-c6ba528dfadd +# ╠═178e0048-069a-4953-bb24-5116eb81cc41 +# ╠═e6bcf0c0-3cc4-41f3-ad20-fa11bf2ca37b +# ╠═f9c1bcd4-bfb4-45d4-ae06-f114a0923bd7 +# ╠═4f07e8ba-30d0-411f-8c3e-b6d5bc1bb5fa +# ╠═7235289e-28f0-43c2-986b-81b96c42d9fe +# ╠═48032d21-53fa-4c0a-85cb-c22327b55073 +# ╠═4089aea2-3946-48b0-bf7c-dcdc73fe87fa +# ╠═ec63fd4b-4323-4a9e-9aa7-46ba4115ec4f +# ╠═2dcb4034-b138-4c3e-b65f-ba13f230439c +# ╠═7271886d-2f87-4dc1-833b-182f4b726738 +# ╠═68b75d5b-2b45-44bd-a973-12cba31d0e53 +# ╠═f0f02012-e0fe-4d11-a60a-dc27b6dd510c +# ╠═e15d0532-0c8a-4cd2-a576-567fc0c625c5 +# ╠═0be4b20e-5f16-43dc-90f6-84a6f29ae8cc +# ╠═9309f7f8-0896-4686-8bfc-b9f82d91bc0f +# ╠═6dbd3935-dada-4cac-903e-2dec1a197304 +# ╠═4330c83f-de39-44c7-bdab-87e5f5830145 +# ╠═8071c92f-9fe8-48cf-b1a0-79d1e34ec7e7 +# ╠═bbe9a87a-a212-4d9d-9c75-8a863d6fb0be +# ╠═d4502528-d058-4899-b3dd-576316116c18 +# ╠═6a246854-601b-4d5a-9fb8-52b0e1620e7d +# ╠═156272d7-56c4-4ac4-bf3e-7882f4edc144 +# ╠═00b90e6d-732f-41c9-a603-cabe9740e329 +# ╠═37a016d8-8384-41c9-abdd-23e88b1f988d +# ╠═0e7bbf13-9187-41ea-8b46-294b93be4c6d +# ╠═36efe6e0-643f-42e6-9d64-de2f5a76b764 From 0aaba4be8016bedba6703aede94bcbed8b5dce2c Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Tue, 1 Oct 2024 19:41:48 +0100 Subject: [PATCH 04/17] Update index.jl --- .../replications/chatzilena-2019/index.jl | 227 +++++++++++++----- 1 file changed, 164 insertions(+), 63 deletions(-) diff --git a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl index e40bb2a44..94c2d3fbc 100644 --- a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl +++ b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl @@ -45,7 +45,7 @@ end # ╔═╡ 33384fc6-7cca-11ef-3567-ab7df9200cde md" -# Example: Contemporary statistical inference for infectious disease models +# Example: Statistical inference for ODE-based infectious disease models # Introduction ## What are we going to do in this Vignette In this vignette, we'll demonstrate how to use `EpiAware` in conjunction with [SciML ecosystem](https://sciml.ai/) for Bayesian inference of infectious disease dynamics. The model and data is heavily based on [Contemporary statistical inference for infectious disease models using Stan _Chatzilena et al. 2019_](https://www.sciencedirect.com/science/article/pii/S1755436519300325). @@ -63,12 +63,12 @@ This vignette builds on concepts from `EpiAware` observation models and a famila ## Packages used in this vignette -Alongside the `EpiAware` package we will use the `OrdinaryDiffEq` package for interfacing with `SciML` ecosystem; this is a lower dependency usage of `DifferentialEquations.jl` that only exposes ODE solvers. Bayesian inference will be done with `NUTS` from the `Turing` ecosystem. We will also use the `CairoMakie` package for plotting and `DataFramesMeta` for data manipulation. +Alongside the `EpiAware` package we will use the `OrdinaryDiffEq` and `SciMLSensitivity` packages for interfacing with `SciML` ecosystem; this is a lower dependency usage of `DifferentialEquations.jl` that, respectively, exposes ODE solvers and adjoint methods for ODE solvees; that is the method of propagating parameter derivatives through functions containing ODE solutions. Bayesian inference will be done with `NUTS` from the `Turing` ecosystem. We will also use the `CairoMakie` package for plotting and `DataFramesMeta` for data manipulation. " # ╔═╡ 943b82ec-b4dc-4537-8183-d6c73cd74a37 md" -# SIR models from _Chatzilena et al_ +# Single population SIR model As mentioned in _Chatzilena et al_ disease spread is frequently modelled in terms of ODE-based models. The study population is divided into compartments representing a specific stage of the epidemic status. In this case, susceptible, infected, and recovered individuals. @@ -84,6 +84,20 @@ where S(t) represents the number of susceptible, I(t) the number of infected and R(t) the number of recovered individuals at time t. The total population size is denoted by N (with N = S(t) + I(t) + R(t)), β denotes the transmission rate and γ denotes the recovery rate. + +" + +# ╔═╡ 0e78285c-d2e8-4c3c-848a-14dae6ead0a4 +md" +We can interface to the `SciML` ecosystem by writing a function with the signature: + +> `(du, u, p, t) -> nothing` + +Where: +- `du` is the _vector field_ of the ODE problem, e.g. ${dS \over dt}$, ${dI \over dt}$ etc. This is calculated _in-place_. +- `u` is the _state_ of the ODE problem, e.g. $S$, $I$, etc. +- `p` is an object that represents the parameters of the ODE problem, e.g. $\beta$, $\gamma$. +- `t` is the time of the ODE problem. " # ╔═╡ ab4269b1-e292-466f-8bfb-713d917c18f9 @@ -97,22 +111,35 @@ function sir!(du, u, p, t) return nothing end +# ╔═╡ f16eb00b-2d77-45df-b767-757fe2f5674c +md" +We combine the function defining the vector field with a initial condition `u0` and the integration period `tspan` to make an `ODEProblem`. We do not define the parameters, these will be defined within an inference approach. +" + # ╔═╡ bb07a580-6d86-48b3-a79f-d2ed9306e87c sir_prob = ODEProblem( sir!, [0.99, 0.01, 0.0], - (0.0, (Date(1978, 2, 4) - Date(1978, 1, 22)).value + 1), - [3.0, 2.0] + (0.0, (Date(1978, 2, 4) - Date(1978, 1, 22)).value + 1) ) -# ╔═╡ aba3f1db-c290-409c-9b9e-6065935ede54 -N = 763 +# ╔═╡ d64388f9-6edd-414d-a191-316f75b35b2c +md" + +## Data for inference + +There was a brief, but intense, outbreak of Influenza within the (semi-) closed community of a boarding school reported to the British medical journal in 1978. The outbreak lasted from 22nd January to 4th February and it is reported that one infected child started the epidemic and then it spread rapidly. Of the 763 children at the boarding scholl, 512 became ill. + +We downloaded the data of this outbreak using the R package `outbreaks` which is maintained as part of the [R Epidemics Consortium(RECON)](http://www. repidemicsconsortium.org). + +" # ╔═╡ 7c9cbbc1-71ef-4d81-b93a-c2b3a8683d53 -url = "https://raw.githubusercontent.com/CDCgov/Rt-without-renewal/refs/heads/446-add-chatzilena-et-al-as-a-replication-example/EpiAware/docs/src/showcase/replications/chatzilena-2019/influenza_england_1978_school.csv2" +data = "https://raw.githubusercontent.com/CDCgov/Rt-without-renewal/refs/heads/446-add-chatzilena-et-al-as-a-replication-example/EpiAware/docs/src/showcase/replications/chatzilena-2019/influenza_england_1978_school.csv2" |> + url -> CSV.read(download(url), DataFrame) -# ╔═╡ eb247c93-1512-4927-9f39-ae408be0dc89 -data = CSV.read(download(url), DataFrame) +# ╔═╡ aba3f1db-c290-409c-9b9e-6065935ede54 +N = 763 # ╔═╡ 3f54bb44-76c4-4744-885a-46dedfaffeca md" @@ -132,17 +159,18 @@ obs = PoissonError() S₀ ~ Beta(0.5, 0.5) # try - _prob = remake(prob; - u0 = [S₀, 1 - S₀, 0.0], - p = [β, γ] - ) + _prob = remake(prob; + u0 = [S₀, 1 - S₀, 0.0], + p = [β, γ] + ) - sol = solve(_prob, AutoTsit5(Rosenbrock23()); saveat = 1.0:nobs, verbose = false, sensealg = ForwardDiffSensitivity()) - λt = N * sol[2, :] .+ 1e-3 + sol = solve(_prob, AutoTsit5(Rosenbrock23()); saveat = 1.0:nobs, + verbose = false, sensealg = ForwardDiffSensitivity()) + λt = N * sol[2, :] .+ 1e-3 - @submodel obsYt = generate_observations(obs, Yt, λt) + @submodel obsYt = generate_observations(obs, Yt, λt) - return (; sol, obsYt, R0 = β / γ) + return (; sol, obsYt, R0 = β / γ) # catch # Turing.@addlogprob! -Inf # return @@ -153,22 +181,23 @@ end mdl = deterministic_ode_mdl(data.in_bed, obs, sir_prob, N) # ╔═╡ e795c2bf-0861-4e96-9921-db47f41af206 -uncond_mdl = deterministic_ode_mdl(fill(missing,length(data.in_bed)), obs, sir_prob, N) +uncond_mdl = deterministic_ode_mdl(fill(missing, length(data.in_bed)), obs, sir_prob, N) # ╔═╡ ba35cebd-0d29-43c5-8db7-f550d7f821bc map_fit = map(1:10) do _ - fit = maximum_a_posteriori(mdl; - initial_params=[1, 0.1, 0.99], - ) + fit = maximum_a_posteriori(mdl; + initial_params = [1, 0.1, 0.99] + ) end |> -fits -> (findmax(fit -> fit.lp, fits)[2], fits) |> -min_and_fits -> min_and_fits[2][min_and_fits[1]] + fits -> (findmax(fit -> fit.lp, fits)[2], fits) |> + min_and_fits -> min_and_fits[2][min_and_fits[1]] # ╔═╡ 0be912c1-22dc-4978-b86a-84273062f5da map_fit.optim_result.retcode # ╔═╡ 2cf64ba3-ff8d-40b0-9bd8-9e80393156f5 -chn = sample(mdl, NUTS(), MCMCThreads(), 1000, 4; initial_params=fill(map_fit.values.array,4)) +chn = sample( + mdl, NUTS(), MCMCThreads(), 1000, 4; initial_params = fill(map_fit.values.array, 4)) # ╔═╡ 6d8a1903-ffcf-47a9-a02a-4ef77525f133 map_fit.values @@ -181,29 +210,90 @@ pairplot(chn) # ╔═╡ 03d1ecf8-543d-444d-b1a3-7a19acd88499 let -ts = 1:size(data, 1) -gens = generated_quantities(uncond_mdl, chn) -fig = Figure() -ax = Axis(fig[1,1]; - title = "Fitted deterministic model", - xticks = (ts[1:3:end], data.date[1:3:end] .|> string), - ylabel = "Number of Infected students" - ) -pred_Yt = mapreduce(hcat, gens) do gen - gen.obsYt -end |> X -> mapreduce(vcat, eachrow(X)) do row - quantile(row, [0.5, 0.025, 0.975])' + ts = 1:size(data, 1) + gens = generated_quantities(uncond_mdl, chn) + fig = Figure() + ax = Axis(fig[1, 1]; + title = "Fitted deterministic model", + xticks = (ts[1:3:end], data.date[1:3:end] .|> string), + ylabel = "Number of Infected students" + ) + pred_Yt = mapreduce(hcat, gens) do gen + gen.obsYt + end |> X -> mapreduce(vcat, eachrow(X)) do row + quantile(row, [0.5, 0.025, 0.975])' + end + + lines!(ax, ts, pred_Yt[:, 1]; linewidth = 3, + label = "Fitted deterministic model", color = :green) + band!(ax, ts, pred_Yt[:, 2], pred_Yt[:, 3], color = (:green, 0.5)) + scatter!(ax, data.in_bed) + + fig end -lines!(ax,ts, pred_Yt[:,1]; linewidth = 3, label = "Fitted deterministic model", color = :green) -band!(ax, ts, pred_Yt[:,2], pred_Yt[:,3], color = (:green, 0.5)) -scatter!(ax, data.in_bed) +# ╔═╡ 506855ac-57f1-40cf-9ee1-c3097b9b554a +@model function deterministic_ode_mdl2(Yt, obs, prob, N) + nobs = length(Yt) -fig + β ~ LogNormal(0.0, 1.0) + γ ~ Gamma(0.004, 1 / 0.002) + S₀ ~ Beta(0.5, 0.5) + + # try + _prob = remake(prob; + u0 = [S₀, 1 / N, 1 - S₀], + p = [β, γ] + ) + + sol = solve(_prob, AutoTsit5(Rosenbrock23()); saveat = 1.0:nobs, + verbose = false, sensealg = ForwardDiffSensitivity()) + λt = N * sol[2, :] .+ 1e-3 + + @submodel obsYt = generate_observations(obs, Yt, λt) + + return (; sol, obsYt, R0 = β / γ) + # catch + # Turing.@addlogprob! -Inf + # return + # end end -# ╔═╡ 506855ac-57f1-40cf-9ee1-c3097b9b554a +# ╔═╡ 019f0d55-d1e6-43a0-88fc-d1d5a7d9334b +mdl3 = deterministic_ode_mdl2(data.in_bed, obs, sir_prob, N) + +# ╔═╡ 466ac63c-9d79-4906-8b97-f7c70eee66d9 +uncond_mdl3 = deterministic_ode_mdl2(fill(missing, length(data.in_bed)), obs, sir_prob, N) + +# ╔═╡ 62cf0ec1-0e5f-457e-882e-9282553680e8 +chn3 = sample(mdl3, NUTS(), MCMCThreads(), 1000, 4) + +# ╔═╡ 7a68937a-afe4-48f6-8e69-6e318bb03887 +let + ts = 1:size(data, 1) + gens = generated_quantities(uncond_mdl3, chn3) + fig = Figure() + ax = Axis(fig[1, 1]; + title = "Fitted deterministic model", + xticks = (ts[1:3:end], data.date[1:3:end] .|> string), + ylabel = "Number of Infected students" + ) + pred_Yt = mapreduce(hcat, gens) do gen + gen.obsYt + end |> X -> mapreduce(vcat, eachrow(X)) do row + quantile(row, [0.5, 0.025, 0.975])' + end + + lines!(ax, ts, pred_Yt[:, 1]; linewidth = 3, + label = "Fitted deterministic model", color = :green) + band!(ax, ts, pred_Yt[:, 2], pred_Yt[:, 3], color = (:green, 0.5)) + scatter!(ax, data.in_bed) + + fig +end +# ╔═╡ 19be6d10-342f-4fbd-a1ad-053ed9d1f039 +describe(chn3) # ╔═╡ e023770d-25f7-4b7a-b509-8a4372f42b76 md" @@ -217,13 +307,13 @@ ar = AR(HalfNormal(0.01), ) # ╔═╡ 178e0048-069a-4953-bb24-5116eb81cc41 -ϕs = rand(truncated(Normal(0,100), lower = 0.), 1000) +ϕs = rand(truncated(Normal(0, 100), lower = 0.0), 1000) # ╔═╡ e6bcf0c0-3cc4-41f3-ad20-fa11bf2ca37b -σs = rand(InverseGamma(0.1,0.1), 1000) .|> x -> 1/x +σs = rand(InverseGamma(0.1, 0.1), 1000) .|> x -> 1 / x # ╔═╡ f9c1bcd4-bfb4-45d4-ae06-f114a0923bd7 -mean(InverseGamma(0.1,0.1)) +mean(InverseGamma(0.1, 0.1)) # ╔═╡ 4f07e8ba-30d0-411f-8c3e-b6d5bc1bb5fa AR_damps = ϕs .|> ϕ -> exp(-ϕ) @@ -232,24 +322,22 @@ AR_damps = ϕs .|> ϕ -> exp(-ϕ) mean(AR_damps) # ╔═╡ 48032d21-53fa-4c0a-85cb-c22327b55073 -AR_stds = zip(ϕs, σs) .|> ϕ_σ -> (1 - exp(-2*ϕ_σ[1])) * ϕ_σ[2] / (2 * ϕ_σ[1]) +AR_stds = zip(ϕs, σs) .|> ϕ_σ -> (1 - exp(-2 * ϕ_σ[1])) * ϕ_σ[2] / (2 * ϕ_σ[1]) # ╔═╡ 4089aea2-3946-48b0-bf7c-dcdc73fe87fa mean(AR_stds) # ╔═╡ ec63fd4b-4323-4a9e-9aa7-46ba4115ec4f - # ╔═╡ 2dcb4034-b138-4c3e-b65f-ba13f230439c hist(AR_stds) # ╔═╡ 7271886d-2f87-4dc1-833b-182f4b726738 # xs = rand(truncated(Normal(0,100), lower = 0.), 1000) .|> x -> exp(-x) -xs = rand(InverseGamma(1/0.1,1/0.1), 1000) +xs = rand(InverseGamma(1 / 0.1, 1 / 0.1), 1000) # ╔═╡ 68b75d5b-2b45-44bd-a973-12cba31d0e53 - # ╔═╡ f0f02012-e0fe-4d11-a60a-dc27b6dd510c density(xs) @@ -267,7 +355,6 @@ gmdl() γ ~ Gamma(0.004, 1 / 0.002) S₀ ~ Beta(0.5, 0.5) - # try _prob = remake(prob; u0 = [S₀, 1 - S₀, 0.0], @@ -275,8 +362,8 @@ gmdl() ) sol = solve(_prob, AutoTsit5(Rosenbrock23()); - sensealg = ForwardDiffSensitivity(), - saveat = 1.0:nobs, verbose = false) + sensealg = ForwardDiffSensitivity(), + saveat = 1.0:nobs, verbose = false) # μ = log.(N * sol[2, :]) @submodel κ = generate_latent(logobsprob, nobs) λt = @. N * sol[2, :] * exp(κ) + 0.1 @@ -292,39 +379,46 @@ end # ╔═╡ 6dbd3935-dada-4cac-903e-2dec1a197304 - # ╔═╡ 4330c83f-de39-44c7-bdab-87e5f5830145 mdl2 = stochastic_ode_mdl(data.in_bed, ar, obs, sir_prob, N) # ╔═╡ 8071c92f-9fe8-48cf-b1a0-79d1e34ec7e7 -uncond_mdl2 = stochastic_ode_mdl(fill(missing,length(data.in_bed)), ar, obs, sir_prob, N) +uncond_mdl2 = stochastic_ode_mdl(fill(missing, length(data.in_bed)), ar, obs, sir_prob, N) # ╔═╡ bbe9a87a-a212-4d9d-9c75-8a863d6fb0be rand(mdl2) # ╔═╡ d4502528-d058-4899-b3dd-576316116c18 map_fit2 = map(1:10) do _ - fit = maximum_a_posteriori(mdl2; - initial_params=vcat([1, 0.1, 0.99, 0.01, 0., 0.01], zeros(13)), - adtype=AutoReverseDiff() - ) + fit = maximum_a_posteriori(mdl2; + initial_params = vcat([1, 0.1, 0.99, 0.01, 0.0, 0.01], zeros(13)), + adtype = AutoReverseDiff() + ) end |> -fits -> (findmax(fit -> fit.lp, fits)[2], fits) |> -min_and_fits -> min_and_fits[2][min_and_fits[1]] + fits -> (findmax(fit -> fit.lp, fits)[2], fits) |> + min_and_fits -> min_and_fits[2][min_and_fits[1]] # ╔═╡ 6a246854-601b-4d5a-9fb8-52b0e1620e7d mdl2() # ╔═╡ 156272d7-56c4-4ac4-bf3e-7882f4edc144 +# ╠═╡ disabled = true +#=╠═╡ chn2 = sample(mdl2, NUTS(; adtype = AutoReverseDiff(true)), MCMCThreads(), 1000, 4; initial_params = fill(map_fit2.values.array,4)) + ╠═╡ =# # ╔═╡ 00b90e6d-732f-41c9-a603-cabe9740e329 +#=╠═╡ describe(chn2) + ╠═╡ =# # ╔═╡ 37a016d8-8384-41c9-abdd-23e88b1f988d +#=╠═╡ pairplot(chn2[[:β, :γ, :S₀]]) + ╠═╡ =# # ╔═╡ 0e7bbf13-9187-41ea-8b46-294b93be4c6d +#=╠═╡ let ts = 1:size(data, 1) gens = generated_quantities(uncond_mdl2, chn2) @@ -346,10 +440,10 @@ scatter!(ax, data.in_bed) fig end + ╠═╡ =# # ╔═╡ 36efe6e0-643f-42e6-9d64-de2f5a76b764 - # ╔═╡ Cell order: # ╟─e34cec5a-a173-4e92-a860-340c7a9e9c72 # ╠═33384fc6-7cca-11ef-3567-ab7df9200cde @@ -362,11 +456,13 @@ end # ╠═14641441-dbea-4fdf-88e0-64a57da60ef7 # ╠═a0d91258-8ab5-4adc-98f2-8f17b4bd685c # ╠═943b82ec-b4dc-4537-8183-d6c73cd74a37 +# ╟─0e78285c-d2e8-4c3c-848a-14dae6ead0a4 # ╠═ab4269b1-e292-466f-8bfb-713d917c18f9 +# ╟─f16eb00b-2d77-45df-b767-757fe2f5674c # ╠═bb07a580-6d86-48b3-a79f-d2ed9306e87c -# ╠═aba3f1db-c290-409c-9b9e-6065935ede54 +# ╟─d64388f9-6edd-414d-a191-316f75b35b2c # ╠═7c9cbbc1-71ef-4d81-b93a-c2b3a8683d53 -# ╠═eb247c93-1512-4927-9f39-ae408be0dc89 +# ╠═aba3f1db-c290-409c-9b9e-6065935ede54 # ╠═3f54bb44-76c4-4744-885a-46dedfaffeca # ╠═87509792-e28d-4618-9bf5-e06b2e5dbe8b # ╠═1d287c8e-7000-4b23-ae7e-f7008c3e53bd @@ -380,6 +476,11 @@ end # ╠═1e7f37c5-4cb4-4d06-8f68-55d80f7a00ad # ╠═03d1ecf8-543d-444d-b1a3-7a19acd88499 # ╠═506855ac-57f1-40cf-9ee1-c3097b9b554a +# ╠═019f0d55-d1e6-43a0-88fc-d1d5a7d9334b +# ╠═466ac63c-9d79-4906-8b97-f7c70eee66d9 +# ╠═62cf0ec1-0e5f-457e-882e-9282553680e8 +# ╠═7a68937a-afe4-48f6-8e69-6e318bb03887 +# ╠═19be6d10-342f-4fbd-a1ad-053ed9d1f039 # ╠═e023770d-25f7-4b7a-b509-8a4372f42b76 # ╠═71a26408-1c26-46cf-bc72-c6ba528dfadd # ╠═178e0048-069a-4953-bb24-5116eb81cc41 From ff9c3dcac9198096dd36f4ba08b12f4507b2eedb Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 2 Oct 2024 01:56:10 +0100 Subject: [PATCH 05/17] Update index.jl --- .../replications/chatzilena-2019/index.jl | 412 ++++++++++-------- 1 file changed, 232 insertions(+), 180 deletions(-) diff --git a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl index 94c2d3fbc..8c993d400 100644 --- a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl +++ b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl @@ -98,6 +98,8 @@ Where: - `u` is the _state_ of the ODE problem, e.g. $S$, $I$, etc. - `p` is an object that represents the parameters of the ODE problem, e.g. $\beta$, $\gamma$. - `t` is the time of the ODE problem. + +We do this for the SIR model described above in a function called `sir!`: " # ╔═╡ ab4269b1-e292-466f-8bfb-713d917c18f9 @@ -113,16 +115,9 @@ end # ╔═╡ f16eb00b-2d77-45df-b767-757fe2f5674c md" -We combine the function defining the vector field with a initial condition `u0` and the integration period `tspan` to make an `ODEProblem`. We do not define the parameters, these will be defined within an inference approach. +We combine vector field function `sir!` with a initial condition `u0` and the integration period `tspan` to make an `ODEProblem`. We do not define the parameters, these will be defined within an inference approach. " -# ╔═╡ bb07a580-6d86-48b3-a79f-d2ed9306e87c -sir_prob = ODEProblem( - sir!, - [0.99, 0.01, 0.0], - (0.0, (Date(1978, 2, 4) - Date(1978, 1, 22)).value + 1) -) - # ╔═╡ d64388f9-6edd-414d-a191-316f75b35b2c md" @@ -136,71 +131,186 @@ We downloaded the data of this outbreak using the R package `outbreaks` which is # ╔═╡ 7c9cbbc1-71ef-4d81-b93a-c2b3a8683d53 data = "https://raw.githubusercontent.com/CDCgov/Rt-without-renewal/refs/heads/446-add-chatzilena-et-al-as-a-replication-example/EpiAware/docs/src/showcase/replications/chatzilena-2019/influenza_england_1978_school.csv2" |> - url -> CSV.read(download(url), DataFrame) + url -> CSV.read(download(url), DataFrame) |> + df -> @transform(df, + :ts = (:date .- minimum(:date)) .|> d -> d.value + 1.0, + ) # ╔═╡ aba3f1db-c290-409c-9b9e-6065935ede54 -N = 763 +N = 763; + +# ╔═╡ bb07a580-6d86-48b3-a79f-d2ed9306e87c +sir_prob = ODEProblem( + sir!, + N .* [0.99, 0.01, 0.0], + (0.0, (Date(1978, 2, 4) - Date(1978, 1, 22)).value + 1) +) # ╔═╡ 3f54bb44-76c4-4744-885a-46dedfaffeca md" -## Deterministic SIR model +## Inference for the deterministic SIR model + +The boarding school data gives the number of children \"in bed\" and \"convalescent\" on each of 14 days from 22nd Jan to 4th Feb 1978. We follow _Chatzilena et al_ and treat the number \"in bed\" as a proxy for the number of children in the infectious (I) compartment in the ODE model. +The full observation model is: + +```math +\begin{aligned} +Y_t &\sim \text{Poisson}(\lambda_t)\\ +\lambda_t &= I(t)\\ +\beta &\sim \text{LogNormal}(\text{logmean}=0,\text{logstd}=1) \\ +\gamma & \sim \text{Gamma}(\text{shape} = 0.004, \text{scale} = 50)\\ +S(0) /N &\sim \text{Beta}(0.5, 0.5). +\end{aligned} +``` + +**NB: Chatzilena et al give $\lambda_t = \int_0^t \beta \frac{I(s)}{N} S(s) - \gamma I(s)ds = I(t) - I(0).$ However, this doesn't match their underlying stan code.** +" + +# ╔═╡ ea1be94b-d722-47ee-8465-982c83dc6838 +md" +From `EpiAware`, we have the `PoissonError` struct which defines the probabilistic structure of this observation error model. " # ╔═╡ 87509792-e28d-4618-9bf5-e06b2e5dbe8b obs = PoissonError() -# ╔═╡ 1d287c8e-7000-4b23-ae7e-f7008c3e53bd -@model function deterministic_ode_mdl(Yt, obs, prob, N) - nobs = length(Yt) +# ╔═╡ 81501c84-5e1f-4829-a26d-52fe00503958 +md" +Now we can write the observation model using the `Turing` PPL. +" +# ╔═╡ 1d287c8e-7000-4b23-ae7e-f7008c3e53bd +@model function deterministic_ode_mdl(Yt, ts, obs, prob, N; + solver = AutoTsit5(Rosenbrock23()), + upjitter = 1e-3 +) + ##Priors## β ~ LogNormal(0.0, 1.0) γ ~ Gamma(0.004, 1 / 0.002) S₀ ~ Beta(0.5, 0.5) - # try + ##remake ODE model## _prob = remake(prob; u0 = [S₀, 1 - S₀, 0.0], p = [β, γ] ) - sol = solve(_prob, AutoTsit5(Rosenbrock23()); saveat = 1.0:nobs, - verbose = false, sensealg = ForwardDiffSensitivity()) - λt = N * sol[2, :] .+ 1e-3 + ##Solve remade ODE model## + + sol = solve(_prob, solver; + saveat = ts, + verbose = false) + ##log-like accumulation using obs## + λt = N * sol[2, :] .+ upjitter #expected It @submodel obsYt = generate_observations(obs, Yt, λt) + ##Generated quantities## return (; sol, obsYt, R0 = β / γ) - # catch - # Turing.@addlogprob! -Inf - # return - # end end +# ╔═╡ e7383885-fa6a-4240-a252-44ae82cae713 +md" +We instantiate the model in two ways: + +1. `deterministic_mdl`: This conditions the generative model on the data observation. We can sample from this model to find the posterior distribution of the parameters. +2. `deterministic_uncond_mdl`: This _doesn't_ condition on the data. This is useful for prior and posterior predictive modelling. +" + # ╔═╡ dbc1b453-1c29-4f82-bec9-098d67f9e63f -mdl = deterministic_ode_mdl(data.in_bed, obs, sir_prob, N) +deterministic_mdl = deterministic_ode_mdl(data.in_bed, data.ts, obs, sir_prob, N); # ╔═╡ e795c2bf-0861-4e96-9921-db47f41af206 -uncond_mdl = deterministic_ode_mdl(fill(missing, length(data.in_bed)), obs, sir_prob, N) +deterministic_uncond_mdl = deterministic_ode_mdl(fill(missing, length(data.in_bed)), data.ts, obs, sir_prob, N); + +# ╔═╡ e848434c-2543-43d1-ae22-5c4241f138bb +md" +We add a useful plotting utility. +" + +# ╔═╡ ab8c98d1-d357-4c49-9f5a-f069e05c45f5 +function plot_predYt(data, gens; title::String, ylabel::String) + fig = Figure() + ga = fig[1, 1:2] = GridLayout() + + ax = Axis(ga[1, 1]; + title = title, + xticks = (data.ts[1:3:end], data.date[1:3:end] .|> string), + ylabel = ylabel, + ) + pred_Yt = mapreduce(hcat, gens) do gen + gen.obsYt + end |> X -> mapreduce(vcat, eachrow(X)) do row + quantile(row, [0.5, 0.025, 0.975, 0.1, 0.9, 0.25, 0.75])' + end + + lines!(ax, data.ts, pred_Yt[:, 1]; linewidth = 3, color = :green, label = "Median") + band!(ax, data.ts, pred_Yt[:, 2], pred_Yt[:, 3], color = (:green, 0.2), label = "95% CI") + band!(ax, data.ts, pred_Yt[:, 4], pred_Yt[:, 5], color = (:green, 0.4), label = "80% CI") + band!(ax, data.ts, pred_Yt[:, 6], pred_Yt[:, 7], color = (:green, 0.6), label = "50% CI") + scatter!(ax, data.in_bed, label = "data") + leg = Legend(ga[1, 2], ax; framevisible = false) + hidespines!(ax) + + fig +end + +# ╔═╡ 2c6ac235-e331-4189-8c8c-74de5f98b2c4 +md" +**Prior predictive sampling** +" + +# ╔═╡ a729f1cd-404c-4a33-a8f9-b2ea6f0adb62 +let + prior_chn = sample(deterministic_uncond_mdl, Prior(), 2000) + gens = generated_quantities(deterministic_uncond_mdl, prior_chn) + plot_predYt(data, gens; + title = "Prior predictive: deterministic model", + ylabel = "Number of Infected students", + ) +end + +# ╔═╡ 4c0759fb-76e9-4de5-9206-89e8bfb6c3bb +md" +The prior predictive checking suggests that _a priori_ our parameter beliefs are very far from the data. Approaching the inference naively can lead to poor fits. + +We do three things to mitigate this: + +1. We choose a switching ODE solver which switches between explicit (`Tsit5`) and implicit (`Rosenbrock23`) solvers. This helps avoid the ODE solver failing when the sampler tries extreme parameter values. This is the default `solver = AutoTsit5(Rosenbrock23())` above. +2. To avoid the effect of numerically negative small values of `λt` we add a small `upjitter`. +3. We locate the maximum likelihood point, that is we ignore the influence of the priors, as a useful starting point for `NUTS`. +" + +# ╔═╡ 8d96db67-de3b-4704-9f54-f4ed50a4ecff +nmle_tries = 100 # ╔═╡ ba35cebd-0d29-43c5-8db7-f550d7f821bc -map_fit = map(1:10) do _ - fit = maximum_a_posteriori(mdl; - initial_params = [1, 0.1, 0.99] +mle_fit = map(1:nmle_tries) do _ + fit = try + maximum_likelihood(deterministic_mdl; ) + catch + (lp = -Inf,) + end end |> fits -> (findmax(fit -> fit.lp, fits)[2], fits) |> min_and_fits -> min_and_fits[2][min_and_fits[1]] # ╔═╡ 0be912c1-22dc-4978-b86a-84273062f5da -map_fit.optim_result.retcode +mle_fit.optim_result.retcode + +# ╔═╡ a1a34b67-ff4e-4fee-aa30-4c2add3ea8a0 +md" +Note that we choose the best out of $nmle_tries tries for the MLE estimators. + +Now, we sample aiming at 1000 samples for each of 4 chains. +" # ╔═╡ 2cf64ba3-ff8d-40b0-9bd8-9e80393156f5 chn = sample( - mdl, NUTS(), MCMCThreads(), 1000, 4; initial_params = fill(map_fit.values.array, 4)) - -# ╔═╡ 6d8a1903-ffcf-47a9-a02a-4ef77525f133 -map_fit.values + deterministic_mdl, NUTS(), MCMCThreads(), 1000, 4; + initial_params = fill(mle_fit.values.array, 4)) # ╔═╡ b2429b68-dd75-499f-a4e1-1b7d72e209c7 describe(chn) @@ -208,144 +318,95 @@ describe(chn) # ╔═╡ 1e7f37c5-4cb4-4d06-8f68-55d80f7a00ad pairplot(chn) +# ╔═╡ c16b81a0-2d36-4012-aed4-a035af31b4c3 +md" +**Posterior predictive plotting** +" + # ╔═╡ 03d1ecf8-543d-444d-b1a3-7a19acd88499 let - ts = 1:size(data, 1) - gens = generated_quantities(uncond_mdl, chn) - fig = Figure() - ax = Axis(fig[1, 1]; - title = "Fitted deterministic model", - xticks = (ts[1:3:end], data.date[1:3:end] .|> string), - ylabel = "Number of Infected students" - ) - pred_Yt = mapreduce(hcat, gens) do gen - gen.obsYt - end |> X -> mapreduce(vcat, eachrow(X)) do row - quantile(row, [0.5, 0.025, 0.975])' - end - - lines!(ax, ts, pred_Yt[:, 1]; linewidth = 3, - label = "Fitted deterministic model", color = :green) - band!(ax, ts, pred_Yt[:, 2], pred_Yt[:, 3], color = (:green, 0.5)) - scatter!(ax, data.in_bed) - - fig + gens = generated_quantities(deterministic_uncond_mdl, chn) + plot_predYt(data, gens; + title = "Fitted deterministic model", + ylabel = "Number of Infected students", + ) end -# ╔═╡ 506855ac-57f1-40cf-9ee1-c3097b9b554a -@model function deterministic_ode_mdl2(Yt, obs, prob, N) - nobs = length(Yt) +# ╔═╡ e023770d-25f7-4b7a-b509-8a4372f42b76 +md" +## Inference for the Stochastic SIR model - β ~ LogNormal(0.0, 1.0) - γ ~ Gamma(0.004, 1 / 0.002) - S₀ ~ Beta(0.5, 0.5) +In _Chatzilena et al_, they present an auto-regressive model for connecting the outcome of the ODE model to illness observations. The argument is that the stochastic component of the model can absorb the noise +generated by a possible mis-specification of the model. - # try - _prob = remake(prob; - u0 = [S₀, 1 / N, 1 - S₀], - p = [β, γ] - ) +In their approach they consider $\kappa_t = \log \lambda_t$ where $\kappa_t$ evolves according to an Ornstein-Uhlenbeck process: - sol = solve(_prob, AutoTsit5(Rosenbrock23()); saveat = 1.0:nobs, - verbose = false, sensealg = ForwardDiffSensitivity()) - λt = N * sol[2, :] .+ 1e-3 +```math +d\kappa_t = \phi(\mu_t - \kappa_t) dt + \sigma dB_t. +``` +Which has transition density: +```math +\kappa_{t+1} | \kappa_t \sim N\Big(\mu_t + \left(\kappa_t - \mu_t\right)e^{-\phi}, {\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right)\Big). +``` +Where $\mu_t = \log(I(t))$. - @submodel obsYt = generate_observations(obs, Yt, λt) +We modify this approach since it implies that the $\mu_t$ is treated as constant between observation times. - return (; sol, obsYt, R0 = β / γ) - # catch - # Turing.@addlogprob! -Inf - # return - # end -end +Instead we redefine $\kappa_t$ as the log-residual: -# ╔═╡ 019f0d55-d1e6-43a0-88fc-d1d5a7d9334b -mdl3 = deterministic_ode_mdl2(data.in_bed, obs, sir_prob, N) +$\kappa_t = \log(\lambda_t / I(t)).$ -# ╔═╡ 466ac63c-9d79-4906-8b97-f7c70eee66d9 -uncond_mdl3 = deterministic_ode_mdl2(fill(missing, length(data.in_bed)), obs, sir_prob, N) +With the transition density: -# ╔═╡ 62cf0ec1-0e5f-457e-882e-9282553680e8 -chn3 = sample(mdl3, NUTS(), MCMCThreads(), 1000, 4) +```math +\kappa_{t+1} | \kappa_t \sim N\Big(\kappa_te^{-\phi}, {\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right)\Big). +``` -# ╔═╡ 7a68937a-afe4-48f6-8e69-6e318bb03887 -let - ts = 1:size(data, 1) - gens = generated_quantities(uncond_mdl3, chn3) - fig = Figure() - ax = Axis(fig[1, 1]; - title = "Fitted deterministic model", - xticks = (ts[1:3:end], data.date[1:3:end] .|> string), - ylabel = "Number of Infected students" - ) - pred_Yt = mapreduce(hcat, gens) do gen - gen.obsYt - end |> X -> mapreduce(vcat, eachrow(X)) do row - quantile(row, [0.5, 0.025, 0.975])' - end - - lines!(ax, ts, pred_Yt[:, 1]; linewidth = 3, - label = "Fitted deterministic model", color = :green) - band!(ax, ts, pred_Yt[:, 2], pred_Yt[:, 3], color = (:green, 0.5)) - scatter!(ax, data.in_bed) - - fig -end +This is an AR(1) process. -# ╔═╡ 19be6d10-342f-4fbd-a1ad-053ed9d1f039 -describe(chn3) +The stochastic model is completed: + +```math +\begin{aligned} +Y_t &\sim \text{Poisson}(\lambda_t)\\ +\lambda_t &= I(t)\exp(\kappa_t)\\ +\beta &\sim \text{LogNormal}(\text{logmean}=0,\text{logstd}=1) \\ +\gamma & \sim \text{Gamma}(\text{shape} = 0.004, \text{scale} = 50)\\ +S(0) /N &\sim \text{Beta}(0.5, 0.5)\\ +\phi & \sim \text{HalfNormal}(0, 100) \\ +1 / \sigma^2 & \sim \text{InvGamma}(0.1,0.1). +\end{aligned} +``` -# ╔═╡ e023770d-25f7-4b7a-b509-8a4372f42b76 -md" -## Stochastic model " -# ╔═╡ 71a26408-1c26-46cf-bc72-c6ba528dfadd -ar = AR(HalfNormal(0.01), - HalfNormal(0.3), - Normal(0, 0.001) -) +# ╔═╡ 69ba59d1-2221-463f-8853-ae172739e512 +md" +We will using the `AR` struct from `EpiAware` to define the auto-regressive process in this model which has a direct parameterisation of the `AR` model. + +To convert from the formulation above we sample from the priors, and define `HalfNormal` priors based on the sampled prior means of $e^{-\phi}$ and ${\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right)$. We also add a strong prior that $\kappa_1 \approx 0$. +" # ╔═╡ 178e0048-069a-4953-bb24-5116eb81cc41 ϕs = rand(truncated(Normal(0, 100), lower = 0.0), 1000) # ╔═╡ e6bcf0c0-3cc4-41f3-ad20-fa11bf2ca37b -σs = rand(InverseGamma(0.1, 0.1), 1000) .|> x -> 1 / x - -# ╔═╡ f9c1bcd4-bfb4-45d4-ae06-f114a0923bd7 -mean(InverseGamma(0.1, 0.1)) +σ²s = rand(InverseGamma(0.1, 0.1), 1000) .|> x -> 1 / x # ╔═╡ 4f07e8ba-30d0-411f-8c3e-b6d5bc1bb5fa -AR_damps = ϕs .|> ϕ -> exp(-ϕ) - -# ╔═╡ 7235289e-28f0-43c2-986b-81b96c42d9fe -mean(AR_damps) +sampled_AR_damps = ϕs .|> ϕ -> exp(-ϕ) # ╔═╡ 48032d21-53fa-4c0a-85cb-c22327b55073 -AR_stds = zip(ϕs, σs) .|> ϕ_σ -> (1 - exp(-2 * ϕ_σ[1])) * ϕ_σ[2] / (2 * ϕ_σ[1]) - -# ╔═╡ 4089aea2-3946-48b0-bf7c-dcdc73fe87fa -mean(AR_stds) - -# ╔═╡ ec63fd4b-4323-4a9e-9aa7-46ba4115ec4f - -# ╔═╡ 2dcb4034-b138-4c3e-b65f-ba13f230439c -hist(AR_stds) - -# ╔═╡ 7271886d-2f87-4dc1-833b-182f4b726738 -# xs = rand(truncated(Normal(0,100), lower = 0.), 1000) .|> x -> exp(-x) -xs = rand(InverseGamma(1 / 0.1, 1 / 0.1), 1000) - -# ╔═╡ 68b75d5b-2b45-44bd-a973-12cba31d0e53 - -# ╔═╡ f0f02012-e0fe-4d11-a60a-dc27b6dd510c -density(xs) - -# ╔═╡ e15d0532-0c8a-4cd2-a576-567fc0c625c5 -gmdl = generate_latent(ar, 10) +sampled_AR_stds = map(ϕs, σ²s) do ϕ, σ² + (1 - exp(-2 * ϕ)) * σ² / (2 * ϕ) +end -# ╔═╡ 0be4b20e-5f16-43dc-90f6-84a6f29ae8cc -gmdl() +# ╔═╡ 71a26408-1c26-46cf-bc72-c6ba528dfadd +ar = AR( + damp_priors = [HalfNormal(mean(sampled_AR_damps))], + std_prior = HalfNormal(mean(sampled_AR_stds)), + init_priors = [Normal(0, 0.001)] +) # ╔═╡ 9309f7f8-0896-4686-8bfc-b9f82d91bc0f @model function stochastic_ode_mdl(Yt, logobsprob, obs, prob, N) @@ -377,26 +438,23 @@ gmdl() # end end -# ╔═╡ 6dbd3935-dada-4cac-903e-2dec1a197304 - # ╔═╡ 4330c83f-de39-44c7-bdab-87e5f5830145 -mdl2 = stochastic_ode_mdl(data.in_bed, ar, obs, sir_prob, N) +stochastic_mdl = stochastic_ode_mdl(data.in_bed, ar, obs, sir_prob, N) # ╔═╡ 8071c92f-9fe8-48cf-b1a0-79d1e34ec7e7 -uncond_mdl2 = stochastic_ode_mdl(fill(missing, length(data.in_bed)), ar, obs, sir_prob, N) - -# ╔═╡ bbe9a87a-a212-4d9d-9c75-8a863d6fb0be -rand(mdl2) +stochastic_uncond_mdl = stochastic_ode_mdl(fill(missing, length(data.in_bed)), ar, obs, sir_prob, N) # ╔═╡ d4502528-d058-4899-b3dd-576316116c18 -map_fit2 = map(1:10) do _ - fit = maximum_a_posteriori(mdl2; - initial_params = vcat([1, 0.1, 0.99, 0.01, 0.0, 0.01], zeros(13)), - adtype = AutoReverseDiff() +mle_fit2 = map(1:nmle_tries) do _ + fit = try + maximum_likelihood(stochastic_mdl; ) + catch + (lp = -Inf,) + end end |> - fits -> (findmax(fit -> fit.lp, fits)[2], fits) |> - min_and_fits -> min_and_fits[2][min_and_fits[1]] + fits -> (findmax(fit -> fit.lp, fits)[2], fits) |> + min_and_fits -> min_and_fits[2][min_and_fits[1]] # ╔═╡ 6a246854-601b-4d5a-9fb8-52b0e1620e7d mdl2() @@ -444,9 +502,10 @@ end # ╔═╡ 36efe6e0-643f-42e6-9d64-de2f5a76b764 + # ╔═╡ Cell order: # ╟─e34cec5a-a173-4e92-a860-340c7a9e9c72 -# ╠═33384fc6-7cca-11ef-3567-ab7df9200cde +# ╟─33384fc6-7cca-11ef-3567-ab7df9200cde # ╠═b1468db3-7ab0-468c-8e27-70013a8f512f # ╠═a4710701-6315-459d-b677-f24b77ff3e80 # ╠═7263d714-2ce4-4d57-8881-6b60db018dd5 @@ -455,7 +514,7 @@ end # ╠═3897e773-ed07-4860-bb62-35605d0dacb0 # ╠═14641441-dbea-4fdf-88e0-64a57da60ef7 # ╠═a0d91258-8ab5-4adc-98f2-8f17b4bd685c -# ╠═943b82ec-b4dc-4537-8183-d6c73cd74a37 +# ╟─943b82ec-b4dc-4537-8183-d6c73cd74a37 # ╟─0e78285c-d2e8-4c3c-848a-14dae6ead0a4 # ╠═ab4269b1-e292-466f-8bfb-713d917c18f9 # ╟─f16eb00b-2d77-45df-b767-757fe2f5674c @@ -463,45 +522,38 @@ end # ╟─d64388f9-6edd-414d-a191-316f75b35b2c # ╠═7c9cbbc1-71ef-4d81-b93a-c2b3a8683d53 # ╠═aba3f1db-c290-409c-9b9e-6065935ede54 -# ╠═3f54bb44-76c4-4744-885a-46dedfaffeca +# ╟─3f54bb44-76c4-4744-885a-46dedfaffeca +# ╟─ea1be94b-d722-47ee-8465-982c83dc6838 # ╠═87509792-e28d-4618-9bf5-e06b2e5dbe8b +# ╠═81501c84-5e1f-4829-a26d-52fe00503958 # ╠═1d287c8e-7000-4b23-ae7e-f7008c3e53bd +# ╟─e7383885-fa6a-4240-a252-44ae82cae713 # ╠═dbc1b453-1c29-4f82-bec9-098d67f9e63f # ╠═e795c2bf-0861-4e96-9921-db47f41af206 +# ╟─e848434c-2543-43d1-ae22-5c4241f138bb +# ╠═ab8c98d1-d357-4c49-9f5a-f069e05c45f5 +# ╟─2c6ac235-e331-4189-8c8c-74de5f98b2c4 +# ╠═a729f1cd-404c-4a33-a8f9-b2ea6f0adb62 +# ╟─4c0759fb-76e9-4de5-9206-89e8bfb6c3bb +# ╠═8d96db67-de3b-4704-9f54-f4ed50a4ecff # ╠═ba35cebd-0d29-43c5-8db7-f550d7f821bc # ╠═0be912c1-22dc-4978-b86a-84273062f5da +# ╟─a1a34b67-ff4e-4fee-aa30-4c2add3ea8a0 # ╠═2cf64ba3-ff8d-40b0-9bd8-9e80393156f5 -# ╠═6d8a1903-ffcf-47a9-a02a-4ef77525f133 # ╠═b2429b68-dd75-499f-a4e1-1b7d72e209c7 # ╠═1e7f37c5-4cb4-4d06-8f68-55d80f7a00ad +# ╟─c16b81a0-2d36-4012-aed4-a035af31b4c3 # ╠═03d1ecf8-543d-444d-b1a3-7a19acd88499 -# ╠═506855ac-57f1-40cf-9ee1-c3097b9b554a -# ╠═019f0d55-d1e6-43a0-88fc-d1d5a7d9334b -# ╠═466ac63c-9d79-4906-8b97-f7c70eee66d9 -# ╠═62cf0ec1-0e5f-457e-882e-9282553680e8 -# ╠═7a68937a-afe4-48f6-8e69-6e318bb03887 -# ╠═19be6d10-342f-4fbd-a1ad-053ed9d1f039 -# ╠═e023770d-25f7-4b7a-b509-8a4372f42b76 -# ╠═71a26408-1c26-46cf-bc72-c6ba528dfadd +# ╟─e023770d-25f7-4b7a-b509-8a4372f42b76 +# ╟─69ba59d1-2221-463f-8853-ae172739e512 # ╠═178e0048-069a-4953-bb24-5116eb81cc41 # ╠═e6bcf0c0-3cc4-41f3-ad20-fa11bf2ca37b -# ╠═f9c1bcd4-bfb4-45d4-ae06-f114a0923bd7 # ╠═4f07e8ba-30d0-411f-8c3e-b6d5bc1bb5fa -# ╠═7235289e-28f0-43c2-986b-81b96c42d9fe # ╠═48032d21-53fa-4c0a-85cb-c22327b55073 -# ╠═4089aea2-3946-48b0-bf7c-dcdc73fe87fa -# ╠═ec63fd4b-4323-4a9e-9aa7-46ba4115ec4f -# ╠═2dcb4034-b138-4c3e-b65f-ba13f230439c -# ╠═7271886d-2f87-4dc1-833b-182f4b726738 -# ╠═68b75d5b-2b45-44bd-a973-12cba31d0e53 -# ╠═f0f02012-e0fe-4d11-a60a-dc27b6dd510c -# ╠═e15d0532-0c8a-4cd2-a576-567fc0c625c5 -# ╠═0be4b20e-5f16-43dc-90f6-84a6f29ae8cc +# ╠═71a26408-1c26-46cf-bc72-c6ba528dfadd # ╠═9309f7f8-0896-4686-8bfc-b9f82d91bc0f -# ╠═6dbd3935-dada-4cac-903e-2dec1a197304 # ╠═4330c83f-de39-44c7-bdab-87e5f5830145 # ╠═8071c92f-9fe8-48cf-b1a0-79d1e34ec7e7 -# ╠═bbe9a87a-a212-4d9d-9c75-8a863d6fb0be # ╠═d4502528-d058-4899-b3dd-576316116c18 # ╠═6a246854-601b-4d5a-9fb8-52b0e1620e7d # ╠═156272d7-56c4-4ac4-bf3e-7882f4edc144 From 32b176b20a91decd4db6e6ccf9d801511b35d758 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 2 Oct 2024 08:50:41 +0100 Subject: [PATCH 06/17] Update index.jl --- .../replications/chatzilena-2019/index.jl | 55 ++++++++++++------- 1 file changed, 35 insertions(+), 20 deletions(-) diff --git a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl index 8c993d400..2907bfbe6 100644 --- a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl +++ b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl @@ -295,7 +295,7 @@ mle_fit = map(1:nmle_tries) do _ end end |> fits -> (findmax(fit -> fit.lp, fits)[2], fits) |> - min_and_fits -> min_and_fits[2][min_and_fits[1]] + max_and_fits -> max_and_fits[2][max_and_fits[1]] # ╔═╡ 0be912c1-22dc-4978-b86a-84273062f5da mle_fit.optim_result.retcode @@ -416,7 +416,6 @@ ar = AR( γ ~ Gamma(0.004, 1 / 0.002) S₀ ~ Beta(0.5, 0.5) - # try _prob = remake(prob; u0 = [S₀, 1 - S₀, 0.0], p = [β, γ] @@ -432,10 +431,6 @@ ar = AR( @submodel obsYt = generate_observations(obs, Yt, λt) return (; sol, obsYt, R0 = β / γ) - # catch - # Turing.@addlogprob! -Inf - # return - # end end # ╔═╡ 4330c83f-de39-44c7-bdab-87e5f5830145 @@ -444,39 +439,58 @@ stochastic_mdl = stochastic_ode_mdl(data.in_bed, ar, obs, sir_prob, N) # ╔═╡ 8071c92f-9fe8-48cf-b1a0-79d1e34ec7e7 stochastic_uncond_mdl = stochastic_ode_mdl(fill(missing, length(data.in_bed)), ar, obs, sir_prob, N) +# ╔═╡ adb9d0ac-d412-4dbc-a601-59fcc33adf43 +md" +**Prior predictive checking** +" + +# ╔═╡ b44286f9-ba88-4e2b-9a34-f14c0a78824d +let + prior_chn = sample(stochastic_uncond_mdl, Prior(), 2000) + gens = generated_quantities(stochastic_uncond_mdl, prior_chn) + plot_predYt(data, gens; + title = "Prior predictive: stochastic model", + ylabel = "Number of Infected students", + ) +end + # ╔═╡ d4502528-d058-4899-b3dd-576316116c18 mle_fit2 = map(1:nmle_tries) do _ fit = try maximum_likelihood(stochastic_mdl; + adtype = AutoReverseDiff(true), ) catch (lp = -Inf,) end end |> fits -> (findmax(fit -> fit.lp, fits)[2], fits) |> - min_and_fits -> min_and_fits[2][min_and_fits[1]] + max_and_fits -> max_and_fits[2][max_and_fits[1]] -# ╔═╡ 6a246854-601b-4d5a-9fb8-52b0e1620e7d -mdl2() +# ╔═╡ 78a732ab-4915-43d9-af55-b01bd84eb364 +map_fit2 = map(1:nmle_tries) do _ + fit = + maximum_likelihood(stochastic_mdl; + adtype = AutoReverseDiff(true), + initial_params = mle_fit2.values.array, + ) + # catch + # (lp = -Inf,) + # end +end |> + fits -> (findmax(fit -> fit.lp, fits)[2], fits) |> + max_and_fits -> max_and_fits[2][max_and_fits[1]] # ╔═╡ 156272d7-56c4-4ac4-bf3e-7882f4edc144 -# ╠═╡ disabled = true -#=╠═╡ -chn2 = sample(mdl2, NUTS(; adtype = AutoReverseDiff(true)), MCMCThreads(), 1000, 4; initial_params = fill(map_fit2.values.array,4)) - ╠═╡ =# +chn2 = sample(stochastic_mdl, NUTS(; adtype = AutoReverseDiff(true)), MCMCThreads(), 1000, 4; initial_params = fill(map_fit2.values.array,4)) # ╔═╡ 00b90e6d-732f-41c9-a603-cabe9740e329 -#=╠═╡ describe(chn2) - ╠═╡ =# # ╔═╡ 37a016d8-8384-41c9-abdd-23e88b1f988d -#=╠═╡ pairplot(chn2[[:β, :γ, :S₀]]) - ╠═╡ =# # ╔═╡ 0e7bbf13-9187-41ea-8b46-294b93be4c6d -#=╠═╡ let ts = 1:size(data, 1) gens = generated_quantities(uncond_mdl2, chn2) @@ -498,7 +512,6 @@ scatter!(ax, data.in_bed) fig end - ╠═╡ =# # ╔═╡ 36efe6e0-643f-42e6-9d64-de2f5a76b764 @@ -554,8 +567,10 @@ end # ╠═9309f7f8-0896-4686-8bfc-b9f82d91bc0f # ╠═4330c83f-de39-44c7-bdab-87e5f5830145 # ╠═8071c92f-9fe8-48cf-b1a0-79d1e34ec7e7 +# ╠═adb9d0ac-d412-4dbc-a601-59fcc33adf43 +# ╠═b44286f9-ba88-4e2b-9a34-f14c0a78824d # ╠═d4502528-d058-4899-b3dd-576316116c18 -# ╠═6a246854-601b-4d5a-9fb8-52b0e1620e7d +# ╠═78a732ab-4915-43d9-af55-b01bd84eb364 # ╠═156272d7-56c4-4ac4-bf3e-7882f4edc144 # ╠═00b90e6d-732f-41c9-a603-cabe9740e329 # ╠═37a016d8-8384-41c9-abdd-23e88b1f988d From b2fae2dc8f7dc4b5794d6bd25a0ec4145096ade8 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Wed, 2 Oct 2024 12:53:10 +0100 Subject: [PATCH 07/17] fix initialisation process for stochastic model --- .../replications/chatzilena-2019/index.jl | 296 ++++++++++-------- 1 file changed, 170 insertions(+), 126 deletions(-) diff --git a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl index 2907bfbe6..d7c9a68e3 100644 --- a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl +++ b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl @@ -132,9 +132,8 @@ We downloaded the data of this outbreak using the R package `outbreaks` which is # ╔═╡ 7c9cbbc1-71ef-4d81-b93a-c2b3a8683d53 data = "https://raw.githubusercontent.com/CDCgov/Rt-without-renewal/refs/heads/446-add-chatzilena-et-al-as-a-replication-example/EpiAware/docs/src/showcase/replications/chatzilena-2019/influenza_england_1978_school.csv2" |> url -> CSV.read(download(url), DataFrame) |> - df -> @transform(df, - :ts = (:date .- minimum(:date)) .|> d -> d.value + 1.0, - ) + df -> @transform(df, + :ts=(:date .- minimum(:date)) .|> d -> d.value + 1.0,) # ╔═╡ aba3f1db-c290-409c-9b9e-6065935ede54 N = 763; @@ -150,7 +149,7 @@ sir_prob = ODEProblem( md" ## Inference for the deterministic SIR model -The boarding school data gives the number of children \"in bed\" and \"convalescent\" on each of 14 days from 22nd Jan to 4th Feb 1978. We follow _Chatzilena et al_ and treat the number \"in bed\" as a proxy for the number of children in the infectious (I) compartment in the ODE model. +The boarding school data gives the number of children \"in bed\" and \"convalescent\" on each of 14 days from 22nd Jan to 4th Feb 1978. We follow _Chatzilena et al_ and treat the number \"in bed\" as a proxy for the number of children in the infectious (I) compartment in the ODE model. The full observation model is: @@ -181,32 +180,32 @@ Now we can write the observation model using the `Turing` PPL. " # ╔═╡ 1d287c8e-7000-4b23-ae7e-f7008c3e53bd -@model function deterministic_ode_mdl(Yt, ts, obs, prob, N; - solver = AutoTsit5(Rosenbrock23()), - upjitter = 1e-3 +@model function deterministic_ode_mdl(Yt, ts, obs, prob, N; + solver = AutoTsit5(Rosenbrock23()), + upjitter = 1e-3 ) - ##Priors## + ##Priors## β ~ LogNormal(0.0, 1.0) γ ~ Gamma(0.004, 1 / 0.002) S₀ ~ Beta(0.5, 0.5) - ##remake ODE model## + ##remake ODE model## _prob = remake(prob; u0 = [S₀, 1 - S₀, 0.0], p = [β, γ] ) - ##Solve remade ODE model## - - sol = solve(_prob, solver; - saveat = ts, + ##Solve remade ODE model## + + sol = solve(_prob, solver; + saveat = ts, verbose = false) - ##log-like accumulation using obs## + ##log-like accumulation using obs## λt = N * sol[2, :] .+ upjitter #expected It @submodel obsYt = generate_observations(obs, Yt, λt) - ##Generated quantities## + ##Generated quantities## return (; sol, obsYt, R0 = β / γ) end @@ -222,7 +221,8 @@ We instantiate the model in two ways: deterministic_mdl = deterministic_ode_mdl(data.in_bed, data.ts, obs, sir_prob, N); # ╔═╡ e795c2bf-0861-4e96-9921-db47f41af206 -deterministic_uncond_mdl = deterministic_ode_mdl(fill(missing, length(data.in_bed)), data.ts, obs, sir_prob, N); +deterministic_uncond_mdl = deterministic_ode_mdl( + fill(missing, length(data.in_bed)), data.ts, obs, sir_prob, N); # ╔═╡ e848434c-2543-43d1-ae22-5c4241f138bb md" @@ -231,29 +231,32 @@ We add a useful plotting utility. # ╔═╡ ab8c98d1-d357-4c49-9f5a-f069e05c45f5 function plot_predYt(data, gens; title::String, ylabel::String) - fig = Figure() - ga = fig[1, 1:2] = GridLayout() - - ax = Axis(ga[1, 1]; - title = title, - xticks = (data.ts[1:3:end], data.date[1:3:end] .|> string), - ylabel = ylabel, - ) - pred_Yt = mapreduce(hcat, gens) do gen - gen.obsYt - end |> X -> mapreduce(vcat, eachrow(X)) do row - quantile(row, [0.5, 0.025, 0.975, 0.1, 0.9, 0.25, 0.75])' - end - - lines!(ax, data.ts, pred_Yt[:, 1]; linewidth = 3, color = :green, label = "Median") - band!(ax, data.ts, pred_Yt[:, 2], pred_Yt[:, 3], color = (:green, 0.2), label = "95% CI") - band!(ax, data.ts, pred_Yt[:, 4], pred_Yt[:, 5], color = (:green, 0.4), label = "80% CI") - band!(ax, data.ts, pred_Yt[:, 6], pred_Yt[:, 7], color = (:green, 0.6), label = "50% CI") - scatter!(ax, data.in_bed, label = "data") - leg = Legend(ga[1, 2], ax; framevisible = false) - hidespines!(ax) - - fig + fig = Figure() + ga = fig[1, 1:2] = GridLayout() + + ax = Axis(ga[1, 1]; + title = title, + xticks = (data.ts[1:3:end], data.date[1:3:end] .|> string), + ylabel = ylabel + ) + pred_Yt = mapreduce(hcat, gens) do gen + gen.obsYt + end |> X -> mapreduce(vcat, eachrow(X)) do row + quantile(row, [0.5, 0.025, 0.975, 0.1, 0.9, 0.25, 0.75])' + end + + lines!(ax, data.ts, pred_Yt[:, 1]; linewidth = 3, color = :green, label = "Median") + band!( + ax, data.ts, pred_Yt[:, 2], pred_Yt[:, 3], color = (:green, 0.2), label = "95% CI") + band!( + ax, data.ts, pred_Yt[:, 4], pred_Yt[:, 5], color = (:green, 0.4), label = "80% CI") + band!( + ax, data.ts, pred_Yt[:, 6], pred_Yt[:, 7], color = (:green, 0.6), label = "50% CI") + scatter!(ax, data.in_bed, label = "data") + leg = Legend(ga[1, 2], ax; framevisible = false) + hidespines!(ax) + + fig end # ╔═╡ 2c6ac235-e331-4189-8c8c-74de5f98b2c4 @@ -263,12 +266,12 @@ md" # ╔═╡ a729f1cd-404c-4a33-a8f9-b2ea6f0adb62 let - prior_chn = sample(deterministic_uncond_mdl, Prior(), 2000) + prior_chn = sample(deterministic_uncond_mdl, Prior(), 2000) gens = generated_quantities(deterministic_uncond_mdl, prior_chn) - plot_predYt(data, gens; - title = "Prior predictive: deterministic model", - ylabel = "Number of Infected students", - ) + plot_predYt(data, gens; + title = "Prior predictive: deterministic model", + ylabel = "Number of Infected students" + ) end # ╔═╡ 4c0759fb-76e9-4de5-9206-89e8bfb6c3bb @@ -287,12 +290,11 @@ nmle_tries = 100 # ╔═╡ ba35cebd-0d29-43c5-8db7-f550d7f821bc mle_fit = map(1:nmle_tries) do _ - fit = try - maximum_likelihood(deterministic_mdl; - ) - catch - (lp = -Inf,) - end + fit = try + maximum_likelihood(deterministic_mdl) + catch + (lp = -Inf,) + end end |> fits -> (findmax(fit -> fit.lp, fits)[2], fits) |> max_and_fits -> max_and_fits[2][max_and_fits[1]] @@ -304,13 +306,13 @@ mle_fit.optim_result.retcode md" Note that we choose the best out of $nmle_tries tries for the MLE estimators. -Now, we sample aiming at 1000 samples for each of 4 chains. +Now, we sample aiming at 1000 samples for each of 4 chains. " # ╔═╡ 2cf64ba3-ff8d-40b0-9bd8-9e80393156f5 chn = sample( - deterministic_mdl, NUTS(), MCMCThreads(), 1000, 4; - initial_params = fill(mle_fit.values.array, 4)) + deterministic_mdl, NUTS(), MCMCThreads(), 1000, 4; + initial_params = fill(mle_fit.values.array, 4)) # ╔═╡ b2429b68-dd75-499f-a4e1-1b7d72e209c7 describe(chn) @@ -326,10 +328,10 @@ md" # ╔═╡ 03d1ecf8-543d-444d-b1a3-7a19acd88499 let gens = generated_quantities(deterministic_uncond_mdl, chn) - plot_predYt(data, gens; - title = "Fitted deterministic model", - ylabel = "Number of Infected students", - ) + plot_predYt(data, gens; + title = "Fitted deterministic model", + ylabel = "Number of Infected students" + ) end # ╔═╡ e023770d-25f7-4b7a-b509-8a4372f42b76 @@ -398,46 +400,75 @@ sampled_AR_damps = ϕs .|> ϕ -> exp(-ϕ) # ╔═╡ 48032d21-53fa-4c0a-85cb-c22327b55073 sampled_AR_stds = map(ϕs, σ²s) do ϕ, σ² - (1 - exp(-2 * ϕ)) * σ² / (2 * ϕ) + (1 - exp(-2 * ϕ)) * σ² / (2 * ϕ) end +# ╔═╡ 89c767b8-97a0-45bb-9e9f-821879ddd38b +md" +We define the AR(1) process by matching means of `HalfNormal` prior distributions for the damp parameters and std deviation parameter to the calculated the prior means from the _Chatzilena et al_ definition. +" + # ╔═╡ 71a26408-1c26-46cf-bc72-c6ba528dfadd ar = AR( - damp_priors = [HalfNormal(mean(sampled_AR_damps))], + damp_priors = [HalfNormal(mean(sampled_AR_damps))], std_prior = HalfNormal(mean(sampled_AR_stds)), init_priors = [Normal(0, 0.001)] ) # ╔═╡ 9309f7f8-0896-4686-8bfc-b9f82d91bc0f -@model function stochastic_ode_mdl(Yt, logobsprob, obs, prob, N) - nobs = length(Yt) +@model function stochastic_ode_mdl(Yt, ts, logobsprob, obs, prob, N; + solver = AutoTsit5(Rosenbrock23()), + upjitter = 1e-2#0.1, +) + ##Priors## β ~ LogNormal(0.0, 1.0) γ ~ Gamma(0.004, 1 / 0.002) S₀ ~ Beta(0.5, 0.5) + ##Remake ODE model## _prob = remake(prob; u0 = [S₀, 1 - S₀, 0.0], p = [β, γ] ) - sol = solve(_prob, AutoTsit5(Rosenbrock23()); - sensealg = ForwardDiffSensitivity(), - saveat = 1.0:nobs, verbose = false) - # μ = log.(N * sol[2, :]) - @submodel κ = generate_latent(logobsprob, nobs) - λt = @. N * sol[2, :] * exp(κ) + 0.1 + ##Solve ODE model## + sol = solve(_prob, solver; + saveat = ts, + verbose = false + ) + ##Sample the log-residual AR process## + nobs = length(Yt) + @submodel κₜ = generate_latent(logobsprob, nobs) + λt = @. N * sol[2, :] * exp(κₜ) + upjitter + + ##log-like accumulation using obs## @submodel obsYt = generate_observations(obs, Yt, λt) + ##Generated quantities## return (; sol, obsYt, R0 = β / γ) end # ╔═╡ 4330c83f-de39-44c7-bdab-87e5f5830145 -stochastic_mdl = stochastic_ode_mdl(data.in_bed, ar, obs, sir_prob, N) +stochastic_mdl = stochastic_ode_mdl( + data.in_bed, + data.ts, + ar, + obs, + sir_prob, + N +) # ╔═╡ 8071c92f-9fe8-48cf-b1a0-79d1e34ec7e7 -stochastic_uncond_mdl = stochastic_ode_mdl(fill(missing, length(data.in_bed)), ar, obs, sir_prob, N) +stochastic_uncond_mdl = stochastic_ode_mdl( + fill(missing, length(data.in_bed)), + data.ts, + ar, + obs, + sir_prob, + N +) # ╔═╡ adb9d0ac-d412-4dbc-a601-59fcc33adf43 md" @@ -446,76 +477,84 @@ md" # ╔═╡ b44286f9-ba88-4e2b-9a34-f14c0a78824d let - prior_chn = sample(stochastic_uncond_mdl, Prior(), 2000) + prior_chn = sample(stochastic_uncond_mdl, Prior(), 2000) gens = generated_quantities(stochastic_uncond_mdl, prior_chn) - plot_predYt(data, gens; - title = "Prior predictive: stochastic model", - ylabel = "Number of Infected students", - ) + plot_predYt(data, gens; + title = "Prior predictive: stochastic model", + ylabel = "Number of Infected students" + ) end -# ╔═╡ d4502528-d058-4899-b3dd-576316116c18 -mle_fit2 = map(1:nmle_tries) do _ - fit = try - maximum_likelihood(stochastic_mdl; - adtype = AutoReverseDiff(true), - ) - catch - (lp = -Inf,) - end -end |> - fits -> (findmax(fit -> fit.lp, fits)[2], fits) |> - max_and_fits -> max_and_fits[2][max_and_fits[1]] +# ╔═╡ f690114f-4dca-4451-8a93-57c9d8d7c20c +md" +The prior predictive checking again shows misaligned prior beliefs; for example _a priori_ without data we would not expect the median prediction of number of ill children as about 600 out of $N after 1 day. -# ╔═╡ 78a732ab-4915-43d9-af55-b01bd84eb364 -map_fit2 = map(1:nmle_tries) do _ - fit = - maximum_likelihood(stochastic_mdl; - adtype = AutoReverseDiff(true), - initial_params = mle_fit2.values.array, - ) - # catch - # (lp = -Inf,) - # end -end |> - fits -> (findmax(fit -> fit.lp, fits)[2], fits) |> - max_and_fits -> max_and_fits[2][max_and_fits[1]] +The latent process for the log-residuals $\kappa_t$ doesn't make much sense without priors, so we look for a reasonable MAP point to start NUTS from. We do this by first making an initial guess which is a mixture of: + +1. The posterior averages from the deterministic model. +2. The prior averages of the structure parameters of the AR(1) process. +3. Zero for the time-varying noise underlying the AR(1) process. +" + +# ╔═╡ 3491e7e5-6c82-4c50-8feb-d730d1fbe457 +rand(stochastic_mdl) + +# ╔═╡ e1d54935-4305-42cf-98c6-ccee9b0813ea +initial_guess = [[mean(chn[:β]), + mean(chn[:γ]), + mean(chn[:S₀]), + mean(ar.std_prior), + mean(ar.init_prior)[1], + mean(ar.damp_prior)[1]] + zeros(13)] + +# ╔═╡ 685221ea-f268-4ddc-937f-e7620d065c28 +md" +Starting from the initial guess, the MAP point is calculated rapidly in one pass. +" + +# ╔═╡ 6796ae76-bc2d-4895-ba0a-5e2c23c50dfb +map_fit_stoch_mdl = maximum_a_posteriori(stochastic_mdl; + adtype = AutoReverseDiff(), + initial_params = initial_guess +) + +# ╔═╡ 62080cc2-3cab-4a22-9b2e-2bff640a17a4 +md" +Now we can run NUTS, sampling 1000 posterior draws per chain for 4 chains. +" # ╔═╡ 156272d7-56c4-4ac4-bf3e-7882f4edc144 -chn2 = sample(stochastic_mdl, NUTS(; adtype = AutoReverseDiff(true)), MCMCThreads(), 1000, 4; initial_params = fill(map_fit2.values.array,4)) +chn2 = sample( + stochastic_mdl, + NUTS(; adtype = AutoReverseDiff(true)), + MCMCThreads(), 1000, 4; + initial_params = fill(map_fit_stoch_mdl.values.array, 4) +) # ╔═╡ 00b90e6d-732f-41c9-a603-cabe9740e329 describe(chn2) # ╔═╡ 37a016d8-8384-41c9-abdd-23e88b1f988d -pairplot(chn2[[:β, :γ, :S₀]]) +pairplot(chn2[[:β, :γ, :S₀, :σ_AR, Symbol("ar_init[1]"), Symbol("damp_AR[1]")]]) -# ╔═╡ 0e7bbf13-9187-41ea-8b46-294b93be4c6d +# ╔═╡ 7df5d669-d3a2-4a66-83c3-f8618e39bec6 let -ts = 1:size(data, 1) -gens = generated_quantities(uncond_mdl2, chn2) -fig = Figure() -ax = Axis(fig[1,1]; - title = "Fitted Stochastic model", - xticks = (ts[1:3:end], data.date[1:3:end] .|> string), - ylabel = "Number of Infected students" - ) -pred_Yt = mapreduce(hcat, gens) do gen - gen.obsYt -end |> X -> mapreduce(vcat, eachrow(X)) do row - quantile(row, [0.5, 0.025, 0.975])' + vars = mapreduce(vcat, 1:13) do i + Symbol("ϵ_t[$i]") + end + pairplot(chn2[vars]) end -lines!(ax,ts, pred_Yt[:,1]; linewidth = 3, label = "Fitted deterministic model", color = :green) -band!(ax, ts, pred_Yt[:,2], pred_Yt[:,3], color = (:green, 0.5)) -scatter!(ax, data.in_bed) - -fig +# ╔═╡ 0e7bbf13-9187-41ea-8b46-294b93be4c6d +let + gens = generated_quantities(stochastic_uncond_mdl, chn2) + plot_predYt(data, gens; + title = "Fitted stochastic model", + ylabel = "Number of Infected students" + ) end -# ╔═╡ 36efe6e0-643f-42e6-9d64-de2f5a76b764 - - # ╔═╡ Cell order: # ╟─e34cec5a-a173-4e92-a860-340c7a9e9c72 # ╟─33384fc6-7cca-11ef-3567-ab7df9200cde @@ -563,16 +602,21 @@ end # ╠═e6bcf0c0-3cc4-41f3-ad20-fa11bf2ca37b # ╠═4f07e8ba-30d0-411f-8c3e-b6d5bc1bb5fa # ╠═48032d21-53fa-4c0a-85cb-c22327b55073 +# ╟─89c767b8-97a0-45bb-9e9f-821879ddd38b # ╠═71a26408-1c26-46cf-bc72-c6ba528dfadd # ╠═9309f7f8-0896-4686-8bfc-b9f82d91bc0f # ╠═4330c83f-de39-44c7-bdab-87e5f5830145 # ╠═8071c92f-9fe8-48cf-b1a0-79d1e34ec7e7 -# ╠═adb9d0ac-d412-4dbc-a601-59fcc33adf43 +# ╟─adb9d0ac-d412-4dbc-a601-59fcc33adf43 # ╠═b44286f9-ba88-4e2b-9a34-f14c0a78824d -# ╠═d4502528-d058-4899-b3dd-576316116c18 -# ╠═78a732ab-4915-43d9-af55-b01bd84eb364 +# ╟─f690114f-4dca-4451-8a93-57c9d8d7c20c +# ╠═3491e7e5-6c82-4c50-8feb-d730d1fbe457 +# ╠═e1d54935-4305-42cf-98c6-ccee9b0813ea +# ╟─685221ea-f268-4ddc-937f-e7620d065c28 +# ╠═6796ae76-bc2d-4895-ba0a-5e2c23c50dfb +# ╟─62080cc2-3cab-4a22-9b2e-2bff640a17a4 # ╠═156272d7-56c4-4ac4-bf3e-7882f4edc144 # ╠═00b90e6d-732f-41c9-a603-cabe9740e329 # ╠═37a016d8-8384-41c9-abdd-23e88b1f988d +# ╠═7df5d669-d3a2-4a66-83c3-f8618e39bec6 # ╠═0e7bbf13-9187-41ea-8b46-294b93be4c6d -# ╠═36efe6e0-643f-42e6-9d64-de2f5a76b764 From 615d7e99e4c3d8e2f356e27006fac4cffe0e532a Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Wed, 2 Oct 2024 15:05:35 +0100 Subject: [PATCH 08/17] Update EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl Co-authored-by: Sam Abbott --- .../docs/src/showcase/replications/chatzilena-2019/index.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl index d7c9a68e3..cace1ceb1 100644 --- a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl +++ b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl @@ -94,7 +94,7 @@ We can interface to the `SciML` ecosystem by writing a function with the signatu > `(du, u, p, t) -> nothing` Where: -- `du` is the _vector field_ of the ODE problem, e.g. ${dS \over dt}$, ${dI \over dt}$ etc. This is calculated _in-place_. +- `du` is the _vector field_ of the ODE problem, e.g. ${dS \over dt}$, ${dI \over dt}$ etc. This is calculated _in-place_ (commonly denoted using ! in function names in Julia). - `u` is the _state_ of the ODE problem, e.g. $S$, $I$, etc. - `p` is an object that represents the parameters of the ODE problem, e.g. $\beta$, $\gamma$. - `t` is the time of the ODE problem. From 813a166208da2a50c64f7e81816c3f2b29497e54 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 2 Oct 2024 16:10:09 +0100 Subject: [PATCH 09/17] add StatsBase to doc env --- EpiAware/docs/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/EpiAware/docs/Project.toml b/EpiAware/docs/Project.toml index 8b9e975a6..44538ff67 100644 --- a/EpiAware/docs/Project.toml +++ b/EpiAware/docs/Project.toml @@ -17,6 +17,7 @@ PlutoStaticHTML = "359b1769-a58e-495b-9770-312e911026ad" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TimeSeries = "9e3dc215-6440-5c97-bce1-76c03772f85e" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" From 3ba8106b2e46d5ac37669982918781162c738602 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 2 Oct 2024 16:10:36 +0100 Subject: [PATCH 10/17] add chatzilena replication to pages and build step --- EpiAware/docs/make.jl | 1 + EpiAware/docs/pages.jl | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/EpiAware/docs/make.jl b/EpiAware/docs/make.jl index 5def2a8c8..eb42ea0e3 100644 --- a/EpiAware/docs/make.jl +++ b/EpiAware/docs/make.jl @@ -15,6 +15,7 @@ include("build.jl") build("getting-started") build("getting-started/tutorials") +build("showcase/replications/chatzilena-2019") build("showcase/replications/mishra-2020") DocMeta.setdocmeta!(EpiAware, :DocTestSetup, :(using EpiAware); recursive = true) diff --git a/EpiAware/docs/pages.jl b/EpiAware/docs/pages.jl index 34db2263d..71268159d 100644 --- a/EpiAware/docs/pages.jl +++ b/EpiAware/docs/pages.jl @@ -26,7 +26,8 @@ getting_started_pages = Any[ showcase_pages = Any[ "Overview" => "showcase/index.md", "Replication" => [ - "On the derivation of the renewal equation from an age-dependent branching process: an epidemic modelling perspective" => "showcase/replications/mishra-2020/index.md" + "On the derivation of the renewal equation from an age-dependent branching process: an epidemic modelling perspective" => "showcase/replications/mishra-2020/index.md", + "Statistical inference for ODE-based infectious disease models" => "showcase/replications/chatzilena-2019/index.md", ] ] From 1b50d1d788a9a14e57010b9a725676c91e3be868 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 2 Oct 2024 16:11:15 +0100 Subject: [PATCH 11/17] reformat --- EpiAware/docs/pages.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/EpiAware/docs/pages.jl b/EpiAware/docs/pages.jl index 71268159d..ec7c4380d 100644 --- a/EpiAware/docs/pages.jl +++ b/EpiAware/docs/pages.jl @@ -27,7 +27,7 @@ showcase_pages = Any[ "Overview" => "showcase/index.md", "Replication" => [ "On the derivation of the renewal equation from an age-dependent branching process: an epidemic modelling perspective" => "showcase/replications/mishra-2020/index.md", - "Statistical inference for ODE-based infectious disease models" => "showcase/replications/chatzilena-2019/index.md", + "Statistical inference for ODE-based infectious disease models" => "showcase/replications/chatzilena-2019/index.md" ] ] From 86479f0fa8004f81c42f768aefc3126397e3afb2 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 2 Oct 2024 17:03:51 +0100 Subject: [PATCH 12/17] redo lines --- .../replications/chatzilena-2019/index.jl | 60 ++++++++++++------- 1 file changed, 39 insertions(+), 21 deletions(-) diff --git a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl index cace1ceb1..f011784f2 100644 --- a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl +++ b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl @@ -48,7 +48,8 @@ md" # Example: Statistical inference for ODE-based infectious disease models # Introduction ## What are we going to do in this Vignette -In this vignette, we'll demonstrate how to use `EpiAware` in conjunction with [SciML ecosystem](https://sciml.ai/) for Bayesian inference of infectious disease dynamics. The model and data is heavily based on [Contemporary statistical inference for infectious disease models using Stan _Chatzilena et al. 2019_](https://www.sciencedirect.com/science/article/pii/S1755436519300325). +In this vignette, we'll demonstrate how to use `EpiAware` in conjunction with [SciML ecosystem](https://sciml.ai/) for Bayesian inference of infectious disease dynamics. +The model and data is heavily based on [Contemporary statistical inference for infectious disease models using Stan _Chatzilena et al. 2019_](https://www.sciencedirect.com/science/article/pii/S1755436519300325). We'll cover the following key points: @@ -63,15 +64,17 @@ This vignette builds on concepts from `EpiAware` observation models and a famila ## Packages used in this vignette -Alongside the `EpiAware` package we will use the `OrdinaryDiffEq` and `SciMLSensitivity` packages for interfacing with `SciML` ecosystem; this is a lower dependency usage of `DifferentialEquations.jl` that, respectively, exposes ODE solvers and adjoint methods for ODE solvees; that is the method of propagating parameter derivatives through functions containing ODE solutions. Bayesian inference will be done with `NUTS` from the `Turing` ecosystem. We will also use the `CairoMakie` package for plotting and `DataFramesMeta` for data manipulation. +Alongside the `EpiAware` package we will use the `OrdinaryDiffEq` and `SciMLSensitivity` packages for interfacing with `SciML` ecosystem; this is a lower dependency usage of `DifferentialEquations.jl` that, respectively, exposes ODE solvers and adjoint methods for ODE solvees; that is the method of propagating parameter derivatives through functions containing ODE solutions. +Bayesian inference will be done with `NUTS` from the `Turing` ecosystem. We will also use the `CairoMakie` package for plotting and `DataFramesMeta` for data manipulation. " # ╔═╡ 943b82ec-b4dc-4537-8183-d6c73cd74a37 md" # Single population SIR model -As mentioned in _Chatzilena et al_ disease spread is frequently modelled in terms -of ODE-based models. The study population is divided into compartments representing a specific stage of the epidemic status. In this case, susceptible, infected, and recovered individuals. +As mentioned in _Chatzilena et al_ disease spread is frequently modelled in terms of ODE-based models. +The study population is divided into compartments representing a specific stage of the epidemic status. +In this case, susceptible, infected, and recovered individuals. ```math \begin{aligned} @@ -81,9 +84,8 @@ of ODE-based models. The study population is divided into compartments represent \end{aligned} ``` where S(t) represents the number of susceptible, I(t) the number of -infected and R(t) the number of recovered individuals at time t. The -total population size is denoted by N (with N = S(t) + I(t) + R(t)), β -denotes the transmission rate and γ denotes the recovery rate. +infected and R(t) the number of recovered individuals at time t. +The total population size is denoted by N (with N = S(t) + I(t) + R(t)), β denotes the transmission rate and γ denotes the recovery rate. " @@ -115,7 +117,13 @@ end # ╔═╡ f16eb00b-2d77-45df-b767-757fe2f5674c md" -We combine vector field function `sir!` with a initial condition `u0` and the integration period `tspan` to make an `ODEProblem`. We do not define the parameters, these will be defined within an inference approach. +We combine vector field function `sir!` with a initial condition `u0` and the integration period `tspan` to make an `ODEProblem`. +We do not define the parameters, these will be defined within an inference approach. +" + +# ╔═╡ b5ff95d1-8a6f-4d48-adf2-60d91b3ebebe +md" +Note that this is analogous " # ╔═╡ d64388f9-6edd-414d-a191-316f75b35b2c @@ -123,7 +131,9 @@ md" ## Data for inference -There was a brief, but intense, outbreak of Influenza within the (semi-) closed community of a boarding school reported to the British medical journal in 1978. The outbreak lasted from 22nd January to 4th February and it is reported that one infected child started the epidemic and then it spread rapidly. Of the 763 children at the boarding scholl, 512 became ill. +There was a brief, but intense, outbreak of Influenza within the (semi-) closed community of a boarding school reported to the British medical journal in 1978. +The outbreak lasted from 22nd January to 4th February and it is reported that one infected child started the epidemic and then it spread rapidly. +Of the 763 children at the boarding scholl, 512 became ill. We downloaded the data of this outbreak using the R package `outbreaks` which is maintained as part of the [R Epidemics Consortium(RECON)](http://www. repidemicsconsortium.org). @@ -149,7 +159,8 @@ sir_prob = ODEProblem( md" ## Inference for the deterministic SIR model -The boarding school data gives the number of children \"in bed\" and \"convalescent\" on each of 14 days from 22nd Jan to 4th Feb 1978. We follow _Chatzilena et al_ and treat the number \"in bed\" as a proxy for the number of children in the infectious (I) compartment in the ODE model. +The boarding school data gives the number of children \"in bed\" and \"convalescent\" on each of 14 days from 22nd Jan to 4th Feb 1978. +We follow _Chatzilena et al_ and treat the number \"in bed\" as a proxy for the number of children in the infectious (I) compartment in the ODE model. The full observation model is: @@ -163,7 +174,8 @@ S(0) /N &\sim \text{Beta}(0.5, 0.5). \end{aligned} ``` -**NB: Chatzilena et al give $\lambda_t = \int_0^t \beta \frac{I(s)}{N} S(s) - \gamma I(s)ds = I(t) - I(0).$ However, this doesn't match their underlying stan code.** +**NB: Chatzilena et al give $\lambda_t = \int_0^t \beta \frac{I(s)}{N} S(s) - \gamma I(s)ds = I(t) - I(0).$ +However, this doesn't match their underlying stan code.** " # ╔═╡ ea1be94b-d722-47ee-8465-982c83dc6838 @@ -213,8 +225,10 @@ end md" We instantiate the model in two ways: -1. `deterministic_mdl`: This conditions the generative model on the data observation. We can sample from this model to find the posterior distribution of the parameters. -2. `deterministic_uncond_mdl`: This _doesn't_ condition on the data. This is useful for prior and posterior predictive modelling. +1. `deterministic_mdl`: This conditions the generative model on the data observation. +We can sample from this model to find the posterior distribution of the parameters. +2. `deterministic_uncond_mdl`: This _doesn't_ condition on the data. +This is useful for prior and posterior predictive modelling. " # ╔═╡ dbc1b453-1c29-4f82-bec9-098d67f9e63f @@ -276,7 +290,8 @@ end # ╔═╡ 4c0759fb-76e9-4de5-9206-89e8bfb6c3bb md" -The prior predictive checking suggests that _a priori_ our parameter beliefs are very far from the data. Approaching the inference naively can lead to poor fits. +The prior predictive checking suggests that _a priori_ our parameter beliefs are very far from the data. +Approaching the inference naively can lead to poor fits. We do three things to mitigate this: @@ -338,8 +353,8 @@ end md" ## Inference for the Stochastic SIR model -In _Chatzilena et al_, they present an auto-regressive model for connecting the outcome of the ODE model to illness observations. The argument is that the stochastic component of the model can absorb the noise -generated by a possible mis-specification of the model. +In _Chatzilena et al_, they present an auto-regressive model for connecting the outcome of the ODE model to illness observations. +The argument is that the stochastic component of the model can absorb the noise generated by a possible mis-specification of the model. In their approach they consider $\kappa_t = \log \lambda_t$ where $\kappa_t$ evolves according to an Ornstein-Uhlenbeck process: @@ -386,7 +401,8 @@ S(0) /N &\sim \text{Beta}(0.5, 0.5)\\ md" We will using the `AR` struct from `EpiAware` to define the auto-regressive process in this model which has a direct parameterisation of the `AR` model. -To convert from the formulation above we sample from the priors, and define `HalfNormal` priors based on the sampled prior means of $e^{-\phi}$ and ${\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right)$. We also add a strong prior that $\kappa_1 \approx 0$. +To convert from the formulation above we sample from the priors, and define `HalfNormal` priors based on the sampled prior means of $e^{-\phi}$ and ${\sigma^2 \over 2 \phi} \left(1 - e^{-2\phi} \right)$. +We also add a strong prior that $\kappa_1 \approx 0$. " # ╔═╡ 178e0048-069a-4953-bb24-5116eb81cc41 @@ -418,7 +434,7 @@ ar = AR( # ╔═╡ 9309f7f8-0896-4686-8bfc-b9f82d91bc0f @model function stochastic_ode_mdl(Yt, ts, logobsprob, obs, prob, N; solver = AutoTsit5(Rosenbrock23()), - upjitter = 1e-2#0.1, + upjitter = 1e-2, ) ##Priors## @@ -489,7 +505,8 @@ end md" The prior predictive checking again shows misaligned prior beliefs; for example _a priori_ without data we would not expect the median prediction of number of ill children as about 600 out of $N after 1 day. -The latent process for the log-residuals $\kappa_t$ doesn't make much sense without priors, so we look for a reasonable MAP point to start NUTS from. We do this by first making an initial guess which is a mixture of: +The latent process for the log-residuals $\kappa_t$ doesn't make much sense without priors, so we look for a reasonable MAP point to start NUTS from. +We do this by first making an initial guess which is a mixture of: 1. The posterior averages from the deterministic model. 2. The prior averages of the structure parameters of the AR(1) process. @@ -571,13 +588,14 @@ end # ╠═ab4269b1-e292-466f-8bfb-713d917c18f9 # ╟─f16eb00b-2d77-45df-b767-757fe2f5674c # ╠═bb07a580-6d86-48b3-a79f-d2ed9306e87c +# ╠═b5ff95d1-8a6f-4d48-adf2-60d91b3ebebe # ╟─d64388f9-6edd-414d-a191-316f75b35b2c # ╠═7c9cbbc1-71ef-4d81-b93a-c2b3a8683d53 # ╠═aba3f1db-c290-409c-9b9e-6065935ede54 # ╟─3f54bb44-76c4-4744-885a-46dedfaffeca # ╟─ea1be94b-d722-47ee-8465-982c83dc6838 # ╠═87509792-e28d-4618-9bf5-e06b2e5dbe8b -# ╠═81501c84-5e1f-4829-a26d-52fe00503958 +# ╟─81501c84-5e1f-4829-a26d-52fe00503958 # ╠═1d287c8e-7000-4b23-ae7e-f7008c3e53bd # ╟─e7383885-fa6a-4240-a252-44ae82cae713 # ╠═dbc1b453-1c29-4f82-bec9-098d67f9e63f @@ -586,7 +604,7 @@ end # ╠═ab8c98d1-d357-4c49-9f5a-f069e05c45f5 # ╟─2c6ac235-e331-4189-8c8c-74de5f98b2c4 # ╠═a729f1cd-404c-4a33-a8f9-b2ea6f0adb62 -# ╟─4c0759fb-76e9-4de5-9206-89e8bfb6c3bb +# ╠═4c0759fb-76e9-4de5-9206-89e8bfb6c3bb # ╠═8d96db67-de3b-4704-9f54-f4ed50a4ecff # ╠═ba35cebd-0d29-43c5-8db7-f550d7f821bc # ╠═0be912c1-22dc-4978-b86a-84273062f5da From 3b75af5e967db34fb9758a00eb5c23dce4bdb336 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 2 Oct 2024 17:08:12 +0100 Subject: [PATCH 13/17] link to `EpiProblem` --- .../docs/src/showcase/replications/chatzilena-2019/index.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl index f011784f2..9c60fbcab 100644 --- a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl +++ b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl @@ -123,7 +123,9 @@ We do not define the parameters, these will be defined within an inference appro # ╔═╡ b5ff95d1-8a6f-4d48-adf2-60d91b3ebebe md" -Note that this is analogous +Note that this is analogous to the `EpiProblem` approach we expose from `EpiAware`, as used in the [Mishra et al replication](https://cdcgov.github.io/Rt-without-renewal/dev/showcase/replications/mishra-2020/). +The difference is that here we are going to use ODE solvers from the `SciML` ecosystem to generate the dynamics of the underlying infections. +In the linked example, we use latent process generation exposed by `EpiAware` as the underlying generative process for underlying dynamics. " # ╔═╡ d64388f9-6edd-414d-a191-316f75b35b2c @@ -588,7 +590,7 @@ end # ╠═ab4269b1-e292-466f-8bfb-713d917c18f9 # ╟─f16eb00b-2d77-45df-b767-757fe2f5674c # ╠═bb07a580-6d86-48b3-a79f-d2ed9306e87c -# ╠═b5ff95d1-8a6f-4d48-adf2-60d91b3ebebe +# ╟─b5ff95d1-8a6f-4d48-adf2-60d91b3ebebe # ╟─d64388f9-6edd-414d-a191-316f75b35b2c # ╠═7c9cbbc1-71ef-4d81-b93a-c2b3a8683d53 # ╠═aba3f1db-c290-409c-9b9e-6065935ede54 From 1b9b943ff87eae9ad934268475defa0519841854 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Thu, 3 Oct 2024 11:21:18 +0100 Subject: [PATCH 14/17] Small changes and prior pred plot for AR --- .../replications/chatzilena-2019/index.jl | 62 +++++++++++++++---- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl index 9c60fbcab..0a328951e 100644 --- a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl +++ b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl @@ -190,11 +190,11 @@ obs = PoissonError() # ╔═╡ 81501c84-5e1f-4829-a26d-52fe00503958 md" -Now we can write the observation model using the `Turing` PPL. +Now we can write the probabilistic model using the `Turing` PPL. " # ╔═╡ 1d287c8e-7000-4b23-ae7e-f7008c3e53bd -@model function deterministic_ode_mdl(Yt, ts, obs, prob, N; +@model function deterministic_ode_mdl(y_t, ts, obs, prob, N; solver = AutoTsit5(Rosenbrock23()), upjitter = 1e-3 ) @@ -217,20 +217,21 @@ Now we can write the observation model using the `Turing` PPL. ##log-like accumulation using obs## λt = N * sol[2, :] .+ upjitter #expected It - @submodel obsYt = generate_observations(obs, Yt, λt) + @submodel generated_y_t = generate_observations(obs, y_t, λt) ##Generated quantities## - return (; sol, obsYt, R0 = β / γ) + return (; sol, generated_y_t, R0 = β / γ) end # ╔═╡ e7383885-fa6a-4240-a252-44ae82cae713 md" We instantiate the model in two ways: -1. `deterministic_mdl`: This conditions the generative model on the data observation. -We can sample from this model to find the posterior distribution of the parameters. -2. `deterministic_uncond_mdl`: This _doesn't_ condition on the data. -This is useful for prior and posterior predictive modelling. +1. `deterministic_mdl`: This conditions the generative model on the data observation. We can sample from this model to find the posterior distribution of the parameters. +2. `deterministic_uncond_mdl`: This _doesn't_ condition on the data. This is useful for prior and posterior predictive modelling. + +Here we construct the `Turing` model directly, in the [Mishra et al replication](https://cdcgov.github.io/Rt-without-renewal/dev/showcase/replications/mishra-2020/) we using the `EpiProblem` functionality to build a `Turing` model under the hood. +Because in this note we are using a mix of functionality from `SciML` and `EpiAware`, we construct the model to sample from directly. " # ╔═╡ dbc1b453-1c29-4f82-bec9-098d67f9e63f @@ -256,7 +257,7 @@ function plot_predYt(data, gens; title::String, ylabel::String) ylabel = ylabel ) pred_Yt = mapreduce(hcat, gens) do gen - gen.obsYt + gen.generated_y_t end |> X -> mapreduce(vcat, eachrow(X)) do row quantile(row, [0.5, 0.025, 0.975, 0.1, 0.9, 0.25, 0.75])' end @@ -433,8 +434,39 @@ ar = AR( init_priors = [Normal(0, 0.001)] ) +# ╔═╡ e1ffdaf6-ca2e-405d-8355-0d8848d005b0 +md" +We can sample directly from the behaviour specified by the `ar` struct to do prior predictive checking on the `AR(1)` process. +" + +# ╔═╡ de1498fa-8502-40ba-9708-2add74368e73 +let +nobs = size(data, 1) +ar_mdl = generate_latent(ar, nobs) +fig = Figure() +ax = Axis(fig[1,1], + xticks = (data.ts[1:3:end], data.date[1:3:end] .|> string), + ylabel = "exp(kt)", + title = "Prior predictive sampling for relative residual in mean pred." +) +for i = 1:500 +lines!(ax, ar_mdl() .|> exp, color = (:grey, 0.15)) +end +fig +end + +# ╔═╡ 9a82c75a-6ea4-48bb-af06-fabaca4c45ee +md" +We see that the choice of priors implies an _a priori_ belief that the extra observation noise on the mean prediction of the ODE model is fairly small, approximately 10% relative to the mean prediction. +" + +# ╔═╡ b693a942-c6c7-40f8-997c-0dc8e5548132 +md" +We can now define the probabilistic model. +" + # ╔═╡ 9309f7f8-0896-4686-8bfc-b9f82d91bc0f -@model function stochastic_ode_mdl(Yt, ts, logobsprob, obs, prob, N; +@model function stochastic_ode_mdl(y_t, ts, logobsprob, obs, prob, N; solver = AutoTsit5(Rosenbrock23()), upjitter = 1e-2, ) @@ -457,15 +489,15 @@ ar = AR( ) ##Sample the log-residual AR process## - nobs = length(Yt) + nobs = length(y_t) @submodel κₜ = generate_latent(logobsprob, nobs) λt = @. N * sol[2, :] * exp(κₜ) + upjitter ##log-like accumulation using obs## - @submodel obsYt = generate_observations(obs, Yt, λt) + @submodel generated_y_t = generate_observations(obs, y_t, λt) ##Generated quantities## - return (; sol, obsYt, R0 = β / γ) + return (; sol, generated_y_t, R0 = β / γ) end # ╔═╡ 4330c83f-de39-44c7-bdab-87e5f5830145 @@ -624,6 +656,10 @@ end # ╠═48032d21-53fa-4c0a-85cb-c22327b55073 # ╟─89c767b8-97a0-45bb-9e9f-821879ddd38b # ╠═71a26408-1c26-46cf-bc72-c6ba528dfadd +# ╟─e1ffdaf6-ca2e-405d-8355-0d8848d005b0 +# ╠═de1498fa-8502-40ba-9708-2add74368e73 +# ╟─9a82c75a-6ea4-48bb-af06-fabaca4c45ee +# ╟─b693a942-c6c7-40f8-997c-0dc8e5548132 # ╠═9309f7f8-0896-4686-8bfc-b9f82d91bc0f # ╠═4330c83f-de39-44c7-bdab-87e5f5830145 # ╠═8071c92f-9fe8-48cf-b1a0-79d1e34ec7e7 From 8c35483f21989cc076f83a26da48b8c9b5132d5b Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Thu, 3 Oct 2024 11:21:43 +0100 Subject: [PATCH 15/17] reformat --- .../replications/chatzilena-2019/index.jl | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl index 0a328951e..c18d9dc3c 100644 --- a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl +++ b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl @@ -441,18 +441,18 @@ We can sample directly from the behaviour specified by the `ar` struct to do pri # ╔═╡ de1498fa-8502-40ba-9708-2add74368e73 let -nobs = size(data, 1) -ar_mdl = generate_latent(ar, nobs) -fig = Figure() -ax = Axis(fig[1,1], - xticks = (data.ts[1:3:end], data.date[1:3:end] .|> string), - ylabel = "exp(kt)", - title = "Prior predictive sampling for relative residual in mean pred." -) -for i = 1:500 -lines!(ax, ar_mdl() .|> exp, color = (:grey, 0.15)) -end -fig + nobs = size(data, 1) + ar_mdl = generate_latent(ar, nobs) + fig = Figure() + ax = Axis(fig[1, 1], + xticks = (data.ts[1:3:end], data.date[1:3:end] .|> string), + ylabel = "exp(kt)", + title = "Prior predictive sampling for relative residual in mean pred." + ) + for i in 1:500 + lines!(ax, ar_mdl() .|> exp, color = (:grey, 0.15)) + end + fig end # ╔═╡ 9a82c75a-6ea4-48bb-af06-fabaca4c45ee @@ -468,7 +468,7 @@ We can now define the probabilistic model. # ╔═╡ 9309f7f8-0896-4686-8bfc-b9f82d91bc0f @model function stochastic_ode_mdl(y_t, ts, logobsprob, obs, prob, N; solver = AutoTsit5(Rosenbrock23()), - upjitter = 1e-2, + upjitter = 1e-2 ) ##Priors## From e5cb99e8a9371e2aee1f23bbf132cf88f2c2c791 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Thu, 3 Oct 2024 11:57:57 +0100 Subject: [PATCH 16/17] implement xexpy and log1pexp --- EpiAware/docs/Project.toml | 1 + .../replications/chatzilena-2019/index.jl | 45 ++++++++++--------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/EpiAware/docs/Project.toml b/EpiAware/docs/Project.toml index 44538ff67..6e42c5742 100644 --- a/EpiAware/docs/Project.toml +++ b/EpiAware/docs/Project.toml @@ -10,6 +10,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da" Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781" diff --git a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl index c18d9dc3c..71df7acf0 100644 --- a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl +++ b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl @@ -25,7 +25,7 @@ using Turing using OrdinaryDiffEq, SciMLSensitivity #ODE solvers and adjoint methods # ╔═╡ 261420cd-4650-402b-b126-7a431f93f37e -using Distributions, Statistics #Statistics packages +using Distributions, Statistics, LogExpFunctions #Statistics and special func packages # ╔═╡ 9c19a98b-a08b-4560-966d-61ff0ece2ad5 using CSV, DataFramesMeta #Data wrangling @@ -191,12 +191,13 @@ obs = PoissonError() # ╔═╡ 81501c84-5e1f-4829-a26d-52fe00503958 md" Now we can write the probabilistic model using the `Turing` PPL. +Note that instead of using $I(t)$ directly we do the [softplus](https://en.wikipedia.org/wiki/Softplus) transform on $I(t)$ implemented by `LogExpFunctions.log1pexp`. +The reason is that the solver can return small negative numbers, the soft plus transform smoothly maintains positivity which being very close to $I(t)$ when $I(t) > 2$. " # ╔═╡ 1d287c8e-7000-4b23-ae7e-f7008c3e53bd @model function deterministic_ode_mdl(y_t, ts, obs, prob, N; - solver = AutoTsit5(Rosenbrock23()), - upjitter = 1e-3 + solver = AutoTsit5(Rosenbrock23()) ) ##Priors## β ~ LogNormal(0.0, 1.0) @@ -216,7 +217,7 @@ Now we can write the probabilistic model using the `Turing` PPL. verbose = false) ##log-like accumulation using obs## - λt = N * sol[2, :] .+ upjitter #expected It + λt = log1pexp.(N * sol[2, :] ) # #expected It @submodel generated_y_t = generate_observations(obs, y_t, λt) ##Generated quantities## @@ -299,8 +300,7 @@ Approaching the inference naively can lead to poor fits. We do three things to mitigate this: 1. We choose a switching ODE solver which switches between explicit (`Tsit5`) and implicit (`Rosenbrock23`) solvers. This helps avoid the ODE solver failing when the sampler tries extreme parameter values. This is the default `solver = AutoTsit5(Rosenbrock23())` above. -2. To avoid the effect of numerically negative small values of `λt` we add a small `upjitter`. -3. We locate the maximum likelihood point, that is we ignore the influence of the priors, as a useful starting point for `NUTS`. +2. We locate the maximum likelihood point, that is we ignore the influence of the priors, as a useful starting point for `NUTS`. " # ╔═╡ 8d96db67-de3b-4704-9f54-f4ed50a4ecff @@ -330,7 +330,8 @@ Now, we sample aiming at 1000 samples for each of 4 chains. # ╔═╡ 2cf64ba3-ff8d-40b0-9bd8-9e80393156f5 chn = sample( deterministic_mdl, NUTS(), MCMCThreads(), 1000, 4; - initial_params = fill(mle_fit.values.array, 4)) + initial_params = fill(mle_fit.values.array, 4), +) # ╔═╡ b2429b68-dd75-499f-a4e1-1b7d72e209c7 describe(chn) @@ -441,18 +442,18 @@ We can sample directly from the behaviour specified by the `ar` struct to do pri # ╔═╡ de1498fa-8502-40ba-9708-2add74368e73 let - nobs = size(data, 1) - ar_mdl = generate_latent(ar, nobs) - fig = Figure() - ax = Axis(fig[1, 1], - xticks = (data.ts[1:3:end], data.date[1:3:end] .|> string), - ylabel = "exp(kt)", - title = "Prior predictive sampling for relative residual in mean pred." - ) - for i in 1:500 - lines!(ax, ar_mdl() .|> exp, color = (:grey, 0.15)) - end - fig +nobs = size(data, 1) +ar_mdl = generate_latent(ar, nobs) +fig = Figure() +ax = Axis(fig[1,1], + xticks = (data.ts[1:3:end], data.date[1:3:end] .|> string), + ylabel = "exp(kt)", + title = "Prior predictive sampling for relative residual in mean pred." +) +for i = 1:500 +lines!(ax, ar_mdl() .|> exp, color = (:grey, 0.15)) +end +fig end # ╔═╡ 9a82c75a-6ea4-48bb-af06-fabaca4c45ee @@ -463,12 +464,12 @@ We see that the choice of priors implies an _a priori_ belief that the extra obs # ╔═╡ b693a942-c6c7-40f8-997c-0dc8e5548132 md" We can now define the probabilistic model. +Note that instead of implementing `exp.(κₜ)` directly, which can be unstable for large primal values, we use the `LogExpFunctions.xexpy` function which implements $x\exp(y)$ stabily for a wide range of values. " # ╔═╡ 9309f7f8-0896-4686-8bfc-b9f82d91bc0f @model function stochastic_ode_mdl(y_t, ts, logobsprob, obs, prob, N; solver = AutoTsit5(Rosenbrock23()), - upjitter = 1e-2 ) ##Priors## @@ -491,7 +492,7 @@ We can now define the probabilistic model. ##Sample the log-residual AR process## nobs = length(y_t) @submodel κₜ = generate_latent(logobsprob, nobs) - λt = @. N * sol[2, :] * exp(κₜ) + upjitter + λt = xexpy.(log1pexp.(N * sol[2, :]), κₜ) ##log-like accumulation using obs## @submodel generated_y_t = generate_observations(obs, y_t, λt) @@ -638,7 +639,7 @@ end # ╠═ab8c98d1-d357-4c49-9f5a-f069e05c45f5 # ╟─2c6ac235-e331-4189-8c8c-74de5f98b2c4 # ╠═a729f1cd-404c-4a33-a8f9-b2ea6f0adb62 -# ╠═4c0759fb-76e9-4de5-9206-89e8bfb6c3bb +# ╟─4c0759fb-76e9-4de5-9206-89e8bfb6c3bb # ╠═8d96db67-de3b-4704-9f54-f4ed50a4ecff # ╠═ba35cebd-0d29-43c5-8db7-f550d7f821bc # ╠═0be912c1-22dc-4978-b86a-84273062f5da From e27927f7a8fc41b87a0baa120bbf93bd739ec4ee Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Thu, 3 Oct 2024 11:58:14 +0100 Subject: [PATCH 17/17] reformat --- .../replications/chatzilena-2019/index.jl | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl index 71df7acf0..a07ff40aa 100644 --- a/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl +++ b/EpiAware/docs/src/showcase/replications/chatzilena-2019/index.jl @@ -217,7 +217,7 @@ The reason is that the solver can return small negative numbers, the soft plus t verbose = false) ##log-like accumulation using obs## - λt = log1pexp.(N * sol[2, :] ) # #expected It + λt = log1pexp.(N * sol[2, :]) # #expected It @submodel generated_y_t = generate_observations(obs, y_t, λt) ##Generated quantities## @@ -330,7 +330,7 @@ Now, we sample aiming at 1000 samples for each of 4 chains. # ╔═╡ 2cf64ba3-ff8d-40b0-9bd8-9e80393156f5 chn = sample( deterministic_mdl, NUTS(), MCMCThreads(), 1000, 4; - initial_params = fill(mle_fit.values.array, 4), + initial_params = fill(mle_fit.values.array, 4) ) # ╔═╡ b2429b68-dd75-499f-a4e1-1b7d72e209c7 @@ -442,18 +442,18 @@ We can sample directly from the behaviour specified by the `ar` struct to do pri # ╔═╡ de1498fa-8502-40ba-9708-2add74368e73 let -nobs = size(data, 1) -ar_mdl = generate_latent(ar, nobs) -fig = Figure() -ax = Axis(fig[1,1], - xticks = (data.ts[1:3:end], data.date[1:3:end] .|> string), - ylabel = "exp(kt)", - title = "Prior predictive sampling for relative residual in mean pred." -) -for i = 1:500 -lines!(ax, ar_mdl() .|> exp, color = (:grey, 0.15)) -end -fig + nobs = size(data, 1) + ar_mdl = generate_latent(ar, nobs) + fig = Figure() + ax = Axis(fig[1, 1], + xticks = (data.ts[1:3:end], data.date[1:3:end] .|> string), + ylabel = "exp(kt)", + title = "Prior predictive sampling for relative residual in mean pred." + ) + for i in 1:500 + lines!(ax, ar_mdl() .|> exp, color = (:grey, 0.15)) + end + fig end # ╔═╡ 9a82c75a-6ea4-48bb-af06-fabaca4c45ee @@ -469,7 +469,7 @@ Note that instead of implementing `exp.(κₜ)` directly, which can be unstable # ╔═╡ 9309f7f8-0896-4686-8bfc-b9f82d91bc0f @model function stochastic_ode_mdl(y_t, ts, logobsprob, obs, prob, N; - solver = AutoTsit5(Rosenbrock23()), + solver = AutoTsit5(Rosenbrock23()) ) ##Priors##