From 80b1fe569a39b77986d38ccd127c99b0fdc5f395 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 10 Dec 2024 10:41:36 -0500 Subject: [PATCH 1/9] use product method since it has a nullptr check --- gtsam/discrete/DiscreteFactorGraph.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 4ededbb8be..523b992016 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -66,8 +66,9 @@ namespace gtsam { /* ************************************************************************* */ DecisionTreeFactor DiscreteFactorGraph::product() const { DecisionTreeFactor result; - for(const sharedFactor& factor: *this) + for (const sharedFactor& factor : *this) { if (factor) result = (*factor) * result; + } return result; } @@ -75,8 +76,9 @@ namespace gtsam { double DiscreteFactorGraph::operator()( const DiscreteValues &values) const { double product = 1.0; - for( const sharedFactor& factor: factors_ ) + for (const sharedFactor& factor : factors_) { product *= (*factor)(values); + } return product; } @@ -117,8 +119,7 @@ namespace gtsam { const Ordering& frontalKeys) { // PRODUCT: multiply all factors gttic(product); - DecisionTreeFactor product; - for (auto&& factor : factors) product = (*factor) * product; + DecisionTreeFactor product = factors.product(); gttoc(product); // Max over all the potentials by pretending all keys are frontal: @@ -206,8 +207,7 @@ namespace gtsam { const Ordering& frontalKeys) { // PRODUCT: multiply all factors gttic(product); - DecisionTreeFactor product; - for (auto&& factor : factors) product = (*factor) * product; + DecisionTreeFactor product = factors.product(); gttoc(product); // Max over all the potentials by pretending all keys are frontal: From 92e0a55e7817d7507ebc098df85e1e8d7cbfca5f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 10 Dec 2024 10:41:57 -0500 Subject: [PATCH 2/9] generalize discreteKeys method --- gtsam/discrete/DiscreteFactorGraph.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 523b992016..96fdfc3383 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -54,7 +54,7 @@ namespace gtsam { DiscreteKeys DiscreteFactorGraph::discreteKeys() const { DiscreteKeys result; for (auto&& factor : *this) { - if (auto p = std::dynamic_pointer_cast(factor)) { + if (auto p = std::dynamic_pointer_cast(factor)) { DiscreteKeys factor_keys = p->discreteKeys(); result.insert(result.end(), factor_keys.begin(), factor_keys.end()); } From dea9c7f7655576d7629903e38fc357e81898ec82 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 10 Dec 2024 10:49:08 -0500 Subject: [PATCH 3/9] common function for product and normalization --- gtsam/discrete/DiscreteFactorGraph.cpp | 32 ++++++++++++++------------ 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 96fdfc3383..436d784f7a 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -112,11 +112,14 @@ namespace gtsam { // } // } - /* ************************************************************************ */ - // Alternate eliminate function for MPE - std::pair // - EliminateForMPE(const DiscreteFactorGraph& factors, - const Ordering& frontalKeys) { + /** + * @brief Multiply all the `factors` and normalize the + * product to prevent underflow. + * + * @param factors The factors to multiply as a DiscreteFactorGraph. + * @return DecisionTreeFactor + */ + static DecisionTreeFactor ProductAndNormalize(const DiscreteFactorGraph& factors) { // PRODUCT: multiply all factors gttic(product); DecisionTreeFactor product = factors.product(); @@ -127,6 +130,14 @@ namespace gtsam { // Normalize the product factor to prevent underflow. product = product / (*normalization); + } + + /* ************************************************************************ */ + // Alternate eliminate function for MPE + std::pair // + EliminateForMPE(const DiscreteFactorGraph& factors, + const Ordering& frontalKeys) { + DecisionTreeFactor product = ProductAndNormalize(factors); // max out frontals, this is the factor on the separator gttic(max); @@ -205,16 +216,7 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - // PRODUCT: multiply all factors - gttic(product); - DecisionTreeFactor product = factors.product(); - gttoc(product); - - // Max over all the potentials by pretending all keys are frontal: - auto normalization = product.max(product.size()); - - // Normalize the product factor to prevent underflow. - product = product / (*normalization); + DecisionTreeFactor product = ProductAndNormalize(factors); // sum out frontals, this is the factor on the separator gttic(sum); From 590293bb92f53291a0a77837dfc51314e34c641a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 10 Dec 2024 10:52:15 -0500 Subject: [PATCH 4/9] return tau factor as DiscreteFactor for discrete elimination --- gtsam/discrete/DiscreteFactorGraph.cpp | 5 +++-- gtsam/discrete/DiscreteFactorGraph.h | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 436d784f7a..508a5f0f8a 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -14,6 +14,7 @@ * @date Feb 14, 2011 * @author Duy-Nguyen Ta * @author Frank Dellaert + * @author Varun Agrawal */ #include @@ -134,7 +135,7 @@ namespace gtsam { /* ************************************************************************ */ // Alternate eliminate function for MPE - std::pair // + std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { DecisionTreeFactor product = ProductAndNormalize(factors); @@ -213,7 +214,7 @@ namespace gtsam { } /* ************************************************************************ */ - std::pair // + std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { DecisionTreeFactor product = ProductAndNormalize(factors); diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index d0dc282b41..c57d2258c2 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -14,6 +14,7 @@ * @date Feb 14, 2011 * @author Duy-Nguyen Ta * @author Frank Dellaert + * @author Varun Agrawal */ #pragma once @@ -48,7 +49,7 @@ class DiscreteJunctionTree; * @ingroup discrete */ GTSAM_EXPORT -std::pair +std::pair EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys); @@ -61,7 +62,7 @@ EliminateDiscrete(const DiscreteFactorGraph& factors, * @ingroup discrete */ GTSAM_EXPORT -std::pair +std::pair EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys); From 162f61061cbbe02298250c759ff84387dd065039 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 10 Dec 2024 10:52:36 -0500 Subject: [PATCH 5/9] use BaseFactor methods to reduce code in DiscreteConditional --- gtsam/discrete/DiscreteConditional.h | 10 +++------- gtsam/discrete/DiscreteDistribution.h | 2 +- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 8586233016..8cba6dbe7f 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -168,13 +168,9 @@ class GTSAM_EXPORT DiscreteConditional static_cast(this)->print(s, formatter); } - /// Evaluate, just look up in AlgebraicDecisionTree - virtual double evaluate(const Assignment& values) const override { - return ADT::operator()(values); - } - - using DecisionTreeFactor::error; ///< DiscreteValues version - using DiscreteFactor::operator(); ///< DiscreteValues version + using BaseFactor::error; ///< DiscreteValues version + using BaseFactor::evaluate; ///< DiscreteValues version + using BaseFactor::operator(); ///< DiscreteValues version /** * @brief restrict to given *parent* values. diff --git a/gtsam/discrete/DiscreteDistribution.h b/gtsam/discrete/DiscreteDistribution.h index 4b690da156..09ea50332e 100644 --- a/gtsam/discrete/DiscreteDistribution.h +++ b/gtsam/discrete/DiscreteDistribution.h @@ -40,7 +40,7 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { /// Default constructor needed for serialization. DiscreteDistribution() {} - /// Constructor from factor. + /// Constructor from DecisionTreeFactor. explicit DiscreteDistribution(const DecisionTreeFactor& f) : Base(f.size(), f) {} From 372b133366626d03d5319843fad08e77909156db Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 10 Dec 2024 11:05:50 -0500 Subject: [PATCH 6/9] formatting --- gtsam/discrete/DiscreteFactorGraph.cpp | 27 +++++++++++++------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 508a5f0f8a..ebba93382e 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -36,13 +36,12 @@ namespace gtsam { template class FactorGraph; template class EliminateableFactorGraph; - /* ************************************************************************* */ - bool DiscreteFactorGraph::equals(const This& fg, double tol) const - { + /* ************************************************************************ */ + bool DiscreteFactorGraph::equals(const This& fg, double tol) const { return Base::equals(fg, tol); } - /* ************************************************************************* */ + /* ************************************************************************ */ KeySet DiscreteFactorGraph::keys() const { KeySet keys; for (const sharedFactor& factor : *this) { @@ -51,7 +50,7 @@ namespace gtsam { return keys; } - /* ************************************************************************* */ + /* ************************************************************************ */ DiscreteKeys DiscreteFactorGraph::discreteKeys() const { DiscreteKeys result; for (auto&& factor : *this) { @@ -64,7 +63,7 @@ namespace gtsam { return result; } - /* ************************************************************************* */ + /* ************************************************************************ */ DecisionTreeFactor DiscreteFactorGraph::product() const { DecisionTreeFactor result; for (const sharedFactor& factor : *this) { @@ -73,9 +72,8 @@ namespace gtsam { return result; } - /* ************************************************************************* */ - double DiscreteFactorGraph::operator()( - const DiscreteValues &values) const { + /* ************************************************************************ */ + double DiscreteFactorGraph::operator()(const DiscreteValues& values) const { double product = 1.0; for (const sharedFactor& factor : factors_) { product *= (*factor)(values); @@ -83,9 +81,9 @@ namespace gtsam { return product; } - /* ************************************************************************* */ + /* ************************************************************************ */ void DiscreteFactorGraph::print(const string& s, - const KeyFormatter& formatter) const { + const KeyFormatter& formatter) const { std::cout << s << std::endl; std::cout << "size: " << size() << std::endl; for (size_t i = 0; i < factors_.size(); i++) { @@ -120,7 +118,8 @@ namespace gtsam { * @param factors The factors to multiply as a DiscreteFactorGraph. * @return DecisionTreeFactor */ - static DecisionTreeFactor ProductAndNormalize(const DiscreteFactorGraph& factors) { + static DecisionTreeFactor ProductAndNormalize( + const DiscreteFactorGraph& factors) { // PRODUCT: multiply all factors gttic(product); DecisionTreeFactor product = factors.product(); @@ -155,8 +154,8 @@ namespace gtsam { // Make lookup with product gttic(lookup); size_t nrFrontals = frontalKeys.size(); - auto lookup = std::make_shared(nrFrontals, - orderedKeys, product); + auto lookup = + std::make_shared(nrFrontals, orderedKeys, product); gttoc(lookup); return {std::dynamic_pointer_cast(lookup), max}; From 62008fc99543c8c87e6a38bea11ff1592b1aac16 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 10 Dec 2024 11:10:25 -0500 Subject: [PATCH 7/9] add return --- gtsam/discrete/DiscreteFactorGraph.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index ebba93382e..ec6dac2fce 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -130,6 +130,8 @@ namespace gtsam { // Normalize the product factor to prevent underflow. product = product / (*normalization); + + return product; } /* ************************************************************************ */ From 2fd60a47a231b6cbebc6b916d213e64d379b39fc Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 10 Dec 2024 11:54:05 -0500 Subject: [PATCH 8/9] fix test --- gtsam/discrete/tests/testDiscreteFactorGraph.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 341eb63e38..0d71c12bad 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -113,7 +113,8 @@ TEST(DiscreteFactorGraph, test) { const Ordering frontalKeys{0}; const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys); - DecisionTreeFactor newFactor = *newFactorPtr; + DecisionTreeFactor newFactor = + *std::dynamic_pointer_cast(newFactorPtr); // Normalize newFactor by max for comparison with expected auto normalization = newFactor.max(newFactor.size()); From 588533751b07d69a1940241ba6a144eb30ae4938 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 10 Dec 2024 14:10:33 -0500 Subject: [PATCH 9/9] add another pointer check --- gtsam/discrete/DiscreteFactorGraph.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index ec6dac2fce..d0bf210477 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -76,7 +76,7 @@ namespace gtsam { double DiscreteFactorGraph::operator()(const DiscreteValues& values) const { double product = 1.0; for (const sharedFactor& factor : factors_) { - product *= (*factor)(values); + if (factor) product *= (*factor)(values); } return product; }