Skip to content

Commit 94b4643

Browse files
authored
Issue 405: model specific priors (#565)
1 parent 0d4095e commit 94b4643

14 files changed

+411
-207
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Prior predictive plots

pipeline/src/EpiAwarePipeline.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ export TruthSimulationConfig, InferenceConfig
2727
export make_gi_params, make_inf_generating_processes, make_model_priors,
2828
make_epiaware_name_latentmodel_pairs, make_Rt, make_truth_data_configs,
2929
make_default_params, make_inference_configs, make_tspan, make_inference_method,
30-
make_delay_distribution, make_delay_distribution, make_observation_model
30+
make_delay_distribution, make_delay_distribution, make_observation_model,
31+
remake_latent_model
3132

3233
# Exported functions: pipeline components
3334
export do_truthdata, do_inference, do_pipeline

pipeline/src/constructors/constructors.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ include("make_tspan.jl")
1111
include("make_default_params.jl")
1212
include("make_delay_distribution.jl")
1313
include("make_observation_model.jl")
14+
include("remake_latent_model.jl")

pipeline/src/constructors/make_model_priors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ deviation 1e-1.
1616
1717
"""
1818
function make_model_priors(pipeline::AbstractEpiAwarePipeline)
19-
transformed_process_init_prior = Normal(0.0, 0.25)
19+
transformed_process_init_prior = Normal(0.0, 0.1)
2020
std_prior = HalfNormal(0.025)
2121
damp_param_prior = Beta(1, 9)
2222
log_I0_prior = Normal(log(100.0), 1e-1)
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
"""
2+
Constructs and returns a latent model based on the provided `inference_config` and `pipeline`.
3+
The purpose of this function is to make adjustments to the latent model based on the
4+
full `inference_config` provided.
5+
6+
The `pipeline` argument is used for dispatch purposes.
7+
8+
The prior decisions are based on the target standard deviation and autocorrelation of the latent process,
9+
which are determined by the infection generating process (igp) and whether the latent process is stationary or non-stationary
10+
via the `_make_target_std_and_autocorr` function.
11+
12+
13+
# Returns
14+
- A latent model object which can be one of `DiffLatentModel`, `AR`, or `RandomWalk` depending on the `latent_model_name` and `igp` specified in `inference_config`.
15+
"""
16+
function remake_latent_model(
17+
inference_config::Dict, pipeline::AbstractRtwithoutRenewalPipeline)
18+
#Baseline choices
19+
prior_dict = make_model_priors(pipeline)
20+
igp = inference_config["igp"]
21+
default_latent_model = inference_config["latent_namemodels"].second
22+
target_std, target_autocorr = default_latent_model isa AR ?
23+
_make_target_std_and_autocorr(igp; stationary = true) :
24+
_make_target_std_and_autocorr(igp; stationary = false)
25+
26+
return _implement_latent_process(
27+
target_std, target_autocorr, default_latent_model, pipeline)
28+
end
29+
30+
"""
31+
This function sets the target standard deviation for an infection generating process (igp)
32+
based on whether the latent process representation of its dynamics are stationary or non-stationary.
33+
34+
## Stationary Processes
35+
36+
- For Renewal process `log(R_t)` in the long run a fluctuation of 0.75 (e.g. ~ 75% of the mean) is not unexpected.
37+
- For Exponential Growth Rate process `r_t` in the long run a fluctuation of 0.2 is not unexpected e.g. going from
38+
`rt = 0.1` (7 day doubling time) to `rt = -0.1` (7 day halving time) is a 0.2 time-to-time fluctuation.
39+
- For Direct Infections process `log(I_t)` in the long run a fluctuation of 2.0 (i.e a couple of orders of magnitude) is not unexpected.
40+
41+
For stationary latent processes Direct Infections and rt processes the autocorrelation is expected to be high at 0.9,
42+
because persistence in residual away from mean is expected. Otherwise, the autocorrelation is expected to be 0.1.
43+
44+
## Non-Stationary Processes
45+
46+
For Renewal process `log(R_t)` in a single time step a fluctuation of 0.025 (e.g. ~ 2.5% of the mean) is not unexpected.
47+
For Exponential Growth Rate process `r_t` in a single time step a fluctuation of 0.005 is not unexpected.
48+
For Direct Infections process `log(I_t)` in a single time step a fluctuation of 0.025 is not unexpected.
49+
50+
The autocorrelation is expected to be 0.1.
51+
"""
52+
function _make_target_std_and_autocorr(::Type{Renewal}; stationary::Bool)
53+
return stationary ? (0.75, 0.1) : (0.025, 0.1)
54+
end
55+
56+
function _make_target_std_and_autocorr(::Type{ExpGrowthRate}; stationary::Bool)
57+
return stationary ? (0.2, 0.9) : (0.005, 0.1)
58+
end
59+
60+
function _make_target_std_and_autocorr(::Type{DirectInfections}; stationary::Bool)
61+
return stationary ? (2.0, 0.9) : (0.25, 0.1)
62+
end
63+
64+
function _make_new_prior_dict(target_std, target_autocorr,
65+
pipeline::AbstractRtwithoutRenewalPipeline; beta_eff_sample_size)
66+
#Get default priors
67+
prior_dict = make_model_priors(pipeline)
68+
#Adjust priors based on target autocorrelation and standard deviation
69+
damp_prior = Beta(target_autocorr * beta_eff_sample_size,
70+
(1 - target_autocorr) * beta_eff_sample_size)
71+
corr_corrected_noise_prior = HalfNormal(target_std * sqrt(1 - target_autocorr^2))
72+
noise_prior = HalfNormal(target_std)
73+
init_prior = prior_dict["transformed_process_init_prior"]
74+
return Dict(
75+
"transformed_process_init_prior" => init_prior,
76+
"corr_corrected_noise_prior" => corr_corrected_noise_prior,
77+
"noise_prior" => noise_prior,
78+
"damp_param_prior" => damp_prior
79+
)
80+
end
81+
82+
"""
83+
Constructs and returns a latent model based on an approximation to the specified target standard deviation and autocorrelation.
84+
85+
NB: The stationary variance of an AR(1) process is given by `σ² = σ²_ε / (1 - ρ²)` where `σ²_ε` is the variance of the noise and `ρ` is the autocorrelation.
86+
The approximation here are based on `E[1/(1 - ρ²)`] ≈ 1 / (1 - E[ρ²])` which only holds for fairly tight distributions of `ρ`.
87+
However, for priors this should get the expected order of magnitude.
88+
89+
# Models
90+
- `"diff_ar"`: Constructs a `DiffLatentModel` with an autoregressive (AR) process.
91+
- `"ar"`: Constructs an autoregressive (AR) process.
92+
- `"rw"`: Constructs a random walk (RW) process.
93+
94+
"""
95+
function _implement_latent_process(
96+
target_std, target_autocorr, default_latent_model, pipeline; beta_eff_sample_size = 10)
97+
prior_dict = make_model_priors(pipeline)
98+
new_priors = _make_new_prior_dict(
99+
target_std, target_autocorr, pipeline; beta_eff_sample_size)
100+
101+
return _make_latent(default_latent_model, new_priors)
102+
end
103+
104+
function _make_latent(::AR, new_priors)
105+
damp_prior = new_priors["damp_param_prior"]
106+
corr_corrected_noise_std = new_priors["corr_corrected_noise_prior"]
107+
init_prior = new_priors["transformed_process_init_prior"]
108+
return AR(damp_priors = [damp_prior],
109+
std_prior = corr_corrected_noise_std,
110+
init_priors = [init_prior])
111+
end
112+
113+
function _make_latent(::DiffLatentModel, new_priors)
114+
init_prior = new_priors["transformed_process_init_prior"]
115+
ar = _make_latent(AR(), new_priors)
116+
return DiffLatentModel(; model = ar, init_priors = [init_prior])
117+
end
118+
119+
function _make_latent(::RandomWalk, new_priors)
120+
noise_std = new_priors["noise_prior"]
121+
init_prior = new_priors["transformed_process_init_prior"]
122+
return RandomWalk(std_prior = noise_std, init_prior = init_prior)
123+
end
124+
125+
"""
126+
Pass through fallback dispatch.
127+
"""
128+
function remake_latent_model(inference_config::Dict, pipeline::AbstractEpiAwarePipeline)
129+
inference_config["latent_namemodels"].second
130+
end

pipeline/src/constructors/selector.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,15 @@ end
1616

1717
"""
1818
Internal method for selecting from a list of items based on the pipeline type.
19-
Example/test mode is to return a randomly selected item from the list.
19+
Example/test mode is to return a randomly selected item from the list. Prior predictive mode
20+
only runs on configurations with the furthest ahead horizon.
2021
"""
2122
function _selector(list, pipeline::AbstractRtwithoutRenewalPipeline)
22-
return pipeline.testmode ? [rand(list)] : list
23+
if pipeline.priorpredictive
24+
maxT = maximum([config["T"] for config in list])
25+
_list = filter(config -> config["T"] == maxT, list)
26+
return pipeline.testmode ? [rand(_list)] : _list
27+
else
28+
return pipeline.testmode ? [rand(list)] : list
29+
end
2330
end

pipeline/src/infer/InferenceConfig.jl

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ Inference configuration struct for specifying the parameters and models used in
2323
"""
2424
struct InferenceConfig{
2525
T, F, IGP, L, O, E, D <: Distribution, X <: Integer,
26-
P <: AbstractRtwithoutRenewalPipeline}
26+
P <: AbstractEpiAwarePipeline}
2727
gi_mean::T
2828
gi_std::T
2929
igp::IGP
@@ -51,26 +51,29 @@ struct InferenceConfig{
5151
case_data, truth_I_t, truth_I0, tspan, epimethod,
5252
transformation, log_I0_prior, lookahead, latent_model_name, pipeline)
5353
end
54+
end
5455

55-
function InferenceConfig(
56-
inference_config::Dict; case_data, truth_I_t, truth_I0, tspan, epimethod, pipeline)
57-
InferenceConfig(
58-
inference_config["igp"],
59-
inference_config["latent_namemodels"].second,
60-
inference_config["observation_model"];
61-
gi_mean = inference_config["gi_mean"],
62-
gi_std = inference_config["gi_std"],
63-
case_data = case_data,
64-
truth_I_t = truth_I_t,
65-
truth_I0 = truth_I0,
66-
tspan = tspan,
67-
epimethod = epimethod,
68-
log_I0_prior = inference_config["log_I0_prior"],
69-
lookahead = inference_config["lookahead"],
70-
latent_model_name = inference_config["latent_namemodels"].first,
71-
pipeline
72-
)
73-
end
56+
function InferenceConfig(
57+
inference_config::Dict, pipeline::AbstractEpiAwarePipeline;
58+
case_data, truth_I_t, truth_I0, tspan, epimethod)
59+
latent_model = remake_latent_model(inference_config::Dict, pipeline)
60+
61+
InferenceConfig(
62+
inference_config["igp"],
63+
latent_model,
64+
inference_config["observation_model"];
65+
gi_mean = inference_config["gi_mean"],
66+
gi_std = inference_config["gi_std"],
67+
case_data = case_data,
68+
truth_I_t = truth_I_t,
69+
truth_I0 = truth_I0,
70+
tspan = tspan,
71+
epimethod = epimethod,
72+
log_I0_prior = inference_config["log_I0_prior"],
73+
lookahead = inference_config["lookahead"],
74+
latent_model_name = inference_config["latent_namemodels"].first,
75+
pipeline
76+
)
7477
end
7578

7679
"""

pipeline/src/infer/generate_inference_results.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ function generate_inference_results(
1919
pipeline; T = inference_config["T"], lookback = inference_config["lookback"])
2020
inference_method = make_inference_method(pipeline)
2121
config = InferenceConfig(
22-
inference_config; case_data = truthdata["y_t"], truth_I_t = truthdata["I_t"],
23-
truth_I0 = truthdata["truth_I0"], tspan, epimethod = inference_method, pipeline = pipeline)
22+
inference_config, pipeline; case_data = truthdata["y_t"], truth_I_t = truthdata["I_t"],
23+
truth_I0 = truthdata["truth_I0"], tspan, epimethod = inference_method)
2424

2525
# produce or load inference results
2626
prfx = _inference_prefix(truthdata, inference_config, pipeline)

0 commit comments

Comments
 (0)