Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
stecrotti committed Feb 21, 2024
1 parent ebb6782 commit ebcaf95
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 50 deletions.
17 changes: 9 additions & 8 deletions src/exact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,15 @@ struct ExactMsg{Periodic,U<:AbstractArray,S,TI<:Integer}
end
is_periodic(::Type{<:ExactMsg{Periodic}}) where {Periodic} = Periodic

function uniform_exact_msg(states, T)
function uniform_exact_msg(states, T; periodic=false)
n = prod(states) ^ (T+1)
x = log(1 / n)
logm = fill(x, reduce(vcat, (fill(s, T+1) for s in states))...)
return ExactMsg(logm, states, T)
return ExactMsg(logm, states, T; periodic)
end
function zero_exact_msg(states, T)
function zero_exact_msg(states, T; periodic=false)
logm = fill(-Inf, reduce(vcat, (fill(s, T+1) for s in states))...)
return ExactMsg(logm, states, T)
return ExactMsg(logm, states, T; periodic)
end

nstates(m::ExactMsg) = prod(m.states)
Expand All @@ -164,10 +164,11 @@ const MPBPExact = MPBP{<:AbstractIndexedDiGraph, <:Real, <:AbstractVector{<:BPFa

function mpbp_exact(g::IndexedBiDiGraph{Int}, w::Vector{<:Vector{<:BPFactor}},
q::AbstractVector{Int}, T::Int;
periodic = false,
ϕ = [[ones(q[i]) for t in 0:T] for i in vertices(g)],
ψ = [[ones(q[i],q[j]) for t in 0:T] for (i,j) in edges(g)],
μ = [uniform_exact_msg((q[i],q[j]), T) for (i,j) in edges(g)],
b = [uniform_exact_msg((q[i],), T) for i in vertices(g)],
μ = [uniform_exact_msg((q[i],q[j]), T; periodic) for (i,j) in edges(g)],
b = [uniform_exact_msg((q[i],), T; periodic) for i in vertices(g)],
f = zeros(nv(g)))
return MPBP(g, w, ϕ, ψ, μ, b, f)
end
Expand All @@ -183,7 +184,7 @@ function f_bp(m_in::Vector{M2}, wᵢ::Vector{U}, ϕᵢ::Vector{Vector{F}},
dt = showprogress ? 1.0 : Inf
prog = Progress(prod(nstates, m_in; init=1), dt=dt, desc="Computing outgoing message")
mⱼᵢ = m_in[j_index]
m_out = zero_exact_msg(reverse(mⱼᵢ.states), mⱼᵢ.T)
m_out = zero_exact_msg(reverse(mⱼᵢ.states), mⱼᵢ.T; periodic)
for xᵢ in eachstate(m_out, 1)
for xₐ in Iterators.product((eachstate(m, 1) for (k,m) in enumerate(m_in))...)
# compute weight
Expand Down Expand Up @@ -217,7 +218,7 @@ function f_bp_dummy_neighbor(m_in::Vector{M2}, wᵢ::Vector{U}, ϕᵢ::Vector{Ve

dt = showprogress ? 1.0 : Inf
prog = Progress(prod(nstates, m_in; init=1), dt=dt, desc="Computing outgoing message")
m_out = zero_exact_msg((length(ϕᵢ[1]),), T)
m_out = zero_exact_msg((length(ϕᵢ[1]),), T; periodic)
for xᵢ in eachstate(m_out, 1)
for xₐ in Iterators.product((eachstate(m, 1) for (k,m) in enumerate(m_in))...)
# compute weight
Expand Down
15 changes: 11 additions & 4 deletions test/exact_msg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using MatrixProductBP, Test
import MatrixProductBP: f_bp, eachstate, zero_exact_msg
using Random, MatrixProductBP.Models, Graphs, IndexedGraphs

# @testset "Exact messages" begin
@testset "Exact messages" begin
rng = MersenneTwister(111)

T = 2
Expand Down Expand Up @@ -35,6 +35,13 @@ using Random, MatrixProductBP.Models, Graphs, IndexedGraphs
@test beliefs(bp) beliefs(bp_ex)
@test bethe_free_energy(bp) bethe_free_energy(bp)

f(x,i) = 2x-3
autocorrelations(f, bp_ex)
# end
bp = periodic_mpbp(deepcopy(gl))
bp_ex = mpbp_exact(bp.g, bp.w, fill(2, nv(bp.g)), T; periodic=true, ϕ = bp.ϕ, ψ = bp.ψ)
@test bp_ex isa MatrixProductBP.MPBPExact

iterate!(bp_ex; maxiter=20)
svd_trunc = TruncBond(10)
iterate!(bp; maxiter=20, svd_trunc)
@test beliefs(bp) beliefs(bp_ex)
@test bethe_free_energy(bp) bethe_free_energy(bp)
end
75 changes: 37 additions & 38 deletions test/periodic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
X, observed = draw_node_observations!(bp, N; rng)

svd_trunc = TruncThresh(0.0)
svd_trunc = TruncBondThresh(10)
svd_trunc = TruncBond(10)
cb = CB_BP(bp; showprogress=false, info="Glauber")
iterate!(bp; maxiter=20, svd_trunc, cb)

Expand Down Expand Up @@ -67,52 +67,52 @@
@test pb_bp pb_bp2
end

########## INFINITE GRAPH
T = 2
k = 3
m⁰ = 0.5
# ########## INFINITE GRAPH
# T = 2
# k = 3
# m⁰ = 0.5

β = 1.0
J = 1.0
h = 0.0
# β = 1.0
# J = 1.0
# h = 0.0

wᵢ = fill(HomogeneousGlauberFactor(J, h, β), T+1)
ϕᵢ = [ t == 0 ? [(1+m⁰)/2, (1-m⁰)/2] : ones(2) for t in 0:T]
ϕᵢ[2] = [0.4, 0.6]
ϕᵢ[end] = [0.95, 0.05]
bp = periodic_mpbp_infinite_graph(k, wᵢ, 2, ϕᵢ)
cb = CB_BP(bp)
# wᵢ = fill(HomogeneousGlauberFactor(J, h, β), T+1)
# ϕᵢ = [ t == 0 ? [(1+m⁰)/2, (1-m⁰)/2] : ones(2) for t in 0:T]
# ϕᵢ[2] = [0.4, 0.6]
# ϕᵢ[end] = [0.95, 0.05]
# bp = periodic_mpbp_infinite_graph(k, wᵢ, 2, ϕᵢ)
# cb = CB_BP(bp)

iters, cb = iterate!(bp; maxiter=150, svd_trunc=TruncBond(10), cb, tol=1e-12, damp=0.2)
# iters, cb = iterate!(bp; maxiter=150, svd_trunc=TruncBond(10), cb, tol=1e-12, damp=0.2)

b_bp = beliefs(bp)
pb_bp = pair_beliefs(bp)[1][1]
p_bp = [[bbb[2] for bbb in bb] for bb in b_bp]
# b_bp = beliefs(bp)
# pb_bp = pair_beliefs(bp)[1][1]
# p_bp = [[bbb[2] for bbb in bb] for bb in b_bp]

f_bethe = bethe_free_energy(bp)
Z_bp = exp(-f_bethe)
# f_bethe = bethe_free_energy(bp)
# Z_bp = exp(-f_bethe)

N = k+1
g = IndexedBiDiGraph(complete_graph(N))
bp_exact = periodic_mpbp(g, fill(wᵢ, N), fill(2,N), T)
for i in 1:N; bp_exact.ϕ[i] = ϕᵢ; end
# N = k+1
# g = IndexedBiDiGraph(complete_graph(N))
# bp_exact = periodic_mpbp(g, fill(wᵢ, N), fill(2,N), T)
# for i in 1:N; bp_exact.ϕ[i] = ϕᵢ; end

cb = CB_BP(bp_exact)
iterate!(bp_exact; maxiter=150, svd_trunc=TruncBond(10), cb, tol=1e-12, damp=0.2)
# cb = CB_BP(bp_exact)
# iterate!(bp_exact; maxiter=150, svd_trunc=TruncBond(10), cb, tol=1e-12, damp=0.2)

b_exact = beliefs(bp_exact)
p_exact = [[bbb[2] for bbb in bb] for bb in b_exact][1:1]
pb_exact = pair_beliefs(bp_exact)[1][1]
# b_exact = beliefs(bp_exact)
# p_exact = [[bbb[2] for bbb in bb] for bb in b_exact][1:1]
# pb_exact = pair_beliefs(bp_exact)[1][1]

f_bethe_exact = bethe_free_energy(bp_exact)
Z_exact = exp(-1/N*f_bethe_exact)
# f_bethe_exact = bethe_free_energy(bp_exact)
# Z_exact = exp(-1/N*f_bethe_exact)

@testset "Glauber infinite graph - periodic" begin
# @test Z_exact ≈ Z_bp ### NOT WORKING!
@test p_exact p_bp
@test pb_exact pb_bp
# @testset "Glauber infinite graph - periodic" begin
# # @test Z_exact ≈ Z_bp ### NOT WORKING!
# @test p_exact ≈ p_bp
# @test pb_exact ≈ pb_bp

end
# end

##### Generic Glauber
T = 2
Expand Down Expand Up @@ -146,8 +146,7 @@

X, observed = draw_node_observations!(bp, N; rng)

svd_trunc = TruncThresh(0.0)
svd_trunc = TruncBondThresh(10)
svd_trunc = TruncBond(10)
cb = CB_BP(bp; showprogress=false, info="Glauber")
iterate!(bp; maxiter=20, svd_trunc, cb)

Expand Down

0 comments on commit ebcaf95

Please sign in to comment.