Skip to content

Commit

Permalink
avoid most allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
abraunst committed Dec 15, 2023
1 parent c0ff68e commit 9cde212
Showing 1 changed file with 17 additions and 35 deletions.
52 changes: 17 additions & 35 deletions src/bp_core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,54 +27,49 @@ function f_bp(A::Vector{M2}, wᵢ::Vector{U}, ϕᵢ::Vector{Vector{F}},
@assert all(length(ϕᵢᵗ) == q for ϕᵢᵗ in ϕᵢ)
@assert j_index in eachindex(A)
z = length(A) # z = |∂i|
x_neigs = Iterators.product((1:size(ψₙᵢ[k][1],2) for k=1:z)...) .|> collect
x_neigs = (collect(x) for x in Iterators.product((1:size(ψₙᵢ[k][1],2) for k=1:z)...))

B = Vector{Array{F,5}}(undef, T + 1)

dt = showprogress ? 1.0 : Inf
prog = Progress(T - 1, dt=dt, desc="Computing outgoing message")
for t in 1:T+1
ϕᵢᵗ = ϕᵢ[t]
# select incoming A's but not the j-th one
Aᵗ = kron2([A[k][t] for k in eachindex(A)[Not(j_index)]]...)
nrows, ncols = size(Aᵗ, 1), size(Aᵗ, 2)
Bᵗ = zeros(nrows, ncols, q, qj, q)

for xᵢᵗ in 1:q
@inbounds for xᵢᵗ in 1:q
for xᵢᵗ⁺¹ in 1:q
for xₙᵢᵗ in x_neigs
xⱼᵗ = xₙᵢᵗ[j_index]
xₙᵢ₋ⱼᵗ = xₙᵢᵗ[Not(j_index)]
Bᵗ[:, :, xᵢᵗ, xⱼᵗ, xᵢᵗ⁺¹] .+= ((t == T + 1) && !periodic ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, xₙᵢᵗ, xᵢᵗ)) *
Aᵗ[:, :, xᵢᵗ, xₙᵢ₋ⱼᵗ...] *
@views Bᵗ[:, :, xᵢᵗ, xⱼᵗ, xᵢᵗ⁺¹] .+= (t == T + 1 && !periodic ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, xₙᵢᵗ, xᵢᵗ)) .*
Aᵗ[:, :, xᵢᵗ, xₙᵢ₋ⱼᵗ...] .*
ϕᵢᵗ[xᵢᵗ] .*
prod(ψₙᵢ[k][t][xᵢᵗ, xₖᵗ] for (k, xₖᵗ) in enumerate(xₙᵢᵗ) if k != j_index; init=1.0)
end
end
Bᵗ[:, :, xᵢᵗ, :, :] *= ϕᵢ[t][xᵢᵗ]
end
B[t] = Bᵗ
any(isnan, Bᵗ) && @error "NaN in tensor at time $t"
next!(prog, showvalues=[(:t, "$t/$T")])
end
# apologies to the gods of type stability
if periodic
return PeriodicMPEM3(B), 0.0
else
return MPEM3(B), 0.0
end

mpem3from2(eltype(A))(B), 0.0
end

# compute outgoing message to dummy neighbor to get the belief
function f_bp_dummy_neighbor(A::Vector{<:AbstractMPEM2},
wᵢ::Vector{U}, ϕᵢ::Vector{Vector{F}}, ψₙᵢ::Vector{Vector{Matrix{F}}};
showprogress=false, svd_trunc::SVDTrunc=TruncThresh(0.0), periodic=false) where {F,U<:BPFactor}

q = length(ϕᵢ[1])
T = length(ϕᵢ) - 1
@assert all(length(a) == T + 1 for a in A)
@assert length(wᵢ) == T + 1
z = length(A) # z = |∂i|
xₙᵢ = Iterators.product((1:size(ψₙᵢ[k][1],2) for k=1:z)...) .|> collect

xₙᵢ = (collect(x) for x in Iterators.product((1:size(ψₙᵢ[k][1],2) for k=1:z)...))
B = Vector{Array{F,5}}(undef, T + 1)
dt = showprogress ? 1.0 : Inf
prog = Progress(T - 1, dt=dt, desc="Computing outgoing message")
Expand All @@ -83,26 +78,17 @@ function f_bp_dummy_neighbor(A::Vector{<:AbstractMPEM2},
nrows = size(Aᵗ, 1)
ncols = size(Aᵗ, 2)
Bᵗ = zeros(nrows, ncols, q, 1, q)
# for xᵢᵗ in 1:q
# for xᵢᵗ⁺¹ in 1:q
# tmp = Matrix(I*(t == T + 1 ? 1.0 : ϕᵢ[t][xᵢᵗ]), 1, 1)
# for xₙᵢᵗ in xₙᵢ
# tmp = (t == T + 1 ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, xₙᵢᵗ, xᵢᵗ)) .*
# Aᵗ[:, :, xᵢᵗ, xₙᵢᵗ...] .* tmp .*
# prod(ψₙᵢ[k][t][xᵢᵗ, xₖᵗ] for (k, xₖᵗ) in enumerate(xₙᵢᵗ))
# end
# Bᵗ[:, :, xᵢᵗ, 1, xᵢᵗ⁺¹] .+= tmp
# end
# end
for xᵢᵗ in 1:q
@inbounds for xᵢᵗ in 1:q
for xᵢᵗ⁺¹ in 1:q
if isempty(A)
Bᵗ[:, :, xᵢᵗ, 1, xᵢᵗ⁺¹] .+= ((t == T + 1) && !periodic ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, Int[], xᵢᵗ)) * ϕᵢ[t][xᵢᵗ]
Bᵗ[:, :, xᵢᵗ, 1, xᵢᵗ⁺¹] .+= (t == T + 1 && !periodic ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, Int[], xᵢᵗ)) * ϕᵢ[t][xᵢᵗ]
else
for xₙᵢᵗ in xₙᵢ
Bᵗ[:, :, xᵢᵗ, 1, xᵢᵗ⁺¹] .+= ((t == T + 1) && !periodic ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, xₙᵢᵗ, xᵢᵗ)) .*
Aᵗ[:, :, xᵢᵗ, xₙᵢᵗ...] .* ϕᵢ[t][xᵢᵗ] .*
prod(ψₙᵢ[k][t][xᵢᵗ, xₖᵗ] for (k, xₖᵗ) in enumerate(xₙᵢᵗ))
@views Bᵗ[:, :, xᵢᵗ, 1, xᵢᵗ⁺¹] .+=
(t == T + 1 && !periodic ? 1.0 : wᵢ[t](xᵢᵗ⁺¹, xₙᵢᵗ, xᵢᵗ)) .*
Aᵗ[:, :, xᵢᵗ, xₙᵢᵗ...] .*
ϕᵢ[t][xᵢᵗ] .*
prod(ψₙᵢ[k][t][xᵢᵗ, xₖᵗ] for (k, xₖᵗ) in enumerate(xₙᵢᵗ))
end
end
end
Expand All @@ -112,11 +98,7 @@ function f_bp_dummy_neighbor(A::Vector{<:AbstractMPEM2},
next!(prog, showvalues=[(:t, "$t/$T")])
end

if periodic
return PeriodicMPEM3(B), 0.0
else
return MPEM3(B), 0.0
end
return mpem3from2(eltype(A))(B), 0.0
end

function pair_belief_as_mpem(Aᵢⱼ::M2, Aⱼᵢ::M2, ψᵢⱼ) where {M2<:AbstractMPEM2}
Expand Down

0 comments on commit 9cde212

Please sign in to comment.