From ebcaf95868f3c1240b0f7ff3c5de63df7d6508da Mon Sep 17 00:00:00 2001 From: stecrotti Date: Mon, 20 Nov 2023 14:37:46 +0100 Subject: [PATCH] add test --- src/exact.jl | 17 ++++++----- test/exact_msg.jl | 15 +++++++--- test/periodic.jl | 75 +++++++++++++++++++++++------------------------ 3 files changed, 57 insertions(+), 50 deletions(-) diff --git a/src/exact.jl b/src/exact.jl index 5fead6b9..506c8e25 100644 --- a/src/exact.jl +++ b/src/exact.jl @@ -133,15 +133,15 @@ struct ExactMsg{Periodic,U<:AbstractArray,S,TI<:Integer} end is_periodic(::Type{<:ExactMsg{Periodic}}) where {Periodic} = Periodic -function uniform_exact_msg(states, T) +function uniform_exact_msg(states, T; periodic=false) n = prod(states) ^ (T+1) x = log(1 / n) logm = fill(x, reduce(vcat, (fill(s, T+1) for s in states))...) - return ExactMsg(logm, states, T) + return ExactMsg(logm, states, T; periodic) end -function zero_exact_msg(states, T) +function zero_exact_msg(states, T; periodic=false) logm = fill(-Inf, reduce(vcat, (fill(s, T+1) for s in states))...) - return ExactMsg(logm, states, T) + return ExactMsg(logm, states, T; periodic) end nstates(m::ExactMsg) = prod(m.states) @@ -164,10 +164,11 @@ const MPBPExact = MPBP{<:AbstractIndexedDiGraph, <:Real, <:AbstractVector{<:BPFa function mpbp_exact(g::IndexedBiDiGraph{Int}, w::Vector{<:Vector{<:BPFactor}}, q::AbstractVector{Int}, T::Int; + periodic = false, ϕ = [[ones(q[i]) for t in 0:T] for i in vertices(g)], ψ = [[ones(q[i],q[j]) for t in 0:T] for (i,j) in edges(g)], - μ = [uniform_exact_msg((q[i],q[j]), T) for (i,j) in edges(g)], - b = [uniform_exact_msg((q[i],), T) for i in vertices(g)], + μ = [uniform_exact_msg((q[i],q[j]), T; periodic) for (i,j) in edges(g)], + b = [uniform_exact_msg((q[i],), T; periodic) for i in vertices(g)], f = zeros(nv(g))) return MPBP(g, w, ϕ, ψ, μ, b, f) end @@ -183,7 +184,7 @@ function f_bp(m_in::Vector{M2}, wᵢ::Vector{U}, ϕᵢ::Vector{Vector{F}}, dt = showprogress ? 1.0 : Inf prog = Progress(prod(nstates, m_in; init=1), dt=dt, desc="Computing outgoing message") mⱼᵢ = m_in[j_index] - m_out = zero_exact_msg(reverse(mⱼᵢ.states), mⱼᵢ.T) + m_out = zero_exact_msg(reverse(mⱼᵢ.states), mⱼᵢ.T; periodic) for xᵢ in eachstate(m_out, 1) for xₐ in Iterators.product((eachstate(m, 1) for (k,m) in enumerate(m_in))...) # compute weight @@ -217,7 +218,7 @@ function f_bp_dummy_neighbor(m_in::Vector{M2}, wᵢ::Vector{U}, ϕᵢ::Vector{Ve dt = showprogress ? 1.0 : Inf prog = Progress(prod(nstates, m_in; init=1), dt=dt, desc="Computing outgoing message") - m_out = zero_exact_msg((length(ϕᵢ[1]),), T) + m_out = zero_exact_msg((length(ϕᵢ[1]),), T; periodic) for xᵢ in eachstate(m_out, 1) for xₐ in Iterators.product((eachstate(m, 1) for (k,m) in enumerate(m_in))...) # compute weight diff --git a/test/exact_msg.jl b/test/exact_msg.jl index 70578071..51957d76 100644 --- a/test/exact_msg.jl +++ b/test/exact_msg.jl @@ -2,7 +2,7 @@ using MatrixProductBP, Test import MatrixProductBP: f_bp, eachstate, zero_exact_msg using Random, MatrixProductBP.Models, Graphs, IndexedGraphs -# @testset "Exact messages" begin +@testset "Exact messages" begin rng = MersenneTwister(111) T = 2 @@ -35,6 +35,13 @@ using Random, MatrixProductBP.Models, Graphs, IndexedGraphs @test beliefs(bp) ≈ beliefs(bp_ex) @test bethe_free_energy(bp) ≈ bethe_free_energy(bp) - f(x,i) = 2x-3 - autocorrelations(f, bp_ex) -# end \ No newline at end of file + bp = periodic_mpbp(deepcopy(gl)) + bp_ex = mpbp_exact(bp.g, bp.w, fill(2, nv(bp.g)), T; periodic=true, ϕ = bp.ϕ, ψ = bp.ψ) + @test bp_ex isa MatrixProductBP.MPBPExact + + iterate!(bp_ex; maxiter=20) + svd_trunc = TruncBond(10) + iterate!(bp; maxiter=20, svd_trunc) + @test beliefs(bp) ≈ beliefs(bp_ex) + @test bethe_free_energy(bp) ≈ bethe_free_energy(bp) +end \ No newline at end of file diff --git a/test/periodic.jl b/test/periodic.jl index b3a4962d..d7cc0bc1 100644 --- a/test/periodic.jl +++ b/test/periodic.jl @@ -33,7 +33,7 @@ X, observed = draw_node_observations!(bp, N; rng) svd_trunc = TruncThresh(0.0) - svd_trunc = TruncBondThresh(10) + svd_trunc = TruncBond(10) cb = CB_BP(bp; showprogress=false, info="Glauber") iterate!(bp; maxiter=20, svd_trunc, cb) @@ -67,52 +67,52 @@ @test pb_bp ≈ pb_bp2 end - ########## INFINITE GRAPH - T = 2 - k = 3 - m⁰ = 0.5 + # ########## INFINITE GRAPH + # T = 2 + # k = 3 + # m⁰ = 0.5 - β = 1.0 - J = 1.0 - h = 0.0 + # β = 1.0 + # J = 1.0 + # h = 0.0 - wᵢ = fill(HomogeneousGlauberFactor(J, h, β), T+1) - ϕᵢ = [ t == 0 ? [(1+m⁰)/2, (1-m⁰)/2] : ones(2) for t in 0:T] - ϕᵢ[2] = [0.4, 0.6] - ϕᵢ[end] = [0.95, 0.05] - bp = periodic_mpbp_infinite_graph(k, wᵢ, 2, ϕᵢ) - cb = CB_BP(bp) + # wᵢ = fill(HomogeneousGlauberFactor(J, h, β), T+1) + # ϕᵢ = [ t == 0 ? [(1+m⁰)/2, (1-m⁰)/2] : ones(2) for t in 0:T] + # ϕᵢ[2] = [0.4, 0.6] + # ϕᵢ[end] = [0.95, 0.05] + # bp = periodic_mpbp_infinite_graph(k, wᵢ, 2, ϕᵢ) + # cb = CB_BP(bp) - iters, cb = iterate!(bp; maxiter=150, svd_trunc=TruncBond(10), cb, tol=1e-12, damp=0.2) + # iters, cb = iterate!(bp; maxiter=150, svd_trunc=TruncBond(10), cb, tol=1e-12, damp=0.2) - b_bp = beliefs(bp) - pb_bp = pair_beliefs(bp)[1][1] - p_bp = [[bbb[2] for bbb in bb] for bb in b_bp] + # b_bp = beliefs(bp) + # pb_bp = pair_beliefs(bp)[1][1] + # p_bp = [[bbb[2] for bbb in bb] for bb in b_bp] - f_bethe = bethe_free_energy(bp) - Z_bp = exp(-f_bethe) + # f_bethe = bethe_free_energy(bp) + # Z_bp = exp(-f_bethe) - N = k+1 - g = IndexedBiDiGraph(complete_graph(N)) - bp_exact = periodic_mpbp(g, fill(wᵢ, N), fill(2,N), T) - for i in 1:N; bp_exact.ϕ[i] = ϕᵢ; end + # N = k+1 + # g = IndexedBiDiGraph(complete_graph(N)) + # bp_exact = periodic_mpbp(g, fill(wᵢ, N), fill(2,N), T) + # for i in 1:N; bp_exact.ϕ[i] = ϕᵢ; end - cb = CB_BP(bp_exact) - iterate!(bp_exact; maxiter=150, svd_trunc=TruncBond(10), cb, tol=1e-12, damp=0.2) + # cb = CB_BP(bp_exact) + # iterate!(bp_exact; maxiter=150, svd_trunc=TruncBond(10), cb, tol=1e-12, damp=0.2) - b_exact = beliefs(bp_exact) - p_exact = [[bbb[2] for bbb in bb] for bb in b_exact][1:1] - pb_exact = pair_beliefs(bp_exact)[1][1] + # b_exact = beliefs(bp_exact) + # p_exact = [[bbb[2] for bbb in bb] for bb in b_exact][1:1] + # pb_exact = pair_beliefs(bp_exact)[1][1] - f_bethe_exact = bethe_free_energy(bp_exact) - Z_exact = exp(-1/N*f_bethe_exact) + # f_bethe_exact = bethe_free_energy(bp_exact) + # Z_exact = exp(-1/N*f_bethe_exact) - @testset "Glauber infinite graph - periodic" begin - # @test Z_exact ≈ Z_bp ### NOT WORKING! - @test p_exact ≈ p_bp - @test pb_exact ≈ pb_bp + # @testset "Glauber infinite graph - periodic" begin + # # @test Z_exact ≈ Z_bp ### NOT WORKING! + # @test p_exact ≈ p_bp + # @test pb_exact ≈ pb_bp - end + # end ##### Generic Glauber T = 2 @@ -146,8 +146,7 @@ X, observed = draw_node_observations!(bp, N; rng) - svd_trunc = TruncThresh(0.0) - svd_trunc = TruncBondThresh(10) + svd_trunc = TruncBond(10) cb = CB_BP(bp; showprogress=false, info="Glauber") iterate!(bp; maxiter=20, svd_trunc, cb)