-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into optimise-mvnormal-scan
- Loading branch information
Showing
15 changed files
with
490 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
## Script to make figure 1 | ||
using Pkg | ||
Pkg.activate(joinpath(@__DIR__(), "..")) | ||
|
||
using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, Plots, DataFramesMeta, | ||
Statistics, Distributions, CSV | ||
|
||
## | ||
pipelines = [ | ||
SmoothOutbreakPipeline(), MeasuresOutbreakPipeline(), | ||
SmoothEndemicPipeline(), RoughEndemicPipeline()] | ||
|
||
## load some data and create a dataframe for the plot | ||
truth_data_files = readdir(datadir("truth_data")) |> | ||
strs -> filter(s -> occursin("jld2", s), strs) | ||
analysis_df = CSV.File(plotsdir("analysis_df.csv")) |> DataFrame | ||
truth_df = mapreduce(vcat, truth_data_files) do filename | ||
D = load(joinpath(datadir("truth_data"), filename)) | ||
make_truthdata_dataframe(filename, D, pipelines) | ||
end | ||
|
||
## Make mainfigure plots | ||
|
||
# Define scenario titles and reference times for figure 1 | ||
scenario_dict = Dict( | ||
"measures_outbreak" => (title = "Outbreak with measures", T = 28), | ||
"smooth_outbreak" => (title = "Outbreak no measures", T = 35), | ||
"smooth_endemic" => (title = "Smooth endemic", T = 35), | ||
"rough_endemic" => (title = "Rough endemic", T = 35) | ||
) | ||
|
||
fig1 = figureone(truth_df, analysis_df, scenario_dict) | ||
|
||
## Save the figure | ||
save(plotsdir("figure1.png"), fig1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
using Pkg | ||
Pkg.activate(joinpath(@__DIR__(), "..")) | ||
|
||
using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, Plots, DataFramesMeta, | ||
Statistics, Distributions, DrWatson | ||
|
||
## load some data and create a dataframe for the plot | ||
files = readdir(datadir("epiaware_observables")) |> | ||
strs -> filter(s -> occursin("jld2", s), strs) | ||
|
||
## Define scenarios | ||
pipelines = [ | ||
SmoothOutbreakPipeline(), MeasuresOutbreakPipeline(), | ||
SmoothEndemicPipeline(), RoughEndemicPipeline()] | ||
|
||
## Set up EpiData objects: Used in the prediction dataframe for infection generating | ||
## processes that don't use directly in simulation. | ||
gi_params = make_gi_params(pipelines[1]) | ||
epi_datas = map(gi_params["gi_means"]) do μ | ||
σ = gi_params["gi_stds"][1] | ||
shape = (μ / σ)^2 | ||
scale = σ^2 / μ | ||
Gamma(shape, scale) | ||
end .|> gen_dist -> EpiData(gen_distribution = gen_dist) | ||
|
||
## Calculate the prediction dataframe | ||
prediction_df = mapreduce(vcat, files) do filename | ||
output = load(joinpath(datadir("epiaware_observables"), filename)) | ||
make_prediction_dataframe_from_output(filename, output, epi_datas, pipelines) | ||
end | ||
|
||
## Save the prediction dataframe | ||
CSV.write(plotsdir("analysis_df.csv"), prediction_df) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
include("make_truthdata_dataframe.jl") | ||
include("make_prediction_dataframe_from_output.jl") |
64 changes: 64 additions & 0 deletions
64
pipeline/src/analysis/make_prediction_dataframe_from_output.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
""" | ||
Create a dataframe containing prediction results based on the given output and input data. | ||
# Arguments | ||
- `filename`: The name of the file. | ||
- `output`: The output data containing inference configuration, IGP model, and other information. | ||
- `epi_datas`: The input data for the epidemiological model. | ||
- `qs`: An optional array of quantiles to calculate. Default is `[0.025, 0.5, 0.975]`. | ||
# Returns | ||
A dataframe containing the prediction results. | ||
""" | ||
function make_prediction_dataframe_from_output( | ||
filename, output, epi_datas, pipelines; qs = [0.025, 0.5, 0.975]) | ||
#Get the scenario, IGP model, latent model and true mean GI | ||
inference_config = output["inference_config"] | ||
igp_model = output["inference_config"].igp |> string | ||
scenario = EpiAwarePipeline._get_scenario_from_filename(filename, pipelines) | ||
latent_model = EpiAwarePipeline._get_latent_model_from_filename(filename) | ||
true_mean_gi = EpiAwarePipeline._get_true_gi_mean_from_filename(filename) | ||
|
||
#Get the quantiles for the targets across the gi mean scenarios | ||
#if Renewal model, then we use the underlying epi model | ||
#otherwise we use the epi datas to loop over different gi mean implications | ||
used_epi_datas = igp_model == "Renewal" ? [output["epiprob"].epi_model.data] : epi_datas | ||
|
||
preds = nothing | ||
try | ||
preds = map(used_epi_datas) do epi_data | ||
generate_quantiles_for_targets(output, epi_data, qs) | ||
end | ||
used_gi_means = igp_model == "Renewal" ? | ||
[EpiAwarePipeline._get_used_gi_mean_from_filename(filename)] : | ||
make_gi_params(EpiAwareExamplePipeline())["gi_means"] | ||
|
||
#Create the dataframe columnwise | ||
df = mapreduce(vcat, preds, used_gi_means) do pred, used_gi_mean | ||
mapreduce(vcat, keys(pred)) do target | ||
target_mat = pred[target] | ||
target_times = collect(1:size(target_mat, 1)) .+ | ||
(inference_config.tspan[1] - 1) | ||
_df = DataFrame(target_times = target_times) | ||
_df[!, "Scenario"] .= scenario | ||
_df[!, "IGP_Model"] .= igp_model | ||
_df[!, "Latent_Model"] .= latent_model | ||
_df[!, "True_GI_Mean"] .= true_mean_gi | ||
_df[!, "Used_GI_Mean"] .= used_gi_mean | ||
_df[!, "Reference_Time"] .= inference_config.tspan[2] | ||
_df[!, "Target"] .= string(target) | ||
# quantile predictions | ||
for (j, q) in enumerate(qs) | ||
q_str = split(string(q), ".")[end] | ||
_df[!, "q_$(q_str)"] = target_mat[:, j] | ||
end | ||
return _df | ||
end | ||
end | ||
return df | ||
catch | ||
@warn "Error in generating quantiles for targets in file $filename" | ||
return nothing | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
|
||
""" | ||
make_truthdata_dataframe(filename, truth_data, pipelines; I_0 = 100.0) | ||
Create a DataFrame containing truth data for analysis. | ||
# Arguments | ||
- `filename::String`: The name of the file. | ||
- `truth_data::Dict`: A dictionary containing truth data. | ||
- `pipelines::Array`: An array of pipelines. | ||
- `I_0::Float64`: Initial value for I_t (default: 100.0). | ||
# Returns | ||
- `df::DataFrame`: A DataFrame containing the truth data. | ||
""" | ||
function make_truthdata_dataframe(filename, truth_data, pipelines; I_0 = 100.0) | ||
I_t = truth_data["I_t"] | ||
true_mean_gi = truth_data["truth_gi_mean"] | ||
log_It = _calc_log_infections(I_t) | ||
rt = _calc_rt(I_t, I_0) | ||
scenario = _get_scenario_from_filename(filename, pipelines) | ||
truth_procs = (; log_I_t = log_It, rt, Rt = truth_data["truth_process"]) | ||
|
||
df = mapreduce(vcat, keys(truth_procs)) do target | ||
proc = truth_procs[target] | ||
_df = DataFrame( | ||
target_times = 1:length(proc), | ||
target_values = proc | ||
) | ||
_df[!, "Scenario"] .= scenario | ||
_df[!, "True_GI_Mean"] .= true_mean_gi | ||
_df[!, "Target"] .= string(target) | ||
return _df | ||
end | ||
|
||
return df | ||
end |
Oops, something went wrong.