Skip to content

Commit af6dfb6

Browse files
authored
Refactor prior construction and make clearer the logic (#566)
1 parent 1b3d397 commit af6dfb6

File tree

2 files changed

+120
-87
lines changed

2 files changed

+120
-87
lines changed

pipeline/src/constructors/remake_latent_model.jl

Lines changed: 109 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -3,92 +3,123 @@ Constructs and returns a latent model based on the provided `inference_config` a
33
The purpose of this function is to make adjustments to the latent model based on the
44
full `inference_config` provided.
55
6-
The `tscale` argument is used to scale the standard deviation of the latent model based on the
7-
idea that some processes have a variance that is (approximately) proportional to a time period (due to non-stationarity)
8-
and some processes have a variance that is constant in time (at stationarity). The default
9-
value is `sqrt(21.0)`, which corresponds to matching the variance of stationary processes to
10-
the eventual variance of non-stationary process after 21 days.
11-
126
The `pipeline` argument is used for dispatch purposes.
137
14-
# Returns
15-
- A latent model object which can be one of `DiffLatentModel`, `AR`, or `RandomWalk` depending on the `latent_model_name` and `igp` specified in `inference_config`.
8+
The prior decisions are based on the target standard deviation and autocorrelation of the latent process,
9+
which are determined by the infection generating process (igp) and whether the latent process is stationary or non-stationary
10+
via the `_make_target_std_and_autocorr` function.
1611
17-
# Details
18-
- The function first constructs a dictionary of priors using `make_model_priors(pipeline)`.
19-
- It then retrieves the `igp` (inference generation process) and `latent_model_name` from `inference_config`.
20-
- Depending on the `latent_model_name` and `igp`, it constructs and returns the appropriate latent model:
21-
- `"diff_ar"`: Constructs a `DiffLatentModel` with an `AR` model.
22-
- `"ar"`: Constructs an `AR` model.
23-
- `"rw"`: Constructs a `RandomWalk` model.
24-
- The priors for the models are set based on the `prior_dict` and the `tscale` parameter.
2512
13+
# Returns
14+
- A latent model object which can be one of `DiffLatentModel`, `AR`, or `RandomWalk` depending on the `latent_model_name` and `igp` specified in `inference_config`.
2615
"""
27-
function remake_latent_model(inference_config::Dict,
28-
pipeline::AbstractRtwithoutRenewalPipeline; tscale = sqrt(21.0))
16+
function remake_latent_model(
17+
inference_config::Dict, pipeline::AbstractRtwithoutRenewalPipeline)
2918
#Baseline choices
3019
prior_dict = make_model_priors(pipeline)
3120
igp = inference_config["igp"]
32-
latent_model_name = inference_config["latent_namemodels"].first
33-
34-
if latent_model_name == "diff_ar"
35-
if igp == Renewal
36-
ar = AR(damp_priors = [prior_dict["damp_param_prior"]],
37-
std_prior = HalfNormal(0.05 / tscale),
38-
init_priors = [prior_dict["transformed_process_init_prior"]])
39-
diff_ar = DiffLatentModel(;
40-
model = ar, init_priors = [prior_dict["transformed_process_init_prior"]])
41-
return diff_ar
42-
elseif igp == ExpGrowthRate
43-
ar = AR(damp_priors = [prior_dict["damp_param_prior"]],
44-
std_prior = HalfNormal(0.005 / tscale),
45-
init_priors = [prior_dict["transformed_process_init_prior"]])
46-
diff_ar = DiffLatentModel(;
47-
model = ar, init_priors = [prior_dict["transformed_process_init_prior"]])
48-
return diff_ar
49-
elseif igp == DirectInfections
50-
ar = AR(damp_priors = [Beta(9, 1)],
51-
std_prior = HalfNormal(0.05 / tscale),
52-
init_priors = [prior_dict["transformed_process_init_prior"]])
53-
diff_ar = DiffLatentModel(;
54-
model = ar, init_priors = [prior_dict["transformed_process_init_prior"]])
55-
return diff_ar
56-
end
57-
elseif latent_model_name == "ar"
58-
if igp == Renewal
59-
ar = AR(damp_priors = [Beta(2, 8)],
60-
std_prior = HalfNormal(0.25),
61-
init_priors = [prior_dict["transformed_process_init_prior"]])
62-
return ar
63-
elseif igp == ExpGrowthRate
64-
ar = AR(damp_priors = [prior_dict["damp_param_prior"]],
65-
std_prior = HalfNormal(0.025),
66-
init_priors = [prior_dict["transformed_process_init_prior"]])
67-
return ar
68-
elseif igp == DirectInfections
69-
ar = AR(damp_priors = [Beta(9, 1)],
70-
std_prior = HalfNormal(0.25),
71-
init_priors = [prior_dict["transformed_process_init_prior"]])
72-
return ar
73-
end
74-
elseif latent_model_name == "rw"
75-
if igp == Renewal
76-
rw = RandomWalk(
77-
std_prior = HalfNormal(0.05 / tscale),
78-
init_prior = prior_dict["transformed_process_init_prior"])
79-
return rw
80-
elseif igp == ExpGrowthRate
81-
rw = RandomWalk(
82-
std_prior = HalfNormal(0.005 / tscale),
83-
init_prior = prior_dict["transformed_process_init_prior"])
84-
return rw
85-
elseif igp == DirectInfections
86-
rw = RandomWalk(
87-
std_prior = HalfNormal(0.1 / tscale),
88-
init_prior = prior_dict["transformed_process_init_prior"])
89-
return rw
90-
end
91-
end
21+
default_latent_model = inference_config["latent_namemodels"].second
22+
target_std, target_autocorr = default_latent_model isa AR ?
23+
_make_target_std_and_autocorr(igp; stationary = true) :
24+
_make_target_std_and_autocorr(igp; stationary = false)
25+
26+
return _implement_latent_process(
27+
target_std, target_autocorr, default_latent_model, pipeline)
28+
end
29+
30+
"""
31+
This function sets the target standard deviation for an infection generating process (igp)
32+
based on whether the latent process representation of its dynamics is stationary or non-stationary.
33+
34+
## Stationary Processes
35+
36+
- For Renewal process `log(R_t)` in the long run a fluctuation of 0.75 (e.g. ~ 75% of the mean) is not unexpected.
37+
- For Exponential Growth Rate process `r_t` in the long run a fluctuation of 0.2 is not unexpected e.g. going from
38+
`rt = 0.1` (7 day doubling time) to `rt = -0.1` (7 day halving time) is a 0.2 time-to-time fluctuation.
39+
- For Direct Infections process `log(I_t)` in the long run a fluctuation of 2.0 (i.e a couple of orders of magnitude) is not unexpected.
40+
41+
For the stationary Direct Infections and `r_t` latent processes the autocorrelation is expected to be high at 0.9,
42+
because persistence of the residual away from the mean is expected. Otherwise (i.e. for the Renewal process), the autocorrelation is expected to be 0.1.
43+
44+
## Non-Stationary Processes
45+
46+
For Renewal process `log(R_t)` in a single time step a fluctuation of 0.025 (e.g. ~ 2.5% of the mean) is not unexpected.
47+
For Exponential Growth Rate process `r_t` in a single time step a fluctuation of 0.005 is not unexpected.
48+
For Direct Infections process `log(I_t)` in a single time step a fluctuation of 0.25 is not unexpected.
49+
50+
The autocorrelation is expected to be 0.1.
51+
"""
52+
function _make_target_std_and_autocorr(::Type{Renewal}; stationary::Bool)
53+
return stationary ? (0.75, 0.1) : (0.025, 0.1)
54+
end
55+
56+
function _make_target_std_and_autocorr(::Type{ExpGrowthRate}; stationary::Bool)
57+
return stationary ? (0.2, 0.9) : (0.005, 0.1)
58+
end
59+
60+
function _make_target_std_and_autocorr(::Type{DirectInfections}; stationary::Bool)
61+
return stationary ? (2.0, 0.9) : (0.25, 0.1)
62+
end
63+
64+
function _make_new_prior_dict(target_std, target_autocorr,
65+
pipeline::AbstractRtwithoutRenewalPipeline; beta_eff_sample_size)
66+
#Get default priors
67+
prior_dict = make_model_priors(pipeline)
68+
#Adjust priors based on target autocorrelation and standard deviation
69+
damp_prior = Beta(target_autocorr * beta_eff_sample_size,
70+
(1 - target_autocorr) * beta_eff_sample_size)
71+
corr_corrected_noise_prior = HalfNormal(target_std * sqrt(1 - target_autocorr^2))
72+
noise_prior = HalfNormal(target_std)
73+
init_prior = prior_dict["transformed_process_init_prior"]
74+
return Dict(
75+
"transformed_process_init_prior" => init_prior,
76+
"corr_corrected_noise_prior" => corr_corrected_noise_prior,
77+
"noise_prior" => noise_prior,
78+
"damp_param_prior" => damp_prior
79+
)
80+
end
81+
82+
"""
83+
Constructs and returns a latent model based on an approximation to the specified target standard deviation and autocorrelation.
84+
85+
NB: The stationary variance of an AR(1) process is given by `σ² = σ²_ε / (1 - ρ²)` where `σ²_ε` is the variance of the noise and `ρ` is the autocorrelation.
86+
The approximation here is based on `E[1/(1 - ρ²)] ≈ 1 / (1 - E[ρ²])`, which only holds for fairly tight distributions of `ρ`.
87+
However, for priors this should get the expected order of magnitude.
88+
89+
# Models
90+
- `"diff_ar"`: Constructs a `DiffLatentModel` with an autoregressive (AR) process.
91+
- `"ar"`: Constructs an autoregressive (AR) process.
92+
- `"rw"`: Constructs a random walk (RW) process.
93+
94+
"""
95+
function _implement_latent_process(
96+
target_std, target_autocorr, default_latent_model, pipeline; beta_eff_sample_size = 10)
97+
prior_dict = make_model_priors(pipeline)
98+
new_priors = _make_new_prior_dict(
99+
target_std, target_autocorr, pipeline; beta_eff_sample_size)
100+
101+
return _make_latent(default_latent_model, new_priors)
102+
end
103+
104+
function _make_latent(::AR, new_priors)
105+
damp_prior = new_priors["damp_param_prior"]
106+
corr_corrected_noise_std = new_priors["corr_corrected_noise_prior"]
107+
init_prior = new_priors["transformed_process_init_prior"]
108+
return AR(damp_priors = [damp_prior],
109+
std_prior = corr_corrected_noise_std,
110+
init_priors = [init_prior])
111+
end
112+
113+
function _make_latent(::DiffLatentModel, new_priors)
114+
init_prior = new_priors["transformed_process_init_prior"]
115+
ar = _make_latent(AR(), new_priors)
116+
return DiffLatentModel(; model = ar, init_priors = [init_prior])
117+
end
118+
119+
function _make_latent(::RandomWalk, new_priors)
120+
noise_std = new_priors["noise_prior"]
121+
init_prior = new_priors["transformed_process_init_prior"]
122+
return RandomWalk(std_prior = noise_std, init_prior = init_prior)
92123
end
93124

94125
"""

pipeline/test/constructors/remake_latent_model.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,49 +7,51 @@
77
)
88
end
99
pipeline = MockPipeline()
10-
10+
ar = AR()
11+
diff_ar = DiffLatentModel(model = ar)
12+
rw = RandomWalk()
1113
@testset "diff_ar model" begin
1214
inference_config = Dict(
13-
"igp" => ExpGrowthRate, "latent_namemodels" => ("diff_ar" => "diff_ar"))
15+
"igp" => ExpGrowthRate, "latent_namemodels" => Pair("diff_ar", diff_ar))
1416
model = remake_latent_model(inference_config, pipeline)
1517
@test model isa DiffLatentModel
1618
@test model.model isa AR
1719

