diff --git a/examples/levy-ssm/gamma_process.jl b/examples/levy-ssm/gamma_process.jl index 6d9d356..11bd33c 100644 --- a/examples/levy-ssm/gamma_process.jl +++ b/examples/levy-ssm/gamma_process.jl @@ -98,7 +98,8 @@ Ns = 10 f(dt, θ) = exp(θ * dt) function Base.exp(dyn::LangevinDynamics, dt::Real) let θ = dyn.θ - return [1.0 (f(dt, θ) - 1)/θ; 0 f(dt, θ)] + f_val = f(dt, θ) + return [1.0 (f_val - 1)/θ; 0 f_val] end end @@ -125,7 +126,7 @@ for (i, t) in enumerate(ts) dt = t - s path = simulate(process, dt, s, t, ϵ) μ, Σ = meancov(t, dyn, path, nvm) - X[i, :] = rand(MultivariateNormal(exp(dyn, dt) * X[i - 1, :] + μ, Σ)) + X[i, :] .= rand(MultivariateNormal(exp(dyn, dt) * X[i - 1, :] + μ, Σ)) end let H = dyn.H, σe = dyn.σe @@ -141,36 +142,113 @@ Parameters = @NamedTuple begin times::Vector{Float64} # Ugly, but avoids global de-ref end +struct MixedState{T} + x::Vector{T} + path::GammaPath{T} +end + mutable struct LevyLangevin <: AdvancedPS.AbstractStateSpaceModel - X::Vector{Vector{Float64}} + X::Vector{MixedState{Float64}} θ::Parameters - LevyLangevin(θ::Parameters) = new(Vector{Array{2,Float64}}(), θ) + LevyLangevin(θ::Parameters) = new(Vector{MixedState{Float64}}(), θ) +end + +struct InitialDistribution end +function Base.rand(rng::Random.AbstractRNG, ::InitialDistribution) + return MixedState(rand(MultivariateNormal([0, 0], I)), GammaPath(Float64[], Float64[])) +end + +struct TransitionDistribution{T} + current_state::MixedState{T} + model::LevyLangevin + s::T + t::T +end + +function Base.rand(rng::Random.AbstractRNG, td::TransitionDistribution) + let model = td.model, s = td.s, t = td.t, state = td.current_state + dt = t - s + path = simulate(model.θ.process, dt, s, t, ϵ) + μ, Σ = meancov(t, model.θ.dyn, path, model.θ.nvm) + Σ += 1e-6 * I + return MixedState(rand(MultivariateNormal(exp(dyn, dt) * state.x + μ, Σ)), path) + end +end + +# Required for ancestor sampling in PGAS +function Distributions.logpdf(td::TransitionDistribution, state::MixedState) + let model = td.model, s = td.s, t = td.t, state = td.current_state + dt = t - s + path = simulate(model.θ.process, dt, s, t, ϵ) + μ, Σ = meancov(t, model.θ.dyn, path, model.θ.nvm) + Σ += 1e-6 * I + return logpdf(MultivariateNormal(exp(dyn, dt) * state.x + μ, Σ), state.x) + end end θ₀ = Parameters((dyn, process, nvm, ts)) -AdvancedPS.initialization(model::LevyLangevin) = MultivariateNormal([0, 0], I) -function AdvancedPS.transition(model::LevyLangevin, state, step) +function AdvancedPS.initialization(::LevyLangevin) + return InitialDistribution() +end +function AdvancedPS.transition(model::LevyLangevin, state::MixedState, step) times = model.θ.times s = times[step - 1] t = times[step] - dt = t - s - path = simulate(model.θ.process, dt, s, t, ϵ) - μ, Σ = meancov(t, model.θ.dyn, path, model.θ.nvm) - return MultivariateNormal(exp(dyn, dt) * state + μ, Σ) + return TransitionDistribution(state, model, s, t) end -function AdvancedPS.observation(model::LevyLangevin, state, step) - return logpdf(Normal(transpose(H) * state, σe), Y[step]) +function AdvancedPS.observation(model::LevyLangevin, state::MixedState, step) + return logpdf(Normal(transpose(H) * state.x, σe), Y[step]) end AdvancedPS.isdone(::LevyLangevin, step) = step > length(ts) model = LevyLangevin(θ₀) -pg = AdvancedPS.PG(Np, 1.0) +pg = AdvancedPS.PG(Np) chains = sample(rng, model, pg, Ns; progress=true); particles = hcat([chain.trajectory.model.X for chain in chains]...) # Concat all sampled states -mean_trajectory = transpose(hcat(mean(particles; dims=2)...)) +marginal_states = map(s -> s.x, particles); +jump_times = map(s -> s.path.times, particles); +jump_intensities = map(s -> s.path.jumps, particles); + +# Plot marginal state and jump intensities for one trajectory +p1 = plot( + ts, + [state[1] for state in marginal_states[:, end]]; + color=:darkorange, + label="Marginal State (x1)", +) +plot!( + ts, + [state[2] for state in marginal_states[:, end]]; + color=:dodgerblue, + label="Marginal State (x2)", +) -plot(X; color=:darkorange, label="Original Trajectory") -plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9) +p2 = scatter( + vcat([t for t in jump_times[:, end]]...), + vcat([j for j in jump_intensities[:, end]]...); + color=:darkorange, + label="Jumps", +) + +plot( + p1, p2; plot_title="Marginal State and Jump Intensities", layout=(2, 1), size=(600, 600) +) + +# Plot mean trajectory with standard deviation +mean_trajectory = transpose(hcat(mean(marginal_states; dims=2)...)) +std_trajectory = dropdims(std(stack(marginal_states); dims=3); dims=3) + +plot( + mean_trajectory; + ribbon=std_trajectory', + color=:darkorange, + label="Mean trajectory", + opacity=0.3, + title="Inference Quality", +) +plot!( + mean_trajectory; color=:dodgerblue, label="Original Trajectory", title="Path Degeneracy" +)