Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 405: model specific priors #565

Merged
merged 7 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pipeline/plots/priorpredictive/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Prior predictive plots
3 changes: 2 additions & 1 deletion pipeline/src/EpiAwarePipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ export TruthSimulationConfig, InferenceConfig
export make_gi_params, make_inf_generating_processes, make_model_priors,
make_epiaware_name_latentmodel_pairs, make_Rt, make_truth_data_configs,
make_default_params, make_inference_configs, make_tspan, make_inference_method,
make_delay_distribution, make_delay_distribution, make_observation_model
make_delay_distribution, make_delay_distribution, make_observation_model,
remake_latent_model

# Exported functions: pipeline components
export do_truthdata, do_inference, do_pipeline
Expand Down
1 change: 1 addition & 0 deletions pipeline/src/constructors/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ include("make_tspan.jl")
include("make_default_params.jl")
include("make_delay_distribution.jl")
include("make_observation_model.jl")
include("remake_latent_model.jl")
2 changes: 1 addition & 1 deletion pipeline/src/constructors/make_model_priors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ deviation 1e-1.

"""
function make_model_priors(pipeline::AbstractEpiAwarePipeline)
transformed_process_init_prior = Normal(0.0, 0.25)
transformed_process_init_prior = Normal(0.0, 0.1)
seabbs marked this conversation as resolved.
Show resolved Hide resolved
std_prior = HalfNormal(0.025)
seabbs marked this conversation as resolved.
Show resolved Hide resolved
damp_param_prior = Beta(1, 9)
seabbs marked this conversation as resolved.
Show resolved Hide resolved
log_I0_prior = Normal(log(100.0), 1e-1)
Expand Down
99 changes: 99 additions & 0 deletions pipeline/src/constructors/remake_latent_model.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""
Constructs and returns a latent model based on the provided `inference_config` and `pipeline`.
The purpose of this function is to make adjustments to the latent model based on the
full `inference_config` provided.

The `tscale` argument is used to scale the standard deviation of the latent model based on the
idea that some processes have a variance that is (approximately) proportional to a time period (due to non-stationarity)
and some processes have a variance that is constant in time (at stationarity). The default
value is `sqrt(21.0)`, which corresponds to matching the variance of stationary processes to
the eventual variance of non-stationary process after 21 days.

The `pipeline` argument is used for dispatch purposes.

# Returns
- A latent model object which can be one of `DiffLatentModel`, `AR`, or `RandomWalk` depending on the `latent_model_name` and `igp` specified in `inference_config`.

# Details
- The function first constructs a dictionary of priors using `make_model_priors(pipeline)`.
- It then retrieves the `igp` (inference generation process) and `latent_model_name` from `inference_config`.
- Depending on the `latent_model_name` and `igp`, it constructs and returns the appropriate latent model:
- `"diff_ar"`: Constructs a `DiffLatentModel` with an `AR` model.
- `"ar"`: Constructs an `AR` model.
- `"rw"`: Constructs a `RandomWalk` model.
- The priors for the models are set based on the `prior_dict` and the `tscale` parameter.

"""
function remake_latent_model(inference_config::Dict,
pipeline::AbstractRtwithoutRenewalPipeline; tscale = sqrt(21.0))
#Baseline choices
prior_dict = make_model_priors(pipeline)
igp = inference_config["igp"]
latent_model_name = inference_config["latent_namemodels"].first

