Commit 5045633

Merge branch 'main' into compathelper/new_version/2024-10-24-12-28-19-253-01404902091

seabbs authored Oct 24, 2024
2 parents 4f920b5 + 119f244 commit 5045633
Showing 8 changed files with 150 additions and 23 deletions.
16 changes: 16 additions & 0 deletions EpiAware/docs/src/developer/contributing.md
@@ -40,6 +40,22 @@ Tests that build example package docs from source and inspect the results (end t
located in `/test/examples`. The main entry points are `test/examples/make.jl` for building and
`test/examples/test.jl` for doing some basic checks on the generated outputs.

## Benchmarking

Benchmarking is orchestrated using `PkgBenchmark.jl` along with a GitHub action that uses `BenchmarkCI.jl`. The benchmarks are located in `benchmarks/` and the main entry point is `benchmarks/runbenchmarks.jl`.

The main function in the `benchmark` environment is `make_epiaware_suite`, which calls `TuringBenchmarking.make_turing_suite` on a set of `Turing` models generated by `EpiAware`, benchmarking their sampling with the following autodiff backends:

- `ForwardDiff.jl`.
- `ReverseDiff.jl`: With `compile = false`.
- `ReverseDiff.jl`: With `compile = true`.
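
A minimal sketch of how such a suite can be assembled, assuming `TuringBenchmarking.jl`'s `make_turing_suite` keyword API; `demo_model` is a hypothetical stand-in for an `EpiAware`-generated model, not part of the package:

```julia
# Sketch only: a PkgBenchmark-style benchmark group built with
# TuringBenchmarking.make_turing_suite over the three autodiff backends above.
using Turing, TuringBenchmarking, BenchmarkTools

# Hypothetical stand-in for a model produced by EpiAware.
@model function demo_model(y)
    μ ~ Normal(0, 1)
    y ~ Normal(μ, 1)
end

suite = BenchmarkGroup()
suite["demo"] = TuringBenchmarking.make_turing_suite(
    demo_model(1.0);
    adbackends = [:forwarddiff, :reversediff, :reversediff_compiled]
)
```

Running `BenchmarkTools.run(suite)` then times both log-density evaluation and gradient computation for each backend.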

### Benchmarking "gotchas"

#### Models with no parameters

In `EpiAware` we expose some models that do not have parameters. For example, Poisson sampling with a transformation on a fixed mean process, implemented by `TransformObservationModel(NegativeBinomialError())`, has no sampleable parameters (although it does contribute log-likelihood as part of a wider model). This causes `TuringBenchmarking.make_turing_suite` to throw an error, as it expects all models to have parameters.
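
One possible guard, sketched under the assumption that `DynamicPPL.VarInfo` exposes a model's sampleable parameters via `keys`; the helper name `maybe_turing_suite` is hypothetical and not part of `EpiAware`:

```julia
# Hypothetical guard (not part of EpiAware): skip parameter-free models
# rather than letting make_turing_suite throw.
using DynamicPPL, TuringBenchmarking

function maybe_turing_suite(mdl)
    vi = DynamicPPL.VarInfo(mdl)
    if isempty(keys(vi))  # model has no sampleable parameters
        @info "Skipping benchmark: model has no parameters"
        return nothing
    end
    return TuringBenchmarking.make_turing_suite(mdl)
end
```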

## Pluto usage in showcase documentation

Some of the showcase examples in `EpiAware/docs/src/showcase` use [`Pluto.jl`](https://plutojl.org/) notebooks for the underlying computation. The output of the notebooks is rendered into HTML for inclusion in the documentation in two steps:
5 changes: 3 additions & 2 deletions EpiAware/src/EpiObsModels/EpiObsModels.jl
@@ -11,7 +11,7 @@ using ..EpiLatentModels: HierarchicalNormal, broadcast_dayofweek
using ..EpiLatentModels: broadcast_rule, PrefixLatentModel, RepeatEach

using Turing, Distributions, DocStringExtensions, SparseArrays, LinearAlgebra
using LogExpFunctions: xexpy
using LogExpFunctions: xexpy, log1pexp

# Observation error models
export PoissonError, NegativeBinomialError
@@ -21,7 +21,7 @@ export generate_observation_error_priors, observation_error

# Observation model modifiers
export LatentDelay, Ascertainment, PrefixObservationModel, RecordExpectedObs
export Aggregate
export Aggregate, TransformObservationModel

# Observation model manipulators
export StackObservationModels
@@ -36,6 +36,7 @@ include("modifiers/ascertainment/helpers.jl")
include("modifiers/Aggregate.jl")
include("modifiers/PrefixObservationModel.jl")
include("modifiers/RecordExpectedObs.jl")
include("modifiers/TransformObservationModel.jl")
include("StackObservationModels.jl")
include("ObservationErrorModels/methods.jl")
include("ObservationErrorModels/NegativeBinomialError.jl")
56 changes: 56 additions & 0 deletions EpiAware/src/EpiObsModels/modifiers/TransformObservationModel.jl
@@ -0,0 +1,56 @@
@doc raw"
The `TransformObservationModel` struct represents an observation model that applies a transformation function to the expected observations before passing them to the underlying observation model.
## Fields
- `model::M`: The underlying observation model.
- `transform::F`: The transformation function applied to the expected observations.
## Constructors
- `TransformObservationModel(model::M, transform::F = x -> log1pexp.(x)) where {M <: AbstractTuringObservationModel, F <: Function}`: Constructs a `TransformObservationModel` instance with the specified observation model and a default transformation function.
- `TransformObservationModel(; model::M, transform::F = x -> log1pexp.(x)) where {M <: AbstractTuringObservationModel, F <: Function}`: Constructs a `TransformObservationModel` instance using named arguments.
- `TransformObservationModel(model::M; transform::F = x -> log1pexp.(x)) where {M <: AbstractTuringObservationModel, F <: Function}`: Constructs a `TransformObservationModel` instance with the model given positionally and the transform as a keyword argument.
## Example
```julia
using EpiAware, Distributions, LogExpFunctions
trans_obs = TransformObservationModel(NegativeBinomialError())
gen_obs = generate_observations(trans_obs, missing, fill(10.0, 30))
gen_obs()
```
"
@kwdef struct TransformObservationModel{
M <: AbstractTuringObservationModel, F <: Function} <: AbstractTuringObservationModel
"The underlying observation model."
model::M
"The transformation function. The default is `log1pexp`, which is the softplus transformation."
transform::F = x -> log1pexp.(x)
end

function TransformObservationModel(model::M;
transform::F = x -> log1pexp.(x)) where {
M <: AbstractTuringObservationModel, F <: Function}
return TransformObservationModel(model, transform)
end

@doc raw"
Generates observations or accumulates log-likelihood based on the `TransformObservationModel`.
## Arguments
- `obs::TransformObservationModel`: The TransformObservationModel.
- `y_t`: The current state of the observations.
- `Y_t`: The expected observations.
## Returns
- `y_t`: The updated observations.
"
@model function EpiAwareBase.generate_observations(
obs::TransformObservationModel, y_t, Y_t
)
transformed_Y_t = obs.transform(Y_t)

@submodel y_t = generate_observations(obs.model, y_t, transformed_Y_t)

return y_t
end
14 changes: 7 additions & 7 deletions EpiAware/test/EpiLatentModels/models/RandomWalk.jl
@@ -44,23 +44,23 @@ end
ReverseDiff
Random.seed!(1234)

rw_process = RandomWalk()
obs_nb = NegativeBinomialError()
process = RandomWalk()
obs = PoissonError()

@model function test_negbin_errors(rw, obs, y_t)
@model function test_poisson_errors(proc, obs, y_t)
n = length(y_t)
@submodel Z_t = generate_latent(rw, n)
@submodel Z_t = generate_latent(proc, n)
@submodel y_t = generate_observations(obs, y_t, exp.(Z_t))
return Z_t, y_t
end

generative_mdl = test_negbin_errors(rw_process, obs_nb, fill(missing, 40))
generative_mdl = test_poisson_errors(process, obs, fill(missing, 40))
θ_true = rand(generative_mdl)
Z_t_obs, y_t_obs = condition(generative_mdl, θ_true)()

mdl = test_negbin_errors(rw_process, obs_nb, Int.(y_t_obs))
mdl = test_poisson_errors(process, obs, Int.(y_t_obs))
chn = sample(
mdl, NUTS(adtype = AutoReverseDiff(; compile = Val(true))), 1000, progess = false)
mdl, NUTS(adtype = AutoReverseDiff(; compile = Val(true))), 1000; progress = false)

#Check that observations are in the central 99.9% of the posterior predictive distribution.
#Therefore, this should be unlikely to fail if the model is correctly implemented.
27 changes: 13 additions & 14 deletions EpiAware/test/EpiObsModels/modifiers/LatentDelay.jl
@@ -162,7 +162,7 @@ end
end
end

@testitem "LatentDelay parameter recovery with mix of IGP + latent processes: Negative binomial errors + EpiProblem interface" begin
@testitem "LatentDelay parameter recovery with mix of IGP + latent processes: Poisson errors + EpiProblem interface" begin
using Random, Turing, Distributions, LinearAlgebra, DynamicPPL, StatsBase, ReverseDiff,
Suppressor, LogExpFunctions
# using PairPlots, CairoMakie
Expand All @@ -178,7 +178,7 @@ end
data = EpiData([0.2, 0.5, 0.3],
em_type == Renewal ? softplus : exp
),
initialisation_prior = Normal(log(100.0), 0.25)
initialisation_prior = Normal(log(100.0), 0.01)
)

latentprocess_types = [RandomWalk, AR, DiffLatentModel]
@@ -190,7 +190,7 @@ end
return (; init_prior, std_prior)
elseif epimodel isa ExpGrowthRate
init_prior = Normal(0.1, 0.025)
std_prior = HalfNormal(0.025)
std_prior = LogNormal(log(0.025), 0.01)
return (; init_prior, std_prior)
elseif epimodel isa DirectInfections
init_prior = Normal(log(100.0), 0.25)
@@ -204,11 +204,11 @@ end
if latentprocess_type == RandomWalk
return RandomWalk(init_prior, std_prior)
elseif latentprocess_type == AR
return AR(damp_priors = [Beta(8, 2; check_args = false)],
return AR(damp_priors = [Beta(2, 8; check_args = false)],
std_prior = std_prior, init_priors = [init_prior])
elseif latentprocess_type == DiffLatentModel
return DiffLatentModel(
AR(damp_priors = [Beta(8, 2; check_args = false)],
AR(damp_priors = [Beta(2, 8; check_args = false)],
std_prior = std_prior, init_priors = [Normal(0.0, 0.25)]),
init_prior; d = 1)
end
@@ -217,15 +217,14 @@ end
function test_full_process(epimodel, latentprocess, n;
ad = AutoReverseDiff(; compile = true), posterior_p_tol = 0.005)
#Fix observation model
obs = LatentDelay(
NegativeBinomialError(cluster_factor_prior = HalfNormal(0.05)), Gamma(3, 7 / 3))
obs = LatentDelay(PoissonError(), Gamma(3, 7 / 3))

#Inference method
inference_method = EpiMethod(
pre_sampler_steps = [ManyPathfinder(nruns = 4, maxiters = 100)],
pre_sampler_steps = [ManyPathfinder(nruns = 4, maxiters = 50)],
sampler = NUTSampler(adtype = ad,
ndraws = 1000,
nchains = 4,
ndraws = 2000,
nchains = 2,
mcmc_parallel = MCMCThreads())
)

@@ -237,15 +236,15 @@ )
)

#Generate data from generative model (i.e. data unconditioned)
generative_mdl = generate_epiaware(
epi_prob, (y_t = Vector{Union{Int, Missing}}(missing, n),))
generative_mdl = generate_epiaware(epi_prob, (y_t = missing,))
θ_true = rand(generative_mdl)
gen_data = condition(generative_mdl, θ_true)()

#Apply inference method to inference model (i.e. generative model conditioned on data)
inference_results = apply_method(epi_prob,
inference_method,
(y_t = gen_data.generated_y_t,)
(y_t = gen_data.generated_y_t,);
progress = false
)

chn = inference_results.samples
@@ -265,7 +264,7 @@ end
@testset "Check true parameters are within 99% central post. prob.: " begin
@testset for latentprocess_type in latentprocess_types, epimodel in epimodels
latentprocess = set_latent_process(epimodel, latentprocess_type)
@suppress _ = test_full_process(epimodel, latentprocess, 50)
@suppress _ = test_full_process(epimodel, latentprocess, 40)
end
end
end
48 changes: 48 additions & 0 deletions EpiAware/test/EpiObsModels/modifiers/TransformObservationModel.jl
@@ -0,0 +1,48 @@
@testitem "Test TransformObservationModel constructor" begin
using Turing, LogExpFunctions

# Test default constructor
trans_obs = TransformObservationModel(NegativeBinomialError())
@test trans_obs.model == NegativeBinomialError()
@test trans_obs.transform([1.0, 2.0, 3.0]) == log1pexp.([1.0, 2.0, 3.0])

# Test constructor with custom transform
custom_transform = x -> exp.(x)
trans_obs_custom = TransformObservationModel(NegativeBinomialError(), custom_transform)
@test trans_obs_custom.model == NegativeBinomialError()
@test trans_obs_custom.transform([1.0, 2.0, 3.0]) == exp.([1.0, 2.0, 3.0])

# Test kwarg constructor
trans_obs_kwarg = TransformObservationModel(
model = PoissonError(), transform = custom_transform)
@test trans_obs_kwarg.model == PoissonError()
@test trans_obs_kwarg.transform == custom_transform
end

@testitem "Test TransformObservationModel generate_observations" begin
using Turing, LogExpFunctions, Distributions

# Test with default log1pexp transform
trans_obs = TransformObservationModel(NegativeBinomialError())
gen_obs = generate_observations(trans_obs, missing, fill(10.0, 1))
samples = sample(gen_obs, Prior(), 1000; progress = false)["y_t[1]"]

# Reverse the transform
reversed_samples = samples .|> exp |> x -> x .- 1 .|> log
# Apply the transform again
recovered_samples = log1pexp.(reversed_samples)

@test all(isapprox.(samples, recovered_samples, rtol = 1e-6))

# Test with custom transform and Poisson distribution
custom_transform = x -> x .^ 2 # Square transform
trans_obs_custom = TransformObservationModel(PoissonError(), custom_transform)
gen_obs_custom = generate_observations(trans_obs_custom, missing, fill(5.0, 1))
samples_custom = sample(gen_obs_custom, Prior(), 1000; progress = false)
# Reverse the transform
reversed_samples_custom = sqrt.(samples_custom["y_t[1]"])
# Apply the transform again
recovered_samples_custom = custom_transform.(reversed_samples_custom)

@test all(isapprox.(samples_custom["y_t[1]"], recovered_samples_custom, rtol = 1e-6))
end
2 changes: 2 additions & 0 deletions benchmark/bench/EpiObsModels/EpiObsModels.jl
@@ -8,6 +8,8 @@ include("modifiers/ascertainment/Ascertainment.jl")
include("modifiers/ascertainment/helpers.jl")
include("modifiers/LatentDelay.jl")
include("modifiers/PrefixObservationModel.jl")
include("modifiers/RecordExpectedObs.jl")
include("modifiers/TransformObservationModel.jl")
include("ObservationErrorModels/methods.jl")
include("ObservationErrorModels/NegativeBinomialError.jl")
include("ObservationErrorModels/PoissonError.jl")
@@ -0,0 +1,5 @@
let
transform_obs = TransformObservationModel(NegativeBinomialError())
mdl = generate_observations(transform_obs, fill(10, 10), fill(9, 10))
suite["TransformObservationModel"] = make_epiaware_suite(mdl)
end
