From 1c91eb574fd9eef37c41b3c04c333ffb7c388cbe Mon Sep 17 00:00:00 2001 From: Alfredo Braunstein Date: Tue, 19 Dec 2023 16:11:35 +0100 Subject: [PATCH] fix bug in DampedFactor --- src/Models/Models.jl | 2 +- src/Models/glauber/glauber_bp.jl | 11 ++++++++++- src/recursive_bp_factor.jl | 12 +++++++----- test/glauber_small_tree.jl | 4 ++-- 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/Models/Models.jl b/src/Models/Models.jl index e19de52b..a65ba24c 100644 --- a/src/Models/Models.jl +++ b/src/Models/Models.jl @@ -4,7 +4,7 @@ import MatrixProductBP: exact_prob, getT, nstates, mpbp, compress!, f_bp, f_bp_dummy_neighbor, onebpiter_dummy_neighbor, beliefs, beliefs_tu, marginals, pair_belief, pair_beliefs, marginalize, cavity, onebpiter!, check_ψs, _compose, - RecursiveBPFactor, nstates, prob_y, prob_xy, prob_yy, prob_y_partial, + RecursiveBPFactor, nstates, prob_y, prob_xy, prob_yy, prob_y0, prob_y_partial, prob_y_dummy, periodic_mpbp using MatrixProductBP diff --git a/src/Models/glauber/glauber_bp.jl b/src/Models/glauber/glauber_bp.jl index f9e656b8..706159f2 100644 --- a/src/Models/glauber/glauber_bp.jl +++ b/src/Models/glauber/glauber_bp.jl @@ -159,4 +159,13 @@ end "P(yₖᵗ| xₖᵗ, xᵢᵗ)" prob_xy(wᵢ::IntegerGlauberFactor, yₖ, xₖ, xᵢ, k) = (yₖ == potts2spin(xₖ)*wᵢ.J[k] + wᵢ.K) prob_yy(wᵢ::IntegerGlauberFactor, y, y1, y2, xᵢ) = (y + wᵢ.K == y1 + y2) -prob_y0(wᵢ::IntegerGlauberFactor, y, xᵢ) = y == wᵢ.K \ No newline at end of file +prob_y0(wᵢ::IntegerGlauberFactor, y, xᵢ) = y == wᵢ.K + +function (wᵢ::IntegerGlauberFactor)(xᵢᵗ⁺¹::Integer, xₙᵢᵗ::AbstractVector{<:Integer}, + xᵢᵗ::Integer) + @unpack J, h, β, K = wᵢ + hᵗ = sum(Jk*potts2spin(xk) for (Jk,xk) in zip(J, xₙᵢᵗ); init=0.0) + βhⱼᵢ = β*(hᵗ + h) + E = - potts2spin(xᵢᵗ⁺¹) * βhⱼᵢ + return 1 / (1 + exp(2E)) +end diff --git a/src/recursive_bp_factor.jl b/src/recursive_bp_factor.jl index 0e6b278e..da143036 100644 --- a/src/recursive_bp_factor.jl +++ b/src/recursive_bp_factor.jl @@ -32,7 +32,7 @@ prob_y0(wᵢ::RecursiveBPFactor, y, xᵢᵗ) = y == 1 function (wᵢ::RecursiveBPFactor)(xᵢᵗ⁺¹::Integer, xₙᵢᵗ::AbstractVector{<:Integer}, xᵢᵗ::Integer) d = length(xₙᵢᵗ) - Pyy = [prob_y0(wᵢ, y, xᵢᵗ) for y in 1:nstates(wᵢ,0)] + Pyy = [float(prob_y0(wᵢ, y, xᵢᵗ)) for y in 1:nstates(wᵢ,0)] for k in 1:d Pyy = [sum(prob_yy(wᵢ, y, y1, y2, xᵢᵗ, 1, k-1) * prob_xy(wᵢ, y1, xₙᵢᵗ[k], xᵢᵗ, k) * @@ -124,10 +124,10 @@ function compute_prob_ys(wᵢ::Vector{U}, qi::Int, μin::Vector{M2}, ψout, T, s B, lz + lz1 + lz2, d1 + d2 end - Minit = [[float(prob_y0(wᵢ[t], y, xᵢ)) for _ in 1:1, - _ in 1:1, + Minit = [[float(prob_y0(wᵢ[t], y, xᵢ)) for _ in 1:1, + _ in 1:1, y in 1:nstates(wᵢ[t],0), - xᵢ in 1:qi] + xᵢ in 1:qi] for t=1:T+1] init = (M2(Minit), 0.0, 0) dest, (full, logzᵢ,) = cavity(B, op, init) @@ -193,4 +193,6 @@ end function prob_y(wᵢ::DampedFactor, xᵢᵗ⁺¹, xᵢᵗ, yᵗ, d) return (1-wᵢ.p)*(prob_y(wᵢ.w, xᵢᵗ⁺¹, xᵢᵗ, yᵗ, d)) + wᵢ.p*(xᵢᵗ⁺¹ == xᵢᵗ) -end \ No newline at end of file +end + +prob_y0(wᵢ::DampedFactor, y, xᵢᵗ) = prob_y0(wᵢ.w, y, xᵢᵗ) diff --git a/test/glauber_small_tree.jl b/test/glauber_small_tree.jl index c8b1ea57..3c561960 100644 --- a/test/glauber_small_tree.jl +++ b/test/glauber_small_tree.jl @@ -173,7 +173,7 @@ end end bp = mpbp(deepcopy(gl)) - #X, observed = draw_node_observations!(bp, N; rng) + X, observed = draw_node_observations!(bp, N; rng) svd_trunc = TruncThresh(0.0) svd_trunc = TruncBondThresh(15) @@ -219,7 +219,7 @@ end logl_bp = - f_bethe logp = logprob(bp, X) - @testset "Glauber small tree - observe everything" begin + @testset "Glauber small tree integer - observe everything" begin @test logl_bp ≈ logp end