Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add code to compute pairwise marginals over neighboring variables #138

Merged
merged 4 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
- '1.8'
- '1.9'
- '1.10'
- '1.11'
os:
- ubuntu-latest
arch:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,4 @@ TensorTrains = "0.10"
Tullio = "0.3"
UnPack = "1"
Unzip = "0.2"
julia = "1.8, 1.9, 1.10"
julia = "1.8"
6 changes: 5 additions & 1 deletion src/MatrixProductBP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,18 @@ export
BPFactor, nstates, MPBP, mpbp, reset_messages!, reset_beliefs!, reset_observations!,
reset!, is_free_dynamics, onebpiter!, CB_BP, iterate!,
pair_beliefs, pair_beliefs_as_mpem, beliefs_tu, autocorrelations,
autocovariances, means, beliefs, bethe_free_energy,
autocovariances, means,
pair_correlations, alternate_marginals, alternate_correlations,
beliefs, bethe_free_energy,
periodic_mpbp, is_periodic,
mpbp_infinite_graph, InfiniteRegularGraph, periodic_mpbp_infinite_graph,
InfiniteBipartiteRegularGraph, mpbp_infinite_bipartite_graph,
logprob, expectation, pair_observations_directed,
pair_observations_nondirected, pair_obs_undirected_to_directed,
exact_prob, exact_marginals, site_marginals, exact_autocorrelations,
exact_autocovariances, exact_marginal_expectations,
exact_pair_marginals, exact_pair_marginal_expectations,
exact_alternate_marginals, exact_alternate_marginal_expectations,
SoftMarginSampler, onesample!, onesample, sample!, sample, marginals, pair_marginals,
continuous_sis_sampler, simulate_queue_sis!,
draw_node_observations!, AtomicVector,
Expand Down
76 changes: 76 additions & 0 deletions src/exact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,82 @@

exact_marginal_expectations(bp; m_exact = exact_marginals(bp)) = exact_marginal_expectations((x,i)->x, bp; m_exact)

function pair_marginals(bp::MPBP{G,F}; p = exact_prob(bp)[1]) where {G,F}
T = getT(bp); N = nv(bp.g)
qs = Tuple(nstates(bp,i) for t=1:T+1, i=1:N)
m = [zeros(vcat(fill(nstates(bp,i),T+1), fill(nstates(bp,j),T+1))...) for (i,j) in edges(bp.g)]
prog = Progress(prod(qs), desc="Computing exact pair marginals")
X = zeros(Int, T+1, N)
for x in 1:prod(qs)
X .= _int_to_matrix(x, qs, (T+1,N))
for (i,j,id) in edges(bp.g)
m[id][X[:,i]...,X[:,j]...] += p[x]
end
next!(prog)
end
@assert all(sum(pᵢ) ≈ 1 for pᵢ in m)
return m
end

function exact_pair_marginals(bp::MPBP{G,F}; p_exact = exact_prob(bp)[1]) where {G,F}
m = pair_marginals(bp; p = p_exact)
T = getT(bp)
pp = [[zeros(nstates(bp,i),nstates(bp,j)) for t in 1:T+1] for (i,j) in edges(bp.g)]
for (i,j,id) in edges(bp.g)
for t in 1:T+1
for xᵢᵗ in 1:nstates(bp,i)
for xⱼᵗ⁺¹ in 1:nstates(bp,j)
indices_i = [s==t ? xᵢᵗ : Colon() for s in 1:T+1]
indices_j = [s==t ? xⱼᵗ⁺¹ : Colon() for s in 1:T+1]
pp[id][t][xᵢᵗ,xⱼᵗ⁺¹] += sum(m[id][indices_i...,indices_j...])
end
end
@debug @assert sum(pp[id][t]) ≈ 1
end
end
return pp
end

function exact_pair_marginal_expectations(f, bp::MPBP{G,F};
m_exact = exact_pair_marginals(bp)) where {G,F}
map(eachindex(m_exact)) do i
expectation.(x->f(x,i), m_exact[i])
end
end

function exact_pair_marginal_expectations(bp; m_exact = exact_alternate_marginals(bp))
return exact_pair_marginal_expectations((x,i)->x, bp; m_exact)

