Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various Discrete Improvements #1927

Merged
merged 9 commits into from
Dec 10, 2024
Merged
10 changes: 3 additions & 7 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,9 @@ class GTSAM_EXPORT DiscreteConditional
static_cast<const BaseConditional*>(this)->print(s, formatter);
}

/// Evaluate, just look up in AlgebraicDecisionTree
virtual double evaluate(const Assignment<Key>& values) const override {
return ADT::operator()(values);
}

using DecisionTreeFactor::error; ///< DiscreteValues version
using DiscreteFactor::operator(); ///< DiscreteValues version
using BaseFactor::error; ///< DiscreteValues version
using BaseFactor::evaluate; ///< DiscreteValues version
using BaseFactor::operator(); ///< DiscreteValues version

/**
* @brief restrict to given *parent* values.
Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/DiscreteDistribution.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
/// Default constructor needed for serialization.
DiscreteDistribution() {}

/// Constructor from factor.
/// Constructor from DecisionTreeFactor.
explicit DiscreteDistribution(const DecisionTreeFactor& f)
: Base(f.size(), f) {}

Expand Down
76 changes: 40 additions & 36 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* @date Feb 14, 2011
* @author Duy-Nguyen Ta
* @author Frank Dellaert
* @author Varun Agrawal
*/

#include <gtsam/discrete/DiscreteBayesTree.h>
Expand All @@ -35,13 +36,12 @@ namespace gtsam {
template class FactorGraph<DiscreteFactor>;
template class EliminateableFactorGraph<DiscreteFactorGraph>;

/* ************************************************************************* */
bool DiscreteFactorGraph::equals(const This& fg, double tol) const
{
/* ************************************************************************ */
bool DiscreteFactorGraph::equals(const This& fg, double tol) const {
return Base::equals(fg, tol);
}

/* ************************************************************************* */
/* ************************************************************************ */
KeySet DiscreteFactorGraph::keys() const {
KeySet keys;
for (const sharedFactor& factor : *this) {
Expand All @@ -50,11 +50,11 @@ namespace gtsam {
return keys;
}

/* ************************************************************************* */
/* ************************************************************************ */
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
DiscreteKeys result;
for (auto&& factor : *this) {
if (auto p = std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
if (auto p = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
DiscreteKeys factor_keys = p->discreteKeys();
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
}
Expand All @@ -63,26 +63,27 @@ namespace gtsam {
return result;
}

/* ************************************************************************* */
/* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result;
for(const sharedFactor& factor: *this)
for (const sharedFactor& factor : *this) {
if (factor) result = (*factor) * result;
}
return result;
}

/* ************************************************************************* */
double DiscreteFactorGraph::operator()(
const DiscreteValues &values) const {
/* ************************************************************************ */
double DiscreteFactorGraph::operator()(const DiscreteValues& values) const {
double product = 1.0;
for( const sharedFactor& factor: factors_ )
product *= (*factor)(values);
for (const sharedFactor& factor : factors_) {
if (factor) product *= (*factor)(values);
}
return product;
}

/* ************************************************************************* */
/* ************************************************************************ */
void DiscreteFactorGraph::print(const string& s,
const KeyFormatter& formatter) const {
const KeyFormatter& formatter) const {
std::cout << s << std::endl;
std::cout << "size: " << size() << std::endl;
for (size_t i = 0; i < factors_.size(); i++) {
Expand Down Expand Up @@ -110,15 +111,18 @@ namespace gtsam {
// }
// }

/* ************************************************************************ */
// Alternate eliminate function for MPE
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
/**
* @brief Multiply all the `factors` and normalize the
* product to prevent underflow.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DecisionTreeFactor
*/
static DecisionTreeFactor ProductAndNormalize(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product;
for (auto&& factor : factors) product = (*factor) * product;
DecisionTreeFactor product = factors.product();
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the biggest fix in this PR since now we also check for nullptrs.

gttoc(product);

// Max over all the potentials by pretending all keys are frontal:
Expand All @@ -127,6 +131,16 @@ namespace gtsam {
// Normalize the product factor to prevent underflow.
product = product / (*normalization);

return product;
}

/* ************************************************************************ */
// Alternate eliminate function for MPE
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = ProductAndNormalize(factors);

// max out frontals, this is the factor on the separator
gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
Expand All @@ -142,8 +156,8 @@ namespace gtsam {
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
auto lookup = std::make_shared<DiscreteLookupTable>(nrFrontals,
orderedKeys, product);
auto lookup =
std::make_shared<DiscreteLookupTable>(nrFrontals, orderedKeys, product);
gttoc(lookup);

return {std::dynamic_pointer_cast<DiscreteConditional>(lookup), max};
Expand Down Expand Up @@ -201,20 +215,10 @@ namespace gtsam {
}

/* ************************************************************************ */
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product;
for (auto&& factor : factors) product = (*factor) * product;
gttoc(product);

// Max over all the potentials by pretending all keys are frontal:
auto normalization = product.max(product.size());

// Normalize the product factor to prevent underflow.
product = product / (*normalization);
DecisionTreeFactor product = ProductAndNormalize(factors);

// sum out frontals, this is the factor on the separator
gttic(sum);
Expand Down
5 changes: 3 additions & 2 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* @date Feb 14, 2011
* @author Duy-Nguyen Ta
* @author Frank Dellaert
* @author Varun Agrawal
*/

#pragma once
Expand Down Expand Up @@ -48,7 +49,7 @@ class DiscreteJunctionTree;
* @ingroup discrete
*/
GTSAM_EXPORT
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);

Expand All @@ -61,7 +62,7 @@ EliminateDiscrete(const DiscreteFactorGraph& factors,
* @ingroup discrete
*/
GTSAM_EXPORT
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);

Expand Down
3 changes: 2 additions & 1 deletion gtsam/discrete/tests/testDiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ TEST(DiscreteFactorGraph, test) {
const Ordering frontalKeys{0};
const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);

DecisionTreeFactor newFactor = *newFactorPtr;
DecisionTreeFactor newFactor =
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);

// Normalize newFactor by max for comparison with expected
auto normalization = newFactor.max(newFactor.size());
Expand Down
Loading