diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 338453404c..2037dd9514 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -19,7 +19,6 @@ #include #include -#include #include #include #include @@ -65,17 +64,10 @@ namespace gtsam { } /* ************************************************************************ */ - TableFactor DiscreteFactorGraph::product() const { - TableFactor result; + DecisionTreeFactor DiscreteFactorGraph::product() const { + DecisionTreeFactor result; for (const sharedFactor& factor : *this) { - if (factor) { - if (auto f = std::dynamic_pointer_cast(factor)) { - result = result * (*f); - } else if (auto dtf = - std::dynamic_pointer_cast(factor)) { - result = TableFactor(result * (*dtf)); - } - } + if (factor) result = (*factor) * result; } return result; } @@ -124,14 +116,15 @@ namespace gtsam { * product to prevent underflow. * * @param factors The factors to multiply as a DiscreteFactorGraph. - * @return TableFactor + * @return DecisionTreeFactor */ - static TableFactor ProductAndNormalize(const DiscreteFactorGraph& factors) { + static DecisionTreeFactor ProductAndNormalize( + const DiscreteFactorGraph& factors) { // PRODUCT: multiply all factors #if GTSAM_HYBRID_TIMING gttic_(DiscreteProduct); #endif - TableFactor product = factors.product(); + DecisionTreeFactor product = factors.product(); #if GTSAM_HYBRID_TIMING gttoc_(DiscreteProduct); #endif @@ -156,11 +149,11 @@ namespace gtsam { std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - TableFactor product = ProductAndNormalize(factors); + DecisionTreeFactor product = ProductAndNormalize(factors); // max out frontals, this is the factor on the separator gttic(max); - TableFactor::shared_ptr max = product.max(frontalKeys); + DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); gttoc(max); // Ordering keys for the conditional so that frontalKeys are really in front @@ -173,8 +166,8 @@ namespace gtsam { // Make lookup with product gttic(lookup); size_t nrFrontals = frontalKeys.size(); - auto lookup = std::make_shared( - nrFrontals, orderedKeys, product.toDecisionTreeFactor()); + auto lookup = + std::make_shared(nrFrontals, orderedKeys, product); gttoc(lookup); return {std::dynamic_pointer_cast(lookup), max}; @@ -234,13 +227,13 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - TableFactor product = ProductAndNormalize(factors); + DecisionTreeFactor product = ProductAndNormalize(factors); // sum out frontals, this is the factor on the separator #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteSum); #endif - TableFactor::shared_ptr sum = product.sum(frontalKeys); + DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteSum); #endif @@ -257,7 +250,7 @@ namespace gtsam { gttic_(EliminateDiscreteToDiscreteConditional); #endif auto conditional = - std::make_shared(product, *sum, orderedKeys); + std::make_shared(product, *sum, orderedKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteToDiscreteConditional); #endif diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index f1575cd7e1..c57d2258c2 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -25,7 +25,6 @@ #include #include #include -#include #include #include @@ -148,7 +147,7 @@ class GTSAM_EXPORT DiscreteFactorGraph DiscreteKeys discreteKeys() const; /** return product of all factors as a single factor */ - TableFactor product() const; + DecisionTreeFactor product() const; /** * Evaluates the factor graph given values, returns the joint probability of