Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved numerical stability of binomial_coefficient_log #1614

Merged
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
8a8ea67
Failing test
martinmodrak Jan 14, 2020
7991f40
Fixes #1592
martinmodrak Jan 14, 2020
77d1e85
Test for derivatives, docs
martinmodrak Jan 14, 2020
8bef133
Merge branch 'bugfix/1611-lbeta-large-arguments' into bugfix/1592-bin…
martinmodrak Jan 14, 2020
118adb2
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 14, 2020
3578ff6
Lint, test sources
martinmodrak Jan 14, 2020
e03205d
Mismatched define
martinmodrak Jan 14, 2020
e44bd90
Merge branch 'bugfix/1611-lbeta-large-arguments' into bugfix/1592-bin…
martinmodrak Jan 14, 2020
035212d
Improved code style
martinmodrak Jan 14, 2020
e7194bf
Merge remote-tracking branch 'origin/bugfix/1611-lbeta-large-argument…
martinmodrak Jan 14, 2020
7922dd5
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 14, 2020
3fd0f3f
Merge remote-tracking branch 'origin/bugfix/1611-lbeta-large-argument…
martinmodrak Jan 15, 2020
6d6ac35
Fixed lint error
martinmodrak Jan 15, 2020
3a320df
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 15, 2020
37760cd
Merge remote-tracking branch 'stan-dev/develop' into bugfix/1592-bino…
martinmodrak Jan 16, 2020
6ebef3e
Removed problematic constexpr
martinmodrak Jan 16, 2020
3df81a2
Merge branch 'bugfix/1611-lbeta-large-arguments' into bugfix/1592-bin…
martinmodrak Jan 16, 2020
d0f7f45
Merge branch 'bugfix/1611-lbeta-large-arguments' into bugfix/1592-bin…
martinmodrak Jan 17, 2020
c01d2be
Merge remote-tracking branch 'stan-dev/develop' into bugfix/1592-bino…
martinmodrak Jan 29, 2020
74d2824
Updated and cleaned up test generating code
martinmodrak Jan 29, 2020
dd59b06
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 29, 2020
5bda4f5
Improved precomputed tests, first pass at formula test
martinmodrak Jan 29, 2020
cd3381a
Work towards formula tests (failing)
martinmodrak Jan 29, 2020
d014464
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Jan 29, 2020
9a5ee11
Merge branch 'develop' into bugfix/1592-binomial_coefficient_log
martinmodrak Mar 7, 2020
51ef491
Passing identity tests, ad test
martinmodrak Mar 7, 2020
215cad0
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Mar 7, 2020
71cbd40
Resolved int and fwd issues
martinmodrak Mar 7, 2020
cfa5b81
Merge remote-tracking branch 'origin/bugfix/1592-binomial_coefficient…
martinmodrak Mar 7, 2020
4fa9c52
Attempt to handle edge cases for derivatives
martinmodrak Mar 9, 2020
eab19c3
Moved to operands_and_partials, fixed edge cases
martinmodrak Mar 9, 2020
75a6f73
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Mar 9, 2020
47c7962
Fixed lint
martinmodrak Mar 9, 2020
cf05d40
Handling review comments
martinmodrak Mar 12, 2020
5b93278
Merge commit 'f2a3c1a6f8de1d5bd380c265f0bc92472dce2d42' into HEAD
yashikno Mar 12, 2020
c84a259
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Mar 12, 2020
6aed316
One more edge case-test+fix
martinmodrak Mar 12, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 96 additions & 28 deletions stan/math/prim/fun/binomial_coefficient_log.hpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
#ifndef STAN_MATH_PRIM_FUN_BINOMIAL_COEFFICIENT_LOG_HPP
#define STAN_MATH_PRIM_FUN_BINOMIAL_COEFFICIENT_LOG_HPP

#include <boost/math/constants/constants.hpp>
martinmodrak marked this conversation as resolved.
Show resolved Hide resolved
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/inv.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/is_any_nan.hpp>
#include <stan/math/prim/fun/log1p.hpp>
#include <stan/math/prim/fun/lbeta.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/multiply_log.hpp>
#include <stan/math/prim/fun/value_of.hpp>

