Skip to content

Commit

Permalink
fix tests and capture forecasting failures
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBrand1 committed Oct 4, 2024
1 parent f93efeb commit 7893076
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 31 deletions.
2 changes: 1 addition & 1 deletion pipeline/src/infer/InferenceConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ to make inference on and model configuration.
"""
function create_inference_results(config, epiprob)
#Return the sampled infections and observations
idxs = config.tspan[1]:config.tspan[2]
y_t = ismissing(config.case_data) ? missing :
Vector{Union{Missing, Int64}}(config.case_data[idxs])
inference_results = apply_method(epiprob,
Expand Down Expand Up @@ -108,7 +109,6 @@ to make inference on and model configuration.
function infer(config::InferenceConfig)
#Define the EpiProblem
epiprob = define_epiprob(config)
idxs = config.tspan[1]:config.tspan[2]

#Return the sampled infections and observations
inference_results = create_inference_results(config, epiprob)
Expand Down
2 changes: 2 additions & 0 deletions pipeline/test/pipeline/test_pipeline.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
include("test_pipelinetypes.jl")
include("test_pipelinefunctions.jl")
36 changes: 9 additions & 27 deletions pipeline/test/pipeline/test_pipelinefunctions.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testset "do_truthdata tests" begin
using EpiAwarePipeline, Dagger
using Dagger
for pipetype in [SmoothOutbreakPipeline, MeasuresOutbreakPipeline,
SmoothEndemicPipeline, RoughEndemicPipeline]
pipeline = pipetype(; testmode = true)
Expand All @@ -13,31 +13,17 @@
end

@testset "do_inference tests" begin
using EpiAwarePipeline, Dagger, EpiAware
using Dagger

function make_inference(pipeline)
truthdata_dg_task = do_truthdata(pipeline)
truthdata = fetch.(truthdata_dg_task)
do_inference(truthdata[1], pipeline)
end

pipetype = SmoothOutbreakPipeline
pipeline = pipetype(; ndraws = 40, testmode = true)
truthdata_dg_task = do_truthdata(pipeline)
truthdata = fetch.(truthdata_dg_task)
inference_configs = make_inference_configs(pipeline)
inference_method = make_inference_method(pipeline)
map(inference_configs) do inference_config
generate_inference_results(
truthdata, inference_config, pipeline; inference_method)
end

inference_results_tsk = make_inference(pipeline)
inference_results = fetch.(inference_results_tsk)

for pipetype in [SmoothOutbreakPipeline, MeasuresOutbreakPipeline,
SmoothEndemicPipeline, RoughEndemicPipeline]
pipeline = pipetype(; ndraws = 40, testmode = true)
pipeline = pipetype(; ndraws = 20, nchains = 1, testmode = true)
inference_results_tsk = make_inference(pipeline)
inference_results = fetch.(inference_results_tsk)
@test length(inference_results) == 1
Expand All @@ -46,17 +32,13 @@ end
end
end

@testset "do_pipeline test: just run" begin
using EpiAwarePipeline
pipeline = EpiAwareExamplePipeline()
res = do_pipeline(pipeline)
fetch(res)
@test isnothing(res)
end
@testset "do_pipeline test: just run all pipeline objects" begin
using Dagger
pipelines = map([SmoothOutbreakPipeline, MeasuresOutbreakPipeline,
SmoothEndemicPipeline, RoughEndemicPipeline]) do pipetype
pipetype(; ndraws = 10, nchains = 1, testmode = true)
end

@testset "do_pipeline test: just run as a vector" begin
using EpiAwarePipeline
pipelines = fill(EpiAwareExamplePipeline(), 2)
res = do_pipeline(pipelines)
fetch(res)
@test isnothing(res)
Expand Down
6 changes: 3 additions & 3 deletions pipeline/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
using DrWatson, Test
quickactivate(@__DIR__(), "EpiAwarePipeline")
using EpiAwarePipeline, EpiAware

import Random
Random.seed!(123)
# Run tests
include("pipeline/test_pipelinetypes.jl");
include("pipeline/test_pipelinefunctions.jl");
include("utils/test_utils.jl");
include("constructors/test_constructors.jl");
include("simulate/test_simulate.jl");
include("infer/test_infer.jl");
include("forecast/test_forecast.jl");
include("scoring/test_score_parameters.jl");
include("plotting/plotting_tests.jl");
include("pipeline/test_pipeline.jl");

0 comments on commit 7893076

Please sign in to comment.