Skip to content

Commit

Permalink
fix bug in DampedFactor
Browse files Browse the repository at this point in the history
  • Loading branch information
abraunst committed Feb 21, 2024
1 parent 289e76d commit 1c91eb5
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/Models/Models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import MatrixProductBP: exact_prob, getT, nstates, mpbp, compress!,
f_bp, f_bp_dummy_neighbor, onebpiter_dummy_neighbor,
beliefs, beliefs_tu, marginals, pair_belief, pair_beliefs,
marginalize, cavity, onebpiter!, check_ψs, _compose,
RecursiveBPFactor, nstates, prob_y, prob_xy, prob_yy, prob_y_partial,
RecursiveBPFactor, nstates, prob_y, prob_xy, prob_yy, prob_y0, prob_y_partial,
prob_y_dummy, periodic_mpbp
using MatrixProductBP

Expand Down
11 changes: 10 additions & 1 deletion src/Models/glauber/glauber_bp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,13 @@ end
"P(yₖᵗ| xₖᵗ, xᵢᵗ)"
prob_xy(wᵢ::IntegerGlauberFactor, yₖ, xₖ, xᵢ, k) = (yₖ == potts2spin(xₖ)*wᵢ.J[k] + wᵢ.K)
prob_yy(wᵢ::IntegerGlauberFactor, y, y1, y2, xᵢ) = (y + wᵢ.K == y1 + y2)
prob_y0(wᵢ::IntegerGlauberFactor, y, xᵢ) = y == wᵢ.K
prob_y0(wᵢ::IntegerGlauberFactor, y, xᵢ) = y == wᵢ.K

function (wᵢ::IntegerGlauberFactor)(xᵢᵗ⁺¹::Integer, xₙᵢᵗ::AbstractVector{<:Integer},
xᵢᵗ::Integer)
@unpack J, h, β, K = wᵢ
hᵗ = sum(Jk*potts2spin(xk) for (Jk,xk) in zip(J, xₙᵢᵗ); init=0.0)
βhⱼᵢ = β*(hᵗ + h)
E = - potts2spin(xᵢᵗ⁺¹) * βhⱼᵢ
return 1 / (1 + exp(2E))
end
12 changes: 7 additions & 5 deletions src/recursive_bp_factor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ prob_y0(wᵢ::RecursiveBPFactor, y, xᵢᵗ) = y == 1
function (wᵢ::RecursiveBPFactor)(xᵢᵗ⁺¹::Integer, xₙᵢᵗ::AbstractVector{<:Integer},
xᵢᵗ::Integer)
d = length(xₙᵢᵗ)
Pyy = [prob_y0(wᵢ, y, xᵢᵗ) for y in 1:nstates(wᵢ,0)]
Pyy = [float(prob_y0(wᵢ, y, xᵢᵗ)) for y in 1:nstates(wᵢ,0)]
for k in 1:d
Pyy = [sum(prob_yy(wᵢ, y, y1, y2, xᵢᵗ, 1, k-1) *
prob_xy(wᵢ, y1, xₙᵢᵗ[k], xᵢᵗ, k) *
Expand Down Expand Up @@ -124,10 +124,10 @@ function compute_prob_ys(wᵢ::Vector{U}, qi::Int, μin::Vector{M2}, ψout, T, s
B, lz + lz1 + lz2, d1 + d2
end

Minit = [[float(prob_y0(wᵢ[t], y, xᵢ)) for _ in 1:1,
_ in 1:1,
Minit = [[float(prob_y0(wᵢ[t], y, xᵢ)) for _ in 1:1,
_ in 1:1,
y in 1:nstates(wᵢ[t],0),
xᵢ in 1:qi]
xᵢ in 1:qi]
for t=1:T+1]
init = (M2(Minit), 0.0, 0)
dest, (full, logzᵢ,) = cavity(B, op, init)
Expand Down Expand Up @@ -193,4 +193,6 @@ end

function prob_y(wᵢ::DampedFactor, xᵢᵗ⁺¹, xᵢᵗ, yᵗ, d)
return (1-wᵢ.p)*(prob_y(wᵢ.w, xᵢᵗ⁺¹, xᵢᵗ, yᵗ, d)) + wᵢ.p*(xᵢᵗ⁺¹ == xᵢᵗ)
end
end

prob_y0(wᵢ::DampedFactor, y, xᵢᵗ) = prob_y0(wᵢ.w, y, xᵢᵗ)
4 changes: 2 additions & 2 deletions test/glauber_small_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ end
end

bp = mpbp(deepcopy(gl))
#X, observed = draw_node_observations!(bp, N; rng)
X, observed = draw_node_observations!(bp, N; rng)

svd_trunc = TruncThresh(0.0)
svd_trunc = TruncBondThresh(15)
Expand Down Expand Up @@ -219,7 +219,7 @@ end
logl_bp = - f_bethe
logp = logprob(bp, X)

@testset "Glauber small tree - observe everything" begin
@testset "Glauber small tree integer - observe everything" begin
@test logl_bp logp
end

Expand Down

0 comments on commit 1c91eb5

Please sign in to comment.