Skip to content

Commit e4951b4

Browse files
authored
455 plotting methods of prior predictive (#462)
* add priorpredictive method to make_inference * Add pipeline priorpredictive boolean * reformat * remove inference_method kwarg because can be dispatched on * remove specialisation on forecast results add missing handling as well * move inference step into own function and give fail cover with error report * remove passing inference_method * add a latent model name to InferenceConfig * Util for setting up PI levels * prior pred plot * export prior_predictive_plot * remove dead end-to-end test in favour of direct test of prior_predictive_plot
1 parent 5d03fc8 commit e4951b4

16 files changed

+345
-348
lines changed

pipeline/src/EpiAwarePipeline.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ export make_prediction_dataframe_from_output, make_truthdata_dataframe,
5656
export figureone, figuretwo
5757

5858
# Exported functions: plot functions
59-
export plot_truth_data, plot_Rt
59+
export plot_truth_data, plot_Rt, prior_predictive_plot
6060

6161
include("docstrings.jl")
6262
include("pipeline/pipeline.jl")

pipeline/src/constructors/make_inference_method.jl

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Constructs an inference method for the given pipeline. This is a default method.
88
- An inference method.
99
1010
"""
11-
function make_inference_method(pipeline::AbstractEpiAwarePipeline; ndraws::Integer = 2000,
11+
function make_inference_method(ndraws::Integer, pipeline::AbstractEpiAwarePipeline;
1212
mcmc_ensemble::AbstractMCMC.AbstractMCMCEnsemble = MCMCSerial(),
1313
nruns_pthf::Integer = 4, maxiters_pthf::Integer = 100, nchains::Integer = 4)
1414
return EpiMethod(
@@ -19,27 +19,28 @@ function make_inference_method(pipeline::AbstractEpiAwarePipeline; ndraws::Integ
1919
end
2020

2121
"""
22-
Method for sampling from prior predictive distribution of the model.
22+
Example pipeline.
2323
"""
24-
function make_inference_method(pipeline::RtwithoutRenewalPriorPipeline; n_samples = 2_000)
24+
function make_inference_method(
25+
pipeline::EpiAwareExamplePipeline; ndraws::Integer = 20,
26+
mcmc_ensemble::AbstractMCMC.AbstractMCMCEnsemble = MCMCThreads(),
27+
nruns_pthf::Integer = 4, maxiters_pthf::Integer = 100, nchains::Integer = 4)
2528
return EpiMethod(
26-
pre_sampler_steps = AbstractEpiOptMethod[],
27-
sampler = DirectSample(n_samples = n_samples)
29+
pre_sampler_steps = [ManyPathfinder(nruns = nruns_pthf, maxiters = maxiters_pthf)],
30+
sampler = NUTSampler(
31+
target_acceptance = 0.9, adtype = AutoReverseDiff(; compile = true),
32+
ndraws = ndraws, nchains = nchains, mcmc_parallel = mcmc_ensemble)
2833
)
2934
end
3035

3136
"""
32-
Pipeline test mode method for sampling from prior predictive distribution of the model.
37+
Method for sampling from prior predictive distribution of the model.
3338
"""
3439
function make_inference_method(
35-
pipeline::EpiAwareExamplePipeline; ndraws::Integer = 20,
36-
mcmc_ensemble::AbstractMCMC.AbstractMCMCEnsemble = MCMCThreads(),
37-
nruns_pthf::Integer = 4, maxiters_pthf::Integer = 100, nchains::Integer = 4)
40+
pipeline::AbstractRtwithoutRenewalPipeline, ::Val{:priorpredictive})
3841
return EpiMethod(
39-
pre_sampler_steps = [ManyPathfinder(nruns = nruns_pthf, maxiters = maxiters_pthf)],
40-
sampler = NUTSampler(
41-
target_acceptance = 0.9, adtype = AutoReverseDiff(; compile = true), ndraws = ndraws,
42-
nchains = nchains, mcmc_parallel = mcmc_ensemble)
42+
pre_sampler_steps = AbstractEpiOptMethod[],
43+
sampler = DirectSample(n_samples = pipeline.ndraws)
4344
)
4445
end
4546

@@ -55,7 +56,12 @@ Constructs an inference method for the Rt-without-renewal pipeline.
5556
# Examples
5657
"""
5758
function make_inference_method(pipeline::AbstractRtwithoutRenewalPipeline)
58-
return make_inference_method(pipeline; ndraws = pipeline.ndraws,
59-
mcmc_ensemble = pipeline.mcmc_ensemble, nruns_pthf = pipeline.nruns_pthf,
60-
maxiters_pthf = pipeline.maxiters_pthf, nchains = pipeline.nchains)
59+
if pipeline.priorpredictive
60+
return make_inference_method(pipeline, Val(:priorpredictive))
61+
else
62+
return make_inference_method(
63+
pipeline.ndraws, pipeline; mcmc_ensemble = pipeline.mcmc_ensemble,
64+
nruns_pthf = pipeline.nruns_pthf,
65+
maxiters_pthf = pipeline.maxiters_pthf, nchains = pipeline.nchains)
66+
end
6167
end

pipeline/src/infer/InferenceConfig.jl

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,17 @@ struct InferenceConfig{T, F, IGP, L, O, E, D <: Distribution, X <: Integer}
3232
transformation::F
3333
log_I0_prior::D
3434
lookahead::X
35+
latent_model_name::String
3536

3637
function InferenceConfig(igp, latent_model, observation_model; gi_mean, gi_std,
3738
case_data, truth_I_t, truth_I0, tspan, epimethod,
38-
transformation = exp, log_I0_prior, lookahead)
39-
new{typeof(gi_mean), typeof(transformation),
40-
typeof(igp), typeof(latent_model), typeof(observation_model),
39+
transformation = exp, log_I0_prior, lookahead, latent_model_name)
40+
new{typeof(gi_mean), typeof(transformation), typeof(igp),
41+
typeof(latent_model), typeof(observation_model),
4142
typeof(epimethod), typeof(log_I0_prior), typeof(lookahead)}(
4243
gi_mean, gi_std, igp, latent_model, observation_model,
43-
case_data, truth_I_t, truth_I0, tspan, epimethod, transformation, log_I0_prior, lookahead)
44+
case_data, truth_I_t, truth_I0, tspan, epimethod,
45+
transformation, log_I0_prior, lookahead, latent_model_name)
4446
end
4547

