Skip to content

Commit

Permalink
Update censored-obs.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBrand1 committed Sep 25, 2024
1 parent 191baf4 commit a1aa16f
Showing 1 changed file with 56 additions and 67 deletions.
123 changes: 56 additions & 67 deletions EpiAware/docs/src/getting-started/explainers/censored-obs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ let
using Pkg: Pkg
Pkg.activate(docs_dir)
Pkg.develop(; path = pkg_dir)
Pkg.add(["DataFramesMeta", "StatsBase", "TuringBenchmarking"])
Pkg.add(["DataFramesMeta", "StatsBase", "PairPlots", "CairoMakie"])
Pkg.instantiate()
end

Expand All @@ -21,8 +21,8 @@ begin
using EpiAware.EpiAwareUtils: censored_pmf
using Random, Distributions, StatsBase #utilities for random events
using DataFramesMeta #Data wrangling
using StatsPlots #plotting
using Turing, TuringBenchmarking #PPL
using StatsPlots, CairoMakie, PairPlots #plotting
using Turing #PPL
end

# ╔═╡ 8de5c5e0-6e95-11ef-1693-bfd465c8d919
Expand Down Expand Up @@ -83,31 +83,38 @@ true_dist = LogNormal(meanlog, sdlog)

# ╔═╡ aea8b28e-fffe-4aa6-b51e-8199a7c7975c
# Fixed pwindow, swindow, and obs_time lengths shared by every sample
pwindows = rand(1:1, n)
pwindow = 1.

# ╔═╡ d231bd0c-165f-4973-a46f-f66991813ea7
swindows = rand(1:1, n)
swindow = 1.

# ╔═╡ 7522f05b-1750-4983-8947-ef70f4298d06
obs_times = fill(10.0, n)
obs_time = 10.

# ╔═╡ 60212cf7-cc3d-42d3-8260-1f067ede3c6f
nsamples_max = 2000

# ╔═╡ a4f5e9b6-ff3a-48fa-aa51-0abccb9c7bed
#Sample secondary time relative to beginning of primary censor window respecting the right-truncation
samples = map(pwindows, swindows, obs_times) do pw, sw, ot
P = rand() * pw # Primary event time
T = rand(truncated(true_dist; upper = ot - P))
end
# Draw up to `nsamples_max` secondary event times measured from the start of the
# primary censor window: a uniform primary event time inside the window plus a
# delay drawn from `true_dist`. Right-truncation is applied by rejection, so only
# times at or before `obs_time` are kept and fewer than `nsamples_max` may remain.
samples = map(1:nsamples_max) do _
    P = rand() * pwindow # Primary event time within the primary window
    P + rand(true_dist) # Secondary event time relative to the window start
end |>
          s -> filter(<=(obs_time), s)

# ╔═╡ a627a544-4c41-4c35-83ec-be7ed7c0a737
nsamples = length(samples)