Check warning on line 129 in src/exact.jl

View check run for this annotation

Codecov / codecov/patch

src/exact.jl#L128-L129

Added lines #L128 - L129 were not covered by tests
end

function exact_alternate_marginals(bp::MPBP{G,F}; p_exact = exact_prob(bp)[1]) where {G,F}
m = pair_marginals(bp; p = p_exact)
T = getT(bp)
pp = [[zeros(nstates(bp,i),nstates(bp,j)) for t in 1:T] for (i,j) in edges(bp.g)]
for (i,j,id) in edges(bp.g)
for t in 1:T
for xᵢᵗ in 1:nstates(bp,i)
for xⱼᵗ⁺¹ in 1:nstates(bp,j)
indices_i = [s==t ? xᵢᵗ : Colon() for s in 1:T+1]
indices_j = [s==t+1 ? xⱼᵗ⁺¹ : Colon() for s in 1:T+1]
pp[id][t][xᵢᵗ,xⱼᵗ⁺¹] = sum(m[id][indices_i...,indices_j...])
end
end
end
end
return pp
end

function exact_alternate_marginal_expectations(f, bp::MPBP{G,F};
m_exact = exact_alternate_marginals(bp)) where {G,F}
map(eachindex(m_exact)) do i
expectation.(x->f(x,i), m_exact[i])
end
end
function exact_alternate_marginal_expectations(bp; m_exact = exact_alternate_marginals(bp))
return exact_alternate_marginal_expectations((x,i)->x, bp; m_exact)

Check warning on line 157 in src/exact.jl

View check run for this annotation

Codecov / codecov/patch

src/exact.jl#L156-L157

Added lines #L156 - L157 were not covered by tests
end


function exact_autocorrelations(f, bp::MPBP{G,F};
p_exact = exact_prob(bp)[1]) where {G,F}
m = site_marginals(bp; p = p_exact)
Expand Down
29 changes: 27 additions & 2 deletions src/mpbp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ function _pair_beliefs!(b, f, bp::MPBP{G,F}) where {G,F}
for j in 1:N
dⱼ = length(nzrange(X, j))
for k in nzrange(X, j)
ji = k # idx of message i→j
ij = vals[k] # idx of message j→i
ij = k # idx of message i→j
ji = vals[k] # idx of message j→i
μᵢⱼ = bp.μ[ij]; μⱼᵢ = bp.μ[ji]
bᵢⱼ, zᵢⱼ = f(μᵢⱼ, μⱼᵢ, bp.ψ[ij])
logz[j] += (1/dⱼ- 1/2) * log(zᵢⱼ)
Expand Down Expand Up @@ -260,6 +260,31 @@ function means(f, bp::MPBP; sites=vertices(bp.g))
end
end

# return <f(xᵢᵗ)f(xⱼᵗ)> per each directed edge i->j
function pair_correlations(f, bp::MPBP{G,F,V,M2}) where {G,F,V,M2}
am = pair_beliefs(bp)[1]
return [expectation.(f, amij) for amij in am]
end

# return p(xᵢᵗ,xⱼᵗ⁺¹) per each directed edge i->j
function alternate_marginals(bp::MPBP{G,F,V,M2}) where {G,F,V,M2}
pbs = pair_beliefs_as_mpem(bp)[1]
tvs = twovar_marginals.(pbs)

return map(tvs) do tv
map(1:size(tv,1)-1) do t
tvt = tv[t,t+1]
dropdims(sum(tvt; dims=(2,3)); dims=(2,3))
end
end
end

# return <f(xᵢᵗ)f(xⱼᵗ⁺¹)> per each directed edge i->j
function alternate_correlations(f, bp::MPBP{G,F,V,M2}) where {G,F,V,M2}
am = alternate_marginals(bp)
return [expectation.(f, amij) for amij in am]
end

covariance(r::Matrix{<:Real}, μ::Vector{<:Real}) = r .- μ*μ'

function autocovariances(f, bp::MPBP; sites=vertices(bp.g), kw...)
Expand Down
22 changes: 21 additions & 1 deletion test/glauber_small_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
f_bethe = bethe_free_energy(bp)
Z_bp = exp(-f_bethe)