if latent_model_name == "diff_ar"
seabbs marked this conversation as resolved.
Show resolved Hide resolved
if igp == Renewal
ar = AR(damp_priors = [prior_dict["damp_param_prior"]],
seabbs marked this conversation as resolved.
Show resolved Hide resolved
std_prior = HalfNormal(0.05 / tscale),
init_priors = [prior_dict["transformed_process_init_prior"]])
diff_ar = DiffLatentModel(;
model = ar, init_priors = [prior_dict["transformed_process_init_prior"]])
return diff_ar
elseif igp == ExpGrowthRate
ar = AR(damp_priors = [prior_dict["damp_param_prior"]],
std_prior = HalfNormal(0.005 / tscale),
init_priors = [prior_dict["transformed_process_init_prior"]])
diff_ar = DiffLatentModel(;
model = ar, init_priors = [prior_dict["transformed_process_init_prior"]])
return diff_ar
elseif igp == DirectInfections
ar = AR(damp_priors = [Beta(9, 1)],
seabbs marked this conversation as resolved.
Show resolved Hide resolved
std_prior = HalfNormal(0.05 / tscale),
init_priors = [prior_dict["transformed_process_init_prior"]])
diff_ar = DiffLatentModel(;
model = ar, init_priors = [prior_dict["transformed_process_init_prior"]])
return diff_ar
end
elseif latent_model_name == "ar"
if igp == Renewal
ar = AR(damp_priors = [Beta(2, 8)],
std_prior = HalfNormal(0.25),
seabbs marked this conversation as resolved.
Show resolved Hide resolved
init_priors = [prior_dict["transformed_process_init_prior"]])
return ar
elseif igp == ExpGrowthRate
ar = AR(damp_priors = [prior_dict["damp_param_prior"]],
std_prior = HalfNormal(0.025),
seabbs marked this conversation as resolved.
Show resolved Hide resolved
init_priors = [prior_dict["transformed_process_init_prior"]])
return ar
elseif igp == DirectInfections
ar = AR(damp_priors = [Beta(9, 1)],
std_prior = HalfNormal(0.25),
init_priors = [prior_dict["transformed_process_init_prior"]])
return ar
end
elseif latent_model_name == "rw"
if igp == Renewal
rw = RandomWalk(
seabbs marked this conversation as resolved.
Show resolved Hide resolved
std_prior = HalfNormal(0.05 / tscale),
init_prior = prior_dict["transformed_process_init_prior"])
return rw
elseif igp == ExpGrowthRate
rw = RandomWalk(
std_prior = HalfNormal(0.005 / tscale),
init_prior = prior_dict["transformed_process_init_prior"])
return rw
elseif igp == DirectInfections
rw = RandomWalk(
std_prior = HalfNormal(0.1 / tscale),
init_prior = prior_dict["transformed_process_init_prior"])
return rw
end
end
end

"""
Pass through fallback dispatch.
"""
function remake_latent_model(inference_config::Dict, pipeline::AbstractEpiAwarePipeline)
inference_config["latent_namemodels"].second
end
11 changes: 9 additions & 2 deletions pipeline/src/constructors/selector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@ end

"""
Internal method for selecting from a list of items based on the pipeline type.
Example/test mode is to return a randomly selected item from the list.
Example/test mode is to return a randomly selected item from the list. Prior predictive mode
only runs on configurations with the furthest ahead horizon.
"""
function _selector(list, pipeline::AbstractRtwithoutRenewalPipeline)
return pipeline.testmode ? [rand(list)] : list
if pipeline.priorpredictive
maxT = maximum([config["T"] for config in list])
_list = filter(config -> config["T"] == maxT, list)
return pipeline.testmode ? [rand(_list)] : _list
else
return pipeline.testmode ? [rand(list)] : list
end
end
43 changes: 23 additions & 20 deletions pipeline/src/infer/InferenceConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Inference configuration struct for specifying the parameters and models used in
"""
struct InferenceConfig{
T, F, IGP, L, O, E, D <: Distribution, X <: Integer,
P <: AbstractRtwithoutRenewalPipeline}
P <: AbstractEpiAwarePipeline}
gi_mean::T
gi_std::T
igp::IGP
Expand Down Expand Up @@ -51,26 +51,29 @@ struct InferenceConfig{
case_data, truth_I_t, truth_I0, tspan, epimethod,
transformation, log_I0_prior, lookahead, latent_model_name, pipeline)
end
end

function InferenceConfig(
inference_config::Dict; case_data, truth_I_t, truth_I0, tspan, epimethod, pipeline)
InferenceConfig(
inference_config["igp"],
inference_config["latent_namemodels"].second,
inference_config["observation_model"];
gi_mean = inference_config["gi_mean"],
gi_std = inference_config["gi_std"],
case_data = case_data,
truth_I_t = truth_I_t,
truth_I0 = truth_I0,
tspan = tspan,
epimethod = epimethod,
log_I0_prior = inference_config["log_I0_prior"],
lookahead = inference_config["lookahead"],
latent_model_name = inference_config["latent_namemodels"].first,
pipeline
)
end
function InferenceConfig(
inference_config::Dict, pipeline::AbstractEpiAwarePipeline;
case_data, truth_I_t, truth_I0, tspan, epimethod)
latent_model = remake_latent_model(inference_config::Dict, pipeline)

InferenceConfig(
inference_config["igp"],
latent_model,
inference_config["observation_model"];
gi_mean = inference_config["gi_mean"],
gi_std = inference_config["gi_std"],
case_data = case_data,
truth_I_t = truth_I_t,
truth_I0 = truth_I0,
tspan = tspan,
epimethod = epimethod,
log_I0_prior = inference_config["log_I0_prior"],
lookahead = inference_config["lookahead"],
latent_model_name = inference_config["latent_namemodels"].first,
pipeline
)
end

"""
Expand Down
4 changes: 2 additions & 2 deletions pipeline/src/infer/generate_inference_results.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ function generate_inference_results(
pipeline; T = inference_config["T"], lookback = inference_config["lookback"])
inference_method = make_inference_method(pipeline)
config = InferenceConfig(
inference_config; case_data = truthdata["y_t"], truth_I_t = truthdata["I_t"],
truth_I0 = truthdata["truth_I0"], tspan, epimethod = inference_method, pipeline = pipeline)
inference_config, pipeline; case_data = truthdata["y_t"], truth_I_t = truthdata["I_t"],
truth_I0 = truthdata["truth_I0"], tspan, epimethod = inference_method)

