From 6b9a0c0e2024ea26f0e17c323e0d6011f0fe24f9 Mon Sep 17 00:00:00 2001 From: Alfredo Braunstein Date: Tue, 19 Dec 2023 16:11:35 +0100 Subject: [PATCH] fix bug in DampedFactor --- src/Models/Models.jl | 2 +- src/Models/glauber/glauber_bp.jl | 11 ++++++++++- src/recursive_bp_factor.jl | 4 +++- test/glauber_small_tree.jl | 4 ++-- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/Models/Models.jl b/src/Models/Models.jl index c6bc53e1..c31759d6 100644 --- a/src/Models/Models.jl +++ b/src/Models/Models.jl @@ -4,7 +4,7 @@ import MatrixProductBP: exact_prob, getT, nstates, mpbp, compress!, kron2, f_bp, f_bp_dummy_neighbor, onebpiter_dummy_neighbor, beliefs, beliefs_tu, marginals, pair_belief, pair_beliefs, marginalize, cavity, onebpiter!, check_ψs, _compose, - RecursiveBPFactor, nstates, prob_y, prob_xy, prob_yy, prob_y_partial, + RecursiveBPFactor, nstates, prob_y, prob_xy, prob_yy, prob_y0, prob_y_partial, prob_y_dummy, periodic_mpbp using MatrixProductBP diff --git a/src/Models/glauber/glauber_bp.jl b/src/Models/glauber/glauber_bp.jl index f9e656b8..6a5281ec 100644 --- a/src/Models/glauber/glauber_bp.jl +++ b/src/Models/glauber/glauber_bp.jl @@ -159,4 +159,13 @@ end "P(yₖᵗ| xₖᵗ, xᵢᵗ)" prob_xy(wᵢ::IntegerGlauberFactor, yₖ, xₖ, xᵢ, k) = (yₖ == potts2spin(xₖ)*wᵢ.J[k] + wᵢ.K) prob_yy(wᵢ::IntegerGlauberFactor, y, y1, y2, xᵢ) = (y + wᵢ.K == y1 + y2) -prob_y0(wᵢ::IntegerGlauberFactor, y, xᵢ) = y == wᵢ.K \ No newline at end of file +prob_y0(wᵢ::IntegerGlauberFactor, y, xᵢ) = y == wᵢ.K + +function (wᵢ::IntegerGlauberFactor)(xᵢᵗ⁺¹::Integer, xₙᵢᵗ::AbstractVector{<:Integer}, + xᵢᵗ::Integer) + @unpack J, h, β, K = wᵢ + hᵗ = sum(Jk*potts2spin(xk) for (Jk,xk) in zip(J, xₙᵢᵗ); init=0.0) + βhⱼᵢ = β*(hᵗ + h) + E = - potts2spin(xᵢᵗ⁺¹) * βhⱼᵢ + return 1 / (1 + exp(2E)) +end \ No newline at end of file diff --git a/src/recursive_bp_factor.jl b/src/recursive_bp_factor.jl index 0e6b278e..3324aa2f 100644 --- a/src/recursive_bp_factor.jl +++ b/src/recursive_bp_factor.jl @@ -193,4 +193,6 @@ end function prob_y(wᵢ::DampedFactor, xᵢᵗ⁺¹, xᵢᵗ, yᵗ, d) return (1-wᵢ.p)*(prob_y(wᵢ.w, xᵢᵗ⁺¹, xᵢᵗ, yᵗ, d)) + wᵢ.p*(xᵢᵗ⁺¹ == xᵢᵗ) -end \ No newline at end of file +end + +prob_y0(wᵢ::DampedFactor, y, xᵢᵗ) = prob_y0(wᵢ.w, y, xᵢᵗ) \ No newline at end of file diff --git a/test/glauber_small_tree.jl b/test/glauber_small_tree.jl index c8b1ea57..3c561960 100644 --- a/test/glauber_small_tree.jl +++ b/test/glauber_small_tree.jl @@ -173,7 +173,7 @@ end end bp = mpbp(deepcopy(gl)) - #X, observed = draw_node_observations!(bp, N; rng) + X, observed = draw_node_observations!(bp, N; rng) svd_trunc = TruncThresh(0.0) svd_trunc = TruncBondThresh(15) @@ -219,7 +219,7 @@ end logl_bp = - f_bethe logp = logprob(bp, X) - @testset "Glauber small tree - observe everything" begin + @testset "Glauber small tree integer - observe everything" begin @test logl_bp ≈ logp end