Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSMProblems integration #97

Merged
merged 6 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
name = "AdvancedPS"
uuid = "576499cb-2369-40b2-a588-c64705576edc"
authors = ["TuringLang"]
version = "0.5.4"
version = "0.6"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[weakdeps]
Expand All @@ -21,10 +22,10 @@ AdvancedPSLibtaskExt = "Libtask"
AbstractMCMC = "2, 3, 4, 5"
Distributions = "0.23, 0.24, 0.25"
Libtask = "0.8"
Random = "1.6"
Random123 = "1.3"
Requires = "1.0"
StatsFuns = "0.9, 1"
Random = "1.6"
julia = "1.6"

[extras]
Expand Down
1 change: 1 addition & 0 deletions examples/gaussian-process/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
35 changes: 23 additions & 12 deletions examples/gaussian-process/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,21 @@ using AbstractGPs
using Plots
using Distributions
using Libtask
using SSMProblems

Parameters = @NamedTuple begin
a::Float64
q::Float64
kernel
end

mutable struct GPSSM <: AdvancedPS.AbstractStateSpaceModel
mutable struct GPSSM <: SSMProblems.AbstractStateSpaceModel
X::Vector{Float64}
observations::Vector{Float64}
θ::Parameters

GPSSM(params::Parameters) = new(Vector{Float64}(), params)
GPSSM(y::Vector{Float64}, params::Parameters) = new(Vector{Float64}(), y, params)
end

seed = 1
Expand All @@ -29,21 +32,20 @@ q = 0.5

params = Parameters((a, q, SqExponentialKernel()))

f(model::GPSSM, x, t) = Normal(model.θ.a * x, model.θ.q)
h(model::GPSSM) = Normal(0, model.θ.q)
g(model::GPSSM, x, t) = Normal(0, exp(0.5 * x)^2)
f(θ::Parameters, x, t) = Normal(θ.a * x, θ.q)
h(θ::Parameters) = Normal(0, θ.q)
g(θ::Parameters, x, t) = Normal(0, exp(0.5 * x)^2)

rng = Random.MersenneTwister(seed)
ref_model = GPSSM(params)

x = zeros(T)
y = similar(x)
x[1] = rand(rng, h(ref_model))
x[1] = rand(rng, h(params))
for t in 1:T
if t < T
x[t + 1] = rand(rng, f(ref_model, x[t], t))
x[t + 1] = rand(rng, f(params, x[t], t))
end
y[t] = rand(rng, g(ref_model, x[t], t))
y[t] = rand(rng, g(params, x[t], t))
end

function gp_update(model::GPSSM, state, step)
Expand All @@ -54,12 +56,21 @@ function gp_update(model::GPSSM, state, step)
return Normal(μ[1], σ[1])
end

AdvancedPS.initialization(::GPSSM) = h(model)
AdvancedPS.transition(model::GPSSM, state, step) = gp_update(model, state, step)
AdvancedPS.observation(model::GPSSM, state, step) = logpdf(g(model, state, step), y[step])
SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM) = rand(rng, h(model.θ))
function SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM, state, step)
return rand(rng, gp_update(model, state, step))
end

function SSMProblems.emission_logdensity(model::GPSSM, state, step)
return logpdf(g(model.θ, state, step), model.observations[step])
end
function SSMProblems.transition_logdensity(model::GPSSM, prev_state, current_state, step)
return logpdf(gp_update(model, prev_state, step), current_state)
end

AdvancedPS.isdone(::GPSSM, step) = step > T

model = GPSSM(params)
model = GPSSM(y, params)
pg = AdvancedPS.PGAS(Nₚ)
chains = sample(rng, model, pg, Nₛ)

Expand Down
1 change: 1 addition & 0 deletions examples/gaussian-ssm/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
42 changes: 27 additions & 15 deletions examples/gaussian-ssm/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using AdvancedPS
using Random
using Distributions
using Plots
using SSMProblems

