From 2880bd3b8bde3cd2739164d5ed181a44032f592e Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Mon, 15 Apr 2024 14:52:56 +0100 Subject: [PATCH] SSMProblems integration (#97) * SSMProblems - Draft * Tests * Format, GP-SSM * Add levy-ssm * Align timesteps * Update Project.toml --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- Project.toml | 5 +- examples/gaussian-process/Project.toml | 1 + examples/gaussian-process/script.jl | 35 ++-- examples/gaussian-ssm/Project.toml | 1 + examples/gaussian-ssm/script.jl | 42 ++-- examples/levy-ssm/Project.toml | 7 + examples/levy-ssm/script.jl | 256 +++++++++++++++++++++++++ examples/particle-gibbs/Project.toml | 1 + examples/particle-gibbs/script.jl | 32 +++- src/AdvancedPS.jl | 1 + src/container.jl | 2 +- src/model.jl | 9 +- src/pgas.jl | 48 +---- src/smc.jl | 11 +- test/Project.toml | 1 + test/container.jl | 11 +- test/pgas.jl | 21 +- test/runtests.jl | 1 + 18 files changed, 395 insertions(+), 90 deletions(-) create mode 100644 examples/levy-ssm/Project.toml create mode 100644 examples/levy-ssm/script.jl diff --git a/Project.toml b/Project.toml index 4524aa40..4c286d59 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AdvancedPS" uuid = "576499cb-2369-40b2-a588-c64705576edc" authors = ["TuringLang"] -version = "0.5.4" +version = "0.6" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -9,6 +9,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [weakdeps] @@ -21,10 +22,10 @@ AdvancedPSLibtaskExt = "Libtask" AbstractMCMC = "2, 3, 4, 5" Distributions = "0.23, 0.24, 0.25" Libtask = "0.8" +Random = "1.6" Random123 = "1.3" Requires = "1.0" StatsFuns = "0.9, 1" -Random = "1.6" julia = "1.6" [extras] diff --git a/examples/gaussian-process/Project.toml b/examples/gaussian-process/Project.toml index e2796748..9a45895d 100644 --- a/examples/gaussian-process/Project.toml +++ b/examples/gaussian-process/Project.toml @@ -5,3 +5,4 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" diff --git a/examples/gaussian-process/script.jl b/examples/gaussian-process/script.jl index d70a9cda..1a58bde3 100644 --- a/examples/gaussian-process/script.jl +++ b/examples/gaussian-process/script.jl @@ -6,6 +6,7 @@ using AbstractGPs using Plots using Distributions using Libtask +using SSMProblems Parameters = @NamedTuple begin a::Float64 @@ -13,11 +14,13 @@ Parameters = @NamedTuple begin kernel end -mutable struct GPSSM <: AdvancedPS.AbstractStateSpaceModel +mutable struct GPSSM <: SSMProblems.AbstractStateSpaceModel X::Vector{Float64} + observations::Vector{Float64} θ::Parameters GPSSM(params::Parameters) = new(Vector{Float64}(), params) + GPSSM(y::Vector{Float64}, params::Parameters) = new(Vector{Float64}(), y, params) end seed = 1 @@ -29,21 +32,20 @@ q = 0.5 params = Parameters((a, q, SqExponentialKernel())) -f(model::GPSSM, x, t) = Normal(model.θ.a * x, model.θ.q) -h(model::GPSSM) = Normal(0, model.θ.q) -g(model::GPSSM, x, t) = Normal(0, exp(0.5 * x)^2) +f(θ::Parameters, x, t) = Normal(θ.a * x, θ.q) +h(θ::Parameters) = Normal(0, θ.q) +g(θ::Parameters, x, t) = Normal(0, exp(0.5 * x)^2) rng = Random.MersenneTwister(seed) -ref_model = GPSSM(params) x = zeros(T) y = similar(x) -x[1] = rand(rng, h(ref_model)) +x[1] = rand(rng, h(params)) for t in 1:T if t < T - x[t + 1] = rand(rng, f(ref_model, x[t], t)) + x[t + 1] = rand(rng, f(params, x[t], t)) end - y[t] = rand(rng, g(ref_model, x[t], t)) + y[t] = rand(rng, g(params, x[t], t)) end function gp_update(model::GPSSM, state, step) @@ -54,12 +56,21 @@ function gp_update(model::GPSSM, state, step) return Normal(μ[1], σ[1]) end -AdvancedPS.initialization(::GPSSM) = h(model) -AdvancedPS.transition(model::GPSSM, state, step) = gp_update(model, state, step) -AdvancedPS.observation(model::GPSSM, state, step) = logpdf(g(model, state, step), y[step]) +SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM) = rand(rng, h(model.θ)) +function SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM, state, step) + return rand(rng, gp_update(model, state, step)) +end + +function SSMProblems.emission_logdensity(model::GPSSM, state, step) + return logpdf(g(model.θ, state, step), model.observations[step]) +end +function SSMProblems.transition_logdensity(model::GPSSM, prev_state, current_state, step) + return logpdf(gp_update(model, prev_state, step), current_state) +end + AdvancedPS.isdone(::GPSSM, step) = step > T -model = GPSSM(params) +model = GPSSM(y, params) pg = AdvancedPS.PGAS(Nₚ) chains = sample(rng, model, pg, Nₛ) diff --git a/examples/gaussian-ssm/Project.toml b/examples/gaussian-ssm/Project.toml index 1093db09..e87cc369 100644 --- a/examples/gaussian-ssm/Project.toml +++ b/examples/gaussian-ssm/Project.toml @@ -5,4 +5,5 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" diff --git a/examples/gaussian-ssm/script.jl b/examples/gaussian-ssm/script.jl index debd9867..d7a5889c 100644 --- a/examples/gaussian-ssm/script.jl +++ b/examples/gaussian-ssm/script.jl @@ -3,6 +3,7 @@ using AdvancedPS using Random using Distributions using Plots +using SSMProblems # We consider the following linear state-space model with Gaussian innovations. The latent state is a simple gaussian random walk # and the observation is linear in the latent states, namely: @@ -33,16 +34,18 @@ Parameters = @NamedTuple begin r::Float64 end -mutable struct LinearSSM <: AdvancedPS.AbstractStateSpaceModel +mutable struct LinearSSM <: SSMProblems.AbstractStateSpaceModel X::Vector{Float64} + observations::Vector{Float64} θ::Parameters LinearSSM(θ::Parameters) = new(Vector{Float64}(), θ) + LinearSSM(y::Vector, θ::Parameters) = new(Vector{Float64}(), y, θ) end # and the densities defined above. -f(m::LinearSSM, state, t) = Normal(m.θ.a * state, m.θ.q) # Transition density -g(m::LinearSSM, state, t) = Normal(state, m.θ.r) # Observation density -f₀(m::LinearSSM) = Normal(0, m.θ.q^2 / (1 - m.θ.a^2)) # Initial state density +f(θ::Parameters, state, t) = Normal(θ.a * state, θ.q) # Transition density +g(θ::Parameters, state, t) = Normal(state, θ.r) # Observation density +f₀(θ::Parameters) = Normal(0, θ.q^2 / (1 - θ.a^2)) # Initial state density #md nothing #hide # We also need to specify the dynamics of the system through the transition equations: @@ -50,15 +53,26 @@ f₀(m::LinearSSM) = Normal(0, m.θ.q^2 / (1 - m.θ.a^2)) # Initial state den # - `AdvancedPS.transition`: the state transition density # - `AdvancedPS.observation`: the observation score given the observed data # - `AdvancedPS.isdone`: signals the end of the execution for the model -AdvancedPS.initialization(model::LinearSSM) = f₀(model) -AdvancedPS.transition(model::LinearSSM, state, step) = f(model, state, step) -function AdvancedPS.observation(model::LinearSSM, state, step) - return logpdf(g(model, state, step), y[step]) +SSMProblems.transition!!(rng::AbstractRNG, model::LinearSSM) = rand(rng, f₀(model.θ)) +function SSMProblems.transition!!( + rng::AbstractRNG, model::LinearSSM, state::Float64, step::Int +) + return rand(rng, f(model.θ, state, step)) +end + +function SSMProblems.emission_logdensity(modeL::LinearSSM, state::Float64, step::Int) + return logpdf(g(model.θ, state, step), model.observations[step]) +end +function SSMProblems.transition_logdensity( + model::LinearSSM, prev_state, current_state, step +) + return logpdf(f(model.θ, prev_state, step), current_state) end + +# We need to think seriously about how the data is handled AdvancedPS.isdone(::LinearSSM, step) = step > Tₘ # Everything is now ready to simulate some data. - a = 0.9 # Scale q = 0.32 # State variance r = 1 # Observation variance @@ -72,14 +86,12 @@ rng = Random.MersenneTwister(seed) x = zeros(Tₘ) y = zeros(Tₘ) - -reference = LinearSSM(θ₀) -x[1] = rand(rng, f₀(reference)) +x[1] = rand(rng, f₀(θ₀)) for t in 1:Tₘ if t < Tₘ - x[t + 1] = rand(rng, f(reference, x[t], t)) + x[t + 1] = rand(rng, f(θ₀, x[t], t)) end - y[t] = rand(rng, g(reference, x[t], t)) + y[t] = rand(rng, g(θ₀, x[t], t)) end # Here are the latent and obseravation timeseries @@ -88,7 +100,7 @@ plot!(y; seriestype=:scatter, label="y", xlabel="t", mc=:red, ms=2, ma=0.5) # `AdvancedPS` subscribes to the `AbstractMCMC` API. To sample we just need to define a Particle Gibbs kernel # and a model interface. -model = LinearSSM(θ₀) +model = LinearSSM(y, θ₀) pgas = AdvancedPS.PGAS(Nₚ) chains = sample(rng, model, pgas, Nₛ; progress=false); #md nothing #hide diff --git a/examples/levy-ssm/Project.toml b/examples/levy-ssm/Project.toml new file mode 100644 index 00000000..572ec481 --- /dev/null +++ b/examples/levy-ssm/Project.toml @@ -0,0 +1,7 @@ +[deps] +AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" diff --git a/examples/levy-ssm/script.jl b/examples/levy-ssm/script.jl new file mode 100644 index 00000000..2b35858f --- /dev/null +++ b/examples/levy-ssm/script.jl @@ -0,0 +1,256 @@ +# # Levy-SSM latent state inference +using AdvancedPS: SSMProblems +using AdvancedPS +using Random +using Plots +using Distributions +using AdvancedPS +using LinearAlgebra +using SSMProblems + +struct GammaProcess + C::Float64 + β::Float64 + tol::Float64 +end + +struct GammaPath{T} + jumps::Vector{T} + times::Vector{T} +end + +struct LangevinDynamics{T} + A::Matrix{T} + L::Vector{T} + θ::T + H::Vector{T} + σe::T +end + +struct NormalMeanVariance{T} + μ::T + σ::T +end + +function simulate( + rng::AbstractRNG, + process::GammaProcess, + rate::Float64, + start::Float64, + finish::Float64, + t0::Float64=0.0, +) + let β = process.β, C = process.C, tolerance = process.tol + jumps = Float64[] + last_jump = Inf + t = t0 + truncated = last_jump < tolerance + while !truncated + t += rand(rng, Exponential(1.0 / rate)) + xi = 1.0 / (β * (exp(t / C) - 1)) + prob = (1.0 + β * xi) * exp(-β * xi) + if rand(rng) < prob + push!(jumps, xi) + last_jump = xi + end + truncated = last_jump < tolerance + end + times = rand(rng, Uniform(start, finish), length(jumps)) + return GammaPath(jumps, times) + end +end + +function integral(times::Array{Float64}, path::GammaPath) + let jumps = path.jumps, jump_times = path.times + return [sum(jumps[jump_times .<= t]) for t in times] + end +end + +# Gamma Process +C = 1.0 +β = 1.0 +ϵ = 1e-10 +process = GammaProcess(C, β, ϵ) + +# Normal Mean-Variance representation +μw = 0.0 +σw = 1.0 +nvm = NormalMeanVariance(μw, σw) + +# Levy SSM with Langevin dynamics +# dx(t) = A x(t) dt + L dW(t) +# y(t) = H x(t) + ϵ(t) +θ = -0.5 +A = [ + 0.0 1.0 + 0.0 θ +] +L = [0.0; 1.0] +σe = 1.0 +H = [1.0, 0] +dyn = LangevinDynamics(A, L, θ, H, σe) + +# Simulation parameters +start, finish = 0, 100 +N = 200 +ts = range(start, finish; length=N) +seed = 4 +rng = Random.MersenneTwister(seed) +Np = 50 +Ns = 100 + +f(dt, θ) = exp(θ * dt) +function Base.exp(dyn::LangevinDynamics, dt::Real) + let θ = dyn.θ + f_val = f(dt, θ) + return [1.0 (f_val - 1)/θ; 0 f_val] + end +end + +function meancov( + t::T, dyn::LangevinDynamics, path::GammaPath, nvm::NormalMeanVariance +) where {T<:Real} + μ = zeros(T, 2) + Σ = zeros(T, (2, 2)) + let times = path.times, jumps = path.jumps, μw = nvm.μ, σw = nvm.σ + for (v, z) in zip(times, jumps) + ft = exp(dyn, (t - v)) * dyn.L + μ += ft .* μw .* z + Σ += ft * transpose(ft) .* σw^2 .* z + end + return μ, Σ + end +end + +X = zeros(Float64, (N, 2)) +Y = zeros(Float64, N) +for (i, t) in enumerate(ts) + if i > 1 + s = ts[i - 1] + dt = t - s + path = simulate(rng, process, dt, s, t, ϵ) + μ, Σ = meancov(t, dyn, path, nvm) + X[i, :] .= rand(rng, MultivariateNormal(exp(dyn, dt) * X[i - 1, :] + μ, Σ)) + end + + let H = dyn.H, σe = dyn.σe + Y[i] = transpose(H) * X[i, :] + rand(rng, Normal(0, σe)) + end +end + +# AdvancedPS +Parameters = @NamedTuple begin + dyn::LangevinDynamics + process::GammaProcess + nvm::NormalMeanVariance + times::Vector{Float64} +end + +struct MixedState{T} + x::Vector{T} + path::GammaPath{T} +end + +mutable struct LevyLangevin <: SSMProblems.AbstractStateSpaceModel + X::Vector{MixedState{Float64}} + observations::Vector{Float64} + θ::Parameters + LevyLangevin(θ::Parameters) = new(Vector{MixedState{Float64}}(), θ) + function LevyLangevin(y::Vector{Float64}, θ::Parameters) + return new(Vector{MixedState{Float64}}(), y, θ) + end +end + +function SSMProblems.transition!!(rng::AbstractRNG, model::LevyLangevin) + return MixedState( + rand(rng, MultivariateNormal([0, 0], I)), GammaPath(Float64[], Float64[]) + ) +end + +function SSMProblems.transition!!( + rng::AbstractRNG, model::LevyLangevin, state::MixedState, step +) + times = model.θ.times + s = times[step - 1] + t = times[step] + dt = t - s + path = simulate(rng, model.θ.process, dt, s, t) + μ, Σ = meancov(t, model.θ.dyn, path, model.θ.nvm) + Σ += 1e-6 * I + return MixedState(rand(rng, MultivariateNormal(exp(dyn, dt) * state.x + μ, Σ)), path) +end + +function SSMProblems.transition_logdensity( + model::LevyLangevin, prev_state::MixedState, current_state::MixedState, step +) + times = model.θ.times + s = times[step - 1] + t = times[step] + dt = t - s + path = simulate(rng, model.θ.process, dt, s, t) + μ, Σ = meancov(t, model.θ.dyn, path, model.θ.nvm) + Σ += 1e-6 * I + return logpdf(MultivariateNormal(exp(dyn, dt) * prev_state.x + μ, Σ), current_state.x) +end + +function SSMProblems.emission_logdensity(model::LevyLangevin, state::MixedState, step) + return logpdf(Normal(transpose(H) * state.x, σe), model.observations[step]) +end + +AdvancedPS.isdone(model::LevyLangevin, step) = step > length(model.θ.times) + +θ₀ = Parameters((dyn, process, nvm, ts)) +model = LevyLangevin(Y, θ₀) +pg = AdvancedPS.PGAS(Np) +chains = sample(rng, model, pg, Ns; progress=false); + +# Concat all sampled states +particles = hcat([chain.trajectory.model.X for chain in chains]...) +marginal_states = map(s -> s.x, particles); +jump_times = map(s -> s.path.times, particles); +jump_intensities = map(s -> s.path.jumps, particles); + +# Plot marginal state and jump intensities for one trajectory +p1 = plot( + ts, + [state[1] for state in marginal_states[:, end]]; + color=:darkorange, + label="Marginal State (x1)", +) +plot!( + ts, + [state[2] for state in marginal_states[:, end]]; + color=:dodgerblue, + label="Marginal State (x2)", +) + +p2 = scatter( + vcat([t for t in jump_times[:, end]]...), + vcat([j for j in jump_intensities[:, end]]...); + color=:darkorange, + label="Jumps", +) + +plot( + p1, p2; plot_title="Marginal State and Jump Intensities", layout=(2, 1), size=(600, 600) +) + +# Plot mean trajectory with standard deviation +mean_trajectory = transpose(hcat(mean(marginal_states; dims=2)...)) +std_trajectory = dropdims(std(stack(marginal_states); dims=3); dims=3) + +ps = [] +for d in 1:2 + p = plot( + ts, + mean_trajectory[:, d]; + ribbon=2 * std_trajectory[:, d]', + color=:darkorange, + label="Mean Trajectory (±2σ)", + fillalpha=0.2, + title="Marginal State Trajectories (X$d)", + ) + plot!(p, ts, X[:, d]; color=:dodgerblue, label="True Trajectory") + push!(ps, p) +end +plot(ps...; layout=(2, 1), size=(600, 600)) diff --git a/examples/particle-gibbs/Project.toml b/examples/particle-gibbs/Project.toml index d8ae6a53..6fa1fbb2 100644 --- a/examples/particle-gibbs/Project.toml +++ b/examples/particle-gibbs/Project.toml @@ -7,4 +7,5 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" +SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" diff --git a/examples/particle-gibbs/script.jl b/examples/particle-gibbs/script.jl index baa6ceb1..f7a9d2aa 100644 --- a/examples/particle-gibbs/script.jl +++ b/examples/particle-gibbs/script.jl @@ -5,6 +5,7 @@ using Distributions using Plots using AbstractMCMC using Random123 +using SSMProblems """ plot_update_rate(update_rate, N) @@ -90,22 +91,41 @@ plot(x; label="x", xlabel="t") plot(y; label="y", xlabel="t") # Each model takes an `AbstractRNG` as input and generates the logpdf of the current transition: -mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel +mutable struct NonLinearTimeSeries <: SSMProblems.AbstractStateSpaceModel X::Vector{Float64} + observations::Vector{Float64} θ::Parameters NonLinearTimeSeries(θ::Parameters) = new(Float64[], θ) + NonLinearTimeSeries(y::Vector{Float64}, θ::Parameters) = new(Float64[], y, θ) end # The dynamics of the model is defined through the `AbstractStateSpaceModel` interface: -AdvancedPS.initialization(model::NonLinearTimeSeries) = f₀(model.θ) -AdvancedPS.transition(model::NonLinearTimeSeries, state, step) = f(model.θ, state, step) -function AdvancedPS.observation(model::NonLinearTimeSeries, state, step) - return logpdf(g(model.θ, state, step), y[step]) +function SSMProblems.transition!!(rng::AbstractRNG, model::NonLinearTimeSeries) + return rand(rng, f₀(model.θ)) end +function SSMProblems.transition!!( + rng::AbstractRNG, model::NonLinearTimeSeries, state::Float64, step::Int +) + return rand(rng, f(model.θ, state, step)) +end + +function SSMProblems.emission_logdensity( + modeL::NonLinearTimeSeries, state::Float64, step::Int +) + return logpdf(g(model.θ, state, step), model.observations[step]) +end +function SSMProblems.transition_logdensity( + model::NonLinearTimeSeries, prev_state, current_state, step +) + return logpdf(f(model.θ, prev_state, step), current_state) +end + +# We need to tell AdvancedPS when to stop the execution of the model +# TODO AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ # Here we use the particle gibbs kernel without adaptive resampling. -model = NonLinearTimeSeries(θ₀) +model = NonLinearTimeSeries(y, θ₀) pg = AdvancedPS.PG(Nₚ, 1.0) chains = sample(rng, model, pg, Nₛ; progress=false); #md nothing #hide diff --git a/src/AdvancedPS.jl b/src/AdvancedPS.jl index 8ce7c2b7..faa673e8 100644 --- a/src/AdvancedPS.jl +++ b/src/AdvancedPS.jl @@ -5,6 +5,7 @@ using Distributions: Distributions using Random: Random using StatsFuns: StatsFuns using Random123: Random123 +using SSMProblems: SSMProblems abstract type AbstractParticleModel <: AbstractMCMC.AbstractModel end diff --git a/src/container.jl b/src/container.jl index c5d693d5..3ac64fef 100644 --- a/src/container.jl +++ b/src/container.jl @@ -269,7 +269,7 @@ function reweight!(pc::ParticleContainer, ref::Union{Particle,Nothing}=nothing) # Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and # ``θᵢ`` are variables of other samplers. isref = p === ref - score = advance!(p, isref) + score = advance!(p, isref) # SSMProblems.transition!! if score === nothing numdone += 1 diff --git a/src/model.jl b/src/model.jl index fab9a34f..b7b72db3 100644 --- a/src/model.jl +++ b/src/model.jl @@ -7,13 +7,20 @@ mutable struct Trace{F,R} end const Particle = Trace -const SSMTrace{R} = Trace{<:AbstractStateSpaceModel,R} +const SSMTrace{R} = Trace{<:SSMProblems.AbstractStateSpaceModel,R} const GenericTrace{R} = Trace{<:AbstractGenericModel,R} reset_logprob!(::AdvancedPS.Particle) = nothing reset_model(f) = deepcopy(f) delete_retained!(f) = nothing +""" + isdone(model::SSMProblems.AbstractStateSpaceModel, step) + +Returns `true` if we reached the end of the model execution +""" +function isdone end + """ copy(trace::Trace) diff --git a/src/pgas.jl b/src/pgas.jl index 8dfb449e..41bcfea7 100644 --- a/src/pgas.jl +++ b/src/pgas.jl @@ -1,33 +1,3 @@ -""" - initialization(model::AbstractStateSpaceModel) - -Define the distribution of the initial state of the State Space Model -""" -function initialization end - -""" - transition(model::AbstractStateSpaceModel, state, step) - -Define the transition density of the State Space Model -Must return `nothing` if it consumed all the data -""" -function transition end - -""" - observation(model::AbstractStateSpaceModel, state, step) - -Return the log-likelihood of the observed measurement conditional on the current state of the model. -Must return `nothing` if it consumed all the data -""" -function observation end - -""" - isdone(model::AbstractStateSpaceModel, step) - -Return `true` if model reached final state else `false` -""" -function isdone end - """ previous_state(trace::SSMTrace) @@ -54,13 +24,11 @@ current_step(trace::SSMTrace) = trace.rng.count Get the log weight of the transition from previous state of `model` to `x` """ function transition_logweight(particle::SSMTrace, x) - score = Distributions.logpdf( - transition( - particle.model, - particle.model.X[current_step(particle) - 2], - current_step(particle) - 2, - ), + score = SSMProblems.transition_logdensity( + particle.model, + particle.model.X[current_step(particle) - 2], x, + current_step(particle) - 1, ) return score end @@ -93,16 +61,18 @@ function advance!(particle::SSMTrace, isref::Bool=false) if !isref if running_step == 1 - new_state = rand(particle.rng, initialization(model)) # Generate initial state, maybe fallback to 0 if initialization is not defined + new_state = SSMProblems.transition!!(particle.rng, model) else current_state = model.X[running_step - 1] - new_state = rand(particle.rng, transition(model, current_state, running_step)) + new_state = SSMProblems.transition!!( + particle.rng, model, current_state, running_step + ) end else new_state = model.X[running_step] # We need the current state from the reference particle end - score = observation(model, new_state, running_step) + score = SSMProblems.emission_logdensity(model, new_state, running_step) # accept transition !isref && push!(model.X, new_state) diff --git a/src/smc.jl b/src/smc.jl index c45b127a..6d982d2b 100644 --- a/src/smc.jl +++ b/src/smc.jl @@ -26,12 +26,17 @@ struct SMCSample{P,W,L} logevidence::L end -function AbstractMCMC.sample(model::AbstractStateSpaceModel, sampler::SMC; kwargs...) +function AbstractMCMC.sample( + model::SSMProblems.AbstractStateSpaceModel, sampler::SMC; kwargs... +) return AbstractMCMC.sample(Random.GLOBAL_RNG, model, sampler; kwargs...) end function AbstractMCMC.sample( - rng::Random.AbstractRNG, model::AbstractStateSpaceModel, sampler::SMC; kwargs... + rng::Random.AbstractRNG, + model::SSMProblems.AbstractStateSpaceModel, + sampler::SMC; + kwargs..., ) if !isempty(kwargs) @warn "keyword arguments $(keys(kwargs)) are not supported by `SMC`" @@ -95,7 +100,7 @@ PGAS(nparticles::Int) = PGAS(nparticles, ResampleWithESSThreshold(1.0)) function AbstractMCMC.step( rng::Random.AbstractRNG, - model::AbstractStateSpaceModel, + model::SSMProblems.AbstractStateSpaceModel, sampler::Union{PGAS,PG}, state::Union{PGState,Nothing}=nothing; kwargs..., diff --git a/test/Project.toml b/test/Project.toml index 9fcb69de..686826f8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" +SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/test/container.jl b/test/container.jl index 7a312291..ba266499 100644 --- a/test/container.jl +++ b/test/container.jl @@ -1,14 +1,17 @@ @testset "container.jl" begin # Since the extension would hide the low level function call API - mutable struct LogPModel{T} <: AdvancedPS.AbstractStateSpaceModel + mutable struct LogPModel{T} <: SSMProblems.AbstractStateSpaceModel logp::T X::Array{T} end - AdvancedPS.initialization(model::LogPModel) = Uniform() - AdvancedPS.transition(model::LogPModel, state, step) = Uniform() - AdvancedPS.observation(model::LogPModel, state, step) = model.logp + SSMProblems.transition!!(rng::AbstractRNG, model::LogPModel) = rand(rng, Uniform()) + function SSMProblems.transition!!(rng::AbstractRNG, model::LogPModel, state, step) + return rand(rng, Uniform()) + end + SSMProblems.emission_logdensity(model::LogPModel, state, step) = model.logp + AdvancedPS.isdone(model::LogPModel, step) = false @testset "copy particle container" begin diff --git a/test/pgas.jl b/test/pgas.jl index 8c7b3160..3a701db7 100644 --- a/test/pgas.jl +++ b/test/pgas.jl @@ -5,18 +5,25 @@ r::Float64 end - mutable struct BaseModel <: AdvancedPS.AbstractStateSpaceModel + mutable struct BaseModel <: SSMProblems.AbstractStateSpaceModel X::Vector{Float64} θ::Params BaseModel(params::Params) = new(Vector{Float64}(), params) end - AdvancedPS.initialization(model::BaseModel) = Normal(0, model.θ.q) - function AdvancedPS.transition(model::BaseModel, state, step) - return Distributions.Normal(model.θ.a * state, model.θ.q) + function SSMProblems.transition!!(rng::AbstractRNG, model::BaseModel) + return rand(rng, Normal(0, model.θ.q)) end - function AdvancedPS.observation(model::BaseModel, state, step) - return Distributions.logpdf(Distributions.Normal(state, model.θ.r), 0) + function SSMProblems.transition!!(rng::AbstractRNG, model::BaseModel, state, step) + return rand(rng, Normal(model.θ.a * state, model.θ.q)) + end + function SSMProblems.emission_logdensity(model::BaseModel, state, step) + return logpdf(Distributions.Normal(state, model.θ.r), 0) + end + function SSMProblems.transition_logdensity( + model::BaseModel, prev_state, current_state, step + ) + return logpdf(Normal(model.θ.a * prev_state, model.θ.q), current_state) end AdvancedPS.isdone(::BaseModel, step) = step > 3 @@ -83,7 +90,7 @@ seed = 10 rng = Random.MersenneTwister(seed) - for sampler in [AdvancedPS.PGAS(10)] + for sampler in [AdvancedPS.PGAS(10), AdvancedPS.PG(10)] Random.seed!(rng, seed) chain1 = sample(rng, model, sampler, 10) vals1 = hcat([chain.trajectory.model.X for chain in chain1]...) diff --git a/test/runtests.jl b/test/runtests.jl index d7058062..b2c0990c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using Distributions using Libtask using Random using Test +using SSMProblems @testset "AdvancedPS.jl" begin @testset "Resampling tests" begin