diff --git a/src/MatrixProductBP.jl b/src/MatrixProductBP.jl index 1f93ec41..88a8ecbf 100644 --- a/src/MatrixProductBP.jl +++ b/src/MatrixProductBP.jl @@ -16,7 +16,7 @@ import Statistics: mean, std import Unzip: unzip import StatsBase: weights, proportions import LogExpFunctions: logistic, logsumexp -import .Threads: SpinLock, lock, unlock +import .Threads: SpinLock, lock, unlock, @threads import Lazy: @forward import CavityTools: cavity import LogarithmicNumbers: ULogarithmic diff --git a/src/sampling.jl b/src/sampling.jl index c1ac1a5a..fd6bc043 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -3,25 +3,25 @@ import CavityTools: ExponentialQueue # as in https://doi.org/10.1103/PhysRevLett.114.248701 # draw samples from the prior and weight them with their likelihood -struct SoftMarginSampler{B<:MPBP, F<:Real} +struct SoftMarginSampler{B<:MPBP,F<:Real,TI<:Integer} bp :: B - X :: Vector{Matrix{Int}} + X :: Vector{Matrix{TI}} w :: Vector{F} function SoftMarginSampler(bp::TBP, - X::Vector{Matrix{Int}}, w::Vector{F}) where{TBP<:MPBP, F} + X::Vector{Matrix{TI}}, w::Vector{F}) where{TBP<:MPBP,F,TI<:Integer} N = nv(bp.g); T = getT(bp) @assert length(X) == length(w) @assert all(≥(0), w) @assert all(x -> size(x) == (N, T+1), X) - new{TBP, F}(bp, X, w) + new{TBP,F,TI}(bp, X, w) end end function SoftMarginSampler(bp::MPBP) - X = Matrix{Int}[] + X = Matrix{UInt8}[] w = zeros(ULogarithmic, 0) SoftMarginSampler(bp, X, w) end @@ -72,7 +72,7 @@ function sample!(sms::SoftMarginSampler, nsamples::Integer; X = [zeros(Int, N, T+1) for _ in 1:nsamples] p⁰ = [ϕᵢ[1] ./ sum(ϕᵢ[1]) for ϕᵢ in sms.bp.ϕ] w = zeros(ULogarithmic, nsamples) - for n in 1:nsamples + @threads for n in 1:nsamples _, w[n] = onesample!(X[n], sms.bp; p⁰, rng) next!(prog) end @@ -98,9 +98,11 @@ function marginals(sms::SoftMarginSampler; showprogress::Bool=true, sites=vertic prog = Progress(N, desc="Marginals from Soft Margin"; dt=showprogress ? 0.1 : Inf) is_free = is_free_dynamics(sms.bp) - for (a,i) in pairs(sites) + @threads for a in eachindex(sites) + i = sites[a] + x = x = zeros(Int, length(X)) for t in 1:T+1 - x = [xx[i, t] for xx in X] + x .= [xx[i, t] for xx in X] mit_avg = is_free ? proportions(x, nstates(bp,i)) : proportions(x, nstates(bp,i), wv) mit_var = mit_avg .* (1 .- mit_avg) ./ nsamples marg[a][t] .= mit_avg .± sqrt.( mit_var ) @@ -126,11 +128,11 @@ function pair_marginals(sms::SoftMarginSampler; showprogress::Bool=true) @assert all(>=(0), w) wv = weights(w) nsamples = length(X) - prog = Progress(E, desc="Marginals from Soft Margin"; dt=showprogress ? 0.1 : Inf) - x = zeros(Int, length(X)) + prog = Progress(E, desc="Pair marginals from Soft Margin"; dt=showprogress ? 0.1 : Inf) - for (i,j,id) in edges(g) + @threads for (i,j,id) in collect(edges(g)) linear = LinearIndices((1:nstates(bp,i), 1:nstates(bp,j))) + x = zeros(Int, length(X)) for t in 1:T+1 x .= [linear[xx[i, t],xx[j,t]] for xx in X] mijt_avg_linear = proportions(x, nstates(bp,i)*nstates(bp,j), wv)