Skip to content

Optimization #95

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -133,23 +133,6 @@
"end"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2023-08-25T13:29:06.554Z"
}
},
"outputs": [],
"source": [
"push!(matrix_sizes, 25)\n",
"reset!(bp)\n",
"it, _ = iterate!(bp; maxiter=50, svd_trunc=TruncBond(25), cb, tol)\n",
"push!(iters, it)\n",
"push!(m, only(means(spin, bp)));"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
49 changes: 49 additions & 0 deletions notebooks/glauber_periodic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# import Pkg; Pkg.develop(path="../../.julia/dev/BeliefPropagation/")
using MatrixProductBP, MatrixProductBP.Models
using Statistics
using Graphs, IndexedGraphs, Random

### LOOPY RANDOM GRAPH
T = 15
J = 0.8
β = 1.0
h = 0.2
rng = MersenneTwister(0)
N = 4
k = 3
g = random_regular_graph(N, k) |> IndexedGraph
ising = Ising(J * adjacency_matrix(g), h*ones(nv(g)), β)
gl = Glauber(ising, T)
M = 15
bp = periodic_mpbp(gl; d=M)
cb = CB_BP(bp)

svd_trunc=TruncBond(M)
iters, cb = iterate!(bp; maxiter=2, svd_trunc, cb, tol=1e-5, damp=0.0)


# import BeliefPropagation
# bp_static = BeliefPropagation.BP(BeliefPropagation.Models.Ising(J * adjacency_matrix(g), h*ones(nv(g)), β))
# BeliefPropagation.iterate!(bp_static; maxiter=100)
# m_static = reduce.(-, BeliefPropagation.beliefs(bp_static)) |> mean

# include("../../telegram/notifications.jl")
# @telegram "glauber periodic"

# using Plots

# unicodeplots()
# pl_conv = plot(cb.Δs, ylabel="convergence error", xlabel="iters", yaxis=:log10, legend=:outertopright,
# size=(300,200))
# display(pl_conv)

# spin(x, i) = 3-2x
# spin(x) = spin(x, 0)
# m = mean(means(spin, bp))

# pl = scatter(m, label="MPBP")
# plot!(pl, 1:T+1, fill(m_static, T+1), label="equilibrium", ylims=m_static .+ 1e-1 .* (-1, 1))
# display(pl)

import TensorTrains
TensorTrains.bond_dims.(bp.μ) |> display
616 changes: 616 additions & 0 deletions notebooks/glauber_periodic_tree.ipynb

Large diffs are not rendered by default.

13 changes: 9 additions & 4 deletions src/MatrixProductBP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,30 @@ import .Threads: SpinLock, lock, unlock
import Lazy: @forward
import CavityTools: cavity
import LogarithmicNumbers: ULogarithmic
import LinearAlgebra: I, tr

import TensorTrains:
getindex, iterate, firstindex, lastindex, setindex!, length, eachindex, +, -, isapprox,
SVDTrunc, TruncBond, TruncThresh, TruncBondMax, TruncBondThresh, summary_compact,
TensorTrain, normalize_eachmatrix!, check_bond_dims, evaluate,
bond_dims, uniform_tt, rand_tt, orthogonalize_right!, orthogonalize_left!, compress!,
AbstractTensorTrain, PeriodicTensorTrain, TensorTrain, normalize_eachmatrix!,
check_bond_dims, evaluate,
bond_dims, uniform_tt, rand_tt, uniform_periodic_tt, rand_periodic_tt,
orthogonalize_right!, orthogonalize_left!, compress!,
marginals, twovar_marginals, normalization, normalize!,
svd, _compose, accumulate_L, accumulate_R


export
SVDTrunc, TruncBond, TruncThresh, TruncBondMax, TruncBondThresh,
PeriodicMPEM2, PeriodicMPEM3, PeriodicMPEM1,
MPEM1, MPEM2, MPEM3, mpem2, normalization, normalize!, orthogonalize_right!,
orthogonalize_left!, compress!, twovar_marginals, evaluate,
BPFactor, nstates, MPBP, mpbp, reset_messages!, reset_beliefs!, reset_observations!,
reset!, is_free_dynamics, onebpiter!, CB_BP, iterate!,
pair_beliefs, pair_beliefs_as_mpem, pair_beliefs_tu, beliefs_tu, autocorrelations,
pair_beliefs, pair_beliefs_as_mpem, beliefs_tu, autocorrelations,
autocovariances, means, beliefs, bethe_free_energy,
mpbp_infinite_graph, InfiniteRegularGraph,
periodic_mpbp, is_periodic,
mpbp_infinite_graph, InfiniteRegularGraph, periodic_mpbp_infinite_graph,
logprob, expectation, pair_observations_directed,
pair_observations_nondirected, pair_obs_undirected_to_directed,
exact_prob, exact_marginals, site_marginals, exact_autocorrelations,
Expand Down
2 changes: 1 addition & 1 deletion src/Models/Models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import MatrixProductBP: exact_prob, getT, nstates, mpbp, compress!,
beliefs, beliefs_tu, marginals, pair_belief, pair_beliefs,
marginalize, cavity, onebpiter!, check_ψs, _compose,
RecursiveBPFactor, nstates, prob_y, prob_xy, prob_yy, prob_y_partial,
prob_y_dummy
prob_y_dummy, periodic_mpbp
using MatrixProductBP

import IndexedGraphs: IndexedGraph, IndexedBiDiGraph, AbstractIndexedDiGraph, ne, nv,
Expand Down
2 changes: 1 addition & 1 deletion src/Models/epidemics/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function auc(guess_zp::Vector{Int}, true_zp::Vector{Int})
Z = maximum(x) * maximum(y)
Z == 0 && return 1.0
a = 0
for i in 2:length(y)
for i in Iterators.drop(eachindex(y), 1)
if y[i] == y[i-1]
a += y[i]
end
Expand Down
7 changes: 7 additions & 0 deletions src/Models/epidemics/sis_bp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ function mpbp(sis::SIS{T,N,F}; kw...) where {T,N,F}
return mpbp(g, w, fill(2, nv(g)), T, ϕ=sis_.ϕ, ψ=sis_.ψ; kw...)
end

function periodic_mpbp(sis::SIS{T,N,F}; kw...) where {T,N,F}
sis_ = deepcopy(sis)
g = IndexedBiDiGraph(sis_.g.A)
w = sis_factors(sis_)
return periodic_mpbp(g, w, fill(2, nv(g)), T, ϕ=sis_.ϕ, ψ=sis_.ψ; kw...)
end

# neighbor j is susceptible -> does nothing
function prob_y(wᵢ::SISFactor, xᵢᵗ⁺¹, xᵢᵗ, yᵗ, d)
@unpack λ, ρ = wᵢ
Expand Down
8 changes: 8 additions & 0 deletions src/Models/glauber/glauber_bp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ function mpbp(gl::Glauber{T,N,F}; kw...) where {T,N,F<:AbstractFloat}
return mpbp(g, w, fill(2, nv(g)), T; ϕ, ψ, kw...)
end

function periodic_mpbp(gl::Glauber{T,N,F}; kw...) where {T,N,F<:AbstractFloat}
g = IndexedBiDiGraph(gl.ising.g.A)
w = glauber_factors(gl.ising, T)
ϕ = gl.ϕ
ψ = pair_obs_undirected_to_directed(gl.ψ, gl.ising.g)
return periodic_mpbp(g, w, fill(2, nv(g)), T; ϕ, ψ, kw...)
end


# construct an array of GlauberFactors corresponding to gl
# seems to be type stable
Expand Down
64 changes: 43 additions & 21 deletions src/bp_core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ Base.broadcastable(b::BPFactor) = Ref(b)
# compute outgoing message as a function of the incoming ones
# A is a vector with all incoming messages. At index j_index there is m(j → i)
# ψᵢⱼ are the ones living on the outedges of node i
function f_bp(A::Vector{MPEM2{F}}, wᵢ::Vector{U},
ϕᵢ::Vector{Vector{F}}, ψₙᵢ::Vector{Vector{Matrix{F}}}, j_index::Integer;
showprogress=false, svd_trunc::SVDTrunc=TruncThresh(0.0)) where {F,U<:BPFactor}
function f_bp(A::Vector{M2}, wᵢ::Vector{U}, ϕᵢ::Vector{Vector{F}},
ψₙᵢ::Vector{Vector{Matrix{F}}}, j_index::Integer; showprogress=false,
svd_trunc::SVDTrunc=TruncThresh(0.0), periodic=false) where {F,U<:BPFactor,M2<:AbstractMPEM2}
T = length(A[1]) - 1
@assert all(length(a) == T + 1 for a in A)
@assert length(wᵢ) == T + 1
Expand Down Expand Up @@ -55,19 +55,23 @@ function f_bp(A::Vector{MPEM2{F}}, wᵢ::Vector{U},
any(isnan, Bᵗ) && @error "NaN in tensor at time $t"
next!(prog, showvalues=[(:t, "$t/$T")])
end

return MPEM3(B), 0.0
# apologies to the gods of type stability
if periodic
return PeriodicMPEM3(B), 0.0
else
return MPEM3(B), 0.0
end
end

# compute outgoing message to dummy neighbor to get the belief
function f_bp_dummy_neighbor(A::Vector{MPEM2{F}},
function f_bp_dummy_neighbor(A::Vector{<:AbstractMPEM2},
wᵢ::Vector{U}, ϕᵢ::Vector{Vector{F}}, ψₙᵢ::Vector{Vector{Matrix{F}}};
showprogress=false, svd_trunc::SVDTrunc=TruncThresh(0.0)) where {F,U<:BPFactor}
T = length(A[1]) - 1
@assert all(length(a) == T + 1 for a in A)
showprogress=false, svd_trunc::SVDTrunc=TruncThresh(0.0), periodic=false) where {F,U<:BPFactor}

q = length(ϕᵢ[1])
T = length(ϕᵢ) - 1
@assert all(length(a) == T + 1 for a in A)
@assert length(wᵢ) == T + 1
@assert length(ϕᵢ) == T + 1
z = length(A) # z = |∂i|
xₙᵢ = Iterators.product((1:size(ψₙᵢ[k][1],2) for k=1:z)...) .|> collect

Expand All @@ -79,12 +83,27 @@ function f_bp_dummy_neighbor(A::Vector{MPEM2{F}},
nrows = size(Aᵗ, 1)
ncols = size(Aᵗ, 2)
Bᵗ = zeros(nrows, ncols, q, 1, q)
# for xᵢᵗ in 1:q
# for xᵢᵗ⁺¹ in 1:q
# tmp = Matrix(I*(t == T + 1 ? 1.0 : ϕᵢ[t][xᵢᵗ]), 1, 1)
# for xₙᵢᵗ in xₙᵢ
# tmp = (t == T + 1 ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, xₙᵢᵗ, xᵢᵗ)) .*
# Aᵗ[:, :, xᵢᵗ, xₙᵢᵗ...] .* tmp .*
# prod(ψₙᵢ[k][t][xᵢᵗ, xₖᵗ] for (k, xₖᵗ) in enumerate(xₙᵢᵗ))
# end
# Bᵗ[:, :, xᵢᵗ, 1, xᵢᵗ⁺¹] .+= tmp
# end
# end
for xᵢᵗ in 1:q
for xᵢᵗ⁺¹ in 1:q
for xₙᵢᵗ in xₙᵢ
Bᵗ[:, :, xᵢᵗ, 1, xᵢᵗ⁺¹] .+= (t == T + 1 ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, xₙᵢᵗ, xᵢᵗ)) .*
Aᵗ[:, :, xᵢᵗ, xₙᵢᵗ...] .* ϕᵢ[t][xᵢᵗ] .*
prod(ψₙᵢ[k][t][xᵢᵗ, xₖᵗ] for (k, xₖᵗ) in enumerate(xₙᵢᵗ))
if isempty(A)
Bᵗ[:, :, xᵢᵗ, 1, xᵢᵗ⁺¹] .+= (t == T + 1 ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, Int[], xᵢᵗ)) * ϕᵢ[t][xᵢᵗ]
else
for xₙᵢᵗ in xₙᵢ
Bᵗ[:, :, xᵢᵗ, 1, xᵢᵗ⁺¹] .+= (t == T + 1 ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, xₙᵢᵗ, xᵢᵗ)) .*
Aᵗ[:, :, xᵢᵗ, xₙᵢᵗ...] .* ϕᵢ[t][xᵢᵗ] .*
prod(ψₙᵢ[k][t][xᵢᵗ, xₖᵗ] for (k, xₖᵗ) in enumerate(xₙᵢᵗ))
end
end
end
end
Expand All @@ -93,22 +112,25 @@ function f_bp_dummy_neighbor(A::Vector{MPEM2{F}},
next!(prog, showvalues=[(:t, "$t/$T")])
end

return MPEM3(B), 0.0
if periodic
return PeriodicMPEM3(B), 0.0
else
return MPEM3(B), 0.0
end
end

function pair_belief_as_mpem(Aᵢⱼ::MPEM2, Aⱼᵢ::MPEM2, ψᵢⱼ)
function pair_belief_as_mpem(Aᵢⱼ::M2, Aⱼᵢ::M2, ψᵢⱼ) where {M2<:AbstractMPEM2}
A = map(zip(Aᵢⱼ, Aⱼᵢ, ψᵢⱼ)) do (Aᵢⱼᵗ, Aⱼᵢᵗ, ψᵢⱼᵗ)
@cast Aᵗ[(aᵗ,bᵗ),(aᵗ⁺¹,bᵗ⁺¹),xᵢᵗ,xⱼᵗ] := Aᵢⱼᵗ[aᵗ,aᵗ⁺¹,xᵢᵗ, xⱼᵗ] *
Aⱼᵢᵗ[bᵗ,bᵗ⁺¹,xⱼᵗ,xᵢᵗ] * ψᵢⱼᵗ[xᵢᵗ, xⱼᵗ]
end
return MPEM2(A)
return M2(A)
end

# compute bᵢⱼᵗ(xᵢᵗ,xⱼᵗ) from μᵢⱼ, μⱼᵢ, ψᵢⱼ
# also return normalization zᵢⱼ
function pair_belief(Aᵢⱼ::MPEM2, Aⱼᵢ::MPEM2, ψᵢⱼ)
function pair_belief(Aᵢⱼ::AbstractMPEM2, Aⱼᵢ::AbstractMPEM2, ψᵢⱼ)
A = pair_belief_as_mpem(Aᵢⱼ, Aⱼᵢ, ψᵢⱼ)
l = accumulate_L(A); r = accumulate_R(A)
z = only(l[end])
marginals(A; l, r), z
l = accumulate_L(A)
marginals(A; l), normalization(A; l)
end
3 changes: 3 additions & 0 deletions src/exact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ function exact_prob(bp::MPBP{G,F}) where {G,F}
logp[x] += log( w[i][t](X[t+1,i], X[t,∂i], X[t,i]) )
logp[x] += log( ϕ[i][t+1][X[t+1,i]] )
end
if is_periodic(bp)
logp[x] += log( w[i][end](X[1,i], X[end,∂i], X[end,i]) )
end
end
for (i, j, ij) in edges(g)
for t in 1:T+1
Expand Down
15 changes: 15 additions & 0 deletions src/infinite_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,19 @@ function _pair_beliefs!(b, f, bp::MPBP{G,F}) where {G<:InfiniteRegularGraph,F}
logz = [(1/(bp.g.k-1)- 1/2) * log(zᵢⱼ)]
b[1] = bᵢⱼ
b, logz
end

function periodic_mpbp_infinite_graph(k::Integer, wᵢ::Vector{U}, qi::Int,
ϕᵢ = fill(ones(qi), length(wᵢ));
ψₖᵢ = fill(ones(qi, qi), length(wᵢ)),
d::Int=1, bondsizes=fill(d, length(wᵢ))) where {U<:BPFactor}

T = length(wᵢ) - 1
@assert length(ϕᵢ) == T + 1
@assert length(ψₖᵢ) == T + 1

g = InfiniteRegularGraph(k)
μ = rand_periodic_mpem2(qi, qi, T; d, bondsizes)
b = rand_periodic_mpem1(qi, T; d, bondsizes)
MPBP(g, [wᵢ], [ϕᵢ], [ψₖᵢ], [μ], [b], [0.0])
end
Loading