f(x,i) = 2x-3
f(x, args...) = 2x-3

r_bp = autocorrelations(f, bp)
r_exact = exact_autocorrelations(f, bp; p_exact)
Expand All @@ -52,11 +52,22 @@
pb_bp = pair_beliefs(bp)[1]
pb_bp2 = marginals.(pair_beliefs_as_mpem(bp)[1])

pb_exact = exact_pair_marginals(bp)

pc_bp = pair_correlations(f, bp)
pc_exact = exact_pair_marginal_expectations(f, bp)

a_bp = alternate_correlations(f, bp)
a_exact = exact_alternate_marginal_expectations(f, bp)

@testset "Observables" begin
@test Z_exact ≈ Z_bp
@test p_ex ≈ p_bp
@test a_bp ≈ a_exact
@test r_bp ≈ r_exact
@test c_bp ≈ c_exact
@test pc_bp ≈ pc_exact
@test pb_bp ≈ pb_exact
@test pb_bp ≈ pb_bp2
end

Expand Down Expand Up @@ -103,11 +114,20 @@
c_bp = autocovariances(f, bp2)
c_exact = exact_autocovariances(f, bp2; r = r_exact)

pb_bp = pair_beliefs(bp2)[1]
pb_exact = exact_pair_marginals(bp2)

a_bp = alternate_marginals(bp2)
a_exact = exact_alternate_marginals(bp2)


@testset "Glauber small tree - DampedFactor" begin
@test Z_exact ≈ Z_bp
@test p_ex ≈ p_bp
@test r_bp ≈ r_exact
@test c_bp ≈ c_exact
@test pb_bp ≈ pb_exact
@test a_bp ≈ a_exact
end

## Generic Factor
Expand Down
8 changes: 2 additions & 6 deletions test/sirs_small_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@

@testset "SIRS small tree" begin

@testset "logprob" begin
@test logprob(bp, X) ≈ -4.017112724421366
end

draw_node_observations!(bp.ϕ, X, N, last_time=true; rng)

svd_trunc = TruncThresh(0.0)
Expand Down Expand Up @@ -80,9 +76,9 @@
reset!(bp; observations=true)
draw_node_observations!(bp.ϕ, X, N*(T+1), last_time=false)
svd_trunc = TruncThresh(0.0)
iterate!(bp, maxiter=10; svd_trunc, showprogress=false)
iters, cb = iterate!(bp, maxiter=10, tol=0.0; svd_trunc, showprogress=false)
f_bethe = bethe_free_energy(bp)
logl_bp = - f_bethe
logl_bp = -f_bethe
logp = logprob(bp, X)
@test logl_bp ≈ logp
end
Expand Down
4 changes: 0 additions & 4 deletions test/sis_heterogeneous.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@
bp = mpbp(sis)
X, _ = onesample(bp; rng)

@testset "logprob" begin
@test logprob(bp, X) ≈ -5.813130622330121
end

draw_node_observations!(bp.ϕ, X, N, last_time=true; rng)

@testset "SIS small tree" begin
Expand Down
9 changes: 4 additions & 5 deletions test/sis_small_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@
bp = mpbp(sis)
rng = MersenneTwister(111)
X, _ = onesample(bp; rng)

@testset "logprob" begin
@test logprob(bp, X) ≈ -10.900027128953564
end


draw_node_observations!(bp.ϕ, X, N, last_time=true; rng)

@testset "SIS small tree" begin
Expand All @@ -44,11 +40,14 @@
c_bp = autocovariances(f, bp)
c_exact = exact_autocovariances(f, bp; r = r_exact)

pb_exact = exact_pair_marginals(bp)
pb_bp = pair_beliefs(bp)[1]

@test Z_exact ≈ Z_bp
@test p_ex ≈ p_bp
@test r_bp ≈ r_exact
@test c_bp ≈ c_exact
@test pb_bp ≈ pb_exact
end

@testset "RestrictedRecursiveBPFactor - RecursiveBPFactor generic methods" begin
Expand Down
Loading