Skip to content

Commit

Permalink
generic BP adapted to work with periodic time, with test (#100)
Browse files Browse the repository at this point in the history
  • Loading branch information
stecrotti authored Nov 13, 2023
1 parent d793589 commit 88f059d
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/bp_core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ function f_bp(A::Vector{M2}, wᵢ::Vector{U}, ϕᵢ::Vector{Vector{F}},
for xₙᵢᵗ in x_neigs
xⱼᵗ = xₙᵢᵗ[j_index]
xₙᵢ₋ⱼᵗ = xₙᵢᵗ[Not(j_index)]
Bᵗ[:, :, xᵢᵗ, xⱼᵗ, xᵢᵗ⁺¹] .+= (t == T + 1 ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, xₙᵢᵗ, xᵢᵗ)) *
Bᵗ[:, :, xᵢᵗ, xⱼᵗ, xᵢᵗ⁺¹] .+= ((t == T + 1) && !periodic ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, xₙᵢᵗ, xᵢᵗ)) *
Aᵗ[:, :, xᵢᵗ, xₙᵢ₋ⱼᵗ...] *
prod(ψₙᵢ[k][t][xᵢᵗ, xₖᵗ] for (k, xₖᵗ) in enumerate(xₙᵢᵗ) if k != j_index; init=1.0)
end
Expand Down Expand Up @@ -97,10 +97,10 @@ function f_bp_dummy_neighbor(A::Vector{<:AbstractMPEM2},
for xᵢᵗ in 1:q
for xᵢᵗ⁺¹ in 1:q
if isempty(A)
Bᵗ[:, :, xᵢᵗ, 1, xᵢᵗ⁺¹] .+= (t == T + 1 ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, Int[], xᵢᵗ)) * ϕᵢ[t][xᵢᵗ]
Bᵗ[:, :, xᵢᵗ, 1, xᵢᵗ⁺¹] .+= ((t == T + 1) && !periodic ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, Int[], xᵢᵗ)) * ϕᵢ[t][xᵢᵗ]
else
for xₙᵢᵗ in xₙᵢ
Bᵗ[:, :, xᵢᵗ, 1, xᵢᵗ⁺¹] .+= (t == T + 1 ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, xₙᵢᵗ, xᵢᵗ)) .*
Bᵗ[:, :, xᵢᵗ, 1, xᵢᵗ⁺¹] .+= ((t == T + 1) && !periodic ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, xₙᵢᵗ, xᵢᵗ)) .*
Aᵗ[:, :, xᵢᵗ, xₙᵢᵗ...] .* ϕᵢ[t][xᵢᵗ] .*
prod(ψₙᵢ[k][t][xᵢᵗ, xₖᵗ] for (k, xₖᵗ) in enumerate(xₙᵢᵗ))
end
Expand Down
3 changes: 0 additions & 3 deletions src/mpbp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,6 @@ function iterate!(bp::MPBP; maxiter::Integer=5,
svd_trunc::SVDTrunc=TruncThresh(1e-6),
showprogress=true, cb=CB_BP(bp; showprogress), tol=1e-10,
nodes = collect(vertices(bp.g)), shuffle_nodes::Bool=true, damp=0.0)
# if is_periodic(bp) && !isa(eltype(eltype(bp.w)), RecursiveBPFactor)
# @warn "MPBP with generic factors + PBCs is not guaranteed to give correct results"
# end
for it in 1:maxiter
Threads.@threads for i in nodes
onebpiter!(bp, i, eltype(bp.w[i]); svd_trunc, damp)
Expand Down
2 changes: 2 additions & 0 deletions src/recursive_bp_factor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ function compute_prob_ys(wᵢ::Vector{U}, qi::Int, μin::Vector{M2}, ψout, T, s
@cast _[(m1,m2),(n1,n2),y,xᵢ] := B3[m1,m2,n1,n2,y,xᵢ]
end |> M2
lz = normalize!(B)
any(any(isnan, b) for b in B) && @error "NaN in tensor train"
compress!(B; svd_trunc)
any(any(isnan, b) for b in B) && @error "NaN in tensor train"
B, lz + lz1 + lz2, d1 + d2
end

Expand Down
65 changes: 65 additions & 0 deletions test/periodic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,69 @@

end

##### Generic Glauber
T = 2
J = [0 1 0 0 0;
1 0 1 1 0;
0 1 0 0 0;
0 1 0 0 0;
0 0 0 0 0] .|> float
J .* rand.(rng)

N = size(J, 1)
h = randn(rng, N)

β = 1.0
ising = Ising(J, h, β)

O = [ (1, 2, 1, [0.1 0.9; 0.3 0.4]),
(2, 4, 2, [0.4 0.6; 0.5 0.9]),
(2, 3, T, rand(2,2)) ]

ψ = pair_observations_nondirected(O, ising.g, T, 2)

gl = Glauber(ising, T; ψ)

for i in 1:N
r = 0.75
gl.ϕ[i][1] .*= [r, 1-r]
end

bp = periodic_mpbp(deepcopy(gl))

X, observed = draw_node_observations!(bp, N; rng)

svd_trunc = TruncThresh(0.0)
svd_trunc = TruncBondThresh(10)
cb = CB_BP(bp; showprogress=false, info="Glauber")
iterate!(bp; maxiter=20, svd_trunc, cb)

b_bp = beliefs(bp)
p_bp = [[bbb[2] for bbb in bb] for bb in b_bp]

p_exact, Z_exact = exact_prob(bp)
b_exact = exact_marginals(bp; p_exact)
p_ex = [[bbb[2] for bbb in bb] for bb in b_exact]

f_bethe = bethe_free_energy(bp)
Z_bp = exp(-f_bethe)

r_bp = autocorrelations(f, bp)
r_exact = exact_autocorrelations(f, bp; p_exact)

c_bp = autocovariances(f, bp)
c_exact = exact_autocovariances(f, bp; r = r_exact)

pb_bp = pair_beliefs(bp)[1]
p_bp = [[bbb[2] for bbb in bb] for bb in b_bp]
pb_bp2 = marginals.(pair_beliefs_as_mpem(bp)[1])

@testset "Glauber small tree - periodic + pair observations" begin
@test Z_exact Z_bp
@test p_ex p_bp
@test r_bp r_exact
@test c_bp c_exact
@test pb_bp pb_bp2
end

end

0 comments on commit 88f059d

Please sign in to comment.