1820
inference_config = Dict(
19-
"igp" => DirectInfections, "latent_namemodels" => ("diff_ar" => "diff_ar"))
21+
"igp" => DirectInfections, "latent_namemodels" => Pair("diff_ar", diff_ar))
2022
model = remake_latent_model(inference_config, pipeline)
2123
@test model isa DiffLatentModel
2224
@test model.model isa AR
2325
end
2426

2527
@testset "ar model" begin
26-
inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("ar", "ar"))
28+
inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("ar", ar))
2729
model = remake_latent_model(inference_config, pipeline)
2830
@test model isa AR
2931

3032
inference_config = Dict(
31-
"igp" => ExpGrowthRate, "latent_namemodels" => Pair("ar", "ar"))
33+
"igp" => ExpGrowthRate, "latent_namemodels" => Pair("ar", ar))
3234
model = remake_latent_model(inference_config, pipeline)
3335
@test model isa AR
3436

3537
inference_config = Dict(
36-
"igp" => DirectInfections, "latent_namemodels" => Pair("ar", "ar"))
38+
"igp" => DirectInfections, "latent_namemodels" => Pair("ar", ar))
3739
model = remake_latent_model(inference_config, pipeline)
3840
@test model isa AR
3941
end
4042

4143
@testset "rw model" begin
42-
inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("rw", "rw"))
44+
inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("rw", rw))
4345
model = remake_latent_model(inference_config, pipeline)
4446
@test model isa RandomWalk
4547

4648
inference_config = Dict(
47-
"igp" => ExpGrowthRate, "latent_namemodels" => Pair("rw", "rw"))
49+
"igp" => ExpGrowthRate, "latent_namemodels" => Pair("rw", rw))
4850
model = remake_latent_model(inference_config, pipeline)
4951
@test model isa RandomWalk
5052

5153
inference_config = Dict(
52-
"igp" => DirectInfections, "latent_namemodels" => Pair("rw", "rw"))
54+
"igp" => DirectInfections, "latent_namemodels" => Pair("rw", rw))
5355
model = remake_latent_model(inference_config, pipeline)
5456
@test model isa RandomWalk
5557
end

0 commit comments

Comments
 (0)