@@ -9,5 +9,117 @@ P = pipetype(; testmode = true, nchains = 1, ndraws = 2000, priorpredictive = tr
9
9
10
10
# #
11
11
12
- inference_method = make_inference_method (P)
13
12
inference_config = make_inference_configs (P) |> first
13
+
14
+ missingdata = Dict (" y_t" => missing , " I_t" => fill (1.0 , 100 ), " truth_I0" => 1.0 ,
15
+ " truth_gi_mean" => inference_config. gi_mean)
16
+ results = generate_inference_results (missingdata, inference_config, P)
17
+
18
+ res = results[" inference_results" ]
19
+
20
+ gens = generate_quantiles_for_targets (
21
+ res, results[" epiprob" ]. epi_model. data, [0.025 , 0.25 , 0.5 , 0.75 , 0.975 ])
22
+
23
+ gens. log_I_t
24
+
25
+ # #
26
+ using CairoMakie
27
+
28
+ function _make_prior_plot_title (config)
29
+ igp_str = string (config. igp)
30
+ latent_model_str = config. latent_model_name |> uppercase
31
+ gi_mean_str = config. gi_mean |> string
32
+ T_str = config. tspan[2 ] |> string
33
+ return " Prior pred. IGP: $(igp_str) , latent model: $(latent_model_str) , truth gi mean: $(gi_mean_str) , T: $(T_str) "
34
+ end
35
+ _make_prior_plot_title (config)
36
+
37
+ function _setup_levels (ps)
38
+ n_levels = length (ps)
39
+ qs = mapreduce (vcat, ps) do percentile
40
+ [percentile / 2 , 1 - percentile / 2 ]
41
+ end |> x -> [0.5 ; x]
42
+ return qs, n_levels
43
+ end
44
+ # #
45
+ # _get_priorpred_plot_title(results["inference_config"])
46
+ # #
47
+
48
+ function prior_predictive_plot (config, output, epiprob;
49
+ ps = [0.05 , 0.1 , 0.25 ],
50
+ bottom_alpha = 0.1 ,
51
+ top_alpha = 0.5 ,
52
+ case_color = :black ,
53
+ logI_color = :purple ,
54
+ rt_color = :blue ,
55
+ Rt_color = :green ,
56
+ figsize = (750 , 600 ))
57
+ @assert all (0 .<= ps .< 0.5 ) " Percentiles must be in the range [0, 0.5)"
58
+ prior_pred_plot_title = _make_prior_plot_title (config)
59
+ qs, n_levels = _setup_levels (sort (ps))
60
+ opacity_scale = range (bottom_alpha, top_alpha, length = n_levels) |> collect
61
+
62
+ # Create the figure and axes
63
+ fig = Figure (size = figsize)
64
+ ax11 = Axis (fig[1 , 1 ]; xlabel = " t" , ylabel = " Cases" )
65
+ ax12 = Axis (fig[1 , 2 ]; xlabel = " t" , ylabel = " log(Incidence)" )
66
+ ax21 = Axis (fig[2 , 1 ]; xlabel = " t" , ylabel = " Exp. growth rate" )
67
+ ax22 = Axis (fig[2 , 2 ]; xlabel = " t" , ylabel = " Reproduction number" )
68
+ linkxaxes! (ax11, ax21)
69
+ linkxaxes! (ax12, ax22)
70
+ Label (fig[0 , :]; text = prior_pred_plot_title, fontsize = 16 )
71
+
72
+ # Quantile calculations
73
+ gen_y_t = mapreduce (hcat, output. generated) do gen
74
+ gen. generated_y_t
75
+ end |> X -> timeseries_samples_into_quantiles (X, qs)
76
+ gen_quantities = generate_quantiles_for_targets (output, epiprob. epi_model. data, qs)
77
+
78
+ # Plot the prior predictive samples
79
+ # Cases
80
+ f = findfirst (! ismissing, gen_y_t[:, 1 ])
81
+ lines! (ax11, 1 : size (gen_y_t, 1 ), gen_y_t[:, 1 ],
82
+ color = case_color, linewidth = 3 , label = " Median" )
83
+ for i in 1 : n_levels
84
+ band! (ax11, f: size (gen_y_t, 1 ), gen_y_t[f: size (gen_y_t, 1 ), (2 * i)],
85
+ gen_y_t[f: size (gen_y_t, 1 ), (2 * i) + 1 ],
86
+ color = (case_color, opacity_scale[i]),
87
+ label = " ($(ps[i]* 100 ) -$((1 - ps[i])* 100 ) )%" )
88
+ end
89
+ vlines! (ax11, [f], color = case_color, linestyle = :dash , label = " Obs. window" )
90
+ axislegend (ax11; position = :lt , framevisible = false )
91
+
92
+ # Other quantities
93
+ for (ax, target, c) in zip (
94
+ [ax12, ax21, ax22], [gen_quantities. log_I_t, gen_quantities. rt, gen_quantities. Rt],
95
+ [logI_color, rt_color, Rt_color])
96
+ lines! (ax, 1 : size (target, 1 ), target[:, 1 ],
97
+ color = logI_color, linewidth = 3 , label = " Median" )
98
+ for i in 1 : n_levels
99
+ band! (ax, 1 : size (target, 1 ), target[:, (2 * i)], target[:, (2 * i) + 1 ],
100
+ color = (c, opacity_scale[i]), label = " " )
101
+ end
102
+ end
103
+
104
+ fig
105
+ end
106
+
107
+ # #
108
+ fig = prior_predictive_plot (config, res, results[" epiprob" ]; ps = [0.025 , 0.1 , 0.25 ])
109
+ # #
110
+ gen_y_t = mapreduce (hcat, res. generated) do gen
111
+ gen. generated_y_t
112
+ end |> X -> timeseries_samples_into_quantiles (X, [0.025 , 0.25 , 0.5 , 0.75 , 0.975 ])
113
+
114
+ # #
115
+ fig = Figure ()
116
+ ax_logIt = Axis (fig[1 , 1 ];
117
+ xticks = vcat (1 , 5 : 5 : 50 ) # xlabel
118
+ )
119
+
120
+ for i in 1 : 5
121
+ lines! (ax_logIt, gen_y_t[:, i], color = :black , alpha = 0.5 )
122
+ end
123
+
124
+ # hlines!(ax, [1.0], color = :red, linestyle = :dash)
125
+ fig
0 commit comments