Skip to content

Commit

Permalink
fix bug in DampedFactor
Browse files Browse the repository at this point in the history
  • Loading branch information
abraunst committed Dec 19, 2023
1 parent 73926f8 commit 6b9a0c0
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/Models/Models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import MatrixProductBP: exact_prob, getT, nstates, mpbp, compress!,
kron2, f_bp, f_bp_dummy_neighbor, onebpiter_dummy_neighbor,
beliefs, beliefs_tu, marginals, pair_belief, pair_beliefs,
marginalize, cavity, onebpiter!, check_ψs, _compose,
RecursiveBPFactor, nstates, prob_y, prob_xy, prob_yy, prob_y_partial,
RecursiveBPFactor, nstates, prob_y, prob_xy, prob_yy, prob_y0, prob_y_partial,
prob_y_dummy, periodic_mpbp
using MatrixProductBP

Expand Down
11 changes: 10 additions & 1 deletion src/Models/glauber/glauber_bp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,13 @@ end
"P(yₖᵗ| xₖᵗ, xᵢᵗ)"
prob_xy(wᵢ::IntegerGlauberFactor, yₖ, xₖ, xᵢ, k) = (yₖ == potts2spin(xₖ)*wᵢ.J[k] + wᵢ.K)
prob_yy(wᵢ::IntegerGlauberFactor, y, y1, y2, xᵢ) = (y + wᵢ.K == y1 + y2)
prob_y0(wᵢ::IntegerGlauberFactor, y, xᵢ) = y == wᵢ.K
prob_y0(wᵢ::IntegerGlauberFactor, y, xᵢ) = y == wᵢ.K

function (wᵢ::IntegerGlauberFactor)(xᵢᵗ⁺¹::Integer, xₙᵢᵗ::AbstractVector{<:Integer},
xᵢᵗ::Integer)
@unpack J, h, β, K = wᵢ
hᵗ = sum(Jk*potts2spin(xk) for (Jk,xk) in zip(J, xₙᵢᵗ); init=0.0)
βhⱼᵢ = β*(hᵗ + h)
E = - potts2spin(xᵢᵗ⁺¹) * βhⱼᵢ
return 1 / (1 + exp(2E))
end
4 changes: 3 additions & 1 deletion src/recursive_bp_factor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,6 @@ end

function prob_y(wᵢ::DampedFactor, xᵢᵗ⁺¹, xᵢᵗ, yᵗ, d)
return (1-wᵢ.p)*(prob_y(wᵢ.w, xᵢᵗ⁺¹, xᵢᵗ, yᵗ, d)) + wᵢ.p*(xᵢᵗ⁺¹ == xᵢᵗ)
end
end

prob_y0(wᵢ::DampedFactor, y, xᵢᵗ) = prob_y0(wᵢ.w, y, xᵢᵗ)
4 changes: 2 additions & 2 deletions test/glauber_small_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ end
end

bp = mpbp(deepcopy(gl))
#X, observed = draw_node_observations!(bp, N; rng)
X, observed = draw_node_observations!(bp, N; rng)

svd_trunc = TruncThresh(0.0)
svd_trunc = TruncBondThresh(15)
Expand Down Expand Up @@ -219,7 +219,7 @@ end
logl_bp = - f_bethe
logp = logprob(bp, X)

@testset "Glauber small tree - observe everything" begin
@testset "Glauber small tree integer - observe everything" begin
@test logl_bp logp
end

Expand Down

0 comments on commit 6b9a0c0

Please sign in to comment.