4648
function InferenceConfig(
@@ -57,41 +59,73 @@ struct InferenceConfig{T, F, IGP, L, O, E, D <: Distribution, X <: Integer}
5759
tspan = tspan,
5860
epimethod = epimethod,
5961
log_I0_prior = inference_config["log_I0_prior"],
60-
lookahead = inference_config["lookahead"]
62+
lookahead = inference_config["lookahead"],
63+
latent_model_name = inference_config["latent_namemodels"].first
6164
)
6265
end
6366
end
6467

6568
"""
66-
This method makes inference on the underlying parameters of the model specified
69+
This function makes inference on the underlying parameters of the model specified
6770
in the `InferenceConfig` object `config`.
6871
6972
# Arguments
7073
- `config::InferenceConfig`: The configuration object containing the case data
7174
to make inference on and model configuration.
75+
- `epiprob::EpiProblem`: The EpiProblem object containing the model to make inference on.
7276
7377
# Returns
7478
- `inference_results`: The results of the simulation or inference.
7579
7680
"""
77-
function infer(config::InferenceConfig)
78-
#Define the EpiProblem
79-
epiprob = define_epiprob(config)
80-
idxs = config.tspan[1]:config.tspan[2]
81-
81+
function create_inference_results(config, epiprob)
8282
#Return the sampled infections and observations
8383
y_t = ismissing(config.case_data) ? missing :
8484
Vector{Union{Missing, Int64}}(config.case_data[idxs])
85+
inference_results = apply_method(epiprob,
86+
config.epimethod,
87+
(y_t = y_t,)
88+
)
8589
inference_results = apply_method(epiprob,
8690
config.epimethod,
8791
(y_t = y_t,);
8892
)
93+
return inference_results
94+
end
95+
96+
"""
97+
This method makes inference on the underlying parameters of the model specified
98+
in the `InferenceConfig` object `config`.
99+
100+
# Arguments
101+
- `config::InferenceConfig`: The configuration object containing the case data
102+
to make inference on and model configuration.
103+
104+
# Returns
105+
- `inference_results`: The results of the simulation or inference.
106+
107+
"""
108+
function infer(config::InferenceConfig)
109+
#Define the EpiProblem
110+
epiprob = define_epiprob(config)
111+
idxs = config.tspan[1]:config.tspan[2]
89112

