|
| 1 | +### A Pluto.jl notebook ### |
| 2 | +# v0.19.46 |
| 3 | + |
| 4 | +using Markdown |
| 5 | +using InteractiveUtils |
| 6 | + |
| 7 | +# ╔═╡ a2624404-48b1-4faa-abbe-6d78b8e04f2b |
| 8 | +let |
| 9 | + docs_dir = dirname(dirname(dirname(@__DIR__))) |
| 10 | + pkg_dir = dirname(docs_dir) |
| 11 | + |
| 12 | + using Pkg: Pkg |
| 13 | + Pkg.activate(docs_dir) |
| 14 | + Pkg.develop(; path = pkg_dir) |
| 15 | + Pkg.instantiate() |
| 16 | +end |
| 17 | + |
| 18 | +# ╔═╡ 5baa8d2e-bcf8-4e3b-b007-175ad3e2ca95 |
| 19 | +begin |
| 20 | + using EpiAware.EpiAwareUtils: censored_pmf, censored_cdf, ∫F |
| 21 | + using Random, Distributions, StatsBase #utilities for random events |
| 22 | + using DataFramesMeta #Data wrangling |
| 23 | + using CairoMakie, PairPlots #plotting |
| 24 | + using Turing #PPL |
| 25 | +end |
| 26 | + |
| 27 | +# ╔═╡ 8de5c5e0-6e95-11ef-1693-bfd465c8d919 |
| 28 | +md" |
| 29 | +# Fitting distributions using `EpiAware` and Turing PPL |
| 30 | +
|
| 31 | +## Introduction |
| 32 | +
|
| 33 | +### What are we going to do in this Vignette |
| 34 | +
|
| 35 | +In this vignette, we'll demonstrate how to use the CDF function for censored delay distributions `EpiAwareUtils.∫F`, which underlies `EpiAwareUtils.censored_pmf` in conjunction with the Turing PPL for Bayesian inference of epidemiological delay distributions. We'll cover the following key points: |
| 36 | +
|
| 37 | +1. Simulating censored delay distribution data |
| 38 | +2. Fitting a naive model using Turing |
| 39 | +3. Evaluating the naive model's performance |
| 40 | +4. Fitting an improved model using censored delay functionality from `EpiAware`. |
| 41 | +5. Comparing the censored delay model's performance to the naive model |
| 42 | +
|
| 43 | +### What might I need to know before starting |
| 44 | +
|
| 45 | +This note builds on the concepts introduced in the R/stan package [`primarycensoreddist`](https://github.com/epinowcast/primarycensoreddist), especially the [Fitting distributions using primarycensorseddist and cmdstan](https://primarycensoreddist.epinowcast.org/articles/fitting-dists-with-stan.html) vignette and assumes familiarity with using Turing tools as covered in the [Turing documentation](https://turinglang.org/). |
| 46 | +
|
| 47 | +This note is generated using the `EpiAware` package locally via `Pkg.develop`, in the `EpiAware/docs` environment. It is also possible to install `EpiAware` using |
| 48 | +
|
| 49 | +```julia |
| 50 | +Pkg.add(url=\"https://github.com/CDCgov/Rt-without-renewal\", subdir=\"EpiAware\") |
| 51 | +``` |
| 52 | +### Packages used in this vignette |
| 53 | +As well as `EpiAware` and `Turing` we will use `Makie` ecosystem packages for plotting and `DataFramesMeta` for data manipulation. |
| 54 | +" |
| 55 | + |
| 56 | +# ╔═╡ 30dd9af4-b64f-42b1-8439-a890752f68e3 |
| 57 | +md" |
| 58 | +The other dependencies are as follows: |
| 59 | +" |
| 60 | + |
| 61 | +# ╔═╡ c5704f67-208d-4c2e-8513-c07c6b94ca99 |
| 62 | +md" |
| 63 | +## Simulating censored and truncated delay distribution data |
| 64 | +
|
| 65 | +We'll start by simulating some censored and truncated delay distribution data. We’ll define a `rpcens` function for generating data. |
| 66 | +" |
| 67 | + |
| 68 | +# ╔═╡ aed124c7-b4ba-4c97-a01f-ff553f376c86 |
| 69 | +Random.seed!(123) # For reproducibility |
| 70 | + |
| 71 | +# ╔═╡ ec5ed3e9-6ea9-4cfe-afd2-82aabbbe8130 |
| 72 | +md"Define the true distribution parameters" |
| 73 | + |
| 74 | +# ╔═╡ 105b9594-36ce-4ae8-87a8-5c81867b1ce3 |
| 75 | +n = 2000 |
| 76 | + |
| 77 | +# ╔═╡ 8aa9f9c1-d3c4-49f3-be18-a400fc71e8f7 |
| 78 | +meanlog = 1.5 |
| 79 | + |
| 80 | +# ╔═╡ 84bb3999-9f2b-4eaa-9c2d-776a86677eaf |
| 81 | +sdlog = 0.75 |
| 82 | + |
| 83 | +# ╔═╡ 2bf6677e-ebe9-4aa8-aa91-f631e99669bb |
| 84 | +true_dist = LogNormal(meanlog, sdlog) |
| 85 | + |
| 86 | +# ╔═╡ f4083aea-8106-401a-b60f-383d0b94102a |
| 87 | +md"Generate varying pwindow, swindow, and obs_time lengths |
| 88 | +" |
| 89 | + |
| 90 | +# ╔═╡ aea8b28e-fffe-4aa6-b51e-8199a7c7975c |
| 91 | +pwindows = rand(1:2, n) |
| 92 | + |
| 93 | +# ╔═╡ 4d3a853d-0b8d-402a-8309-e9f6da2b7a8c |
| 94 | +swindows = rand(1:2, n) |
| 95 | + |
| 96 | +# ╔═╡ 7522f05b-1750-4983-8947-ef70f4298d06 |
| 97 | +obs_times = rand(8:10, n) |
| 98 | + |
| 99 | +# ╔═╡ 5eac2f60-8cec-4460-9d10-6bade7f0f406 |
| 100 | +md" |
| 101 | +We recreate the primary censored sampling function from `primarycensoreddist`, c.f. documentation [here](https://primarycensoreddist.epinowcast.org/reference/rprimarycensoreddist.html). |
| 102 | +" |
| 103 | + |
| 104 | +# ╔═╡ 9443b893-9e22-4267-9a1f-319a3adb8c0d |
| 105 | +""" |
| 106 | + function rpcens(dist; pwindow = 1, swindow = 1, D = Inf, max_tries = 1000) |
| 107 | +
|
| 108 | +Does a truncated censored sample from `dist` with a uniform primary time on `[0, pwindow]`. |
| 109 | +""" |
| 110 | +function rpcens(dist; pwindow = 1, swindow = 1, D = Inf, max_tries = 1000) |
| 111 | + T = zero(eltype(dist)) |
| 112 | + invalid_sample = true |
| 113 | + attempts = 1 |
| 114 | + while (invalid_sample && attempts <= max_tries) |
| 115 | + X = rand(dist) |
| 116 | + U = rand() * pwindow |
| 117 | + T = X + U |
| 118 | + attempts += 1 |
| 119 | + if X + U < D |
| 120 | + invalid_sample = false |
| 121 | + end |
| 122 | + end |
| 123 | + |
| 124 | + @assert !invalid_sample "censored value not found in $max_tries attempts" |
| 125 | + |
| 126 | + return (T ÷ swindow) * swindow |
| 127 | +end |
| 128 | + |
| 129 | +# ╔═╡ a4f5e9b6-ff3a-48fa-aa51-0abccb9c7bed |
| 130 | +#Sample secondary time relative to beginning of primary censor window respecting the right-truncation |
| 131 | +samples = map(pwindows, swindows, obs_times) do pw, sw, ot |
| 132 | + rpcens(true_dist; pwindow = pw, swindow = sw, D = ot) |
| 133 | +end |
| 134 | + |
| 135 | +# ╔═╡ 2a9da9e5-0925-4ae0-8b70-8db90903cb0b |
| 136 | +md" |
| 137 | +Aggregate to unique combinations and count occurrences |
| 138 | +" |
| 139 | + |
| 140 | +# ╔═╡ 0b5e96eb-9312-472e-8a88-d4509a4f25d0 |
| 141 | +delay_counts = mapreduce(vcat, pwindows, swindows, obs_times, samples) do pw, sw, ot, s |
| 142 | + DataFrame( |
| 143 | + pwindow = pw, |
| 144 | + swindow = sw, |
| 145 | + obs_time = ot, |
| 146 | + observed_delay = s, |
| 147 | + observed_delay_upper = s + sw |
| 148 | + ) |
| 149 | +end |> |
| 150 | + df -> @groupby(df, :pwindow, :swindow, :obs_time, :observed_delay, |
| 151 | + :observed_delay_upper) |> |
| 152 | + gd -> @combine(gd, :n=length(:pwindow)) |
| 153 | + |
| 154 | +# ╔═╡ c0cce80f-dec7-4a55-aefd-339ef863f854 |
| 155 | +md" |
| 156 | +Compare the samples with and without secondary censoring to the true distribution and calculate empirical CDF |
| 157 | +" |
| 158 | + |
| 159 | +# ╔═╡ a7bff47d-b61f-499e-8631-206661c2bdc0 |
| 160 | +empirical_cdf = ecdf(samples) |
| 161 | + |
| 162 | +# ╔═╡ 16bcb80a-970f-4633-aca2-261fa04172f7 |
| 163 | +empirical_cdf_obs = ecdf(delay_counts.observed_delay, weights = delay_counts.n) |
| 164 | + |
| 165 | +# ╔═╡ 60711c3c-266e-42b5-acc6-6624db294f24 |
| 166 | +x_seq = range(minimum(samples), maximum(samples), 100) |
| 167 | + |
| 168 | +# ╔═╡ 1f1bcee4-8e0d-46fb-9a6f-41998bf54957 |
| 169 | +theoretical_cdf = x_seq |> x -> cdf(true_dist, x) |
| 170 | + |
| 171 | +# ╔═╡ 59bb2a18-eaf4-438a-9359-341efadfe897 |
| 172 | +let |
| 173 | + f = Figure() |
| 174 | + ax = Axis(f[1, 1], |
| 175 | + title = "Comparison of Observed vs Theoretical CDF", |
| 176 | + ylabel = "Cumulative Probability", |
| 177 | + xlabel = "Delay" |
| 178 | + ) |
| 179 | + lines!( |
| 180 | + ax, x_seq, empirical_cdf_obs, label = "Empirical CDF", color = :blue, linewidth = 2) |
| 181 | + lines!(ax, x_seq, theoretical_cdf, label = "Theoretical CDF", |
| 182 | + color = :black, linewidth = 2) |
| 183 | + vlines!(ax, [mean(samples)], color = :blue, linestyle = :dash, |
| 184 | + label = "Empirical mean", linewidth = 2) |
| 185 | + vlines!(ax, [mean(true_dist)], linestyle = :dash, |
| 186 | + label = "Theoretical mean", color = :black, linewidth = 2) |
| 187 | + axislegend(position = :rb) |
| 188 | + |
| 189 | + f |
| 190 | +end |
| 191 | + |
| 192 | +# ╔═╡ f66d4b2e-ed66-423e-9cba-62bff712862b |
| 193 | +md" |
| 194 | +We've aggregated the data to unique combinations of `pwindow`, `swindow`, and `obs_time` and counted the number of occurrences of each `observed_delay` for each combination. This is the data we will use to fit our model. |
| 195 | +" |
| 196 | + |
| 197 | +# ╔═╡ 010ebe37-782b-4a35-bf5c-dca6dc0fee45 |
| 198 | +md" |
| 199 | +## Fitting a naive model using Turing |
| 200 | +
|
| 201 | +We'll start by fitting a naive model using NUTS from `Turing`. We define the model in the `Turing` PPL. |
| 202 | +" |
| 203 | + |
| 204 | +# ╔═╡ d9d14c48-8700-42b5-89b4-7fc51d0f577c |
| 205 | +@model function naive_model(N, y, n) |
| 206 | + mu ~ Normal(1.0, 1.0) |
| 207 | + sigma ~ truncated(Normal(0.5, 1.0); lower = 0.0) |
| 208 | + d = LogNormal(mu, sigma) |
| 209 | + |
| 210 | + for i in eachindex(y) |
| 211 | + Turing.@addlogprob! n[i] * logpdf(d, y[i]) |
| 212 | + end |
| 213 | +end |
| 214 | + |
| 215 | +# ╔═╡ 8a7cd9ec-5640-4f5f-84c3-ae3f465ca68b |
| 216 | +md" |
| 217 | +Now lets instantiate this model with data |
| 218 | +" |
| 219 | + |
| 220 | +# ╔═╡ 028ade5c-17bd-4dfc-8433-23aaff02c181 |
| 221 | +naive_mdl = naive_model( |
| 222 | + size(delay_counts, 1), |
| 223 | + delay_counts.observed_delay .+ 1e-6, # Add a small constant to avoid log(0) |
| 224 | + delay_counts.n) |
| 225 | + |
| 226 | +# ╔═╡ 04b4eefb-f0f9-4887-8db0-7cbb7f3b169b |
| 227 | +md" |
| 228 | +and now let's fit the compiled model. |
| 229 | +" |
| 230 | + |
| 231 | +# ╔═╡ 21655344-d12b-4e47-a9a9-d06bd909f6ea |
| 232 | +naive_fit = sample(naive_mdl, NUTS(), MCMCThreads(), 500, 4) |
| 233 | + |
| 234 | +# ╔═╡ 3b89fe00-6aaf-4764-8b29-e71479f1e641 |
| 235 | +summarize(naive_fit) |
| 236 | + |
| 237 | +# ╔═╡ 8e09d931-fca7-4ac2-81f7-2bc36b0174f3 |
| 238 | +let |
| 239 | + f = pairplot(naive_fit) |
| 240 | + vlines!(f[1, 1], [meanlog], linewidth = 4) |
| 241 | + vlines!(f[2, 2], [sdlog], linewidth = 4) |
| 242 | + f |
| 243 | +end |
| 244 | + |
| 245 | +# ╔═╡ 43eac8dd-8f1d-440e-b1e8-85db9e740651 |
| 246 | +md" |
| 247 | +We see that the model has converged and the diagnostics look good. However, just from the model posterior summary we see that we might not be very happy with the fit. `mu` is smaller than the target $(meanlog) and `sigma` is larger than the target $(sdlog). |
| 248 | +
|
| 249 | +" |
| 250 | + |
| 251 | +# ╔═╡ b2efafab-8849-4a7a-bb64-ac9ce126ca75 |
| 252 | +md" |
| 253 | +## Fitting an improved model using censoring utilities |
| 254 | +
|
| 255 | +We'll now fit an improved model using the `∫F` function from `EpiAware.EpiAwareUtils` for calculating the CDF of the _total delay_ from the beginning of the primary window to the secondary event time. This includes both the delay distribution we are making inference on and the time between the start of the primary censor window and the primary event. |
| 256 | +The `∫F` function underlies `censored_pmf` function from the `EpiAware.EpiAwareUtils` submodule. |
| 257 | +
|
| 258 | +Using the `∫F` function we can write a log-pmf function `primary_censored_dist_lpmf` that accounts for: |
| 259 | +- The primary and secondary censoring windows, which can vary in length. |
| 260 | +- The effect of right truncation in biasing our observations. |
| 261 | +
|
| 262 | +This is the analog function to the function of the same name in `primarycensoreddist`: it calculates the log-probability of the secondary event occurring in the secondary censoring window conditional on the primary event occurring in the primary censoring window by calculating the increase in the CDF over the secondary window and rescaling by the probability of the secondary event occuring within the maximum observation time `D`. |
| 263 | +" |
| 264 | + |
| 265 | +# ╔═╡ 348fc3b4-073b-4997-ae50-58ede5d6d0c9 |
| 266 | +function primary_censored_dist_lpmf(dist, y, pwindow, y_upper, D) |
| 267 | + if y == 0.0 |
| 268 | + return log(∫F(dist, y_upper, pwindow)) - log(∫F(dist, D, pwindow)) |
| 269 | + else |
| 270 | + return log(∫F(dist, y_upper, pwindow) - ∫F(dist, y, pwindow)) - |
| 271 | + log(∫F(dist, D, pwindow)) |
| 272 | + end |
| 273 | +end |
| 274 | + |
| 275 | +# ╔═╡ cefb5d56-fecd-4de7-bd0e-156be91c705c |
| 276 | +md" |
| 277 | +We make a new `Turing` model that now uses `primary_censored_dist_lpmf` rather than the naive uncensored and untruncated `logpdf`. |
| 278 | +" |
| 279 | + |
| 280 | +# ╔═╡ ef40112b-f23e-4d4b-8a7d-3793b786f472 |
| 281 | +@model function primarycensoreddist_model(y, y_upper, n, pws, Ds) |
| 282 | + mu ~ Normal(1.0, 1.0) |
| 283 | + sigma ~ truncated(Normal(0.5, 0.5); lower = 0.0) |
| 284 | + dist = LogNormal(mu, sigma) |
| 285 | + |
| 286 | + for i in eachindex(y) |
| 287 | + Turing.@addlogprob! n[i] * primary_censored_dist_lpmf( |
| 288 | + dist, y[i], pws[i], y_upper[i], Ds[i]) |
| 289 | + end |
| 290 | +end |
| 291 | + |
| 292 | +# ╔═╡ b823d824-419d-41e9-9ac9-2c45ef190acf |
| 293 | +md" |
| 294 | +Lets instantiate this model with data |
| 295 | +" |
| 296 | + |
| 297 | +# ╔═╡ 93bca93a-5484-47fa-8424-7315eef15e37 |
| 298 | +primarycensoreddist_mdl = primarycensoreddist_model( |
| 299 | + delay_counts.observed_delay, |
| 300 | + delay_counts.observed_delay_upper, |
| 301 | + delay_counts.n, |
| 302 | + delay_counts.pwindow, |
| 303 | + delay_counts.obs_time |
| 304 | +) |
| 305 | + |
| 306 | +# ╔═╡ d5144247-eb57-48bf-8e32-fd71167ecbc8 |
| 307 | +md"Now let’s fit the compiled model." |
| 308 | + |
| 309 | +# ╔═╡ 7ae6c61d-0e33-4af8-b8d2-e31223a15a7c |
| 310 | +primarycensoreddist_fit = sample( |
| 311 | + primarycensoreddist_mdl, NUTS(), MCMCThreads(), 1000, 4) |
| 312 | + |
| 313 | +# ╔═╡ 1210443f-480f-4e9f-b195-d557e9e1fc31 |
| 314 | +summarize(primarycensoreddist_fit) |
| 315 | + |
| 316 | +# ╔═╡ b2376beb-dd7b-442d-9ff5-ac864e75366b |
| 317 | +let |
| 318 | + f = pairplot(primarycensoreddist_fit) |
| 319 | + CairoMakie.vlines!(f[1, 1], [meanlog], linewidth = 3) |
| 320 | + CairoMakie.vlines!(f[2, 2], [sdlog], linewidth = 3) |
| 321 | + f |
| 322 | +end |
| 323 | + |
| 324 | +# ╔═╡ 673b47ec-b333-45e8-9557-9e65ad425c35 |
| 325 | +md" |
| 326 | +We see that the model has converged and the diagnostics look good. We also see that the posterior means are very near the true parameters and the 90% credible intervals include the true parameters. |
| 327 | +" |
| 328 | + |
| 329 | +# ╔═╡ Cell order: |
| 330 | +# ╟─8de5c5e0-6e95-11ef-1693-bfd465c8d919 |
| 331 | +# ╠═a2624404-48b1-4faa-abbe-6d78b8e04f2b |
| 332 | +# ╟─30dd9af4-b64f-42b1-8439-a890752f68e3 |
| 333 | +# ╠═5baa8d2e-bcf8-4e3b-b007-175ad3e2ca95 |
| 334 | +# ╟─c5704f67-208d-4c2e-8513-c07c6b94ca99 |
| 335 | +# ╠═aed124c7-b4ba-4c97-a01f-ff553f376c86 |
| 336 | +# ╟─ec5ed3e9-6ea9-4cfe-afd2-82aabbbe8130 |
| 337 | +# ╠═105b9594-36ce-4ae8-87a8-5c81867b1ce3 |
| 338 | +# ╠═8aa9f9c1-d3c4-49f3-be18-a400fc71e8f7 |
| 339 | +# ╠═84bb3999-9f2b-4eaa-9c2d-776a86677eaf |
| 340 | +# ╠═2bf6677e-ebe9-4aa8-aa91-f631e99669bb |
| 341 | +# ╟─f4083aea-8106-401a-b60f-383d0b94102a |
| 342 | +# ╠═aea8b28e-fffe-4aa6-b51e-8199a7c7975c |
| 343 | +# ╠═4d3a853d-0b8d-402a-8309-e9f6da2b7a8c |
| 344 | +# ╠═7522f05b-1750-4983-8947-ef70f4298d06 |
| 345 | +# ╟─5eac2f60-8cec-4460-9d10-6bade7f0f406 |
| 346 | +# ╠═9443b893-9e22-4267-9a1f-319a3adb8c0d |
| 347 | +# ╠═a4f5e9b6-ff3a-48fa-aa51-0abccb9c7bed |
| 348 | +# ╟─2a9da9e5-0925-4ae0-8b70-8db90903cb0b |
| 349 | +# ╠═0b5e96eb-9312-472e-8a88-d4509a4f25d0 |
| 350 | +# ╟─c0cce80f-dec7-4a55-aefd-339ef863f854 |
| 351 | +# ╠═a7bff47d-b61f-499e-8631-206661c2bdc0 |
| 352 | +# ╠═16bcb80a-970f-4633-aca2-261fa04172f7 |
| 353 | +# ╠═60711c3c-266e-42b5-acc6-6624db294f24 |
| 354 | +# ╠═1f1bcee4-8e0d-46fb-9a6f-41998bf54957 |
| 355 | +# ╠═59bb2a18-eaf4-438a-9359-341efadfe897 |
| 356 | +# ╟─f66d4b2e-ed66-423e-9cba-62bff712862b |
| 357 | +# ╠═010ebe37-782b-4a35-bf5c-dca6dc0fee45 |
| 358 | +# ╠═d9d14c48-8700-42b5-89b4-7fc51d0f577c |
| 359 | +# ╟─8a7cd9ec-5640-4f5f-84c3-ae3f465ca68b |
| 360 | +# ╠═028ade5c-17bd-4dfc-8433-23aaff02c181 |
| 361 | +# ╟─04b4eefb-f0f9-4887-8db0-7cbb7f3b169b |
| 362 | +# ╠═21655344-d12b-4e47-a9a9-d06bd909f6ea |
| 363 | +# ╠═3b89fe00-6aaf-4764-8b29-e71479f1e641 |
| 364 | +# ╠═8e09d931-fca7-4ac2-81f7-2bc36b0174f3 |
| 365 | +# ╟─43eac8dd-8f1d-440e-b1e8-85db9e740651 |
| 366 | +# ╟─b2efafab-8849-4a7a-bb64-ac9ce126ca75 |
| 367 | +# ╠═348fc3b4-073b-4997-ae50-58ede5d6d0c9 |
| 368 | +# ╟─cefb5d56-fecd-4de7-bd0e-156be91c705c |
| 369 | +# ╠═ef40112b-f23e-4d4b-8a7d-3793b786f472 |
| 370 | +# ╟─b823d824-419d-41e9-9ac9-2c45ef190acf |
| 371 | +# ╠═93bca93a-5484-47fa-8424-7315eef15e37 |
| 372 | +# ╟─d5144247-eb57-48bf-8e32-fd71167ecbc8 |
| 373 | +# ╠═7ae6c61d-0e33-4af8-b8d2-e31223a15a7c |
| 374 | +# ╠═1210443f-480f-4e9f-b195-d557e9e1fc31 |
| 375 | +# ╠═b2376beb-dd7b-442d-9ff5-ac864e75366b |
| 376 | +# ╟─673b47ec-b333-45e8-9557-9e65ad425c35 |
0 commit comments