Skip to content

Commit 07dd19f

Browse files
committed
clean up and add arma and arima helpers
1 parent 9f09ace commit 07dd19f

File tree

6 files changed

+121
-43
lines changed

6 files changed

+121
-43
lines changed

EpiAware/src/EpiLatentModels/EpiLatentModels.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ export broadcast_rule, broadcast_dayofweek, broadcast_weekly, equal_dimensions
2828
# Export tools for modifying latent models
2929
export DiffLatentModel, TransformLatentModel, PrefixLatentModel, RecordExpectedLatent
3030

31+
# Export combinations of models and modifiers
32+
export define_arma, define_arima
33+
3134
include("docstrings.jl")
3235
include("models/Intercept.jl")
3336
include("models/IDD.jl")
@@ -44,6 +47,8 @@ include("manipulators/ConcatLatentModels.jl")
4447
include("manipulators/broadcast/LatentModel.jl")
4548
include("manipulators/broadcast/rules.jl")
4649
include("manipulators/broadcast/helpers.jl")
50+
include("combinations/define_arma.jl")
51+
include("combinations/define_arima.jl")
4752
include("utils.jl")
4853

4954
end

EpiAware/src/EpiLatentModels/combinations/arma.jl

Lines changed: 0 additions & 3 deletions
This file was deleted.
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
@doc raw"""
2+
Define an ARIMA model by wrapping `define_arma` and applying differencing via `DiffLatentModel`.
3+
4+
# Arguments
5+
- `ar_init`: Prior distribution for AR initial conditions.
6+
A vector of distributions.
7+
- `d_init`: Prior distribution for differencing initial conditions.
8+
A vector of distributions.
9+
- `θ`: Prior distribution for MA coefficients.
10+
A vector of distributions.
11+
- `damp`: Prior distribution for AR damping coefficients.
12+
A vector of distributions.
13+
- `ϵ_t`: Distribution of the error term.
14+
Default is `HierarchicalNormal()`.
15+
16+
# Returns
17+
An ARIMA model consisting of AR and MA components with differencing applied.
18+
19+
# Example
20+
21+
```julia
22+
using EpiAware, Distributions
23+
24+
ARIMA = arima(
25+
ar_init = [Normal(0.0, 1.0)],
26+
d_init = [Normal()],
27+
θ = [truncated(Normal(0.0, 0.02), -1, 1)],
28+
damp = [truncated(Normal(0.0, 0.02), 0, 1)]
29+
)
30+
arma_model = generate_latent(ARIMA, 10)
31+
arma_model()
32+
```
33+
"""
34+
function arima(;
35+
ar_init = [Normal()],
36+
d_init = [Normal()],
37+
damp = [truncated(Normal(0.0, 0.05), 0, 1)],
38+
θ = [truncated(Normal(0.0, 0.05), -1, 1)],
39+
ϵ_t = HierarchicalNormal()
40+
)
41+
arma = define_arma(; init = ar_init, damp = damp, θ = θ, ϵ_t = ϵ_t)
42+
arima_model = DiffLatentModel(; model = arma, init_priors = d_init)
43+
return arima_model
44+
end
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
@doc raw"""
2+
Define an ARMA model using AR and MA components.
3+
4+
# Arguments
5+
- `init`: Prior distribution for AR initial conditions.
6+
A vector of distributions.
7+
- `θ`: Prior distribution for MA coefficients.
8+
A vector of distributions.
9+
- `damp`: Prior distribution for AR damping coefficients.
10+
A vector of distributions.
11+
- `ϵ_t`: Distribution of the error term.
12+
Default is `HierarchicalNormal()`.
13+
14+
# Returns
15+
An AR model with an MA model as its error term, effectively creating an ARMA model.
16+
17+
# Example
18+
19+
```@example
20+
using EpiAware, Distributions
21+
22+
ARMA = define_arma(;
23+
θ = [truncated(Normal(0.0, 0.02), -1, 1)],
24+
damp = [truncated(Normal(0.0, 0.02), 0, 1)]
25+
)
26+
arma = generate_latent(ARMA, 10)
27+
arma()
28+
```
29+
"""
30+
function define_arma(;
31+
init = [Normal()],
32+
damp = [truncated(Normal(0.0, 0.05), 0, 1)],
33+
θ = [truncated(Normal(0.0, 0.05), -1, 1)],
34+
ϵ_t = HierarchicalNormal())
35+
ma = MA(; θ_priors = θ, ϵ_t = ϵ_t)
36+
ar = AR(; damp_priors = damp, init_priors = init, ϵ_t = ma)
37+
return ar
38+
end

