From 78930766f3fef2c8363841268796e66ae238a472 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Fri, 4 Oct 2024 12:43:40 +0100 Subject: [PATCH] fix tests and capture forecasting failures --- pipeline/src/infer/InferenceConfig.jl | 2 +- pipeline/test/pipeline/test_pipeline.jl | 2 ++ .../test/pipeline/test_pipelinefunctions.jl | 36 +++++-------------- pipeline/test/runtests.jl | 6 ++-- 4 files changed, 15 insertions(+), 31 deletions(-) create mode 100644 pipeline/test/pipeline/test_pipeline.jl diff --git a/pipeline/src/infer/InferenceConfig.jl b/pipeline/src/infer/InferenceConfig.jl index e7832d9fb..aa28a4964 100644 --- a/pipeline/src/infer/InferenceConfig.jl +++ b/pipeline/src/infer/InferenceConfig.jl @@ -80,6 +80,7 @@ to make inference on and model configuration. """ function create_inference_results(config, epiprob) #Return the sampled infections and observations + idxs = config.tspan[1]:config.tspan[2] y_t = ismissing(config.case_data) ? missing : Vector{Union{Missing, Int64}}(config.case_data[idxs]) inference_results = apply_method(epiprob, @@ -108,7 +109,6 @@ to make inference on and model configuration. function infer(config::InferenceConfig) #Define the EpiProblem epiprob = define_epiprob(config) - idxs = config.tspan[1]:config.tspan[2] #Return the sampled infections and observations inference_results = create_inference_results(config, epiprob) diff --git a/pipeline/test/pipeline/test_pipeline.jl b/pipeline/test/pipeline/test_pipeline.jl new file mode 100644 index 000000000..ad5f71815 --- /dev/null +++ b/pipeline/test/pipeline/test_pipeline.jl @@ -0,0 +1,2 @@ +include("test_pipelinetypes.jl") +include("test_pipelinefunctions.jl") diff --git a/pipeline/test/pipeline/test_pipelinefunctions.jl b/pipeline/test/pipeline/test_pipelinefunctions.jl index 19f9a6bd2..0a49944f3 100644 --- a/pipeline/test/pipeline/test_pipelinefunctions.jl +++ b/pipeline/test/pipeline/test_pipelinefunctions.jl @@ -1,5 +1,5 @@ @testset "do_truthdata tests" begin - using EpiAwarePipeline, Dagger + using Dagger for pipetype in [SmoothOutbreakPipeline, MeasuresOutbreakPipeline, SmoothEndemicPipeline, RoughEndemicPipeline] pipeline = pipetype(; testmode = true) @@ -13,7 +13,7 @@ end @testset "do_inference tests" begin - using EpiAwarePipeline, Dagger, EpiAware + using Dagger function make_inference(pipeline) truthdata_dg_task = do_truthdata(pipeline) @@ -21,23 +21,9 @@ end do_inference(truthdata[1], pipeline) end - pipetype = SmoothOutbreakPipeline - pipeline = pipetype(; ndraws = 40, testmode = true) - truthdata_dg_task = do_truthdata(pipeline) - truthdata = fetch.(truthdata_dg_task) - inference_configs = make_inference_configs(pipeline) - inference_method = make_inference_method(pipeline) - map(inference_configs) do inference_config - generate_inference_results( - truthdata, inference_config, pipeline; inference_method) - end - - inference_results_tsk = make_inference(pipeline) - inference_results = fetch.(inference_results_tsk) - for pipetype in [SmoothOutbreakPipeline, MeasuresOutbreakPipeline, SmoothEndemicPipeline, RoughEndemicPipeline] - pipeline = pipetype(; ndraws = 40, testmode = true) + pipeline = pipetype(; ndraws = 20, nchains = 1, testmode = true) inference_results_tsk = make_inference(pipeline) inference_results = fetch.(inference_results_tsk) @test length(inference_results) == 1 @@ -46,17 +32,13 @@ end end end -@testset "do_pipeline test: just run" begin - using EpiAwarePipeline - pipeline = EpiAwareExamplePipeline() - res = do_pipeline(pipeline) - fetch(res) - @test isnothing(res) -end +@testset "do_pipeline test: just run all pipeline objects" begin + using Dagger + pipelines = map([SmoothOutbreakPipeline, MeasuresOutbreakPipeline, + SmoothEndemicPipeline, RoughEndemicPipeline]) do pipetype + pipetype(; ndraws = 10, nchains = 1, testmode = true) + end -@testset "do_pipeline test: just run as a vector" begin - using EpiAwarePipeline - pipelines = fill(EpiAwareExamplePipeline(), 2) res = do_pipeline(pipelines) fetch(res) @test isnothing(res) diff --git a/pipeline/test/runtests.jl b/pipeline/test/runtests.jl index d565ec9a7..2e111aed5 100644 --- a/pipeline/test/runtests.jl +++ b/pipeline/test/runtests.jl @@ -1,10 +1,9 @@ using DrWatson, Test quickactivate(@__DIR__(), "EpiAwarePipeline") using EpiAwarePipeline, EpiAware - +import Random +Random.seed!(123) # Run tests -include("pipeline/test_pipelinetypes.jl"); -include("pipeline/test_pipelinefunctions.jl"); include("utils/test_utils.jl"); include("constructors/test_constructors.jl"); include("simulate/test_simulate.jl"); @@ -12,3 +11,4 @@ include("infer/test_infer.jl"); include("forecast/test_forecast.jl"); include("scoring/test_score_parameters.jl"); include("plotting/plotting_tests.jl"); +include("pipeline/test_pipeline.jl");