Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Customize discrete elimination in Hybrid #1955

Merged
merged 19 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 0 additions & 24 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,25 +121,13 @@ namespace gtsam {
static DecisionTreeFactor ProductAndNormalize(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
DecisionTreeFactor product = factors.product();
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif

// Max over all the potentials by pretending all keys are frontal:
auto normalizer = product.max(product.size());

#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Normalize the product factor to prevent underflow.
product = product / (*normalizer);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif

return product;
}
Expand Down Expand Up @@ -230,13 +218,7 @@ namespace gtsam {
DecisionTreeFactor product = ProductAndNormalize(factors);

// sum out frontals, this is the factor on the separator
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum);
#endif
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif

// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
Expand All @@ -246,14 +228,8 @@ namespace gtsam {
sum->keys().end());

// now divide product/sum to get conditional
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteToDiscreteConditional);
#endif
auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteToDiscreteConditional);
#endif

return {conditional, sum};
}
Expand Down
72 changes: 69 additions & 3 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,48 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
return std::make_shared<TableFactor>(discreteKeys, potentials);
}

/**
* @brief Multiply all the `factors` and normalize the
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
* product to prevent underflow.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return TableFactor
*/
static TableFactor ProductAndNormalize(const DiscreteFactorGraph &factors) {
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
// PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
TableFactor product;
for (auto &&factor : factors) {
if (factor) {
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
product = product * (*f);
} else if (auto dtf =
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
product = TableFactor(product * (*dtf));
}
}
}
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif

// Max over all the potentials by pretending all keys are frontal:
auto normalizer = product.max(product.size());
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved

#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Normalize the product factor to prevent underflow.
product = product / (*normalizer);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif

return product;
}

/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors,
Expand Down Expand Up @@ -299,13 +341,37 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscrete);
#endif
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
auto result = EliminateDiscrete(dfg, frontalKeys);
/**** NOTE: This does sum-product. ****/
// Get product factor
TableFactor product = ProductAndNormalize(dfg);

#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum);
#endif
// All the discrete variables should form a single clique,
// so we can sum out on all the variables as frontals.
// This should give an empty separator.
Ordering orderedKeys(product.keys());
TableFactor::shared_ptr sum = product.sum(orderedKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif

#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteFormDiscreteConditional);
#endif
// Finally, get the conditional
auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif

#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete);
#endif

return {std::make_shared<HybridConditional>(result.first), result.second};
return {std::make_shared<HybridConditional>(conditional), sum};
}

/* ************************************************************************ */
Expand Down
Loading