90-
forecast_results = generate_forecasts(
91-
inference_results.samples, inference_results.data, epiprob, config.lookahead)
113+
#Return the sampled infections and observations
114+
inference_results = create_inference_results(config, epiprob)
115+
116+
forecast_results = try
117+
generate_forecasts(
118+
inference_results.samples, inference_results.data, epiprob, config.lookahead)
119+
catch e
120+
e
121+
end
92122

93123
epidata = epiprob.epi_model.data
94-
score_results = summarise_crps(config, inference_results, forecast_results, epidata)
124+
score_results = try
125+
summarise_crps(config, inference_results, forecast_results, epidata)
126+
catch e
127+
e
128+
end
95129

96130
return Dict("inference_results" => inference_results,
97131
"epiprob" => epiprob, "inference_config" => config,

pipeline/src/infer/generate_inference_results.jl

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ Generate inference results based on the given configuration of inference model o
1414
- `inference_results`: The generated inference results.
1515
"""
1616
function generate_inference_results(
17-
truthdata, inference_config, pipeline::AbstractEpiAwarePipeline;
18-
inference_method)
17+
truthdata, inference_config, pipeline::AbstractEpiAwarePipeline)
1918
tspan = make_tspan(
2019
pipeline; T = inference_config["T"], lookback = inference_config["lookback"])
20+
inference_method = make_inference_method(pipeline)
2121
config = InferenceConfig(
2222
inference_config; case_data = truthdata["y_t"], truth_I_t = truthdata["I_t"],
2323
truth_I0 = truthdata["truth_I0"], tspan, epimethod = inference_method)
@@ -50,9 +50,10 @@ which is deleted after the function call.
5050
- `inference_results`: The generated inference results.
5151
"""
5252
function generate_inference_results(
53-
truthdata, inference_config, pipeline::EpiAwareExamplePipeline; inference_method)
53+
truthdata, inference_config, pipeline::EpiAwareExamplePipeline)
5454
tspan = make_tspan(
5555
pipeline; T = inference_config["T"], lookback = inference_config["lookback"])
56+
inference_method = make_inference_method(pipeline)
5657
config = InferenceConfig(
5758
inference_config; case_data = truthdata["y_t"], truth_I_t = truthdata["I_t"],
5859
truth_I0 = truthdata["truth_I0"], tspan = tspan, epimethod = inference_method)
@@ -66,23 +67,3 @@ function generate_inference_results(
6667
infer, config, datadir_name; prefix = prfx)
6768
return inference_results
6869
end
69-
70-
"""
71-
Method for prior predictive modelling.
72-
"""
73-
function generate_inference_results(
74-
inference_config, pipeline::RtwithoutRenewalPriorPipeline)
75-
tspan = make_tspan(
76-
pipeline; T = inference_config["T"], lookback = inference_config["lookback"])
77-
config = InferenceConfig(
78-
inference_config; case_data = missing, tspan, epimethod = DirectSample())
79-
80-
# produce or load inference results
81-
prfx = _inference_prefix(truthdata, inference_config, pipeline)
82-
83-
datadir_name = mktempdir()
84-
85-
inference_results, inferencefile = produce_or_load(
86-
infer, config, datadir(datadir_name); prefix = prfx)
87-
return inference_results
88-
end

