-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into compathelper/new_version/2024-10-24-12-28-19…
…-253-01404902091
- Loading branch information
Showing
8 changed files
with
150 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
56 changes: 56 additions & 0 deletions
56
EpiAware/src/EpiObsModels/modifiers/TransformObservationModel.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
@doc raw" | ||
The `TransformObservationModel` struct represents an observation model that applies a transformation function to the expected observations before passing them to the underlying observation model. | ||
## Fields | ||
- `model::M`: The underlying observation model. | ||
- `transform::F`: The transformation function applied to the expected observations. | ||
## Constructors | ||
- `TransformObservationModel(model::M, transform::F = x -> log1pexp.(x)) where {M <: AbstractTuringObservationModel, F <: Function}`: Constructs a `TransformObservationModel` instance with the specified observation model and a default transformation function. | ||
- `TransformObservationModel(; model::M, transform::F = x -> log1pexp.(x)) where {M <: AbstractTuringObservationModel, F <: Function}`: Constructs a `TransformObservationModel` instance using named arguments. | ||
- `TransformObservationModel(model::M; transform::F = x -> log1pexp.(x)) where {M <: AbstractTuringObservationModel, F <: Function}`: Constructs a `TransformObservationModel` instance with the specified observation model and a default transformation function. | ||
## Example | ||
```julia | ||
using EpiAware, Distributions, LogExpFunctions | ||
trans_obs = TransformObservationModel(NegativeBinomialError()) | ||
gen_obs = generate_observations(trans_obs, missing, fill(10.0, 30)) | ||
gen_obs() | ||
``` | ||
" | ||
@kwdef struct TransformObservationModel{ | ||
M <: AbstractTuringObservationModel, F <: Function} <: AbstractTuringObservationModel | ||
"The underlying observation model." | ||
model::M | ||
"The transformation function. The default is `log1pexp` which is the softplus transformation" | ||
transform::F = x -> log1pexp.(x) | ||
end | ||
|
||
function TransformObservationModel(model::M; | ||
transform::F = x -> log1pexp.(x)) where { | ||
M <: AbstractTuringObservationModel, F <: Function} | ||
return TransformObservationModel(model, transform) | ||
end | ||
|
||
@doc raw" | ||
Generates observations or accumulates log-likelihood based on the `TransformObservationModel`. | ||
## Arguments | ||
- `obs::TransformObservationModel`: The TransformObservationModel. | ||
- `y_t`: The current state of the observations. | ||
- `Y_t`: The expected observations. | ||
## Returns | ||
- `y_t`: The updated observations. | ||
" | ||
@model function EpiAwareBase.generate_observations( | ||
obs::TransformObservationModel, y_t, Y_t | ||
) | ||
transformed_Y_t = obs.transform(Y_t) | ||
|
||
@submodel y_t = generate_observations(obs.model, y_t, transformed_Y_t) | ||
|
||
return y_t | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
48 changes: 48 additions & 0 deletions
48
EpiAware/test/EpiObsModels/modifiers/TransformObservationModel.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
@testitem "Test TransformObservationModel constructor" begin | ||
using Turing, LogExpFunctions | ||
|
||
# Test default constructor | ||
trans_obs = TransformObservationModel(NegativeBinomialError()) | ||
@test trans_obs.model == NegativeBinomialError() | ||
@test trans_obs.transform([1.0, 2.0, 3.0]) == log1pexp.([1.0, 2.0, 3.0]) | ||
|
||
# Test constructor with custom transform | ||
custom_transform = x -> exp.(x) | ||
trans_obs_custom = TransformObservationModel(NegativeBinomialError(), custom_transform) | ||
@test trans_obs_custom.model == NegativeBinomialError() | ||
@test trans_obs_custom.transform([1.0, 2.0, 3.0]) == exp.([1.0, 2.0, 3.0]) | ||
|
||
# Test kwarg constructor | ||
trans_obs_kwarg = TransformObservationModel( | ||
model = PoissonError(), transform = custom_transform) | ||
@test trans_obs_kwarg.model == PoissonError() | ||
@test trans_obs_kwarg.transform == custom_transform | ||
end | ||
|
||
@testitem "Test TransformObservationModel generate_observations" begin | ||
using Turing, LogExpFunctions, Distributions | ||
|
||
# Test with default log1pexp transform | ||
trans_obs = TransformObservationModel(NegativeBinomialError()) | ||
gen_obs = generate_observations(trans_obs, missing, fill(10.0, 1)) | ||
samples = sample(gen_obs, Prior(), 1000; progress = false)["y_t[1]"] | ||
|
||
# Reverse the transform | ||
reversed_samples = samples .|> exp |> x -> x .- 1 .|> log | ||
# Apply the transform again | ||
recovered_samples = log1pexp.(reversed_samples) | ||
|
||
@test all(isapprox.(samples, recovered_samples, rtol = 1e-6)) | ||
|
||
# Test with custom transform and Poisson distribution | ||
custom_transform = x -> x .^ 2 # Square transform | ||
trans_obs_custom = TransformObservationModel(PoissonError(), custom_transform) | ||
gen_obs_custom = generate_observations(trans_obs_custom, missing, fill(5.0, 1)) | ||
samples_custom = sample(gen_obs_custom, Prior(), 1000; progress = false) | ||
# Reverse the transform | ||
reversed_samples_custom = sqrt.(samples_custom["y_t[1]"]) | ||
# Apply the transform again | ||
recovered_samples_custom = custom_transform.(reversed_samples_custom) | ||
|
||
@test all(isapprox.(samples_custom["y_t[1]"], recovered_samples_custom, rtol = 1e-6)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
5 changes: 5 additions & 0 deletions
5
benchmark/bench/EpiObsModels/modifiers/TransformObservationModel.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
let | ||
transform_obs = TransformObservationModel(NegativeBinomialError()) | ||
mdl = generate_observations(transform_obs, fill(10, 10), fill(9, 10)) | ||
suite["TransformObservationModel"] = make_epiaware_suite(mdl) | ||
end |