EpiAware/src/EpiLatentModels/models/AR.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ rand(mdl)
2727
# output
2828
```
2929
"
30-
struct AR{D <: Sampleable, S <: Sampleable, I <: Sampleable,
30+
struct AR{D <: Sampleable, I <: Sampleable,
3131
P <: Int, E <: AbstractTuringLatentModel} <: AbstractTuringLatentModel
3232
"Prior distribution for the damping coefficients."
3333
damp_prior::D
@@ -60,8 +60,7 @@ struct AR{D <: Sampleable, S <: Sampleable, I <: Sampleable,
6060
@assert p>0 "p must be greater than 0"
6161
@assert p==length(damp_prior)==length(init_prior) "p must be equal to the length of damp_prior and init_prior"
6262
new{typeof(damp_prior), typeof(init_prior), typeof(p), typeof(ϵ_t)}(
63-
damp_prior, init_prior, p, ϵ_t
64-
)
63+
damp_prior, init_prior, p, ϵ_t)
6564
end
6665
end
6766

@@ -84,12 +83,11 @@ Generate a latent AR series.
8483
p = latent_model.p
8584
@assert n>p "n must be longer than order of the autoregressive process"
8685

87-
σ_AR ~ latent_model.std_prior
8886
ar_init ~ latent_model.init_prior
8987
damp_AR ~ latent_model.damp_prior
9088
@submodel ϵ_t = generate_latent(latent_model.ϵ_t, n - p)
9189

92-
ar = accumulate_scan(ARStep(damp_AR), ar_init, σ_AR * ϵ_t)
90+
ar = accumulate_scan(ARStep(damp_AR), ar_init, ϵ_t)
9391

9492
return ar
9593
end

EpiAware/src/EpiLatentModels/models/RandomWalk.jl

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ Constructing a random walk requires specifying:
1717
1818
## Constructors
1919
20-
- `RandomWalk(; init_prior, std_prior, ϵ_t)`
20+
- `RandomWalk(init_prior::Sampleable, ϵ_t::AbstractTuringLatentModel)`: Constructs a random walk model with the specified prior distributions for the initial condition and white noise sequence.
21+
- `RandomWalk(; init_prior::Sampleable = Normal(), ϵ_t::AbstractTuringLatentModel = HierarchicalNormal())`: Constructs a random walk model with the specified prior distributions for the initial condition and white noise sequence.
2122
2223
## Example usage
2324
@@ -26,7 +27,7 @@ using Distributions, Turing, EpiAware
2627
rw = RandomWalk()
2728
rw
2829
# output
29-
RandomWalk{Normal{Float64}, HalfNormal{Float64}, IDD{Normal{Float64}}}(init_prior=Normal{Float64}(μ=0.0, σ=1.0), std_prior=HalfNormal{Float64}(σ=0.25), ϵ_t=IDD{Normal{Float64}}(ϵ_t=Normal{Float64}(μ=0.0, σ=1.0)))
30+
RandomWalk{Normal{Float64}, HierarchicalNormal{Float64}}(init_prior=Normal{Float64}(μ=0.0, σ=1.0), ϵ_t=HierarchicalNormal{Float64}(mean=0.0, std_prior=Truncated{Normal{Float64}, Continuous, Float64}(a=0.0, b=Inf, x=Normal{Float64}(μ=0.0, σ=0.1))))
3031
```
3132
3233
```jldoctest RandomWalk; filter=r\"\b\d+(\.\d+)?\b\" => \"*\"
@@ -41,51 +42,46 @@ rand(mdl)
4142
```
4243
"
4344
@kwdef struct RandomWalk{
44-
D <: Sampleable, S <: Sampleable, E <: AbstractTuringLatentModel} <:
45+
D <: Sampleable, E <: AbstractTuringLatentModel} <:
4546
AbstractTuringLatentModel
4647
init_prior::D = Normal()
47-
std_prior::S = HalfNormal(0.25)
48-
ϵ_t::E = IDD(Normal())
49-
end
50-
51-
function RandomWalk(init_prior::D, std_prior::S) where {D <: Sampleable, S <: Sampleable}
52-
return RandomWalk(; init_prior = init_prior, std_prior = std_prior)
48+
ϵ_t::E = HierarchicalNormal()
5349
end
5450

5551
@doc raw"
56-
Implement the `generate_latent` function for the `RandomWalk` model.
57-
58-
## Example usage of `generate_latent` with `RandomWalk` type of latent process model
59-
60-
```julia
61-
using Distributions, Turing, EpiAware
52+
Generate a latent RW series using accumulate_scan.
6253
63-
# Create a RandomWalk model
64-
rw = RandomWalk(init_prior = Normal(2., 1.),
65-
std_prior = HalfNormal(0.1))
66-
```
67-
68-
Then, we can use `generate_latent` to construct a Turing model for a 10 step random walk.
54+
# Arguments
6955
70-
```julia
71-
# Construct a Turing model
72-
rw_model = generate_latent(rw, 10)
73-
```
56+
- `latent_model::RandomWalk`: The RandomWalk model.
57+
- `n::Int`: The length of the RW series.
7458
75-
Now we can use the `Turing` PPL API to sample underlying parameters and generate the
76-
unobserved infections.
59+
# Returns
60+
- `rw::Vector{Float64}`: The generated RW series.
7761
78-
```julia
79-
#Sample random parameters from prior
80-
θ = rand(rw_model)
81-
#Get random walk sample path as a generated quantities from the model
82-
Z_t, _ = generated_quantities(rw_model, θ)
83-
```
62+
# Notes
63+
- `n` must be greater than 0.
8464
"
8565
@model function EpiAwareBase.generate_latent(latent_model::RandomWalk, n)
86-
σ_RW ~ latent_model.std_prior
66+
@assert n>0 "n must be greater than 0"
67+
8768
rw_init ~ latent_model.init_prior
8869
@submodel ϵ_t = generate_latent(latent_model.ϵ_t, n - 1)
89-
rw = rw_init .+ vcat(0.0, σ_RW .* cumsum(ϵ_t))
70+
71+
rw = accumulate_scan(RWStep(), rw_init, ϵ_t)
72+
9073
return rw
9174
end
75+
76+
@doc raw"
77+
The random walk (RW) step function struct
78+
"
79+
struct RWStep <: AbstractAccumulationStep end
80+
81+
@doc raw"
82+
The random walk (RW) step function for use with `accumulate_scan`.
83+
"
84+
function (rw::RWStep)(state, ϵ)
85+
new_val = state + ϵ
86+
return new_val
87+
end

0 commit comments

Comments
 (0)