Skip to content

Commit

Permalink
Merge branch 'develop' into discrete-elimination-refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Dec 11, 2024
2 parents cc4e9cb + 2c9e315 commit 0b3f058
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 44 deletions.
71 changes: 28 additions & 43 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* @date Feb 14, 2011
* @author Duy-Nguyen Ta
* @author Frank Dellaert
* @author Varun Agrawal
*/

#include <gtsam/discrete/DiscreteBayesTree.h>
Expand All @@ -35,13 +36,12 @@ namespace gtsam {
template class FactorGraph<DiscreteFactor>;
template class EliminateableFactorGraph<DiscreteFactorGraph>;

/* ************************************************************************* */
bool DiscreteFactorGraph::equals(const This& fg, double tol) const
{
/* ************************************************************************ */
bool DiscreteFactorGraph::equals(const This& fg, double tol) const {
return Base::equals(fg, tol);
}

/* ************************************************************************* */
/* ************************************************************************ */
KeySet DiscreteFactorGraph::keys() const {
KeySet keys;
for (const sharedFactor& factor : *this) {
Expand All @@ -50,7 +50,7 @@ namespace gtsam {
return keys;
}

/* ************************************************************************* */
/* ************************************************************************ */
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
DiscreteKeys result;
for (auto&& factor : *this) {
Expand All @@ -63,7 +63,7 @@ namespace gtsam {
return result;
}

/* ************************************************************************* */
/* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result;
for (const sharedFactor& factor : *this) {
Expand All @@ -72,18 +72,18 @@ namespace gtsam {
return result;
}

/* ************************************************************************* */
double DiscreteFactorGraph::operator()(
const DiscreteValues &values) const {
/* ************************************************************************ */
double DiscreteFactorGraph::operator()(const DiscreteValues& values) const {
double product = 1.0;
for( const sharedFactor& factor: factors_ )
product *= (*factor)(values);
for (const sharedFactor& factor : factors_) {
if (factor) product *= (*factor)(values);
}
return product;
}

/* ************************************************************************* */
/* ************************************************************************ */
void DiscreteFactorGraph::print(const string& s,
const KeyFormatter& formatter) const {
const KeyFormatter& formatter) const {
std::cout << s << std::endl;
std::cout << "size: " << size() << std::endl;
for (size_t i = 0; i < factors_.size(); i++) {
Expand Down Expand Up @@ -112,43 +112,36 @@ namespace gtsam {
// }

/**
* @brief Helper method to normalize the product factor by
* the max value to prevent underflow
* @brief Multiply all the `factors` and normalize the
* product to prevent underflow.
*
* @param product The product discrete factor.
* @return DiscreteFactor::shared_ptr
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DecisionTreeFactor
*/
static DecisionTreeFactor Normalize(const DecisionTreeFactor& product) {
static DecisionTreeFactor ProductAndNormalize(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product = factors.product();
gttoc(product);

// Max over all the potentials by pretending all keys are frontal:
gttic_(DiscreteFindMax);
auto normalization = product.max(product.size());
gttoc_(DiscreteFindMax);

gttic_(DiscreteNormalization);
// Normalize the product factor to prevent underflow.
auto normalized_product =
product /
(*std::dynamic_pointer_cast<DecisionTreeFactor>(normalization));
gttoc_(DiscreteNormalization);

return normalized_product;
}

/* ************************************************************************ */
// Alternate eliminate function for MPE
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic_(MPEProduct);
DecisionTreeFactor product = factors.product();
gttoc_(MPEProduct);

gttic_(Normalize);

// Normalize the product
product = Normalize(product);
gttoc_(Normalize);
DecisionTreeFactor product = ProductAndNormalize(factors);

// max out frontals, this is the factor on the separator
gttic(max);
Expand Down Expand Up @@ -225,18 +218,10 @@ namespace gtsam {
}

/* ************************************************************************ */
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic_(product);
DecisionTreeFactor product = factors.product();
gttoc_(product);

gttic_(Normalize);
// Normalize the product
product = Normalize(product);
gttoc_(Normalize);
DecisionTreeFactor product = ProductAndNormalize(factors);

// sum out frontals, this is the factor on the separator
gttic_(sum);
Expand Down
1 change: 1 addition & 0 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* @date Feb 14, 2011
* @author Duy-Nguyen Ta
* @author Frank Dellaert
* @author Varun Agrawal
*/

#pragma once
Expand Down
3 changes: 2 additions & 1 deletion gtsam/discrete/tests/testDiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ TEST(DiscreteFactorGraph, test) {
const Ordering frontalKeys{0};
const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);

auto newFactor = *std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
DecisionTreeFactor newFactor =
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);

// Normalize newFactor by max for comparison with expected
auto normalization = newFactor.max(newFactor.size());
Expand Down

0 comments on commit 0b3f058

Please sign in to comment.