Skip to content

Discrete Improvements #1947

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,24 +112,23 @@ namespace gtsam {
// }

/**
* @brief Multiply all the `factors` and normalize the
* product to prevent underflow.
* @brief Multiply all the `factors`.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DecisionTreeFactor
*/
static DecisionTreeFactor ProductAndNormalize(
static DecisionTreeFactor DiscreteProduct(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product = factors.product();
gttoc(product);

// Max over all the potentials by pretending all keys are frontal:
auto normalization = product.max(product.size());
auto denominator = product.max(product.size());

// Normalize the product factor to prevent underflow.
product = product / (*normalization);
product = product / (*denominator);

return product;
}
Expand All @@ -139,7 +138,7 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = ProductAndNormalize(factors);
DecisionTreeFactor product = DiscreteProduct(factors);

// max out frontals, this is the factor on the separator
gttic(max);
Expand Down Expand Up @@ -207,8 +206,7 @@ namespace gtsam {
return dag.argmax();
}

DiscreteValues DiscreteFactorGraph::optimize(
const Ordering& ordering) const {
DiscreteValues DiscreteFactorGraph::optimize(const Ordering& ordering) const {
gttic(DiscreteFactorGraph_optimize);
DiscreteLookupDAG dag = maxProduct(ordering);
return dag.argmax();
Expand All @@ -218,7 +216,7 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = ProductAndNormalize(factors);
DecisionTreeFactor product = DiscreteProduct(factors);

// sum out frontals, this is the factor on the separator
gttic(sum);
Expand Down
37 changes: 4 additions & 33 deletions gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,41 +252,12 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();

// Record key assignment and value pairs in pair_table.
// The assignments are stored in descending order of keys so that the order of
// the values matches what is expected by a DecisionTree.
// This is why we reverse the keys and then
// query for the key value/assignment.
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
std::vector<std::pair<uint64_t, double>> pair_table;
for (auto i = 0; i < sparse_table_.size(); i++) {
std::stringstream ss;
for (auto&& [key, _] : rdkeys) {
ss << keyValueForIndex(key, i);
}
// k will be in reverse key order already
uint64_t k;
ss >> k;
pair_table.push_back(std::make_pair(k, sparse_table_.coeff(i)));
}

// Sort the pair_table (of assignment-value pairs) based on assignment so we
// get values in reverse key order.
std::sort(
pair_table.begin(), pair_table.end(),
[](const std::pair<uint64_t, double>& a,
const std::pair<uint64_t, double>& b) { return a.first < b.first; });

// Create the table vector by extracting the values from pair_table.
// The pair_table has already been sorted in the desired order,
// so the values will be in descending key order.
std::vector<double> table;
std::for_each(pair_table.begin(), pair_table.end(),
[&table](const std::pair<uint64_t, double>& pair) {
table.push_back(pair.second);
});
for (auto i = 0; i < sparse_table_.size(); i++) {
table.push_back(sparse_table_.coeff(i));
}

AlgebraicDecisionTree<Key> tree(rdkeys, table);
AlgebraicDecisionTree<Key> tree(dkeys, table);
DecisionTreeFactor f(dkeys, tree);
return f;
}
Expand Down
10 changes: 5 additions & 5 deletions gtsam/discrete/tests/testDiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) {
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);

// Normalize newFactor by max for comparison with expected
auto normalization = newFactor.max(newFactor.size());
auto normalizer = newFactor.max(newFactor.size());

newFactor = newFactor / *normalization;
newFactor = newFactor / *normalizer;

// Check Conditional
CHECK(conditional);
Expand All @@ -131,9 +131,9 @@ TEST(DiscreteFactorGraph, test) {
CHECK(&newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
// Normalize by max.
normalization = expectedFactor.max(expectedFactor.size());
// Ensure normalization is correct.
expectedFactor = expectedFactor / *normalization;
normalizer = expectedFactor.max(expectedFactor.size());
// Ensure normalizer is correct.
expectedFactor = expectedFactor / *normalizer;
EXPECT(assert_equal(expectedFactor, newFactor));

// Test using elimination tree
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
auto dc = hc->asDiscrete();
if (!dc) throwRuntimeError("discreteElimination", dc);
dfg.push_back(hc->asDiscrete());
dfg.push_back(dc);
} else {
throwRuntimeError("discreteElimination", f);
}
Expand Down
Loading