Commit d1f660b

initial commit of using censored_pmf
1 parent e660145 commit d1f660b

2 files changed: 324 additions and 0 deletions

EpiAware/docs/Project.toml

Lines changed: 3 additions & 0 deletions
@@ -8,11 +8,14 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
 EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855"
+FFMPEG = "c87230d0-a227-11e9-1b43-d7ebe4e7570a"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
 PlutoStaticHTML = "359b1769-a58e-495b-9770-312e911026ad"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
 Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
 Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
+TuringBenchmarking = "0db1332d-5c25-4deb-809f-459bc696f94f"
Lines changed: 321 additions & 0 deletions
@@ -0,0 +1,321 @@
### A Pluto.jl notebook ###
# v0.19.46

using Markdown
using InteractiveUtils

# ╔═╡ a2624404-48b1-4faa-abbe-6d78b8e04f2b
let
    docs_dir = dirname(dirname(dirname(@__DIR__)))
    pkg_dir = dirname(docs_dir)

    using Pkg: Pkg
    Pkg.activate(docs_dir)
    Pkg.develop(; path = pkg_dir)
    Pkg.add(["DataFramesMeta", "StatsBase", "TuringBenchmarking"])
    Pkg.instantiate()
end

# ╔═╡ 5baa8d2e-bcf8-4e3b-b007-175ad3e2ca95
begin
    using EpiAware.EpiAwareUtils: censored_pmf
    using Random, Distributions, StatsBase # utilities for random events
    using DataFramesMeta # data wrangling
    using StatsPlots # plotting
    using Turing, TuringBenchmarking # PPL
end

# ╔═╡ 8de5c5e0-6e95-11ef-1693-bfd465c8d919
md"
# Fitting distributions using `censored_pmf` and Turing PPL

## Introduction

### What are we going to do in this vignette?

In this vignette, we'll demonstrate how to use `EpiAwareUtils.censored_pmf` in conjunction with the Turing PPL for Bayesian inference of epidemiological delay distributions. We'll cover the following key points:

1. Simulating censored delay distribution data
2. Fitting a naive model using Turing
3. Evaluating the naive model's performance
4. Fitting an improved model using `censored_pmf` functionality
5. Comparing the `censored_pmf` model's performance to the naive model

### What might I need to know before starting?

This note builds on the concepts introduced in the R/Stan package [`primarycensoreddist`](https://github.com/epinowcast/primarycensoreddist), especially the [Getting Started with primarycensoreddist](https://primarycensoreddist.epinowcast.org/articles/fitting-dists-with-stan.html) vignette, and assumes familiarity with using Turing tools as covered in the [Turing documentation](https://turinglang.org/).

This note is generated using the `EpiAware` package locally via `Pkg.develop`, in the `EpiAware/docs` environment. It is also possible to install `EpiAware` using

```julia
Pkg.add(url=\"https://github.com/CDCgov/Rt-without-renewal\", subdir=\"EpiAware\")
```

"

# ╔═╡ 30dd9af4-b64f-42b1-8439-a890752f68e3
md"
The other dependencies are as follows:
"

# ╔═╡ c5704f67-208d-4c2e-8513-c07c6b94ca99
md"
## Simulating censored and truncated delay distribution data

We'll start by simulating some censored and truncated delay distribution data.
"

# ╔═╡ aed124c7-b4ba-4c97-a01f-ff553f376c86
Random.seed!(123) # For reproducibility

# ╔═╡ 105b9594-36ce-4ae8-87a8-5c81867b1ce3
# Define the sample size; the true distribution parameters follow in the next cells
n = 1000

# ╔═╡ 8aa9f9c1-d3c4-49f3-be18-a400fc71e8f7
meanlog = 1.5

# ╔═╡ 84bb3999-9f2b-4eaa-9c2d-776a86677eaf
sdlog = 0.75

# ╔═╡ 2bf6677e-ebe9-4aa8-aa91-f631e99669bb
true_dist = LogNormal(meanlog, sdlog)

# ╔═╡ aea8b28e-fffe-4aa6-b51e-8199a7c7975c
# Generate pwindow, swindow, and obs_time values (all constant in this example)
pwindows = rand(1:1, n)

# ╔═╡ d231bd0c-165f-4973-a46f-f66991813ea7
swindows = rand(1:1, n)

# ╔═╡ 7522f05b-1750-4983-8947-ef70f4298d06
obs_times = fill(10.0, n)

# ╔═╡ a4f5e9b6-ff3a-48fa-aa51-0abccb9c7bed
# Sample the secondary event time relative to the start of the primary censoring window, respecting the right truncation
samples = map(pwindows, swindows, obs_times) do pw, sw, ot
    P = rand() * pw # Primary event time within its window
    T = rand(truncated(true_dist; upper = ot - P)) # Delay, truncated so that P + T < ot
    P + T # Secondary event time measured from the start of the primary window
end

# ╔═╡ 0b5e96eb-9312-472e-8a88-d4509a4f25d0
# Censor each sample to its secondary window and tabulate the result
delay_counts = mapreduce(vcat, samples, pwindows, swindows, obs_times) do T, pw, sw, ot
    DataFrame(
        pwindow = pw,
        swindow = sw,
        obs_time = ot,
        observed_delay = T ÷ sw .|> Int,
        observed_delay_upper = (T ÷ sw) + sw |> Int,
    )
end |> # Aggregate to unique combinations and count occurrences
    df -> @groupby(df, :pwindow, :swindow, :obs_time, :observed_delay, :observed_delay_upper) |>
    gd -> @combine(gd, :n = length(:pwindow))

# ╔═╡ a7bff47d-b61f-499e-8631-206661c2bdc0
empirical_cdf = ecdf(samples)

# ╔═╡ 16bcb80a-970f-4633-aca2-261fa04172f7
empirical_cdf_obs = ecdf(delay_counts.observed_delay, weights = delay_counts.n)

# ╔═╡ 60711c3c-266e-42b5-acc6-6624db294f24
x_seq = range(minimum(samples), maximum(samples), 100)

# ╔═╡ c6fe3c52-af87-4a84-b280-bc9a8532e269
# Plot the empirical CDFs against the theoretical CDF
let
    plot(; title = "Comparison of Observed vs Theoretical CDF",
        ylabel = "Cumulative Probability",
        xlabel = "Delay",
        xticks = 0:obs_times[1],
        xlims = (-0.1, obs_times[1] + 0.5)
    )
    plot!(x_seq, x_seq .|> x -> empirical_cdf(x),
        lab = "Observed secondary times",
        c = :blue,
        lw = 3,
    )
    plot!(x_seq, x_seq .|> x -> empirical_cdf_obs(x),
        lab = "Observed censored secondary times",
        c = :green,
        lw = 3,
    )
    plot!(x_seq, x_seq .|> x -> cdf(true_dist, x),
        lab = "Theoretical",
        c = :black,
        lw = 3,
    )
    vline!([mean(samples)], ls = :dash, c = :blue, lw = 3, lab = "")
    vline!([mean(true_dist)], ls = :dash, c = :black, lw = 3, lab = "")
end

# ╔═╡ f66d4b2e-ed66-423e-9cba-62bff712862b
md"
We've aggregated the data to unique combinations of `pwindow`, `swindow`, and `obs_time` and counted the number of occurrences of each `observed_delay` for each combination. This is the data we will use to fit our model.
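
Fitting to counts rather than individual records works because each record's log-likelihood contribution depends only on its observed delay, so identical records can be collapsed into a single term weighted by its count. The following is a minimal sketch of that equivalence on made-up delays. It uses a `Poisson` distribution purely as a stand-in for a discrete delay model, and `delays`, `ys`, and `ns` are illustrative names that do not appear elsewhere in this notebook.

```julia
using Distributions, StatsBase

delays = [0, 1, 1, 2, 2, 2]        # made-up observed delays, one per record
counts = countmap(delays)          # maps each unique delay to its count
ys = sort(collect(keys(counts)))   # unique delays
ns = [counts[y] for y in ys]       # matching counts

d = Poisson(1.5)                   # stand-in discrete delay model (assumption)
ll_individual = sum(logpdf(d, y) for y in delays)
ll_aggregated = sum(n * logpdf(d, y) for (n, y) in zip(ns, ys))
```

The two log-likelihoods agree up to floating point error, which is the identity that the `n[i] * ...` terms in the models of this notebook rely on.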
"

# ╔═╡ 010ebe37-782b-4a35-bf5c-dca6dc0fee45
md"
## Fitting a naive model using Turing

We'll start by fitting a naive model using Turing. It treats each observed delay as an exact draw from a `LogNormal` distribution and ignores both the censoring and the truncation.
"

# ╔═╡ d9d14c48-8700-42b5-89b4-7fc51d0f577c
@model function naive_model(N, y, n)
    mu ~ Normal(1.0, 1.0)
    sigma ~ truncated(Normal(0.5, 1.0); lower = 0.0)
    d = LogNormal(mu, sigma)

    for i in eachindex(y)
        Turing.@addlogprob! n[i] * logpdf(d, y[i])
    end
end

# ╔═╡ 8a7cd9ec-5640-4f5f-84c3-ae3f465ca68b
md"
Now let's instantiate this model with data.
"

# ╔═╡ 028ade5c-17bd-4dfc-8433-23aaff02c181
naive_mdl = naive_model(
    size(delay_counts, 1),
    delay_counts.observed_delay .+ 1e-6, # Add a small constant to avoid log(0)
    delay_counts.n)

# ╔═╡ 04b4eefb-f0f9-4887-8db0-7cbb7f3b169b
md"
And now let's fit the instantiated model.
"

# ╔═╡ 21655344-d12b-4e47-a9a9-d06bd909f6ea
naive_fit = sample(naive_mdl, NUTS(), MCMCThreads(), 500, 4)

# ╔═╡ 3b89fe00-6aaf-4764-8b29-e71479f1e641
summarize(naive_fit)

# ╔═╡ 43eac8dd-8f1d-440e-b1e8-85db9e740651
md"
The model has converged and the diagnostics look good. However, the posterior summary alone shows that the fit is poor: `mu` is smaller than the target $(meanlog) and `sigma` is larger than the target $(sdlog).
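
Two features of the data that the naive model ignores plausibly contribute to this mismatch: delays longer than the observation time are never seen, and a zero-day delay shifted by `1e-6` becomes a very large negative number on the log scale. The sketch below quantifies both for the true parameters used in this notebook. It is an illustration under those assumptions (and it ignores the primary event time), not a diagnosis of the fit.

```julia
using Distributions

true_dist = LogNormal(1.5, 0.75)

# Probability that a delay is shorter than the 10-day observation time;
# the remainder is lost to right truncation.
p_observed = cdf(true_dist, 10.0)

# Log-scale value assigned to a zero-day delay after adding 1e-6, to compare
# with the bulk of the distribution, which is centred on meanlog = 1.5.
log_zero_delay = log(1e-6)
```

Under these parameters roughly 14% of delays exceed the observation time, which would tend to pull `mu` down, while zero-day records sit near -13.8 on the log scale, which would tend to inflate `sigma`.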

"

# ╔═╡ b2efafab-8849-4a7a-bb64-ac9ce126ca75
md"
## Fitting an improved model using `censored_pmf`

We'll now fit an improved model using the `censored_pmf` function from the `EpiAware.EpiAwareUtils` submodule. This accounts for the primary and secondary censoring windows as well as the truncation.
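
To make concrete what accounting for both censoring windows and the truncation involves, here is a minimal numerical sketch of a double interval censored, right-truncated PMF. It assumes a uniformly distributed primary event time within its window and approximates the resulting integral with a midpoint rule. The helper names `primary_censored_cdf` and `sketch_censored_pmf` are hypothetical, and this is not the implementation used by `censored_pmf`.

```julia
using Distributions

# CDF of the secondary event time P + T measured from the start of the primary
# window, with P ~ Uniform(0, pwindow) and delay T ~ dist.
# The integral over P is approximated by a midpoint rule with m points.
function primary_censored_cdf(dist, t; pwindow = 1.0, m = 1_000)
    midpoints = ((1:m) .- 0.5) .* (pwindow / m)
    return sum(cdf(dist, max(t - p, 0.0)) for p in midpoints) / m
end

# Probability of each secondary window [k * swindow, (k + 1) * swindow),
# conditional on the secondary event occurring before the observation time D.
function sketch_censored_pmf(dist; pwindow = 1.0, swindow = 1.0, D = 10.0)
    edges = 0.0:swindow:D
    F = [primary_censored_cdf(dist, t; pwindow = pwindow) for t in edges]
    return diff(F) ./ F[end]
end

pmf = sketch_censored_pmf(LogNormal(1.5, 0.75))
```

In this sketch the vector sums to one by construction, and element `k + 1` is the probability of observing a delay of `k` days, which is the same convention as the `log_pmf[y[i] + 1]` indexing in the model cell of this section.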

"

# ╔═╡ ef40112b-f23e-4d4b-8a7d-3793b786f472
@model function primarycensoreddist_model(N, y, y_upper, n, pwindow, D)
    try
        mu ~ Normal(1.0, 1.0)
        sigma ~ truncated(Normal(0.5, 0.5); lower = 0.1)
        d = LogNormal(mu, sigma)
        log_pmf = censored_pmf(d; Δd = pwindow, D = D) .|> log

        for i in eachindex(y)
            Turing.@addlogprob! n[i] * log_pmf[y[i] + 1] # 0 obs is first element of array
        end
        return log_pmf
    catch
        Turing.@addlogprob! -Inf
    end
end

# ╔═╡ b823d824-419d-41e9-9ac9-2c45ef190acf
md"
Let's instantiate this model with data.
"

# ╔═╡ 93bca93a-5484-47fa-8424-7315eef15e37
primarycensoreddist_mdl = primarycensoreddist_model(
    size(delay_counts, 1),
    delay_counts.observed_delay, # Lower bound of each secondary window
    delay_counts.observed_delay_upper, # Upper bound of each secondary window
    delay_counts.n,
    delay_counts.pwindow[1],
    delay_counts.obs_time[1]
)

# ╔═╡ 8f1d32fd-f54b-4f69-8c93-8f0786366cef
# ╠═╡ disabled = true
#=╠═╡
benchmark_model(
    primarycensoreddist_mdl;
    # Check correctness of computations
    check=true,
    # Automatic differentiation backends to check and benchmark
    adbackends=[:forwarddiff, :reversediff, :reversediff_compiled]
)
  ╠═╡ =#

# ╔═╡ 44132e2e-5a1a-49ad-9e57-cec24f981f52
# Run the MAP optimisation 10 times and keep the run with the highest log posterior
map_estimate = let opts = [maximum_a_posteriori(primarycensoreddist_mdl) for _ in 1:10]
    opts[argmax([o.lp for o in opts])]
end

# ╔═╡ a34c19e8-ba9e-4276-a17e-c853bb3341cf
# ╠═╡ disabled = true
#=╠═╡
primarycensoreddist_fit = sample(primarycensoreddist_mdl, NUTS(), MCMCThreads(), 500, 4)
  ╠═╡ =#

# ╔═╡ 1210443f-480f-4e9f-b195-d557e9e1fc31
summarize(primarycensoreddist_fit)

# ╔═╡ 46711233-f680-4962-9e3e-60c747db4d2c
censored_pmf(true_dist; D = obs_times[1])

# ╔═╡ 604458a6-7b6f-4b5c-b2e7-09be1908c0f9
# ╠═╡ disabled = true
#=╠═╡
primarycensoreddist_fit = sample(primarycensoreddist_mdl, MH(), 100_000; initial_params=map_estimate.values.array) |>
    chn -> chn[50_000:end, :, :]
  ╠═╡ =#

# ╔═╡ 7ae6c61d-0e33-4af8-b8d2-e31223a15a7c
primarycensoreddist_fit = sample(primarycensoreddist_mdl, NUTS(), 1000; initial_params=map_estimate.values.array)

# ╔═╡ Cell order:
# ╟─8de5c5e0-6e95-11ef-1693-bfd465c8d919
# ╠═a2624404-48b1-4faa-abbe-6d78b8e04f2b
# ╟─30dd9af4-b64f-42b1-8439-a890752f68e3
# ╠═5baa8d2e-bcf8-4e3b-b007-175ad3e2ca95
# ╟─c5704f67-208d-4c2e-8513-c07c6b94ca99
# ╠═aed124c7-b4ba-4c97-a01f-ff553f376c86
# ╠═105b9594-36ce-4ae8-87a8-5c81867b1ce3
# ╠═8aa9f9c1-d3c4-49f3-be18-a400fc71e8f7
# ╠═84bb3999-9f2b-4eaa-9c2d-776a86677eaf
# ╠═2bf6677e-ebe9-4aa8-aa91-f631e99669bb
# ╠═aea8b28e-fffe-4aa6-b51e-8199a7c7975c
# ╠═d231bd0c-165f-4973-a46f-f66991813ea7
# ╠═7522f05b-1750-4983-8947-ef70f4298d06
# ╠═a4f5e9b6-ff3a-48fa-aa51-0abccb9c7bed
# ╠═0b5e96eb-9312-472e-8a88-d4509a4f25d0
# ╠═a7bff47d-b61f-499e-8631-206661c2bdc0
# ╠═16bcb80a-970f-4633-aca2-261fa04172f7
# ╠═60711c3c-266e-42b5-acc6-6624db294f24
# ╠═c6fe3c52-af87-4a84-b280-bc9a8532e269
# ╟─f66d4b2e-ed66-423e-9cba-62bff712862b
# ╟─010ebe37-782b-4a35-bf5c-dca6dc0fee45
# ╠═d9d14c48-8700-42b5-89b4-7fc51d0f577c
# ╟─8a7cd9ec-5640-4f5f-84c3-ae3f465ca68b
# ╠═028ade5c-17bd-4dfc-8433-23aaff02c181
# ╟─04b4eefb-f0f9-4887-8db0-7cbb7f3b169b
# ╠═21655344-d12b-4e47-a9a9-d06bd909f6ea
# ╠═3b89fe00-6aaf-4764-8b29-e71479f1e641
# ╟─43eac8dd-8f1d-440e-b1e8-85db9e740651
# ╠═b2efafab-8849-4a7a-bb64-ac9ce126ca75
# ╠═ef40112b-f23e-4d4b-8a7d-3793b786f472
# ╟─b823d824-419d-41e9-9ac9-2c45ef190acf
# ╠═93bca93a-5484-47fa-8424-7315eef15e37
# ╠═8f1d32fd-f54b-4f69-8c93-8f0786366cef
# ╠═44132e2e-5a1a-49ad-9e57-cec24f981f52
# ╠═604458a6-7b6f-4b5c-b2e7-09be1908c0f9
# ╠═a34c19e8-ba9e-4276-a17e-c853bb3341cf
# ╠═7ae6c61d-0e33-4af8-b8d2-e31223a15a7c
# ╠═1210443f-480f-4e9f-b195-d557e9e1fc31
# ╠═46711233-f680-4962-9e3e-60c747db4d2c
