Use native samplers for Poisson distribution #1021

Merged
merged 4 commits into from
Dec 6, 2019

Conversation

devmotion
Member

I noticed that sampler is not implemented for Poisson distributions, although two native samplers, PoissonCountSampler and PoissonADSampler, exist. Moreover, sampling from Poisson distributions currently goes through R, which is quite slow. For now I use the same cut-off value of 6 for switching between the two samplers as the one chosen in PoissonRandom.jl.

I repeated the benchmarks

using Distributions, PoissonRandom, StatsFuns
using Plots

function n_count(rng, λ, n)
  tmp = 0
  for i in 1:n
    tmp += PoissonRandom.count_rand(rng,λ)
  end
  tmp
end

function n_pois(rng,λ,n)
  tmp = 0
  for i in 1:n
    tmp += pois_rand(rng,λ)
  end
  tmp
end

function n_ad(rng, λ, n)
  tmp = 0
  for i in 1:n
    tmp += PoissonRandom.ad_rand(rng, λ)
  end
  tmp
end

# Takes the RNG explicitly so that it matches the call in time_λ!
function n_dist(rng, λ, n)
  tmp = 0
  for i in 1:n
    tmp += rand(rng, Poisson(λ))
  end
  tmp
end

function n_rfunctions(λ, n)
  tmp = 0
  for i in 1:n
    tmp += convert(Int, StatsFuns.RFunctions.poisrand(λ))
  end
  tmp
end

function n_countsampler(rng, λ::Float64, n)
  tmp = 0
  for i in 1:n
    tmp += rand(rng, Distributions.PoissonCountSampler(λ))
  end
  tmp
end

function n_adsampler(rng, λ::Float64, n)
  tmp = 0
  for i in 1:n
    tmp += rand(rng, Distributions.PoissonADSampler(λ))
  end
  tmp
end

function time_λ!(rng, times, λ::Float64, n)
  times[1] = @elapsed n_count(rng, λ, n)
  times[2] = @elapsed n_ad(rng, λ, n)
  times[3] = @elapsed n_pois(rng, λ, n)
  times[4] = @elapsed n_dist(rng, λ, n)
  times[5] = @elapsed n_rfunctions(λ, n)
  times[6] = @elapsed n_countsampler(rng, λ, n)
  times[7] = @elapsed n_adsampler(rng, λ, n)

  nothing
end

function plot_benchmark(rng)
    times = Matrix{Float64}(undef, 7, 20)

    # Compile
    time_λ!(rng, view(times, :, 1), 5.0, 5_000_000)  # λ must be a Float64 for time_λ!

    # Run with a bunch of λ
    for λ in 1:20
        time_λ!(rng, view(times, :, λ), float(λ), 5_000_000)
    end

    plot(times',
         labels = ["count_rand" "ad_rand" "pois_rand" "Distributions" "RFunctions" "PoissonCountSampler" "PoissonADSampler"],
         lw = 3)
end

from SciML/PoissonRandom.jl#6 with this PR. I get

using Random
Random.seed!(1234)
plot_benchmark(Random.GLOBAL_RNG)
savefig("global_rng.png")

[Figure: benchmark plot for Random.GLOBAL_RNG (global_rng.png)]
and

using RandomNumbers
plot_benchmark(Xorshifts.Xoroshiro128Plus(1234))
savefig("xoroshiro128plus.png")

[Figure: benchmark plot for Xoroshiro128Plus (xoroshiro128plus.png)]

Using the native samplers leads to a significant speed-up, with performance on par with PoissonRandom.jl.

@codecov-io

codecov-io commented Nov 28, 2019

Codecov Report

Merging #1021 into master will increase coverage by 0.25%.
The diff coverage is 100%.


@@            Coverage Diff             @@
##           master    #1021      +/-   ##
==========================================
+ Coverage   77.92%   78.17%   +0.25%     
==========================================
  Files         112      112              
  Lines        5391     5325      -66     
==========================================
- Hits         4201     4163      -38     
+ Misses       1190     1162      -28
Impacted Files                                      Coverage Δ
src/samplers/poisson.jl                             92.18% <100%> (+1.27%) ⬆️
src/univariate/discrete/poisson.jl                  66.66% <100%> (+4.56%) ⬆️
src/multivariate/mvnormalcanon.jl                   78.72% <0%> (-2.13%) ⬇️
src/multivariate/mvnormal.jl                        71.59% <0%> (ø) ⬆️
src/univariate/discrete/discretenonparametric.jl    98.03% <0%> (+0.01%) ⬆️
src/utils.jl                                        80% <0%> (+6.53%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 802a65b...ef80351.

@matbesancon
Member

Samplers are not my area, but do we have a way to test for correctness of the sampler?

@devmotion
Member Author

It seems they are tested in

## Poisson samplers
for (S, paramlst) in [
        (PoissonCountSampler, [0.2, 0.5, 1.0, 2.0, 5.0, 10.0, 15.0, 20.0, 30.0]),
        (PoissonADSampler, [5.0, 10.0, 15.0, 20.0, 30.0])]
    local S
    println("    testing $S")
    for μ in paramlst
        test_samples(S(μ), Poisson(μ), n_tsamples)
        test_samples(S(μ), Poisson(μ), n_tsamples, rng=rng)
    end
end
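
For readers who want a feel for what such a correctness test does, a much cruder check than test_samples can be sketched as follows: draw many samples and compare the empirical mean and variance with μ, since both equal μ for a Poisson distribution. The names inversion_rand and moments_close are made up for this illustration, the tolerance is arbitrary, and the inversion sampler is only a stand-in so that the snippet runs without Distributions.

```julia
using Random, Statistics

# Stand-in sampler (not taken from Distributions): inversion by sequential search.
function inversion_rand(rng::AbstractRNG, μ::Float64)
    u = rand(rng)
    k = 0
    p = exp(-μ)      # P(X = 0)
    F = p            # running CDF value P(X <= k)
    # The `p > 0` guard stops the search if rounding keeps F slightly below u.
    while u > F && p > 0
        k += 1
        p *= μ / k   # P(X = k) computed from P(X = k - 1)
        F += p
    end
    return k
end

# Crude moment check: for Poisson(μ), mean and variance both equal μ.
function moments_close(rng::AbstractRNG, μ::Float64; n::Int = 100_000, tol::Float64 = 0.1)
    xs = [inversion_rand(rng, μ) for _ in 1:n]
    return abs(mean(xs) - μ) < tol && abs(var(xs) - μ) < tol
end

moments_close(MersenneTwister(1234), 3.0)
```

A check like this only catches gross errors; comparing empirical frequencies against the probability mass function is considerably stricter.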

@matbesancon
Member

LGTM, I'll wait for another review before merging though, to get more educated opinions


if G >= 0.0
K = floor(Int,G)
if G >= zero(G)
Member

You can just compare with 0?

Member Author

I guess that could be used here as well. I think it only makes a difference if one works with non-standard number types such as unitful numbers, e.g., u"1.0m" > 0 errors whereas u"1.0m" > zero(u"1.0m") works as expected.

Member

Ah, good point, let's be defensive then.
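
The point about zero(G) can be illustrated without Unitful. The sketch below uses a toy type Meters, invented for this example, that only defines comparisons between two Meters values; it is an assumption-laden stand-in and does not reproduce how Unitful itself is implemented.

```julia
# Toy stand-in for a unitful quantity (hypothetical type, not Unitful.jl).
struct Meters
    value::Float64
end

# Only comparisons between two Meters values are defined.
Base.isless(a::Meters, b::Meters) = isless(a.value, b.value)
Base.zero(x::Meters) = Meters(zero(x.value))

x = Meters(1.0)

# Works: both sides are Meters, so the isless method above applies.
positive = x > zero(x)

# Comparing with the bare literal 0 finds no applicable method.
fails_with_literal = try
    x > 0
    false
catch err
    err isa MethodError
end
```

Writing the comparison against zero(G) therefore keeps the code generic over the type of G at no cost for plain floating-point numbers.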

else # Case B
# Ahrens & Dieter use a sequential method for tabulating and looking up quantiles.
# TODO: check which is more efficient.
return quantile(d,rand(rng))
Member

So this is the sub-approach we will no longer use?

Member Author

Yes. Instead of the quantile-based generation, this PR uses the PoissonCountSampler for small rates, which simply counts exponentially distributed random numbers.
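
The counting idea can be sketched in a few lines: the number of unit-rate exponential waiting times that fit into the interval [0, μ] is Poisson distributed with mean μ. The function name count_poisson is made up for this illustration; it is a minimal sketch of the technique and not a copy of the code in src/samplers/poisson.jl.

```julia
using Random

# Count how many unit-rate exponential waiting times fit into [0, μ].
function count_poisson(rng::AbstractRNG, μ::Real)
    μ >= zero(μ) || throw(ArgumentError("rate must be non-negative"))
    n = 0
    t = randexp(rng)          # time of the first arrival
    while t < μ
        n += 1
        t += randexp(rng)     # add the next waiting time
    end
    return n
end

rng = MersenneTwister(1234)
draws = [count_poisson(rng, 2.5) for _ in 1:10]
```

The expected number of loop iterations grows linearly with μ, which is why this approach is only attractive for small rates.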

px = -μ
py = μ^K/factorial(K) # replace with loopup?
function sampler(d::Poisson)
if rate(d) < 6
Member

this magic threshold appears twice, make it a constant?

Member Author

I guess it's even better to not define rand at all and remove the code duplication.

Member

bump on this? Maybe naming the constant is still a nice idea
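
As a rough sketch of that suggestion, the threshold could live in a single named constant. Both POISSON_COUNT_CUTOFF and sampler_kind are hypothetical names chosen for this example, and the function only returns a label instead of constructing a sampler.

```julia
# Hypothetical constant name; the value 6 is the cut-off mentioned in this PR.
const POISSON_COUNT_CUTOFF = 6

# Return a label for the sampler that would be chosen for rate μ.
sampler_kind(μ::Real) = μ < POISSON_COUNT_CUTOFF ? :count : :ahrens_dieter

sampler_kind(2.5), sampler_kind(6.0)
```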

@matbesancon
Member

@mschauer @devmotion good to merge?

@devmotion
Member Author

Yes 👍 I think so; I adjusted the PR according to your suggestions and comments.

@matbesancon matbesancon merged commit 3151573 into JuliaStats:master Dec 6, 2019
@matbesancon
Member

thanks for the PR :)
