Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Customize discrete elimination in Hybrid #1955

Merged
merged 19 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
#include <gtsam/hybrid/HybridValues.h>

#include <algorithm>
#include <cassert>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <cassert>

using namespace std;
using std::pair;
Expand All @@ -47,6 +47,15 @@ DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f)
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}

/* ************************************************************************** */
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
const DecisionTreeFactor& f,
const Ordering& orderedKeys)
: BaseFactor(f), BaseConditional(nrFrontals) {
keys_.clear();
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
}

/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
const DiscreteKeys& keys,
Expand Down
11 changes: 11 additions & 0 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ class GTSAM_EXPORT DiscreteConditional
/// Construct from factor, taking the first `nFrontals` keys as frontals.
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);

/**
* @brief Construct from DecisionTreeFactor,
* taking the first `nrFrontals` from `orderedKeys`.
*
* @param nrFrontals The number of frontal variables.
* @param f The DecisionTreeFactor to construct from.
* @param orderedKeys Ordered list of keys involved in the conditional.
*/
DiscreteConditional(size_t nrFrontals, const DecisionTreeFactor& f,
const Ordering& orderedKeys);

/**
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
* `nFrontals` keys as frontals, in the order given.
Expand Down
24 changes: 0 additions & 24 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,25 +121,13 @@ namespace gtsam {
static DecisionTreeFactor ProductAndNormalize(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
DecisionTreeFactor product = factors.product();
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif

// Max over all the potentials by pretending all keys are frontal:
auto normalizer = product.max(product.size());

#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Normalize the product factor to prevent underflow.
product = product / (*normalizer);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif

return product;
}
Expand Down Expand Up @@ -230,13 +218,7 @@ namespace gtsam {
DecisionTreeFactor product = ProductAndNormalize(factors);

// sum out frontals, this is the factor on the separator
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum);
#endif
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif

// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
Expand All @@ -246,14 +228,8 @@ namespace gtsam {
sum->keys().end());

// now divide product/sum to get conditional
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteToDiscreteConditional);
#endif
auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteToDiscreteConditional);
#endif

return {conditional, sum};
}
Expand Down
9 changes: 9 additions & 0 deletions gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,15 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();

// If no keys, then return empty DecisionTreeFactor
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix in separate PR, with unit test you already have

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in #1960

if (dkeys.size() == 0) {
AlgebraicDecisionTree<Key> tree;
if (sparse_table_.size() != 0) {
tree = AlgebraicDecisionTree<Key>(sparse_table_.coeff(0));
}
return DecisionTreeFactor(dkeys, tree);
}

std::vector<double> table;
for (auto i = 0; i < sparse_table_.size(); i++) {
table.push_back(sparse_table_.coeff(i));
Expand Down
76 changes: 73 additions & 3 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,48 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
return std::make_shared<TableFactor>(discreteKeys, potentials);
}

/**
* @brief Multiply all the `factors` and normalize the
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
* product to prevent underflow.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return TableFactor
*/
static TableFactor ProductAndNormalize(const DiscreteFactorGraph &factors) {
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
// PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
TableFactor product;
for (auto &&factor : factors) {
if (factor) {
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
product = product * (*f);
} else if (auto dtf =
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
product = TableFactor(product * (*dtf));
}
}
}
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif

// Max over all the potentials by pretending all keys are frontal:
auto normalizer = product.max(product.size());
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved

#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Normalize the product factor to prevent underflow.
product = product / (*normalizer);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif

return product;
}

/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors,
Expand Down Expand Up @@ -299,13 +341,41 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscrete);
#endif
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
auto result = EliminateDiscrete(dfg, frontalKeys);
/**** NOTE: This does sum-product. ****/
// Get product factor
TableFactor product = ProductAndNormalize(dfg);

#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum);
#endif
// All the discrete variables should form a single clique,
// so we can sum out on all the variables as frontals.
// This should give an empty separator.
TableFactor::shared_ptr sum = product.sum(frontalKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif

// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());

#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteFormDiscreteConditional);
#endif
auto c = product / (*sum);
auto conditional = std::make_shared<DiscreteConditional>(
frontalKeys.size(), c.toDecisionTreeFactor(), orderedKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif

#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete);
#endif

return {std::make_shared<HybridConditional>(result.first), result.second};
return {std::make_shared<HybridConditional>(conditional), sum};
}

/* ************************************************************************ */
Expand Down
1 change: 1 addition & 0 deletions gtsam/hybrid/tests/testGaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
}
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
Loading