# ╔═╡ 0b5e96eb-9312-472e-8a88-d4509a4f25d0
# Censor each sampled time to its secondary window and count the unique observations
delay_counts = mapreduce(vcat, samples, pwindows, swindows, obs_times) do T, pw, sw, ot
delay_counts = mapreduce(vcat, samples) do T
DataFrame(
pwindow = pw,
swindow = sw,
obs_time = ot,
observed_delay = (T ÷ sw) * sw,
observed_delay_upper = (T ÷ sw) * (sw + 1),
observed_delay_step = Int(T ÷ sw) + 1
pwindow = pwindow,
swindow = swindow,
obs_time = obs_time,
observed_delay = (T ÷ swindow) * swindow,
observed_delay_upper = (T ÷ swindow) * (swindow) + swindow,
observed_delay_step = Int(T ÷ swindow) + 1
)
end |> # Aggregate to unique combinations and count occurrences
df -> @groupby(df, :pwindow, :swindow, :obs_time, :observed_delay,
Expand All @@ -126,29 +133,29 @@ x_seq = range(minimum(samples), maximum(samples), 100)
# ╔═╡ c6fe3c52-af87-4a84-b280-bc9a8532e269
#plot
let
# Compare the empirical CDFs of the simulated data with the CDF of `true_dist`.
# `empirical_cdf` and `empirical_cdf_obs` are defined outside this cell
# (presumably from `samples` and the censored delays respectively; confirm there).
# First set up the empty axes, with ticks and limits derived from the observation time.
plot(; title = "Comparison of Observed vs Theoretical CDF",
StatsPlots.plot(; title = "Comparison of Observed vs Theoretical CDF",
ylabel = "Cumulative Probability",
xlabel = "Delay",
xticks = 0:obs_times[1],
xlims = (-0.1, obs_times[1] + 0.5)
xticks = 0:obs_time,
xlims = (-0.1, obs_time + 0.5)
)
# Blue curve: `empirical_cdf` evaluated on the grid `x_seq`.
plot!(x_seq, x_seq .|> x -> empirical_cdf(x),
StatsPlots.plot!(x_seq, x_seq .|> x -> empirical_cdf(x),
lab = "Observed secondary times",
c = :blue,
lw = 3
)
# Green curve: `empirical_cdf_obs` evaluated on the same grid.
plot!(x_seq, x_seq .|> x -> empirical_cdf_obs(x),
StatsPlots.plot!(x_seq, x_seq .|> x -> empirical_cdf_obs(x),
lab = "Observed censored secondary times",
c = :green,
lw = 3
)
# Black curve: the CDF of the generating distribution `true_dist`.
plot!(x_seq, x_seq .|> x -> cdf(true_dist, x),
StatsPlots.plot!(x_seq, x_seq .|> x -> cdf(true_dist, x),
lab = "Theoretical",
c = :black,
lw = 3
)
# Dashed vertical lines: mean of `samples` (blue) and mean of `true_dist` (black).
vline!([mean(samples)], ls = :dash, c = :blue, lw = 3, lab = "")
vline!([mean(true_dist)], ls = :dash, c = :black, lw = 3, lab = "")
StatsPlots.vline!([mean(samples)], ls = :dash, c = :blue, lw = 3, lab = "")
StatsPlots.vline!([mean(true_dist)], ls = :dash, c = :black, lw = 3, lab = "")
end

# ╔═╡ f66d4b2e-ed66-423e-9cba-62bff712862b
Expand Down Expand Up @@ -196,6 +203,14 @@ naive_fit = sample(naive_mdl, NUTS(), MCMCThreads(), 500, 4)
# ╔═╡ 3b89fe00-6aaf-4764-8b29-e71479f1e641
summarize(naive_fit)

# ╔═╡ 8e09d931-fca7-4ac2-81f7-2bc36b0174f3
# Pair plot of the naive posterior with the simulation values `meanlog` and
# `sdlog` drawn as vertical lines on grid positions (1, 1) and (2, 2).
let fig = pairplot(naive_fit)
    for (k, truth) in enumerate((meanlog, sdlog))
        CairoMakie.vlines!(fig[k, k], [truth])
    end
    fig
end

# ╔═╡ 43eac8dd-8f1d-440e-b1e8-85db9e740651
md"
We see that the model has converged and the diagnostics look good. However, just from the model posterior summary we see that we might not be very happy with the fit. `mu` is smaller than the target $(meanlog) and `sigma` is larger than the target $(sdlog).
Expand All @@ -214,7 +229,7 @@ We'll now fit an improved model using the `censored_pmf` function from the `EpiA
@model function primarycensoreddist_model(N, y, n, pwindow, D)
try
mu ~ Normal(1.0, 1.0)
sigma ~ truncated(Normal(0.5, 0.5); lower = 0.1)
sigma ~ truncated(Normal(0.5, 0.5); lower = 0.)
d = LogNormal(mu, sigma)
log_pmf = censored_pmf(d; Δd = pwindow, D = D) .|> log

Expand All @@ -237,49 +252,24 @@ primarycensoreddist_mdl = primarycensoreddist_model(
size(delay_counts, 1),
delay_counts.observed_delay_step,
delay_counts.n,
delay_counts.pwindow[1],
delay_counts.obs_time[1]
pwindow,
obs_time,
)

# ╔═╡ 8f1d32fd-f54b-4f69-8c93-8f0786366cef
# ╠═╡ disabled = true
#=╠═╡
benchmark_model(
primarycensoreddist_mdl;
# Check correctness of computations
check=true,
# Automatic differentiation backends to check and benchmark
adbackends=[:forwarddiff, :reversediff, :reversediff_compiled]
)
╠═╡ =#

# ╔═╡ 44132e2e-5a1a-49ad-9e57-cec24f981f52
# Run the MAP optimisation ten times and keep the run with the highest `lp`
# (the first such run if several tie).
map_estimate = let estimates = [maximum_a_posteriori(primarycensoreddist_mdl) for _ in 1:10]
    _, best = findmax([est.lp for est in estimates])
    estimates[best]
end

# ╔═╡ 604458a6-7b6f-4b5c-b2e7-09be1908c0f9
# ╠═╡ disabled = true
#=╠═╡
primarycensoreddist_fit = sample(primarycensoreddist_mdl, MH(), 100_000; initial_params=map_estimate.values.array) |>
chn -> chn[50_000:end, :, :]
╠═╡ =#

# ╔═╡ a34c19e8-ba9e-4276-a17e-c853bb3341cf
# ╠═╡ disabled = true
#=╠═╡
primarycensoreddist_fit = sample(primarycensoreddist_mdl, NUTS(), MCMCThreads(), 500, 4)
╠═╡ =#

# ╔═╡ 7ae6c61d-0e33-4af8-b8d2-e31223a15a7c
primarycensoreddist_fit = sample(
primarycensoreddist_mdl, NUTS(), 1000; initial_params = map_estimate.values.array)
primarycensoreddist_mdl, NUTS(), MCMCThreads(), 1000, 4)

# ╔═╡ 1210443f-480f-4e9f-b195-d557e9e1fc31
summarize(primarycensoreddist_fit)

# ╔═╡ 46711233-f680-4962-9e3e-60c747db4d2c
censored_pmf(true_dist; D = obs_times[1])
# ╔═╡ b2376beb-dd7b-442d-9ff5-ac864e75366b
# Pair plot of the censoring-aware posterior with the simulation values `meanlog`
# and `sdlog` drawn as thick vertical lines on grid positions (1, 1) and (2, 2).
let fig = pairplot(primarycensoreddist_fit)
    for (k, truth) in enumerate((meanlog, sdlog))
        CairoMakie.vlines!(fig[k, k], [truth], linewidth = 3)
    end
    fig
end

# ╔═╡ Cell order:
# ╟─8de5c5e0-6e95-11ef-1693-bfd465c8d919
Expand All @@ -295,7 +285,9 @@ censored_pmf(true_dist; D = obs_times[1])
# ╠═aea8b28e-fffe-4aa6-b51e-8199a7c7975c
# ╠═d231bd0c-165f-4973-a46f-f66991813ea7
# ╠═7522f05b-1750-4983-8947-ef70f4298d06
# ╠═60212cf7-cc3d-42d3-8260-1f067ede3c6f
# ╠═a4f5e9b6-ff3a-48fa-aa51-0abccb9c7bed
# ╠═a627a544-4c41-4c35-83ec-be7ed7c0a737
# ╠═0b5e96eb-9312-472e-8a88-d4509a4f25d0
# ╠═a7bff47d-b61f-499e-8631-206661c2bdc0
# ╠═16bcb80a-970f-4633-aca2-261fa04172f7
Expand All @@ -309,15 +301,12 @@ censored_pmf(true_dist; D = obs_times[1])
# ╟─04b4eefb-f0f9-4887-8db0-7cbb7f3b169b
# ╠═21655344-d12b-4e47-a9a9-d06bd909f6ea
# ╠═3b89fe00-6aaf-4764-8b29-e71479f1e641
# ╠═8e09d931-fca7-4ac2-81f7-2bc36b0174f3
# ╟─43eac8dd-8f1d-440e-b1e8-85db9e740651
# ╠═b2efafab-8849-4a7a-bb64-ac9ce126ca75
# ╟─b2efafab-8849-4a7a-bb64-ac9ce126ca75
# ╠═ef40112b-f23e-4d4b-8a7d-3793b786f472
# ╟─b823d824-419d-41e9-9ac9-2c45ef190acf
# ╠═93bca93a-5484-47fa-8424-7315eef15e37
# ╠═8f1d32fd-f54b-4f69-8c93-8f0786366cef
# ╠═44132e2e-5a1a-49ad-9e57-cec24f981f52
# ╠═604458a6-7b6f-4b5c-b2e7-09be1908c0f9
# ╠═a34c19e8-ba9e-4276-a17e-c853bb3341cf
# ╠═7ae6c61d-0e33-4af8-b8d2-e31223a15a7c
# ╠═1210443f-480f-4e9f-b195-d557e9e1fc31
# ╠═46711233-f680-4962-9e3e-60c747db4d2c
# ╠═b2376beb-dd7b-442d-9ff5-ac864e75366b

0 comments on commit a1aa16f

Please sign in to comment.