diff --git a/src/gamma.jl b/src/gamma.jl index 25aa086..a00fc43 100644 --- a/src/gamma.jl +++ b/src/gamma.jl @@ -3,44 +3,27 @@ gamma(z::Number) = _gamma(float(z)) _gamma(x::Float32) = Float32(_gamma(Float64(x))) function _gamma(x::Float64) + T = Float64 if x < 0 - isinteger(x) && throw(DomainError(x, "NaN result for non-NaN input.")) + (isinteger(x) || x==-Inf) && throw(DomainError(x, "NaN result for non-NaN input.")) xp1 = abs(x) + 1.0 - return π / sinpi(xp1) / _gammax(xp1) - else - return _gammax(x) + return π / (sinpi(xp1) * _gamma(xp1)) end -end -# only have a Float64 implementations -function _gammax(x) + isfinite(x) || return x if x > 11.5 - return large_gamma(x) - elseif x <= 11.5 - return small_gamma(x) - elseif isnan(x) - return x - end -end -function large_gamma(x) - isinf(x) && return x - T = Float64 - w = inv(x) - s = ( - 8.333333333333331800504e-2, 3.472222222230075327854e-3, -2.681327161876304418288e-3, -2.294719747873185405699e-4, - 7.840334842744753003862e-4, 6.989332260623193171870e-5, -5.950237554056330156018e-4, -2.363848809501759061727e-5, - 7.147391378143610789273e-4 - ) - w = w * evalpoly(w, s) + one(T) - # lose precision on following block - y = exp((x)) - # avoid overflow - v = x^(0.5 * x - 0.25) - y = v * (v / y) + w = inv(x) + s = ( + 8.333333333333331800504e-2, 3.472222222230075327854e-3, -2.681327161876304418288e-3, -2.294719747873185405699e-4, + 7.840334842744753003862e-4, 6.989332260623193171870e-5, -5.950237554056330156018e-4, -2.363848809501759061727e-5, + 7.147391378143610789273e-4 + ) + w = muladd(w, evalpoly(w, s), 1.0) + # avoid overflow + v = x ^ muladd(0.5, x, -0.25) + y = v * (v / exp(x)) - return SQ2PI(T) * y * w -end -function small_gamma(x) - T = Float64 + return SQ2PI(T) * y * w + end P = ( 1.000000000000000000009e0, 8.378004301573126728826e-1, 3.629515436640239168939e-1, 1.113062816019361559013e-1, 2.385363243461108252554e-2, 4.092666828394035500949e-3, 4.542931960608009155600e-4, 4.212760487471622013093e-5 @@ -51,18 +34,26 @@ function small_gamma(x) -1.397148517476170440917e-5 ) - z = one(T) + z = 1.0 while x >= 3.0 - x -= one(T) + x -= 1.0 z *= x end while x < 2.0 z /= x - x += one(T) + x += 1.0 end - x -= T(2) + x -= 2.0 p = evalpoly(x, P) q = evalpoly(x, Q) return z * p / q end + + +function gamma(n::Integer) + n < 0 && throw(DomainError(n, "`n` must not be negative.")) + n == 0 && return Inf*float(n) + n > 20 && return gamma(float(n)) + @inbounds return Float64(factorial(n-1)) +end