Skip to content

Commit

Permalink
undo changes to DiscreteFactorGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Dec 31, 2024
1 parent 9f85d4c commit 9cacb98
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 23 deletions.
35 changes: 14 additions & 21 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteTableConditional.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
Expand Down Expand Up @@ -65,17 +64,10 @@ namespace gtsam {
}

/* ************************************************************************ */
TableFactor DiscreteFactorGraph::product() const {
TableFactor result;
DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result;
for (const sharedFactor& factor : *this) {
if (factor) {
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
result = result * (*f);
} else if (auto dtf =
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
result = TableFactor(result * (*dtf));
}
}
if (factor) result = (*factor) * result;
}
return result;
}
Expand Down Expand Up @@ -124,14 +116,15 @@ namespace gtsam {
* product to prevent underflow.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return TableFactor
* @return DecisionTreeFactor
*/
static TableFactor ProductAndNormalize(const DiscreteFactorGraph& factors) {
static DecisionTreeFactor ProductAndNormalize(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
TableFactor product = factors.product();
DecisionTreeFactor product = factors.product();
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif
Expand All @@ -156,11 +149,11 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
TableFactor product = ProductAndNormalize(factors);
DecisionTreeFactor product = ProductAndNormalize(factors);

// max out frontals, this is the factor on the separator
gttic(max);
TableFactor::shared_ptr max = product.max(frontalKeys);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
gttoc(max);

// Ordering keys for the conditional so that frontalKeys are really in front
Expand All @@ -173,8 +166,8 @@ namespace gtsam {
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
auto lookup = std::make_shared<DiscreteLookupTable>(
nrFrontals, orderedKeys, product.toDecisionTreeFactor());
auto lookup =
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product);
gttoc(lookup);

return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
Expand Down Expand Up @@ -234,13 +227,13 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
TableFactor product = ProductAndNormalize(factors);
DecisionTreeFactor product = ProductAndNormalize(factors);

// sum out frontals, this is the factor on the separator
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum);
#endif
TableFactor::shared_ptr sum = product.sum(frontalKeys);
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif
Expand All @@ -257,7 +250,7 @@ namespace gtsam {
gttic_(EliminateDiscreteToDiscreteConditional);
#endif
auto conditional =
std::make_shared<DiscreteTableConditional>(product, *sum, orderedKeys);
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteToDiscreteConditional);
#endif
Expand Down
3 changes: 1 addition & 2 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/base/FastSet.h>
#include <gtsam/discrete/TableFactor.h>

#include <string>
#include <utility>
Expand Down Expand Up @@ -148,7 +147,7 @@ class GTSAM_EXPORT DiscreteFactorGraph
DiscreteKeys discreteKeys() const;

/** return product of all factors as a single factor */
TableFactor product() const;
DecisionTreeFactor product() const;

/**
* Evaluates the factor graph given values, returns the joint probability of
Expand Down

0 comments on commit 9cacb98

Please sign in to comment.