diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 8586233016..8cba6dbe7f 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -168,13 +168,9 @@ class GTSAM_EXPORT DiscreteConditional static_cast(this)->print(s, formatter); } - /// Evaluate, just look up in AlgebraicDecisionTree - virtual double evaluate(const Assignment& values) const override { - return ADT::operator()(values); - } - - using DecisionTreeFactor::error; ///< DiscreteValues version - using DiscreteFactor::operator(); ///< DiscreteValues version + using BaseFactor::error; ///< DiscreteValues version + using BaseFactor::evaluate; ///< DiscreteValues version + using BaseFactor::operator(); ///< DiscreteValues version /** * @brief restrict to given *parent* values. diff --git a/gtsam/discrete/DiscreteDistribution.h b/gtsam/discrete/DiscreteDistribution.h index 4b690da156..09ea50332e 100644 --- a/gtsam/discrete/DiscreteDistribution.h +++ b/gtsam/discrete/DiscreteDistribution.h @@ -40,7 +40,7 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { /// Default constructor needed for serialization. DiscreteDistribution() {} - /// Constructor from factor. + /// Constructor from DecisionTreeFactor. explicit DiscreteDistribution(const DecisionTreeFactor& f) : Base(f.size(), f) {} diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 4ededbb8be..d0bf210477 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -14,6 +14,7 @@ * @date Feb 14, 2011 * @author Duy-Nguyen Ta * @author Frank Dellaert + * @author Varun Agrawal */ #include @@ -35,13 +36,12 @@ namespace gtsam { template class FactorGraph; template class EliminateableFactorGraph; - /* ************************************************************************* */ - bool DiscreteFactorGraph::equals(const This& fg, double tol) const - { + /* ************************************************************************ */ + bool DiscreteFactorGraph::equals(const This& fg, double tol) const { return Base::equals(fg, tol); } - /* ************************************************************************* */ + /* ************************************************************************ */ KeySet DiscreteFactorGraph::keys() const { KeySet keys; for (const sharedFactor& factor : *this) { @@ -50,11 +50,11 @@ namespace gtsam { return keys; } - /* ************************************************************************* */ + /* ************************************************************************ */ DiscreteKeys DiscreteFactorGraph::discreteKeys() const { DiscreteKeys result; for (auto&& factor : *this) { - if (auto p = std::dynamic_pointer_cast(factor)) { + if (auto p = std::dynamic_pointer_cast(factor)) { DiscreteKeys factor_keys = p->discreteKeys(); result.insert(result.end(), factor_keys.begin(), factor_keys.end()); } @@ -63,26 +63,27 @@ namespace gtsam { return result; } - /* ************************************************************************* */ + /* ************************************************************************ */ DecisionTreeFactor DiscreteFactorGraph::product() const { DecisionTreeFactor result; - for(const sharedFactor& factor: *this) + for (const sharedFactor& factor : *this) { if (factor) result = (*factor) * result; + } return result; } - /* ************************************************************************* */ - double DiscreteFactorGraph::operator()( - const DiscreteValues &values) const { + /* ************************************************************************ */ + double DiscreteFactorGraph::operator()(const DiscreteValues& values) const { double product = 1.0; - for( const sharedFactor& factor: factors_ ) - product *= (*factor)(values); + for (const sharedFactor& factor : factors_) { + if (factor) product *= (*factor)(values); + } return product; } - /* ************************************************************************* */ + /* ************************************************************************ */ void DiscreteFactorGraph::print(const string& s, - const KeyFormatter& formatter) const { + const KeyFormatter& formatter) const { std::cout << s << std::endl; std::cout << "size: " << size() << std::endl; for (size_t i = 0; i < factors_.size(); i++) { @@ -110,15 +111,18 @@ namespace gtsam { // } // } - /* ************************************************************************ */ - // Alternate eliminate function for MPE - std::pair // - EliminateForMPE(const DiscreteFactorGraph& factors, - const Ordering& frontalKeys) { + /** + * @brief Multiply all the `factors` and normalize the + * product to prevent underflow. + * + * @param factors The factors to multiply as a DiscreteFactorGraph. + * @return DecisionTreeFactor + */ + static DecisionTreeFactor ProductAndNormalize( + const DiscreteFactorGraph& factors) { // PRODUCT: multiply all factors gttic(product); - DecisionTreeFactor product; - for (auto&& factor : factors) product = (*factor) * product; + DecisionTreeFactor product = factors.product(); gttoc(product); // Max over all the potentials by pretending all keys are frontal: @@ -127,6 +131,16 @@ namespace gtsam { // Normalize the product factor to prevent underflow. product = product / (*normalization); + return product; + } + + /* ************************************************************************ */ + // Alternate eliminate function for MPE + std::pair // + EliminateForMPE(const DiscreteFactorGraph& factors, + const Ordering& frontalKeys) { + DecisionTreeFactor product = ProductAndNormalize(factors); + // max out frontals, this is the factor on the separator gttic(max); DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); @@ -142,8 +156,8 @@ namespace gtsam { // Make lookup with product gttic(lookup); size_t nrFrontals = frontalKeys.size(); - auto lookup = std::make_shared(nrFrontals, - orderedKeys, product); + auto lookup = + std::make_shared(nrFrontals, orderedKeys, product); gttoc(lookup); return {std::dynamic_pointer_cast(lookup), max}; @@ -201,20 +215,10 @@ namespace gtsam { } /* ************************************************************************ */ - std::pair // + std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - // PRODUCT: multiply all factors - gttic(product); - DecisionTreeFactor product; - for (auto&& factor : factors) product = (*factor) * product; - gttoc(product); - - // Max over all the potentials by pretending all keys are frontal: - auto normalization = product.max(product.size()); - - // Normalize the product factor to prevent underflow. - product = product / (*normalization); + DecisionTreeFactor product = ProductAndNormalize(factors); // sum out frontals, this is the factor on the separator gttic(sum); diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index d0dc282b41..c57d2258c2 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -14,6 +14,7 @@ * @date Feb 14, 2011 * @author Duy-Nguyen Ta * @author Frank Dellaert + * @author Varun Agrawal */ #pragma once @@ -48,7 +49,7 @@ class DiscreteJunctionTree; * @ingroup discrete */ GTSAM_EXPORT -std::pair +std::pair EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys); @@ -61,7 +62,7 @@ EliminateDiscrete(const DiscreteFactorGraph& factors, * @ingroup discrete */ GTSAM_EXPORT -std::pair +std::pair EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys); diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 341eb63e38..0d71c12bad 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -113,7 +113,8 @@ TEST(DiscreteFactorGraph, test) { const Ordering frontalKeys{0}; const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys); - DecisionTreeFactor newFactor = *newFactorPtr; + DecisionTreeFactor newFactor = + *std::dynamic_pointer_cast(newFactorPtr); // Normalize newFactor by max for comparison with expected auto normalization = newFactor.max(newFactor.size());