Skip to content

Commit d834ace

Browse files
seabbsSamuelBrand1
andauthored
Issue 340: Remove nan handling (#415)
* remove nan handling * get rid of all NB padding and clamping * remove nan handling * get rid of all NB padding and clamping * remove overflow test * Add `rand` safe version of Poisson and Negative binomial distributions (#418) * SafePoisson with safety for large means * better selection for conversion to Int or BigInt * add SafeNegativeBinomial * add unit tests to doctests * reformat * Add type promotion so AD works with distribution constructor * Add logpdf grad call unit tests for Safe discrete dists * reformat * change neg bin param to (r, p) * Update utils.jl * reformat * change empirical var test to more principled approach * add default nadapts rather than just 50% of target sampling * Update NUTSampler.jl * set dist check_args = false * Set nadapts to Turing Default * reformat --------- Co-authored-by: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com>
1 parent bc99b2e commit d834ace

File tree

9 files changed

+606
-32
lines changed

9 files changed

+606
-32
lines changed

EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@ using ..EpiAwareBase
88
using DataFramesMeta: DataFrame, @rename!
99
using DynamicPPL: Model, fix, condition, @submodel, @model
1010
using MCMCChains: Chains
11-
using Random: AbstractRNG
11+
using Random: AbstractRNG, randexp
1212
using Tables: rowtable
1313

1414
using Distributions, DocStringExtensions, QuadGK, Statistics, Turing
1515

1616
#Export Structures
17-
export HalfNormal, DirectSample
17+
export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial
1818

1919
#Export functions
2020
export scan, spread_draws, censored_pmf, get_param_array, prefix_submodel
@@ -32,5 +32,7 @@ include("turing-methods.jl")
3232
include("DirectSample.jl")
3333
include("post-inference.jl")
3434
include("get_param_array.jl")
35+
include("SafePoisson.jl")
36+
include("SafeNegativeBinomial.jl")
3537

3638
end
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
@doc raw"
2+
Create a Negative binomial distribution with the specified mean that avoids `InExactError`
3+
when the mean is too large.
4+
5+
6+
# Parameterisation:
7+
We are using a mean and cluster factorization of the negative binomial distribution such
8+
that the variance to mean relationship is:
9+
10+
```math
11+
\sigma^2 = \mu + \alpha^2 \mu^2
12+
```
13+
14+
The reason for this parameterisation is that at sufficiently large mean values (i.e. `r > 1 / p`) `p` is approximately equal to the
15+
standard fluctuation of the distribution, e.g. if `p = 0.05` we expect typical fluctuations of samples from the negative binomial to be
16+
about 5% of the mean when the mean is notably larger than 20. Otherwise, we expect approximately Poisson noise. In our opinion, this
17+
parameterisation is useful for specifying the distribution in a way that is easier to reason on priors for `p`.
18+
19+
# Arguments:
20+
21+
- `r`: The number of successes, although this can be extended to a continous number.
22+
- `p`: Success rate.
23+
24+
# Returns:
25+
26+
- A `SafeNegativeBinomial` distribution with the specified mean.
27+
28+
# Examples:
29+
30+
```jldoctest SafeNegativeBinomial
31+
using EpiAware, Distributions
32+
33+
bigμ = exp(48.0) #Large value of μ
34+
σ² = bigμ + 0.05 * bigμ^2 #Large variance
35+
36+
# We can calculate the success rate from the mean to variance relationship
37+
p = bigμ / σ²
38+
r = bigμ * p / (1 - p)
39+
d = SafeNegativeBinomial(r, p)
40+
# output
41+
EpiAware.EpiAwareUtils.SafeNegativeBinomial{Float64}(r=20.0, p=2.85032816548187e-20)
42+
```
43+
44+
```jldoctest SafeNegativeBinomial
45+
cdf(d, 100)
46+
# output
47+
0.0
48+
```
49+
50+
```jldoctest SafeNegativeBinomial
51+
logpdf(d, 100)
52+
# output
53+
-850.1397180331871
54+
```
55+
56+
```jldoctest SafeNegativeBinomial
57+
mean(d)
58+
# output
59+
7.016735912097631e20
60+
```
61+
62+
```jldoctest SafeNegativeBinomial
63+
var(d)
64+
# output
65+
2.4617291430060293e40
66+
```
67+
"
68+
struct SafeNegativeBinomial{T <: Real} <: DiscreteUnivariateDistribution
69+
r::T
70+
p::T
71+
72+
function SafeNegativeBinomial{T}(r::T, p::T) where {T <: Real}
73+
return new{T}(r, p)
74+
end
75+
end
76+
77+
#Outer constructors make AD work
78+
function SafeNegativeBinomial(r::T, p::T) where {T <: Real}
79+
return SafeNegativeBinomial{T}(r, p)
80+
end
81+
82+
SafeNegativeBinomial(r::Real, p::Real) = SafeNegativeBinomial(promote(r, p)...)
83+
84+
# helper function
85+
_negbin(d::SafeNegativeBinomial) = NegativeBinomial(d.r, d.p; check_args = false)
86+
87+
### Support
88+
89+
Base.minimum(d::SafeNegativeBinomial) = 0
90+
Base.maximum(d::SafeNegativeBinomial) = Inf
91+
Distributions.insupport(d::SafeNegativeBinomial, x::Integer) = x >= 0
92+
93+
#### Parameters
94+
95+
Distributions.params(d::SafeNegativeBinomial) = _negbin(d) |> params
96+
Distributions.partype(::SafeNegativeBinomial{T}) where {T} = T
97+
98+
Distributions.succprob(d::SafeNegativeBinomial) = _negbin(d).p
99+
Distributions.failprob(d::SafeNegativeBinomial{T}) where {T} = one(T) - _negbin(d).p
100+
101+
#### Statistics
102+
103+
Distributions.mean(d::SafeNegativeBinomial) = _negbin(d) |> mean
104+
Distributions.var(d::SafeNegativeBinomial) = _negbin(d) |> var
105+
Distributions.std(d::SafeNegativeBinomial) = _negbin(d) |> std
106+
Distributions.skewness(d::SafeNegativeBinomial) = _negbin(d) |> skewness
107+
Distributions.kurtosis(d::SafeNegativeBinomial) = _negbin(d) |> kurtosis
108+
Distributions.mode(d::SafeNegativeBinomial) = _negbin(d) |> mode
109+
function Distributions.kldivergence(p::SafeNegativeBinomial, q::SafeNegativeBinomial)
110+
kldivergence(_negbin(p), _negbin(q))
111+
end
112+
113+
#### Evaluation & Sampling
114+
115+
Distributions.logpdf(d::SafeNegativeBinomial, k::Real) = logpdf(_negbin(d), k)
116+
117+
Distributions.cdf(d::SafeNegativeBinomial, x::Real) = cdf(_negbin(d), x)
118+
Distributions.ccdf(d::SafeNegativeBinomial, x::Real) = ccdf(_negbin(d), x)
119+
Distributions.logcdf(d::SafeNegativeBinomial, x::Real) = logcdf(_negbin(d), x)
120+
Distributions.logccdf(d::SafeNegativeBinomial, x::Real) = logccdf(_negbin(d), x)
121+
Distributions.quantile(d::SafeNegativeBinomial, q::Real) = quantile(_negbin(d), q)
122+
Distributions.cquantile(d::SafeNegativeBinomial, q::Real) = cquantile(_negbin(d), q)
123+
Distributions.invlogcdf(d::SafeNegativeBinomial, lq::Real) = invlogcdf(_negbin(d), lq)
124+
Distributions.invlogccdf(d::SafeNegativeBinomial, lq::Real) = invlogccdf(_negbin(d), lq)
125+
126+
## sampling
127+
function Base.rand(rng::AbstractRNG, d::SafeNegativeBinomial)
128+
if isone(d.p)
129+
return 0
130+
else
131+
return rand(rng, SafePoisson(rand(rng, Gamma(d.r, (1 - d.p) / d.p))))
132+
end
133+
end
134+
135+
Distributions.mgf(d::SafeNegativeBinomial, t::Real) = mgf(_negbin(d), t)
136+
Distributions.cgf(d::SafeNegativeBinomial, t) = cgf(_negbin(d), t)
137+
Distributions.cf(d::SafeNegativeBinomial, t::Real) = cf(_negbin(d), t)

0 commit comments

Comments
 (0)