# produce or load inference results
prfx = _inference_prefix(truthdata, inference_config, pipeline)
Expand Down
175 changes: 175 additions & 0 deletions pipeline/test/constructors/constructors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
@testset "make_gi_params: returns a dictionary with correct keys" begin
pipeline = EpiAwareExamplePipeline()
params = make_gi_params(pipeline)

@test params isa Dict
@test haskey(params, "gi_means")
@test haskey(params, "gi_stds")
end

@testset "make_inf_generating_processes" begin
pipeline = EpiAwareExamplePipeline()
igps = make_inf_generating_processes(pipeline)
@test igps == [DirectInfections, ExpGrowthRate, Renewal]
end

@testset "make_Rt: returns an array" begin
map([EpiAwareExamplePipeline(), SmoothOutbreakPipeline(),
MeasuresOutbreakPipeline(), SmoothEndemicPipeline(), RoughEndemicPipeline()]) do pipeline
Rt = make_Rt(pipeline)
@test Rt isa Array
end
end

@testset "default_tspan: returns an Tuple{Integer, Integer}" begin
pipeline = EpiAwareExamplePipeline()

tspan = make_tspan(pipeline; lookback = 90)
@test tspan isa Tuple{Integer, Integer}
end

@testset "make_model_priors: generates a dict with correct keys and distributions" begin
using Distributions
pipeline = EpiAwareExamplePipeline()

priors_dict = make_model_priors(pipeline)

# Check if the priors dictionary is constructed correctly
@test haskey(priors_dict, "transformed_process_init_prior")
@test haskey(priors_dict, "std_prior")
@test haskey(priors_dict, "damp_param_prior")

# Check if the values are all distributions
@test valtype(priors_dict) <: Distribution
end

@testset "make_epiaware_name_latentmodel_pairs: generates a vector of Pairs with correct keys and latent models" begin
pipeline = EpiAwareExamplePipeline()

namemodel_vect = make_epiaware_name_latentmodel_pairs(pipeline)

@test first.(namemodel_vect) == ["ar", "rw", "diff_ar"]
@test all([model isa AbstractTuringLatentModel for model in last.(namemodel_vect)])
end

@testset "make_inference_method: constructor and defaults" begin
using ADTypes, AbstractMCMC
pipeline = EpiAwareExamplePipeline()

method = make_inference_method(pipeline)

@test length(method.pre_sampler_steps) == 1
@test method.pre_sampler_steps[1] isa ManyPathfinder
@test method.pre_sampler_steps[1].nruns == 4
@test method.pre_sampler_steps[1].maxiters == 100
@test method.sampler isa NUTSampler
@test method.sampler.adtype isa AbstractADType
@test method.sampler.ndraws == 20
@test method.sampler.nchains == 4
@test method.sampler.mcmc_parallel == MCMCThreads()
end

