Skip to content

Commit

Permalink
multithreaded, lighter sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
stecrotti committed Feb 28, 2024
1 parent 1c91eb5 commit f15d9fd
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/MatrixProductBP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import Statistics: mean, std
import Unzip: unzip
import StatsBase: weights, proportions
import LogExpFunctions: logistic, logsumexp
import .Threads: SpinLock, lock, unlock
import .Threads: SpinLock, lock, unlock, @threads
import Lazy: @forward
import CavityTools: cavity
import LogarithmicNumbers: ULogarithmic
Expand Down
24 changes: 13 additions & 11 deletions src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@ import CavityTools: ExponentialQueue

# as in https://doi.org/10.1103/PhysRevLett.114.248701
# draw samples from the prior and weight them with their likelihood
struct SoftMarginSampler{B<:MPBP, F<:Real}
struct SoftMarginSampler{B<:MPBP,F<:Real,TI<:Integer}
bp :: B
X :: Vector{Matrix{Int}}
X :: Vector{Matrix{TI}}
w :: Vector{F}

function SoftMarginSampler(bp::TBP,
X::Vector{Matrix{Int}}, w::Vector{F}) where{TBP<:MPBP, F}
X::Vector{Matrix{TI}}, w::Vector{F}) where{TBP<:MPBP,F,TI<:Integer}

N = nv(bp.g); T = getT(bp)
@assert length(X) == length(w)
@assert all((0), w)
@assert all(x -> size(x) == (N, T+1), X)

new{TBP, F}(bp, X, w)
new{TBP,F,TI}(bp, X, w)
end
end

function SoftMarginSampler(bp::MPBP)
X = Matrix{Int}[]
X = Matrix{UInt8}[]
w = zeros(ULogarithmic, 0)
SoftMarginSampler(bp, X, w)
end
Expand Down Expand Up @@ -72,7 +72,7 @@ function sample!(sms::SoftMarginSampler, nsamples::Integer;
X = [zeros(Int, N, T+1) for _ in 1:nsamples]
p⁰ = [ϕᵢ[1] ./ sum(ϕᵢ[1]) for ϕᵢ in sms.bp.ϕ]
w = zeros(ULogarithmic, nsamples)
for n in 1:nsamples
@threads for n in 1:nsamples
_, w[n] = onesample!(X[n], sms.bp; p⁰, rng)
next!(prog)
end
Expand All @@ -98,9 +98,11 @@ function marginals(sms::SoftMarginSampler; showprogress::Bool=true, sites=vertic
prog = Progress(N, desc="Marginals from Soft Margin"; dt=showprogress ? 0.1 : Inf)
is_free = is_free_dynamics(sms.bp)

for (a,i) in pairs(sites)
@threads for a in eachindex(sites)
i = sites[a]
x = x = zeros(Int, length(X))
for t in 1:T+1
x = [xx[i, t] for xx in X]
x .= [xx[i, t] for xx in X]
mit_avg = is_free ? proportions(x, nstates(bp,i)) : proportions(x, nstates(bp,i), wv)
mit_var = mit_avg .* (1 .- mit_avg) ./ nsamples
marg[a][t] .= mit_avg sqrt.( mit_var )
Expand All @@ -126,11 +128,11 @@ function pair_marginals(sms::SoftMarginSampler; showprogress::Bool=true)
@assert all(>=(0), w)
wv = weights(w)
nsamples = length(X)
prog = Progress(E, desc="Marginals from Soft Margin"; dt=showprogress ? 0.1 : Inf)
x = zeros(Int, length(X))
prog = Progress(E, desc="Pair marginals from Soft Margin"; dt=showprogress ? 0.1 : Inf)

for (i,j,id) in edges(g)
@threads for (i,j,id) in collect(edges(g))
linear = LinearIndices((1:nstates(bp,i), 1:nstates(bp,j)))
x = zeros(Int, length(X))
for t in 1:T+1
x .= [linear[xx[i, t],xx[j,t]] for xx in X]
mijt_avg_linear = proportions(x, nstates(bp,i)*nstates(bp,j), wv)
Expand Down

0 comments on commit f15d9fd

Please sign in to comment.