Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quick fix for notable problems with neg_binomial_2 and neg_binomial_2_log #1622

Merged
33 changes: 21 additions & 12 deletions stan/math/prim/prob/neg_binomial_2_log_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/log_sum_exp.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/prob/poisson_log_lpmf.hpp>
#include <cmath>

namespace stan {
Expand Down Expand Up @@ -84,19 +85,27 @@ return_type_t<T_log_location, T_precision> neg_binomial_2_log_lpmf(
}

for (size_t i = 0; i < max_size_seq_view; i++) {
if (include_summand<propto>::value) {
logp -= lgamma(n_vec[i] + 1.0);
if (phi__[i] > 1e5) {
// TODO(martinmodrak) This is wrong (doesn't pass propto information),
// and inaccurate for n = 0, but shouldn't break most models.
// Also the 1e5 cutoff is way too low.
// Will be adressed better once PR #1497 is merged
logp += poisson_log_lpmf(n_vec[i], eta__[i]);
} else {
if (include_summand<propto>::value) {
logp -= lgamma(n_vec[i] + 1.0);
}
if (include_summand<propto, T_precision>::value) {
logp += multiply_log(phi__[i], phi__[i]) - lgamma(phi__[i]);
}
if (include_summand<propto, T_log_location>::value) {
logp += n_vec[i] * eta__[i];
}
if (include_summand<propto, T_precision>::value) {
logp += lgamma(n_plus_phi[i]);
}
logp -= (n_plus_phi[i]) * logsumexp_eta_logphi[i];
}
if (include_summand<propto, T_precision>::value) {
logp += multiply_log(phi__[i], phi__[i]) - lgamma(phi__[i]);
}
if (include_summand<propto, T_log_location>::value) {
logp += n_vec[i] * eta__[i];
}
if (include_summand<propto, T_precision>::value) {
logp += lgamma(n_plus_phi[i]);
}
logp -= (n_plus_phi[i]) * logsumexp_eta_logphi[i];

if (!is_constant_all<T_log_location>::value) {
ops_partials.edge1_.partials_[i]
Expand Down
34 changes: 19 additions & 15 deletions stan/math/prim/prob/neg_binomial_2_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,23 +79,27 @@ return_type_t<T_location, T_precision> neg_binomial_2_lpmf(
}

for (size_t i = 0; i < max_size_seq_view; i++) {
if (include_summand<propto>::value) {
logp -= lgamma(n_vec[i] + 1.0);
}
if (include_summand<propto, T_precision>::value) {
logp += multiply_log(phi__[i], phi__[i]) - lgamma(phi__[i]);
}
if (include_summand<propto, T_location>::value) {
logp += multiply_log(n_vec[i], mu__[i]);
}
if (include_summand<propto, T_precision>::value) {
logp += lgamma(n_plus_phi[i]);
}
logp -= (n_plus_phi[i]) * log_mu_plus_phi[i];

// if phi is large we probably overflow, defer to Poisson:
if (phi__[i] > 1e5) {
logp = poisson_lpmf(n_vec[i], mu__[i]);
// TODO(martinmodrak) This is wrong (doesn't pass propto information),
// and inaccurate for n = 0, but shouldn't break most models.
// Also the 1e5 cutoff is too small.
// Will be adressed better in PR #1497
logp += poisson_lpmf(n_vec[i], mu__[i]);
} else {
if (include_summand<propto>::value) {
logp -= lgamma(n_vec[i] + 1.0);
}
if (include_summand<propto, T_precision>::value) {
logp += multiply_log(phi__[i], phi__[i]) - lgamma(phi__[i]);
}
if (include_summand<propto, T_location>::value) {
logp += multiply_log(n_vec[i], mu__[i]);
}
if (include_summand<propto, T_precision>::value) {
logp += lgamma(n_plus_phi[i]);
}
logp -= (n_plus_phi[i]) * log_mu_plus_phi[i];
}

if (!is_constant_all<T_location>::value) {
Expand Down
29 changes: 29 additions & 0 deletions test/unit/math/prim/prob/neg_binomial_2_log_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <gtest/gtest.h>
#include <boost/random/mersenne_twister.hpp>
#include <boost/math/distributions.hpp>
#include <algorithm>
#include <limits>
#include <vector>
#include <string>
Expand Down Expand Up @@ -207,3 +208,31 @@ TEST(ProbNegBinomial2, log_matches_lpmf) {
(stan::math::neg_binomial_2_lpmf<double, double, double>(y, mu, phi)),
(stan::math::neg_binomial_2_log<double, double, double>(y, mu, phi)));
}

TEST(ProbDistributionsNegBinomial2Log, neg_binomial_2_log_grid_test) {
std::vector<double> mu_log_to_test
= {-101, -27, -3, -1, -0.132, 0, 4, 10, 87};
std::vector<double> phi_to_test = {2e-5, 0.36, 1, 2.3e5, 1.8e10, 6e16};
std::vector<int> n_to_test = {0, 1, 10, 39, 101, 3048, 150054};

// TODO(martinmdorak) Only weak tolerance for this quick fix
auto tolerance = [](double x) { return std::max(fabs(x * 1e-8), 1e-8); };

for (double mu_log : mu_log_to_test) {
for (double phi : phi_to_test) {
for (int n : n_to_test) {
double val_log = stan::math::neg_binomial_2_log_lpmf(n, mu_log, phi);
EXPECT_LE(val_log, 0)
<< "neg_binomial_2_log_lpmf yields " << val_log
<< " which si greater than 0 for n = " << n
<< ", mu_log = " << mu_log << ", phi = " << phi << ".";
double val_orig
= stan::math::neg_binomial_2_lpmf(n, std::exp(mu_log), phi);
EXPECT_NEAR(val_log, val_orig, tolerance(val_orig))
<< "neg_binomial_2_log_lpmf yields different result (" << val_log
<< ") than neg_binomial_2_lpmf (" << val_orig << ") for n = " << n
<< ", mu_log = " << mu_log << ", phi = " << phi << ".";
}
}
}
}
13 changes: 13 additions & 0 deletions test/unit/math/prim/prob/neg_binomial_2_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,19 @@ TEST(ProbDistributionsNegBinomial, extreme_values) {
}
}

TEST(ProbDistributionsNegBinomial2, vectorAroundCutoff) {
int y = 10;
double mu = 9.36;
std::vector<double> phi;
phi.push_back(1);
phi.push_back(1e15);
double vector_value = stan::math::neg_binomial_2_lpmf(y, mu, phi);
double scalar_value = stan::math::neg_binomial_2_lpmf(y, mu, phi[0])
+ stan::math::neg_binomial_2_lpmf(y, mu, phi[1]);

EXPECT_FLOAT_EQ(vector_value, scalar_value);
}

TEST(ProbDistributionsNegativeBinomial2Log, distributionCheck) {
check_counts_real_real(NegativeBinomial2LogTestRig());
}
Loading