@testset "make_inference_method: for prior predictive checking" begin
using EpiAwarePipeline, EpiAware, ADTypes, AbstractMCMC
pipetype = [SmoothOutbreakPipeline, MeasuresOutbreakPipeline,
SmoothEndemicPipeline, RoughEndemicPipeline] |> rand
pipeline = pipetype(; ndraws = 100, testmode = true, priorpredictive = true)

method = make_inference_method(pipeline)

@test length(method.pre_sampler_steps) == 0
@test method.sampler isa DirectSample
end

@testset "make_truth_data_configs" begin
pipeline = SmoothOutbreakPipeline()
example_pipeline = EpiAwareExamplePipeline()
@testset "make_truth_data_configs should return a dictionary" begin
config_dicts = make_truth_data_configs(pipeline)
@test eltype(config_dicts) <: Dict
end

@testset "make_truth_data_configs should contain gi_mean and gi_std keys" begin
config_dicts = make_truth_data_configs(pipeline)
@test all(config_dicts .|> config -> haskey(config, "gi_mean"))
@test all(config_dicts .|> config -> haskey(config, "gi_std"))
end

@testset "make_truth_data_configs should return a vector of length 1 for EpiAwareExamplePipeline" begin
config_dicts = make_truth_data_configs(example_pipeline)
@test length(config_dicts) == 1
end
end

@testset "default inference configurations" begin
pipeline = SmoothOutbreakPipeline()
example_pipeline = EpiAwareExamplePipeline()

@testset "make_inference_configs should return a vector of dictionaries" begin
inference_configs = make_inference_configs(pipeline)
@test eltype(inference_configs) <: Dict
end

@testset "make_inference_configs should contain igp, latent_namemodels, observation_model, gi_mean, gi_std, and log_I0_prior keys" begin
inference_configs = make_inference_configs(pipeline)
@test inference_configs .|> (config -> haskey(config, "igp")) |> all
@test inference_configs .|> (config -> haskey(config, "latent_namemodels")) |> all
@test inference_configs .|> (config -> haskey(config, "observation_model")) |> all
@test inference_configs .|> (config -> haskey(config, "gi_mean")) |> all
@test inference_configs .|> (config -> haskey(config, "gi_std")) |> all
@test inference_configs .|> (config -> haskey(config, "log_I0_prior")) |> all
end

@testset "make_inference_configs should return a vector of length 1 for EpiAwareExamplePipeline" begin
inference_configs = make_inference_configs(example_pipeline)
@test length(inference_configs) == 1
end
end

@testset "make_default_params" begin
pipeline = SmoothOutbreakPipeline()

# Expected default parameters
expected_params = Dict(
"Rt" => make_Rt(pipeline),
"logit_daily_ascertainment" => [zeros(5); -0.5 * ones(2)],
"cluster_factor" => 0.05,
"I0" => 100.0,
"α_delay" => 4.0,
"θ_delay" => 5.0 / 4.0,
"lookahead" => 21,
"lookback" => 90,
"stride" => 7
)

# Test the make_default_params function
@test make_default_params(pipeline) == expected_params
end

@testset "make_delay_distribution" begin
using Distributions
pipeline = SmoothOutbreakPipeline()
delay_distribution = make_delay_distribution(pipeline)
@test delay_distribution isa Distribution
@test delay_distribution isa Gamma
@test delay_distribution.α == 4.0
@test delay_distribution.θ == 5.0 / 4.0
end

@testset "make_observation_model" begin
# Mock pipeline object
pipeline = SmoothOutbreakPipeline()
default_params = make_default_params(pipeline)
obs = make_observation_model(pipeline)

# Test case 1: Check if the returned object is of type LatentDelay
@testset "Returned object type" begin
@test obs isa LatentDelay
end

# Test case 2: Check if the default parameters are correctly passed to ascertainment_dayofweek
@testset "Default parameters" begin
@test obs.model.model.cluster_factor_prior ==
HalfNormal(default_params["cluster_factor"])
end
end
Loading
Loading