diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index f57cda28d9..623b82eea7 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -197,6 +197,30 @@ AlgebraicDecisionTree HybridBayesNet::errorTree( return result; } +/* ************************************************************************* */ +double HybridBayesNet::negLogConstant( + const std::optional &discrete) const { + double negLogNormConst = 0.0; + // Iterate over each conditional. + for (auto &&conditional : *this) { + if (discrete.has_value()) { + if (auto gm = conditional->asHybrid()) { + negLogNormConst += gm->choose(*discrete)->negLogConstant(); + } else if (auto gc = conditional->asGaussian()) { + negLogNormConst += gc->negLogConstant(); + } else if (auto dc = conditional->asDiscrete()) { + negLogNormConst += dc->choose(*discrete)->negLogConstant(); + } else { + throw std::runtime_error( + "Unknown conditional type when computing negLogConstant"); + } + } else { + negLogNormConst += conditional->negLogConstant(); + } + } + return negLogNormConst; +} + /* ************************************************************************* */ AlgebraicDecisionTree HybridBayesNet::discretePosterior( const VectorValues &continuousValues) const { diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index bba301be2f..96afb87d6d 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -237,6 +237,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { using BayesNet::logProbability; // expose HybridValues version + /** + * @brief Get the negative log of the normalization constant + * corresponding to the joint density represented by this Bayes net. + * Optionally index by `discrete`. + * + * @param discrete Optional DiscreteValues + * @return double + */ + double negLogConstant(const std::optional &discrete) const; + /** * @brief Compute normalized posterior P(M|X=x) and return as a tree. * diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index ac03bd3a3e..1bec428107 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -322,8 +322,11 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( const GaussianFactorValuePair &pair) -> GaussianFactorValuePair { if (max->evaluate(choices) == 0.0) return {nullptr, std::numeric_limits::infinity()}; - else - return pair; + else { + // Add negLogConstant_ back so that the minimum negLogConstant in the + // HybridGaussianConditional is set correctly. + return {pair.first, pair.second + negLogConstant_}; + } }; FactorValuePairs prunedConditionals = factors().apply(pruner); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index ceabe0871a..9ca7a3938e 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -59,10 +59,11 @@ using OrphanWrapper = BayesTreeOrphanWrapper; /// Result from elimination. struct Result { + // Gaussian conditional resulting from elimination. GaussianConditional::shared_ptr conditional; - double negLogK; - GaussianFactor::shared_ptr factor; - double scalar; + double negLogK; // Negative log of the normalization constant K. + GaussianFactor::shared_ptr factor; // Leftover factor 𝜏. + double scalar; // Scalar value associated with factor 𝜏. bool operator==(const Result &other) const { return conditional == other.conditional && negLogK == other.negLogK && diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 16d0ae1a12..135da5dc73 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -363,10 +363,6 @@ TEST(HybridBayesNet, Pruning) { AlgebraicDecisionTree expected(s.modes, leaves); EXPECT(assert_equal(expected, discretePosterior, 1e-6)); - // Prune and get probabilities - auto prunedBayesNet = posterior->prune(2); - auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous()); - // Verify logProbability computation and check specific logProbability value const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; const HybridValues hybridValues{delta.continuous(), discrete_values}; @@ -381,10 +377,21 @@ TEST(HybridBayesNet, Pruning) { EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues), 1e-9); + double negLogConstant = posterior->negLogConstant(discrete_values); + + // The sum of all the mode densities + double normalizer = + AlgebraicDecisionTree(posterior->errorTree(delta.continuous()), + [](double error) { return exp(-error); }) + .sum(); + // Check agreement with discrete posterior - // double density = exp(logProbability); - // FAILS: EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values), - // 1e-6); + double density = exp(logProbability + negLogConstant) / normalizer; + EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values), 1e-6); + + // Prune and get probabilities + auto prunedBayesNet = posterior->prune(2); + auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous()); // Regression test on pruned logProbability tree std::vector pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578}; @@ -392,7 +399,26 @@ TEST(HybridBayesNet, Pruning) { EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6)); // Regression - // FAILS: EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9); + double pruned_logProbability = 0; + pruned_logProbability += + prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues); + pruned_logProbability += + prunedBayesNet.at(1)->asHybrid()->logProbability(hybridValues); + pruned_logProbability += + prunedBayesNet.at(2)->asHybrid()->logProbability(hybridValues); + pruned_logProbability += + prunedBayesNet.at(3)->asHybrid()->logProbability(hybridValues); + + double pruned_negLogConstant = prunedBayesNet.negLogConstant(discrete_values); + + // The sum of all the mode densities + double pruned_normalizer = + AlgebraicDecisionTree(prunedBayesNet.errorTree(delta.continuous()), + [](double error) { return exp(-error); }) + .sum(); + double pruned_density = + exp(pruned_logProbability + pruned_negLogConstant) / pruned_normalizer; + EXPECT_DOUBLES_EQUAL(pruned_density, prunedTree(discrete_values), 1e-9); } /* ****************************************************************************/ diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index e29c485afd..350bc91848 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -275,6 +275,11 @@ TEST(HybridGaussianConditional, Prune) { // Check that the pruned HybridGaussianConditional has 2 conditionals EXPECT_LONGS_EQUAL(2, pruned->nrComponents()); + + // Check that the minimum negLogConstant is set correctly + EXPECT_DOUBLES_EQUAL( + hgc.conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(), + pruned->negLogConstant(), 1e-9); } { const std::vector potentials{0.2, 0, 0.3, 0, // @@ -285,6 +290,9 @@ TEST(HybridGaussianConditional, Prune) { // Check that the pruned HybridGaussianConditional has 3 conditionals EXPECT_LONGS_EQUAL(3, pruned->nrComponents()); + + // Check that the minimum negLogConstant is correct + EXPECT_DOUBLES_EQUAL(hgc.negLogConstant(), pruned->negLogConstant(), 1e-9); } }