pipeline/src/infer/map_inference_results.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ tasks from `Dagger.@spawn`.
1515
1616
"""
1717
function map_inference_results(
18-
truthdata, inference_configs, pipeline::AbstractEpiAwarePipeline; inference_method)
18+
truthdata, inference_configs, pipeline::AbstractEpiAwarePipeline)
1919
map(inference_configs) do inference_config
2020
Dagger.@spawn generate_inference_results(
21-
truthdata, inference_config, pipeline; inference_method)
21+
truthdata, inference_config, pipeline)
2222
end
2323
end

pipeline/src/pipeline/do_inference.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ An array of inference results.
1111
"""
1212
function do_inference(truthdata, pipeline::AbstractEpiAwarePipeline)
1313
inference_configs = make_inference_configs(pipeline)
14-
inference_method = make_inference_method(pipeline)
1514
inference_results = map_inference_results(
16-
truthdata, inference_configs, pipeline; inference_method)
15+
truthdata, inference_configs, pipeline)
1716
return inference_results
1817
end

pipeline/src/pipeline/pipelinetypes.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ Rt = make_Rt(pipeline) |> Rt -> plot(Rt,
4545
nchains::Integer = 4
4646
prefix::String = "smooth_outbreak"
4747
testmode::Bool = false
48+
priorpredictive::Bool = false
4849
end
4950

5051
"""
@@ -59,6 +60,7 @@ The pipeline type for the Rt pipeline for an outbreak scenario where Rt has
5960
nchains::Integer = 4
6061
prefix::String = "measures_outbreak"
6162
testmode::Bool = false
63+
priorpredictive::Bool = false
6264
end
6365

6466
"""
@@ -73,6 +75,7 @@ The pipeline type for the Rt pipeline for an endemic scenario where Rt changes i
7375
nchains::Integer = 4
7476
prefix::String = "smooth_endemic"
7577
testmode::Bool = false
78+
priorpredictive::Bool = false
7679
end
7780

7881
"""
@@ -87,4 +90,5 @@ The pipeline type for the Rt pipeline for an endemic scenario where Rt changes i
8790
nchains::Integer = 4
8891
prefix::String = "rough_endemic"
8992
testmode::Bool = false
93+
priorpredictive::Bool = false
9094
end

pipeline/src/plotting/plotting.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ include("basicplots.jl")
22
include("df_checking.jl")
33
include("figureone.jl")
44
include("figuretwo.jl")
5+
include("prior_predictive_plot.jl")
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""
2+
Internal method for generating a title string for a prior predictive plot based on the provided
3+
configuration.
4+
5+
# Arguments
6+
- `config::InferenceConfig`: `InferenceConfig` object containing the configuration for the
7+
prior predictive plot.
8+
# Returns
9+
- `String`: A formatted title string for the prior predictive plot.
10+
11+
"""
12+
function _make_prior_plot_title(config::InferenceConfig)
13+
igp_str = string(config.igp)
14+
latent_model_str = config.latent_model_name |> uppercase
15+
gi_mean_str = config.gi_mean |> string
16+
T_str = config.tspan[2] |> string
17+
return "Prior pred. IGP: $(igp_str), latent model: $(latent_model_str), truth gi mean: $(gi_mean_str), T: $(T_str)"
18+
end
19+
20+
"""
21+
Generate a prior predictive plot for the given configuration, output, and epidemiological probabilities.
22+
23+
# Arguments
24+
- `config`: Configuration object containing settings for the plot.
25+
- `output`: Output object containing generated data for plotting.
26+
- `epiprob`: Epidemiological probabilities object.
27+
- `ps`: Array of percentiles for quantile calculations (default: [0.05, 0.1, 0.25]).
28+
- `bottom_alpha`: Opacity for the lowest percentile band (default: 0.1).
29+
- `top_alpha`: Opacity for the highest percentile band (default: 0.5).
30+
- `case_color`: Color for the cases plot (default: :black).
31+
- `logI_color`: Color for the log(Incidence) plot (default: :purple).
32+
- `rt_color`: Color for the exponential growth rate plot (default: :blue).
33+
- `Rt_color`: Color for the reproduction number plot (default: :green).
34+
- `figsize`: Tuple specifying the size of the figure (default: (750, 600)).
35+
36+
# Returns
37+
- `fig`: A Figure object containing the prior predictive plots.
38+
39+
# Notes
40+
- The function asserts that all percentiles in `ps` are in the range [0, 0.5).
41+
- The function creates a 2x2 grid of subplots with linked x-axes for the top and bottom rows.
42+
- The function plots the median and percentile bands for cases, log(Incidence), exponential growth rate, and reproduction number.
43+
"""
44+
function prior_predictive_plot(config, output, epiprob;
45+
ps = [0.05, 0.1, 0.25],
46+
bottom_alpha = 0.1,
47+
top_alpha = 0.5,
48+
case_color = :black,
49+
logI_color = :purple,
50+
rt_color = :blue,
51+
Rt_color = :green,
52+
figsize = (750, 600))
53+
@assert all(0 .<= ps .< 0.5) "Percentiles must be in the range [0, 0.5)"
54+
prior_pred_plot_title = _make_prior_plot_title(config)
55+
qs, n_levels = _setup_levels(sort(ps))
56+
opacity_scale = range(bottom_alpha, top_alpha, length = n_levels) |> collect
57+
58+
# Create the figure and axes
59+
fig = Figure(size = figsize)
60+
ax11 = Axis(fig[1, 1]; xlabel = "t", ylabel = "Cases")
61+
ax12 = Axis(fig[1, 2]; xlabel = "t", ylabel = "log(Incidence)")
62+
ax21 = Axis(fig[2, 1]; xlabel = "t", ylabel = "Exp. growth rate")
63+
ax22 = Axis(fig[2, 2]; xlabel = "t", ylabel = "Reproduction number")
64+
linkxaxes!(ax11, ax21)
65+
linkxaxes!(ax12, ax22)
66+
Label(fig[0, :]; text = prior_pred_plot_title, fontsize = 16)
67+
68+
# Quantile calculations
69+
gen_y_t = mapreduce(hcat, output.generated) do gen
70+
gen.generated_y_t
71+
end |> X -> timeseries_samples_into_quantiles(X, qs)
72+
gen_quantities = generate_quantiles_for_targets(output, epiprob.epi_model.data, qs)
73+
74+
# Plot the prior predictive samples
75+
# Cases
76+
f = findfirst(!ismissing, gen_y_t[:, 1])
77+
lines!(ax11, 1:size(gen_y_t, 1), gen_y_t[:, 1],
78+
color = case_color, linewidth = 3, label = "Median")
79+
for i in 1:n_levels
80+
band!(ax11, f:size(gen_y_t, 1), gen_y_t[f:size(gen_y_t, 1), (2 * i)],
81+
gen_y_t[f:size(gen_y_t, 1), (2 * i) + 1],
82+
color = (case_color, opacity_scale[i]),
83+
label = "($(ps[i]*100)-$((1 - ps[i])*100))%")
84+
end
85+
vlines!(ax11, [f], color = case_color, linestyle = :dash, label = "Obs. window")
86+
axislegend(ax11; position = :lt, framevisible = false)
87+
88+
# Other quantities
89+
for (ax, target, c) in zip(
90+
[ax12, ax21, ax22], [gen_quantities.log_I_t, gen_quantities.rt, gen_quantities.Rt],
91+
[logI_color, rt_color, Rt_color])
92+
lines!(ax, 1:size(target, 1), target[:, 1],
93+
color = logI_color, linewidth = 3, label = "Median")
94+
for i in 1:n_levels
95+
band!(ax, 1:size(target, 1), target[:, (2 * i)], target[:, (2 * i) + 1],
96+
color = (c, opacity_scale[i]), label = "")
97+
end
98+
end
99+
100+
fig
101+
end

0 commit comments

Comments
 (0)