Skip to content

Commit 67cc102

Browse files
SamuelBrand1seabbs
authored andcommitted
remove dead end-to-end test
in favour of direct test of prior_predictive_plot
1 parent 1dcda88 commit 67cc102

File tree

4 files changed

+132
-78
lines changed

4 files changed

+132
-78
lines changed

pipeline/test/end-to-end/test_full_inference.jl

Lines changed: 0 additions & 77 deletions
This file was deleted.

pipeline/test/end-to-end/test_prior_predictive.jl

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,117 @@ P = pipetype(; testmode = true, nchains = 1, ndraws = 2000, priorpredictive = tr
99

1010
##
1111

12-
inference_method = make_inference_method(P)
1312
inference_config = make_inference_configs(P) |> first
13+
14+
missingdata = Dict("y_t" => missing, "I_t" => fill(1.0, 100), "truth_I0" => 1.0,
15+
"truth_gi_mean" => inference_config.gi_mean)
16+
results = generate_inference_results(missingdata, inference_config, P)
17+
18+
res = results["inference_results"]
19+
20+
gens = generate_quantiles_for_targets(
21+
res, results["epiprob"].epi_model.data, [0.025, 0.25, 0.5, 0.75, 0.975])
22+
23+
gens.log_I_t
24+
25+
##
26+
using CairoMakie
27+
28+
function _make_prior_plot_title(config)
29+
igp_str = string(config.igp)
30+
latent_model_str = config.latent_model_name |> uppercase
31+
gi_mean_str = config.gi_mean |> string
32+
T_str = config.tspan[2] |> string
33+
return "Prior pred. IGP: $(igp_str), latent model: $(latent_model_str), truth gi mean: $(gi_mean_str), T: $(T_str)"
34+
end
35+
_make_prior_plot_title(config)
36+
37+
function _setup_levels(ps)
38+
n_levels = length(ps)
39+
qs = mapreduce(vcat, ps) do percentile
40+
[percentile / 2, 1 - percentile / 2]
41+
end |> x -> [0.5; x]
42+
return qs, n_levels
43+
end
44+
##
45+
# _get_priorpred_plot_title(results["inference_config"])
46+
##
47+
48+
function prior_predictive_plot(config, output, epiprob;
49+
ps = [0.05, 0.1, 0.25],
50+
bottom_alpha = 0.1,
51+
top_alpha = 0.5,
52+
case_color = :black,
53+
logI_color = :purple,
54+
rt_color = :blue,
55+
Rt_color = :green,
56+
figsize = (750, 600))
57+
@assert all(0 .<= ps .< 0.5) "Percentiles must be in the range [0, 0.5)"
58+
prior_pred_plot_title = _make_prior_plot_title(config)
59+
qs, n_levels = _setup_levels(sort(ps))
60+
opacity_scale = range(bottom_alpha, top_alpha, length = n_levels) |> collect
61+
62+
# Create the figure and axes
63+
fig = Figure(size = figsize)
64+
ax11 = Axis(fig[1, 1]; xlabel = "t", ylabel = "Cases")
65+
ax12 = Axis(fig[1, 2]; xlabel = "t", ylabel = "log(Incidence)")
66+
ax21 = Axis(fig[2, 1]; xlabel = "t", ylabel = "Exp. growth rate")
67+
ax22 = Axis(fig[2, 2]; xlabel = "t", ylabel = "Reproduction number")
68+
linkxaxes!(ax11, ax21)
69+
linkxaxes!(ax12, ax22)
70+
Label(fig[0, :]; text = prior_pred_plot_title, fontsize = 16)
71+
72+
# Quantile calculations
73+
gen_y_t = mapreduce(hcat, output.generated) do gen
74+
gen.generated_y_t
75+
end |> X -> timeseries_samples_into_quantiles(X, qs)
76+
gen_quantities = generate_quantiles_for_targets(output, epiprob.epi_model.data, qs)
77+
78+
# Plot the prior predictive samples
79+
# Cases
80+
f = findfirst(!ismissing, gen_y_t[:, 1])
81+
lines!(ax11, 1:size(gen_y_t, 1), gen_y_t[:, 1],
82+
color = case_color, linewidth = 3, label = "Median")
83+
for i in 1:n_levels
84+
band!(ax11, f:size(gen_y_t, 1), gen_y_t[f:size(gen_y_t, 1), (2 * i)],
85+
gen_y_t[f:size(gen_y_t, 1), (2 * i) + 1],
86+
color = (case_color, opacity_scale[i]),
87+
label = "($(ps[i]*100)-$((1 - ps[i])*100))%")
88+
end
89+
vlines!(ax11, [f], color = case_color, linestyle = :dash, label = "Obs. window")
90+
axislegend(ax11; position = :lt, framevisible = false)
91+
92+
# Other quantities
93+
for (ax, target, c) in zip(
94+
[ax12, ax21, ax22], [gen_quantities.log_I_t, gen_quantities.rt, gen_quantities.Rt],
95+
[logI_color, rt_color, Rt_color])
96+
lines!(ax, 1:size(target, 1), target[:, 1],
97+
color = logI_color, linewidth = 3, label = "Median")
98+
for i in 1:n_levels
99+
band!(ax, 1:size(target, 1), target[:, (2 * i)], target[:, (2 * i) + 1],
100+
color = (c, opacity_scale[i]), label = "")
101+
end
102+
end
103+
104+
fig
105+
end
106+
107+
##
108+
fig = prior_predictive_plot(config, res, results["epiprob"]; ps = [0.025, 0.1, 0.25])
109+
##
110+
gen_y_t = mapreduce(hcat, res.generated) do gen
111+
gen.generated_y_t
112+
end |> X -> timeseries_samples_into_quantiles(X, [0.025, 0.25, 0.5, 0.75, 0.975])
113+
114+
##
115+
fig = Figure()
116+
ax_logIt = Axis(fig[1, 1];
117+
xticks = vcat(1, 5:5:50) # xlabel
118+
)
119+
120+
for i in 1:5
121+
lines!(ax_logIt, gen_y_t[:, i], color = :black, alpha = 0.5)
122+
end
123+
124+
# hlines!(ax, [1.0], color = :red, linestyle = :dash)
125+
fig
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
include("test_utils.jl")
22
include("test_plot_funcs.jl")
3+
include("prior_pred_plot.jl")
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
@testset "Prior pred plotting" begin
2+
using CairoMakie
3+
#pick a random scenario
4+
pipetype = [SmoothOutbreakPipeline, MeasuresOutbreakPipeline,
5+
SmoothEndemicPipeline, RoughEndemicPipeline] |> rand
6+
P = pipetype(; testmode = true, nchains = 1, ndraws = 2000, priorpredictive = true)
7+
inference_config = make_inference_configs(P) |> rand
8+
9+
#Add missing data
10+
missingdata = Dict("y_t" => missing, "I_t" => fill(1.0, 100), "truth_I0" => 1.0,
11+
"truth_gi_mean" => inference_config["gi_mean"])
12+
results = generate_inference_results(missingdata, inference_config, P)
13+
14+
fig = prior_predictive_plot(results["inference_config"], results["inference_results"],
15+
results["epiprob"]; ps = [0.025, 0.1, 0.25])
16+
17+
@test fig isa Figure
18+
end

0 commit comments

Comments
 (0)