# We consider the following linear state-space model with Gaussian innovations. The latent state is a simple gaussian random walk
# and the observation is linear in the latent states, namely:
Expand Down Expand Up @@ -33,32 +34,45 @@ Parameters = @NamedTuple begin
r::Float64
end

mutable struct LinearSSM <: AdvancedPS.AbstractStateSpaceModel
mutable struct LinearSSM <: SSMProblems.AbstractStateSpaceModel
X::Vector{Float64}
observations::Vector{Float64}
θ::Parameters
LinearSSM(θ::Parameters) = new(Vector{Float64}(), θ)
LinearSSM(y::Vector, θ::Parameters) = new(Vector{Float64}(), y, θ)
end

# and the densities defined above.
f(m::LinearSSM, state, t) = Normal(m.θ.a * state, m.θ.q) # Transition density
g(m::LinearSSM, state, t) = Normal(state, m.θ.r) # Observation density
f₀(m::LinearSSM) = Normal(0, m.θ.q^2 / (1 - m.θ.a^2)) # Initial state density
f(θ::Parameters, state, t) = Normal(θ.a * state, θ.q) # Transition density
g(θ::Parameters, state, t) = Normal(state, θ.r) # Observation density
f₀(θ::Parameters) = Normal(0, θ.q^2 / (1 - θ.a^2)) # Initial state density
#md nothing #hide

# We also need to specify the dynamics of the system through the transition equations:
# - `AdvancedPS.initialization`: the initial state density
# - `AdvancedPS.transition`: the state transition density
# - `AdvancedPS.observation`: the observation score given the observed data
# - `AdvancedPS.isdone`: signals the end of the execution for the model
AdvancedPS.initialization(model::LinearSSM) = f₀(model)
AdvancedPS.transition(model::LinearSSM, state, step) = f(model, state, step)
function AdvancedPS.observation(model::LinearSSM, state, step)
return logpdf(g(model, state, step), y[step])
SSMProblems.transition!!(rng::AbstractRNG, model::LinearSSM) = rand(rng, f₀(model.θ))
function SSMProblems.transition!!(
rng::AbstractRNG, model::LinearSSM, state::Float64, step::Int
)
return rand(rng, f(model.θ, state, step))
end

function SSMProblems.emission_logdensity(modeL::LinearSSM, state::Float64, step::Int)
return logpdf(g(model.θ, state, step), model.observations[step])
end
function SSMProblems.transition_logdensity(
model::LinearSSM, prev_state, current_state, step
)
return logpdf(f(model.θ, prev_state, step), current_state)
end

# We need to think seriously about how the data is handled
AdvancedPS.isdone(::LinearSSM, step) = step > Tₘ

# Everything is now ready to simulate some data.

a = 0.9 # Scale
q = 0.32 # State variance
r = 1 # Observation variance
Expand All @@ -72,14 +86,12 @@ rng = Random.MersenneTwister(seed)

x = zeros(Tₘ)
y = zeros(Tₘ)

reference = LinearSSM(θ₀)
x[1] = rand(rng, f₀(reference))
x[1] = rand(rng, f₀(θ₀))
for t in 1:Tₘ
if t < Tₘ
x[t + 1] = rand(rng, f(reference, x[t], t))
x[t + 1] = rand(rng, f(θ₀, x[t], t))
end
y[t] = rand(rng, g(reference, x[t], t))
y[t] = rand(rng, g(θ₀, x[t], t))
end

# Here are the latent and obseravation timeseries
Expand All @@ -88,7 +100,7 @@ plot!(y; seriestype=:scatter, label="y", xlabel="t", mc=:red, ms=2, ma=0.5)

# `AdvancedPS` subscribes to the `AbstractMCMC` API. To sample we just need to define a Particle Gibbs kernel
# and a model interface.
model = LinearSSM(θ₀)
model = LinearSSM(y, θ₀)
pgas = AdvancedPS.PGAS(Nₚ)
chains = sample(rng, model, pgas, Nₛ; progress=false);
#md nothing #hide
Expand Down
7 changes: 7 additions & 0 deletions examples/levy-ssm/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
Loading
Loading