Skip to content

Commit cae03e3

Browse files
committed
Merge branch 'main' into 446-add-chatzilena-et-al-as-a-replication-example
2 parents b2fae2d + f1f3571 commit cae03e3

File tree

6 files changed

+490
-25
lines changed

6 files changed

+490
-25
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,4 @@ EpiAware/docs/src/examples/*.md
389389
# benchmark ignore
390390
/.benchmarkci
391391
/benchmark/*.json
392+
EpiAware/docs/src/getting-started/tutorials/censored-obs.md

EpiAware/docs/pages.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ getting_started_pages = Any[
1818
"Nowcasting" => "getting-started/tutorials/nowcasting.md",
1919
"Multiple observation models" => "getting-started/tutorials/multiple-observation-models.md",
2020
"Multiple infection processes" => "getting-started/tutorials/multiple-infection-processes.md",
21+
"Fitting distributions with censored data" => "getting-started/tutorials/censored-obs.md",
2122
"Partial pooling" => "getting-started/tutorials/partial-pooling.md"
2223
]
2324
]
Lines changed: 376 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
### A Pluto.jl notebook ###
2+
# v0.19.46
3+
4+
using Markdown
5+
using InteractiveUtils
6+
7+
# ╔═╡ a2624404-48b1-4faa-abbe-6d78b8e04f2b
8+
let
9+
docs_dir = dirname(dirname(dirname(@__DIR__)))
10+
pkg_dir = dirname(docs_dir)
11+
12+
using Pkg: Pkg
13+
Pkg.activate(docs_dir)
14+
Pkg.develop(; path = pkg_dir)
15+
Pkg.instantiate()
16+
end
17+
18+
# ╔═╡ 5baa8d2e-bcf8-4e3b-b007-175ad3e2ca95
19+
begin
20+
using EpiAware.EpiAwareUtils: censored_pmf, censored_cdf, ∫F
21+
using Random, Distributions, StatsBase #utilities for random events
22+
using DataFramesMeta #Data wrangling
23+
using CairoMakie, PairPlots #plotting
24+
using Turing #PPL
25+
end
26+
27+
# ╔═╡ 8de5c5e0-6e95-11ef-1693-bfd465c8d919
28+
md"
29+
# Fitting distributions using `EpiAware` and Turing PPL
30+
31+
## Introduction
32+
33+
### What are we going to do in this Vignette
34+
35+
In this vignette, we'll demonstrate how to use the CDF function for censored delay distributions `EpiAwareUtils.∫F`, which underlies `EpiAwareUtils.censored_pmf` in conjunction with the Turing PPL for Bayesian inference of epidemiological delay distributions. We'll cover the following key points:
36+
37+
1. Simulating censored delay distribution data
38+
2. Fitting a naive model using Turing
39+
3. Evaluating the naive model's performance
40+
4. Fitting an improved model using censored delay functionality from `EpiAware`.
41+
5. Comparing the censored delay model's performance to the naive model
42+
43+
### What might I need to know before starting
44+
45+
This note builds on the concepts introduced in the R/stan package [`primarycensoreddist`](https://github.com/epinowcast/primarycensoreddist), especially the [Fitting distributions using primarycensorseddist and cmdstan](https://primarycensoreddist.epinowcast.org/articles/fitting-dists-with-stan.html) vignette and assumes familiarity with using Turing tools as covered in the [Turing documentation](https://turinglang.org/).
46+
47+
This note is generated using the `EpiAware` package locally via `Pkg.develop`, in the `EpiAware/docs` environment. It is also possible to install `EpiAware` using
48+
49+
```julia
50+
Pkg.add(url=\"https://github.com/CDCgov/Rt-without-renewal\", subdir=\"EpiAware\")
51+
```
52+
### Packages used in this vignette
53+
As well as `EpiAware` and `Turing` we will use `Makie` ecosystem packages for plotting and `DataFramesMeta` for data manipulation.
54+
"
55+
56+
# ╔═╡ 30dd9af4-b64f-42b1-8439-a890752f68e3
57+
md"
58+
The other dependencies are as follows:
59+
"
60+
61+
# ╔═╡ c5704f67-208d-4c2e-8513-c07c6b94ca99
62+
md"
63+
## Simulating censored and truncated delay distribution data
64+
65+
We'll start by simulating some censored and truncated delay distribution data. We’ll define a `rpcens` function for generating data.
66+
"
67+
68+
# ╔═╡ aed124c7-b4ba-4c97-a01f-ff553f376c86
69+
Random.seed!(123) # For reproducibility
70+
71+
# ╔═╡ ec5ed3e9-6ea9-4cfe-afd2-82aabbbe8130
72+
md"Define the true distribution parameters"
73+
74+
# ╔═╡ 105b9594-36ce-4ae8-87a8-5c81867b1ce3
75+
n = 2000
76+
77+
# ╔═╡ 8aa9f9c1-d3c4-49f3-be18-a400fc71e8f7
78+
meanlog = 1.5
79+
80+
# ╔═╡ 84bb3999-9f2b-4eaa-9c2d-776a86677eaf
81+
sdlog = 0.75
82+
83+
# ╔═╡ 2bf6677e-ebe9-4aa8-aa91-f631e99669bb
84+
true_dist = LogNormal(meanlog, sdlog)
85+
86+
# ╔═╡ f4083aea-8106-401a-b60f-383d0b94102a
87+
md"Generate varying pwindow, swindow, and obs_time lengths
88+
"
89+
90+
# ╔═╡ aea8b28e-fffe-4aa6-b51e-8199a7c7975c
91+
pwindows = rand(1:2, n)
92+
93+
# ╔═╡ 4d3a853d-0b8d-402a-8309-e9f6da2b7a8c
94+
swindows = rand(1:2, n)
95+
96+
# ╔═╡ 7522f05b-1750-4983-8947-ef70f4298d06
97+
obs_times = rand(8:10, n)
98+
99+
# ╔═╡ 5eac2f60-8cec-4460-9d10-6bade7f0f406
100+
md"
101+
We recreate the primary censored sampling function from `primarycensoreddist`, c.f. documentation [here](https://primarycensoreddist.epinowcast.org/reference/rprimarycensoreddist.html).
102+
"
103+
104+
# ╔═╡ 9443b893-9e22-4267-9a1f-319a3adb8c0d
105+
"""
106+
function rpcens(dist; pwindow = 1, swindow = 1, D = Inf, max_tries = 1000)
107+
108+
Does a truncated censored sample from `dist` with a uniform primary time on `[0, pwindow]`.
109+
"""
110+
function rpcens(dist; pwindow = 1, swindow = 1, D = Inf, max_tries = 1000)
111+
T = zero(eltype(dist))
112+
invalid_sample = true
113+
attempts = 1
114+
while (invalid_sample && attempts <= max_tries)
115+
X = rand(dist)
116+
U = rand() * pwindow
117+
T = X + U
118+
attempts += 1
119+
if X + U < D
120+
invalid_sample = false
121+
end
122+
end
123+
124+
@assert !invalid_sample "censored value not found in $max_tries attempts"
125+
126+
return (T ÷ swindow) * swindow
127+
end
128+
129+
# ╔═╡ a4f5e9b6-ff3a-48fa-aa51-0abccb9c7bed
130+
#Sample secondary time relative to beginning of primary censor window respecting the right-truncation
131+
samples = map(pwindows, swindows, obs_times) do pw, sw, ot
132+
rpcens(true_dist; pwindow = pw, swindow = sw, D = ot)
133+
end
134+
135+
# ╔═╡ 2a9da9e5-0925-4ae0-8b70-8db90903cb0b
136+
md"
137+
Aggregate to unique combinations and count occurrences
138+
"
139+
140+
# ╔═╡ 0b5e96eb-9312-472e-8a88-d4509a4f25d0
141+
delay_counts = mapreduce(vcat, pwindows, swindows, obs_times, samples) do pw, sw, ot, s
142+
DataFrame(
143+
pwindow = pw,
144+
swindow = sw,
145+
obs_time = ot,
146+
observed_delay = s,
147+
observed_delay_upper = s + sw
148+
)
149+
end |>
150+
df -> @groupby(df, :pwindow, :swindow, :obs_time, :observed_delay,
151+
:observed_delay_upper) |>
152+
gd -> @combine(gd, :n=length(:pwindow))
153+
154+
# ╔═╡ c0cce80f-dec7-4a55-aefd-339ef863f854
155+
md"
156+
Compare the samples with and without secondary censoring to the true distribution and calculate empirical CDF
157+
"
158+
159+
# ╔═╡ a7bff47d-b61f-499e-8631-206661c2bdc0
160+
empirical_cdf = ecdf(samples)
161+
162+
# ╔═╡ 16bcb80a-970f-4633-aca2-261fa04172f7
163+
empirical_cdf_obs = ecdf(delay_counts.observed_delay, weights = delay_counts.n)
164+
165+
# ╔═╡ 60711c3c-266e-42b5-acc6-6624db294f24
166+
x_seq = range(minimum(samples), maximum(samples), 100)
167+
168+
# ╔═╡ 1f1bcee4-8e0d-46fb-9a6f-41998bf54957
169+
theoretical_cdf = x_seq |> x -> cdf(true_dist, x)
170+
171+
# ╔═╡ 59bb2a18-eaf4-438a-9359-341efadfe897
172+
let
173+
f = Figure()
174+
ax = Axis(f[1, 1],
175+
title = "Comparison of Observed vs Theoretical CDF",
176+
ylabel = "Cumulative Probability",
177+
xlabel = "Delay"
178+
)
179+
lines!(
180+
ax, x_seq, empirical_cdf_obs, label = "Empirical CDF", color = :blue, linewidth = 2)
181+
lines!(ax, x_seq, theoretical_cdf, label = "Theoretical CDF",
182+
color = :black, linewidth = 2)
183+
vlines!(ax, [mean(samples)], color = :blue, linestyle = :dash,
184+
label = "Empirical mean", linewidth = 2)
185+
vlines!(ax, [mean(true_dist)], linestyle = :dash,
186+
label = "Theoretical mean", color = :black, linewidth = 2)
187+
axislegend(position = :rb)
188+
189+
f
190+
end
191+
192+
# ╔═╡ f66d4b2e-ed66-423e-9cba-62bff712862b
193+
md"
194+
We've aggregated the data to unique combinations of `pwindow`, `swindow`, and `obs_time` and counted the number of occurrences of each `observed_delay` for each combination. This is the data we will use to fit our model.
195+
"
196+
197+
# ╔═╡ 010ebe37-782b-4a35-bf5c-dca6dc0fee45
198+
md"
199+
## Fitting a naive model using Turing
200+
201+
We'll start by fitting a naive model using NUTS from `Turing`. We define the model in the `Turing` PPL.
202+
"
203+
204+
# ╔═╡ d9d14c48-8700-42b5-89b4-7fc51d0f577c
205+
@model function naive_model(N, y, n)
206+
mu ~ Normal(1.0, 1.0)
207+
sigma ~ truncated(Normal(0.5, 1.0); lower = 0.0)
208+
d = LogNormal(mu, sigma)
209+
210+
for i in eachindex(y)
211+
Turing.@addlogprob! n[i] * logpdf(d, y[i])
212+
end
213+
end
214+
215+
# ╔═╡ 8a7cd9ec-5640-4f5f-84c3-ae3f465ca68b
216+
md"
217+
Now lets instantiate this model with data
218+
"
219+
220+
# ╔═╡ 028ade5c-17bd-4dfc-8433-23aaff02c181
221+
naive_mdl = naive_model(
222+
size(delay_counts, 1),
223+
delay_counts.observed_delay .+ 1e-6, # Add a small constant to avoid log(0)
224+
delay_counts.n)
225+
226+
# ╔═╡ 04b4eefb-f0f9-4887-8db0-7cbb7f3b169b
227+
md"
228+
and now let's fit the compiled model.
229+
"
230+
231+
# ╔═╡ 21655344-d12b-4e47-a9a9-d06bd909f6ea
232+
naive_fit = sample(naive_mdl, NUTS(), MCMCThreads(), 500, 4)
233+
234+
# ╔═╡ 3b89fe00-6aaf-4764-8b29-e71479f1e641
235+
summarize(naive_fit)
236+
237+
# ╔═╡ 8e09d931-fca7-4ac2-81f7-2bc36b0174f3
238+
let
239+
f = pairplot(naive_fit)
240+
vlines!(f[1, 1], [meanlog], linewidth = 4)
241+
vlines!(f[2, 2], [sdlog], linewidth = 4)
242+
f
243+
end
244+
245+
# ╔═╡ 43eac8dd-8f1d-440e-b1e8-85db9e740651
246+
md"
247+
We see that the model has converged and the diagnostics look good. However, just from the model posterior summary we see that we might not be very happy with the fit. `mu` is smaller than the target $(meanlog) and `sigma` is larger than the target $(sdlog).
248+
249+
"
250+
251+
# ╔═╡ b2efafab-8849-4a7a-bb64-ac9ce126ca75
252+
md"
253+
## Fitting an improved model using censoring utilities
254+
255+
We'll now fit an improved model using the `∫F` function from `EpiAware.EpiAwareUtils` for calculating the CDF of the _total delay_ from the beginning of the primary window to the secondary event time. This includes both the delay distribution we are making inference on and the time between the start of the primary censor window and the primary event.
256+
The `∫F` function underlies `censored_pmf` function from the `EpiAware.EpiAwareUtils` submodule.
257+
258+
Using the `∫F` function we can write a log-pmf function `primary_censored_dist_lpmf` that accounts for:
259+
- The primary and secondary censoring windows, which can vary in length.
260+
- The effect of right truncation in biasing our observations.
261+
262+
This is the analog function to the function of the same name in `primarycensoreddist`: it calculates the log-probability of the secondary event occurring in the secondary censoring window conditional on the primary event occurring in the primary censoring window by calculating the increase in the CDF over the secondary window and rescaling by the probability of the secondary event occuring within the maximum observation time `D`.
263+
"
264+
265+
# ╔═╡ 348fc3b4-073b-4997-ae50-58ede5d6d0c9
266+
function primary_censored_dist_lpmf(dist, y, pwindow, y_upper, D)
267+
if y == 0.0
268+
return log(∫F(dist, y_upper, pwindow)) - log(∫F(dist, D, pwindow))
269+
else
270+
return log(∫F(dist, y_upper, pwindow) - ∫F(dist, y, pwindow)) -
271+
log(∫F(dist, D, pwindow))
272+
end
273+
end
274+
275+
# ╔═╡ cefb5d56-fecd-4de7-bd0e-156be91c705c
276+
md"
277+
We make a new `Turing` model that now uses `primary_censored_dist_lpmf` rather than the naive uncensored and untruncated `logpdf`.
278+
"
279+
280+
# ╔═╡ ef40112b-f23e-4d4b-8a7d-3793b786f472
281+
@model function primarycensoreddist_model(y, y_upper, n, pws, Ds)
282+
mu ~ Normal(1.0, 1.0)
283+
sigma ~ truncated(Normal(0.5, 0.5); lower = 0.0)
284+
dist = LogNormal(mu, sigma)
285+
286+
for i in eachindex(y)
287+
Turing.@addlogprob! n[i] * primary_censored_dist_lpmf(
288+
dist, y[i], pws[i], y_upper[i], Ds[i])
289+
end
290+
end
291+
292+
# ╔═╡ b823d824-419d-41e9-9ac9-2c45ef190acf
293+
md"
294+
Lets instantiate this model with data
295+
"
296+
297+
# ╔═╡ 93bca93a-5484-47fa-8424-7315eef15e37
298+
primarycensoreddist_mdl = primarycensoreddist_model(
299+
delay_counts.observed_delay,
300+
delay_counts.observed_delay_upper,
301+
delay_counts.n,
302+
delay_counts.pwindow,
303+
delay_counts.obs_time
304+
)
305+
306+
# ╔═╡ d5144247-eb57-48bf-8e32-fd71167ecbc8
307+
md"Now let’s fit the compiled model."
308+
309+
# ╔═╡ 7ae6c61d-0e33-4af8-b8d2-e31223a15a7c
310+
primarycensoreddist_fit = sample(
311+
primarycensoreddist_mdl, NUTS(), MCMCThreads(), 1000, 4)
312+
313+
# ╔═╡ 1210443f-480f-4e9f-b195-d557e9e1fc31
314+
summarize(primarycensoreddist_fit)
315+
316+
# ╔═╡ b2376beb-dd7b-442d-9ff5-ac864e75366b
317+
let
318+
f = pairplot(primarycensoreddist_fit)
319+
CairoMakie.vlines!(f[1, 1], [meanlog], linewidth = 3)
320+
CairoMakie.vlines!(f[2, 2], [sdlog], linewidth = 3)
321+
f
322+
end
323+
324+
# ╔═╡ 673b47ec-b333-45e8-9557-9e65ad425c35
325+
md"
326+
We see that the model has converged and the diagnostics look good. We also see that the posterior means are very near the true parameters and the 90% credible intervals include the true parameters.
327+
"
328+
329+
# ╔═╡ Cell order:
330+
# ╟─8de5c5e0-6e95-11ef-1693-bfd465c8d919
331+
# ╠═a2624404-48b1-4faa-abbe-6d78b8e04f2b
332+
# ╟─30dd9af4-b64f-42b1-8439-a890752f68e3
333+
# ╠═5baa8d2e-bcf8-4e3b-b007-175ad3e2ca95
334+
# ╟─c5704f67-208d-4c2e-8513-c07c6b94ca99
335+
# ╠═aed124c7-b4ba-4c97-a01f-ff553f376c86
336+
# ╟─ec5ed3e9-6ea9-4cfe-afd2-82aabbbe8130
337+
# ╠═105b9594-36ce-4ae8-87a8-5c81867b1ce3
338+
# ╠═8aa9f9c1-d3c4-49f3-be18-a400fc71e8f7
339+
# ╠═84bb3999-9f2b-4eaa-9c2d-776a86677eaf
340+
# ╠═2bf6677e-ebe9-4aa8-aa91-f631e99669bb
341+
# ╟─f4083aea-8106-401a-b60f-383d0b94102a
342+
# ╠═aea8b28e-fffe-4aa6-b51e-8199a7c7975c
343+
# ╠═4d3a853d-0b8d-402a-8309-e9f6da2b7a8c
344+
# ╠═7522f05b-1750-4983-8947-ef70f4298d06
345+
# ╟─5eac2f60-8cec-4460-9d10-6bade7f0f406
346+
# ╠═9443b893-9e22-4267-9a1f-319a3adb8c0d
347+
# ╠═a4f5e9b6-ff3a-48fa-aa51-0abccb9c7bed
348+
# ╟─2a9da9e5-0925-4ae0-8b70-8db90903cb0b
349+
# ╠═0b5e96eb-9312-472e-8a88-d4509a4f25d0
350+
# ╟─c0cce80f-dec7-4a55-aefd-339ef863f854
351+
# ╠═a7bff47d-b61f-499e-8631-206661c2bdc0
352+
# ╠═16bcb80a-970f-4633-aca2-261fa04172f7
353+
# ╠═60711c3c-266e-42b5-acc6-6624db294f24
354+
# ╠═1f1bcee4-8e0d-46fb-9a6f-41998bf54957
355+
# ╠═59bb2a18-eaf4-438a-9359-341efadfe897
356+
# ╟─f66d4b2e-ed66-423e-9cba-62bff712862b
357+
# ╠═010ebe37-782b-4a35-bf5c-dca6dc0fee45
358+
# ╠═d9d14c48-8700-42b5-89b4-7fc51d0f577c
359+
# ╟─8a7cd9ec-5640-4f5f-84c3-ae3f465ca68b
360+
# ╠═028ade5c-17bd-4dfc-8433-23aaff02c181
361+
# ╟─04b4eefb-f0f9-4887-8db0-7cbb7f3b169b
362+
# ╠═21655344-d12b-4e47-a9a9-d06bd909f6ea
363+
# ╠═3b89fe00-6aaf-4764-8b29-e71479f1e641
364+
# ╠═8e09d931-fca7-4ac2-81f7-2bc36b0174f3
365+
# ╟─43eac8dd-8f1d-440e-b1e8-85db9e740651
366+
# ╟─b2efafab-8849-4a7a-bb64-ac9ce126ca75
367+
# ╠═348fc3b4-073b-4997-ae50-58ede5d6d0c9
368+
# ╟─cefb5d56-fecd-4de7-bd0e-156be91c705c
369+
# ╠═ef40112b-f23e-4d4b-8a7d-3793b786f472
370+
# ╟─b823d824-419d-41e9-9ac9-2c45ef190acf
371+
# ╠═93bca93a-5484-47fa-8424-7315eef15e37
372+
# ╟─d5144247-eb57-48bf-8e32-fd71167ecbc8
373+
# ╠═7ae6c61d-0e33-4af8-b8d2-e31223a15a7c
374+
# ╠═1210443f-480f-4e9f-b195-d557e9e1fc31
375+
# ╠═b2376beb-dd7b-442d-9ff5-ac864e75366b
376+
# ╟─673b47ec-b333-45e8-9557-9e65ad425c35

EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ using Distributions, DocStringExtensions, QuadGK, Statistics, Turing
1717
export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial
1818

1919
#Export functions
20-
export scan, spread_draws, censored_pmf, get_param_array, prefix_submodel
20+
export scan, spread_draws, censored_cdf, censored_pmf, get_param_array, prefix_submodel, ∫F
2121

2222
# Export accumulate tools
2323
export get_state, accumulate_scan

0 commit comments

Comments
 (0)