Skip to content

Commit

Permalink
Merge pull request #1830 from martinmodrak/bugfix/1495-neg_binomial_2_log_stability
Browse files Browse the repository at this point in the history

More stable implementation of neg_binomial_2_log_lpmf
  • Loading branch information
bbbales2 authored Apr 11, 2020
2 parents 11742e2 + 766de89 commit d8392cb
Show file tree
Hide file tree
Showing 3 changed files with 556 additions and 443 deletions.
69 changes: 32 additions & 37 deletions stan/math/prim/prob/neg_binomial_2_log_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/binomial_coefficient_log.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/inv.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/log_sum_exp.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/multiply_log.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/prob/poisson_log_lpmf.hpp>
#include <cmath>

namespace stan {
Expand Down Expand Up @@ -52,13 +51,20 @@ return_type_t<T_log_location, T_precision> neg_binomial_2_log_lpmf(
size_t size_phi = stan::math::size(phi);
size_t size_eta_phi = max_size(eta, phi);
size_t size_n_phi = max_size(n, phi);
size_t max_size_seq_view = max_size(n, eta, phi);
size_t size_all = max_size(n, eta, phi);

VectorBuilder<true, T_partials_return, T_log_location> eta_val(size_eta);
for (size_t i = 0; i < size_eta; ++i) {
eta_val[i] = value_of(eta_vec[i]);
}

VectorBuilder<true, T_partials_return, T_precision> phi_val(size_phi);
VectorBuilder<true, T_partials_return, T_precision> log_phi(size_phi);
for (size_t i = 0; i < size_phi; ++i) {
phi_val[i] = value_of(phi_vec[i]);
log_phi[i] = log(phi_val[i]);
}

VectorBuilder<!is_constant_all<T_log_location, T_precision>::value,
T_partials_return, T_log_location>
exp_eta(size_eta);
Expand All @@ -68,17 +74,19 @@ return_type_t<T_log_location, T_precision> neg_binomial_2_log_lpmf(
}
}

VectorBuilder<true, T_partials_return, T_precision> phi_val(size_phi);
VectorBuilder<true, T_partials_return, T_precision> log_phi(size_phi);
for (size_t i = 0; i < size_phi; ++i) {
phi_val[i] = value_of(phi_vec[i]);
log_phi[i] = log(phi_val[i]);
VectorBuilder<!is_constant_all<T_log_location, T_precision>::value,
T_partials_return, T_log_location, T_precision>
exp_eta_over_exp_eta_phi(size_eta_phi);
if (!is_constant_all<T_log_location, T_precision>::value) {
for (size_t i = 0; i < size_eta_phi; ++i) {
exp_eta_over_exp_eta_phi[i] = inv(phi_val[i] / exp_eta[i] + 1);
}
}

VectorBuilder<true, T_partials_return, T_log_location, T_precision>
logsumexp_eta_logphi(size_eta_phi);
log1p_exp_eta_m_logphi(size_eta_phi);
for (size_t i = 0; i < size_eta_phi; ++i) {
logsumexp_eta_logphi[i] = log_sum_exp(eta_val[i], log_phi[i]);
log1p_exp_eta_m_logphi[i] = log1p_exp(eta_val[i] - log_phi[i]);
}

VectorBuilder<true, T_partials_return, T_n, T_precision> n_plus_phi(
Expand All @@ -87,38 +95,25 @@ return_type_t<T_log_location, T_precision> neg_binomial_2_log_lpmf(
n_plus_phi[i] = n_vec[i] + phi_val[i];
}

for (size_t i = 0; i < max_size_seq_view; i++) {
if (phi_val[i] > 1e5) {
// TODO(martinmodrak) This is wrong (doesn't pass propto information),
// and inaccurate for n = 0, but shouldn't break most models.
// Also the 1e5 cutoff is way too low.
// Will be addressed better once PR #1497 is merged
logp += poisson_log_lpmf(n_vec[i], eta_val[i]);
} else {
if (include_summand<propto>::value) {
logp -= lgamma(n_vec[i] + 1.0);
}
if (include_summand<propto, T_precision>::value) {
logp += multiply_log(phi_val[i], phi_val[i]) - lgamma(phi_val[i]);
}
if (include_summand<propto, T_log_location>::value) {
logp += n_vec[i] * eta_val[i];
}
if (include_summand<propto, T_precision>::value) {
logp += lgamma(n_plus_phi[i]);
}
logp -= (n_plus_phi[i]) * logsumexp_eta_logphi[i];
for (size_t i = 0; i < size_all; i++) {
if (include_summand<propto, T_precision>::value) {
logp += binomial_coefficient_log(n_plus_phi[i] - 1, n_vec[i]);
}
if (include_summand<propto, T_log_location>::value) {
logp += n_vec[i] * eta_val[i];
}
logp += -phi_val[i] * log1p_exp_eta_m_logphi[i]
- n_vec[i] * (log_phi[i] + log1p_exp_eta_m_logphi[i]);

if (!is_constant_all<T_log_location>::value) {
ops_partials.edge1_.partials_[i]
+= n_vec[i] - n_plus_phi[i] / (phi_val[i] / exp_eta[i] + 1);
+= n_vec[i] - n_plus_phi[i] * exp_eta_over_exp_eta_phi[i];
}
if (!is_constant_all<T_precision>::value) {
ops_partials.edge2_.partials_[i]
+= 1.0 - n_plus_phi[i] / (exp_eta[i] + phi_val[i]) + log_phi[i]
- logsumexp_eta_logphi[i] - digamma(phi_val[i])
+ digamma(n_plus_phi[i]);
+= exp_eta_over_exp_eta_phi[i] - n_vec[i] / (exp_eta[i] + phi_val[i])
- log1p_exp_eta_m_logphi[i]
- (digamma(phi_val[i]) - digamma(n_plus_phi[i]));
}
}
return ops_partials.build(logp);
Expand Down
17 changes: 5 additions & 12 deletions test/unit/math/prim/prob/neg_binomial_2_log_test.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <stan/math/prim.hpp>
#include <test/unit/math/prim/prob/vector_rng_test_helper.hpp>
#include <test/unit/math/prim/prob/NegativeBinomial2LogTestRig.hpp>
#include <test/unit/math/expect_near_rel.hpp>
#include <gtest/gtest.h>
#include <boost/random/mersenne_twister.hpp>
#include <boost/math/distributions.hpp>
Expand Down Expand Up @@ -212,29 +213,21 @@ TEST(ProbNegBinomial2, log_matches_lpmf) {
TEST(ProbDistributionsNegBinomial2Log, neg_binomial_2_log_grid_test) {
std::vector<double> mu_log_to_test
= {-101, -27, -3, -1, -0.132, 0, 4, 10, 87};
// TODO(martinmodrak) Reducing the span of the test, should be fixed
// along with #1495
// std::vector<double> phi_to_test = {2e-5, 0.36, 1, 10, 2.3e5, 1.8e10, 6e16};
std::vector<double> phi_to_test = {0.36, 1, 10};
std::vector<double> phi_to_test = {2e-5, 0.36, 1, 10, 2.3e5, 1.8e10, 6e16};
std::vector<int> n_to_test = {0, 1, 10, 39, 101, 3048, 150054};

// TODO(martinmdorak) Only weak tolerance for this quick fix
auto tolerance = [](double x) { return std::max(fabs(x * 1e-8), 1e-8); };

for (double mu_log : mu_log_to_test) {
for (double phi : phi_to_test) {
for (int n : n_to_test) {
double val_log = stan::math::neg_binomial_2_log_lpmf(n, mu_log, phi);
EXPECT_LE(val_log, 0)
<< "neg_binomial_2_log_lpmf yields " << val_log
<< " which si greater than 0 for n = " << n
<< ", mu_log = " << mu_log << ", phi = " << phi << ".";
std::stringstream msg;
double val_orig
= stan::math::neg_binomial_2_lpmf(n, std::exp(mu_log), phi);
EXPECT_NEAR(val_log, val_orig, tolerance(val_orig))
msg << std::setprecision(22)
<< "neg_binomial_2_log_lpmf yields different result (" << val_log
<< ") than neg_binomial_2_lpmf (" << val_orig << ") for n = " << n
<< ", mu_log = " << mu_log << ", phi = " << phi << ".";
stan::test::expect_near_rel(msg.str(), val_log, val_orig);
}
}
}
Expand Down
Loading

0 comments on commit d8392cb

Please sign in to comment.