namespace stan {
namespace math {
Expand All @@ -13,22 +19,24 @@ namespace math {
* Return the log of the binomial coefficient for the specified
* arguments.
*
* The binomial coefficient, \f${N \choose n}\f$, read "N choose n", is
* defined for \f$0 \leq n \leq N\f$ by
* The binomial coefficient, \f${n \choose k}\f$, read "n choose k", is
* defined for \f$0 \leq k \leq n\f$ by
*
* \f${N \choose n} = \frac{N!}{n! (N-n)!}\f$.
* \f${n \choose k} = \frac{n!}{k! (n-k)!}\f$.
*
* This function uses Gamma functions to define the log
* and generalize the arguments to continuous N and n.
* and generalize the arguments to continuous n and k.
*
* \f$ \log {n \choose k}
* = \log \ \Gamma(n+1) - \log \Gamma(k+1) - \log \Gamma(n-k+1)\f$.
*
* \f$ \log {N \choose n}
* = \log \ \Gamma(N+1) - \log \Gamma(n+1) - \log \Gamma(N-n+1)\f$.
*
\f[
\mbox{binomial\_coefficient\_log}(x, y) =
\begin{cases}
\textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\
\ln\Gamma(x+1) & \mbox{if } 0\leq y \leq x \\
\textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x
< -1\\
\ln\Gamma(x+1) & \mbox{if } -1 < y < x + 1 \\
\quad -\ln\Gamma(y+1)& \\
\quad -\ln\Gamma(x-y+1)& \\[6pt]
\textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
Expand All @@ -38,7 +46,8 @@ namespace math {
\f[
\frac{\partial\, \mbox{binomial\_coefficient\_log}(x, y)}{\partial x} =
\begin{cases}
\textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\
\textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x
< -1\\
\Psi(x+1) & \mbox{if } 0\leq y \leq x \\
\quad -\Psi(x-y+1)& \\[6pt]
\textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
Expand All @@ -48,32 +57,91 @@ namespace math {
\f[
\frac{\partial\, \mbox{binomial\_coefficient\_log}(x, y)}{\partial y} =
\begin{cases}
\textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\
\textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x
< -1\\
-\Psi(y+1) & \mbox{if } 0\leq y \leq x \\
\quad +\Psi(x-y+1)& \\[6pt]
\textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
\end{cases}
\f]
*
* @tparam T_N type of the first argument
* @tparam T_n type of the second argument
* @param N total number of objects.
* @param n number of objects chosen.
* @return log (N choose n).
* This function is numerically more stable than naive evaluation via lgamma.
*
* @tparam T_n type of the first argument
* @tparam T_k type of the second argument
*
* @param n total number of objects.
* @param k number of objects chosen.
* @return log (n choose k).
*/
template <typename T_N, typename T_n>
inline return_type_t<T_N, T_n> binomial_coefficient_log(const T_N N,
const T_n n) {
const double CUTOFF = 1000;
if (N - n < CUTOFF) {
const T_N N_plus_1 = N + 1;
return lgamma(N_plus_1) - lgamma(n + 1) - lgamma(N_plus_1 - n);

template <typename T_n, typename T_k>
inline return_type_t<T_n, T_k> binomial_coefficient_log(const T_n n,
const T_k k) {
if (is_any_nan(n, k)) {
return stan::math::NOT_A_NUMBER;
martinmodrak marked this conversation as resolved.
Show resolved Hide resolved
}

// Choosing the more stable of the symmetric branches
if (n > 0 && k > value_of_rec(n) / 2.0 + 1e-8) {
return binomial_coefficient_log(n, n - k);
}

using T_partials_return = partials_return_t<T_n, T_k>;
martinmodrak marked this conversation as resolved.
Show resolved Hide resolved

const T_partials_return n_ = value_of(n);
const T_partials_return k_ = value_of(k);
martinmodrak marked this conversation as resolved.
Show resolved Hide resolved
const T_partials_return n_plus_1 = n_ + 1;
const T_partials_return n_plus_1_mk = n_plus_1 - k_;

static const char* function = "binomial_coefficient_log";
check_greater_or_equal(function, "first argument", n, -1);
check_greater_or_equal(function, "second argument", k, -1);
check_greater_or_equal(function, "(first argument - second argument + 1)",
n_plus_1_mk, 0.0);

operands_and_partials<T_n, T_k> ops_partials(n, k);

T_partials_return value;
if (k_ == 0) {
value = 0;
} else if (n_plus_1 < lgamma_stirling_diff_useful) {
value = lgamma(n_plus_1) - lgamma(k_ + 1) - lgamma(n_plus_1_mk);
} else {
return_type_t<T_N, T_n> N_minus_n = N - n;
const double one_twelfth = inv(12);
return multiply_log(n, N_minus_n) + multiply_log((N + 0.5), N / N_minus_n)
+ one_twelfth / N - n - one_twelfth / N_minus_n - lgamma(n + 1);
value = -lbeta(n_plus_1_mk, k_ + 1) - log1p(n_);
}

if (!is_constant_all<T_n, T_k>::value) {
// Branching on all the edge cases.
// In direct computation many of those would be NaN
// But one-sided limits from within the domain exist.
T_partials_return digamma_n_plus_1_mk = digamma(n_plus_1_mk);

if (!is_constant_all<T_n>::value) {
if (n_ == -1.0) {
if (k_ == 0) {
ops_partials.edge1_.partials_[0] = 0;
} else {
ops_partials.edge1_.partials_[0] = stan::math::NEGATIVE_INFTY;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be ops_partials.edge1_.partials_[0] = k_dbl == 0 ? 0 : NEGATIVE_INFTY.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I just thought the ternary operator is discouraged. But if that is the Stan style, I can happily use it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's discouraged, I've been suggested to use it in similar cases. It's a matter of preference, if you find your version clearer, don't change it. :) For me, given that there are a series of nested ifs, it helped reducing the number of things I had to keep in mind while reading.

}
} else {
ops_partials.edge1_.partials_[0]
= (digamma(n_plus_1) - digamma_n_plus_1_mk);
}
}
if (!is_constant_all<T_k>::value) {
if (k_ == 0 && n_ == -1.0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain the gradient logic for k here?

I think I kinda followed the ones for n but these ones confused me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what to add. The only fact I use is that lim x-> 0 digamma(x) from above is negative infinity. Mentioned that in the comments, if that is still confusing, let me know (it is also possible I am missing some of the edge cases).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yeah, that makes sense.

ops_partials.edge2_.partials_[0] = stan::math::NEGATIVE_INFTY;
} else if (k_ == -1) {
ops_partials.edge2_.partials_[0] = stan::math::INFTY;
} else {
ops_partials.edge2_.partials_[0]
= (digamma_n_plus_1_mk - digamma(k_ + 1));
}
}
}

return ops_partials.build(value);
}

} // namespace math
Expand Down
4 changes: 4 additions & 0 deletions test/unit/math/mix/fun/binomial_coefficient_log_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,9 @@ TEST(mathMixScalFun, binomialCoefficientLog) {
};
stan::test::expect_ad(f, 3, 2);
stan::test::expect_ad(f, 24.0, 12.0);
stan::test::expect_ad(f, 1.0, 0.0);
stan::test::expect_ad(f, 0.0, 1.0);
stan::test::expect_ad(f, -0.3, 0.5);

stan::test::expect_common_nonzero_binary(f);
}
32 changes: 29 additions & 3 deletions test/unit/math/prim/fun/binomial_coefficient_log_test.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#include <stan/math/prim.hpp>
#include <test/unit/math/expect_near_rel.hpp>
#include <gtest/gtest.h>
#include <cmath>
#include <limits>

template <typename T_N, typename T_n>
void test_binom_coefficient(const T_N& N, const T_n& n) {
using stan::math::binomial_coefficient_log;
EXPECT_FLOAT_EQ(lgamma(N + 1) - lgamma(n + 1) - lgamma(N - n + 1),
binomial_coefficient_log(N, n));
binomial_coefficient_log(N, n))
<< "N = " << N << ", n = " << n;
}

TEST(MathFunctions, binomial_coefficient_log) {
Expand All @@ -19,6 +20,13 @@ TEST(MathFunctions, binomial_coefficient_log) {

EXPECT_FLOAT_EQ(29979.16, binomial_coefficient_log(100000, 91116));

EXPECT_EQ(binomial_coefficient_log(-1, 0), 0); // Needed for neg_binomial_2
EXPECT_EQ(binomial_coefficient_log(50, 0), 0);
EXPECT_EQ(binomial_coefficient_log(10000, 0), 0);

EXPECT_EQ(binomial_coefficient_log(10, 11), stan::math::NEGATIVE_INFTY);
EXPECT_EQ(binomial_coefficient_log(10, -1), stan::math::NEGATIVE_INFTY);

for (int n = 0; n < 1010; ++n) {
test_binom_coefficient(1010, n);
test_binom_coefficient(1010.0, n);
Expand All @@ -32,9 +40,27 @@ TEST(MathFunctions, binomial_coefficient_log) {
}

TEST(MathFunctions, binomial_coefficient_log_nan) {
double nan = std::numeric_limits<double>::quiet_NaN();
double nan = stan::math::NOT_A_NUMBER;

EXPECT_TRUE(std::isnan(stan::math::binomial_coefficient_log(2.0, nan)));
EXPECT_TRUE(std::isnan(stan::math::binomial_coefficient_log(nan, 2.0)));
EXPECT_TRUE(std::isnan(stan::math::binomial_coefficient_log(nan, nan)));
}

TEST(MathFunctions, binomial_coefficient_log_errors_edge_cases) {
using stan::math::INFTY;
using stan::math::binomial_coefficient_log;

EXPECT_NO_THROW(binomial_coefficient_log(10, 11));
EXPECT_THROW(binomial_coefficient_log(10, 11.01), std::domain_error);
EXPECT_THROW(binomial_coefficient_log(10, -1.1), std::domain_error);
EXPECT_THROW(binomial_coefficient_log(-1, 0.3), std::domain_error);
EXPECT_NO_THROW(binomial_coefficient_log(-0.5, 0.49));
EXPECT_NO_THROW(binomial_coefficient_log(10, -0.9));

EXPECT_FLOAT_EQ(binomial_coefficient_log(0, -1), -INFTY);
EXPECT_FLOAT_EQ(binomial_coefficient_log(-1, 0), 0);
EXPECT_FLOAT_EQ(binomial_coefficient_log(-1, -0.3), INFTY);
EXPECT_FLOAT_EQ(binomial_coefficient_log(0.3, -1), -INFTY);
EXPECT_FLOAT_EQ(binomial_coefficient_log(5.0, 6.0), -INFTY);
}
Loading