-
Notifications
You must be signed in to change notification settings - Fork 782
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1890 from NewThinker-Jeffrey/jeffrey/isam2_margin…
…alization
- Loading branch information
Showing
4 changed files
with
471 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
358 changes: 358 additions & 0 deletions
358
gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,358 @@ | ||
/* ---------------------------------------------------------------------------- | ||
* GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||
* Atlanta, Georgia 30332-0415 | ||
* All Rights Reserved | ||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||
* See LICENSE for the license information | ||
* -------------------------------------------------------------------------- */ | ||
|
||
/** | ||
* @file BayesTreeMarginalizationHelper.h | ||
* @brief Helper functions for marginalizing variables from a Bayes Tree. | ||
* | ||
* @author Jeffrey (Zhiwei Wang) | ||
* @date Oct 28, 2024 | ||
*/ | ||
|
||
// \callgraph | ||
#pragma once | ||
|
||
#include <unordered_map> | ||
#include <unordered_set> | ||
#include <deque> | ||
#include <gtsam/inference/BayesTree.h> | ||
#include <gtsam/inference/BayesTreeCliqueBase.h> | ||
#include <gtsam/base/debug.h> | ||
#include "gtsam_unstable/dllexport.h" | ||
|
||
namespace gtsam { | ||
|
||
/** | ||
* This class provides helper functions for marginalizing variables from a Bayes Tree. | ||
*/ | ||
template <typename BayesTree> | ||
class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper { | ||
|
||
public: | ||
using Clique = typename BayesTree::Clique; | ||
using sharedClique = typename BayesTree::sharedClique; | ||
|
||
/** | ||
* This function identifies variables that need to be re-eliminated before | ||
* performing marginalization. | ||
* | ||
* Re-elimination is necessary for a clique containing marginalizable | ||
* variables if: | ||
* | ||
* 1. Some non-marginalizable variables appear before marginalizable ones | ||
* in that clique; | ||
* 2. Or it has a child node depending on a marginalizable variable AND the | ||
* subtree rooted at that child contains non-marginalizables. | ||
* | ||
* In addition, for any descendant node depending on a marginalizable | ||
* variable, if the subtree rooted at that descendant contains | ||
* non-marginalizable variables (i.e., it lies on a path from one of the | ||
* aforementioned cliques that require re-elimination to a node containing | ||
* non-marginalizable variables at the leaf side), then it also needs to | ||
* be re-eliminated. | ||
* | ||
* @param[in] bayesTree The Bayes tree | ||
* @param[in] marginalizableKeys Keys to be marginalized | ||
* @return Set of additional keys that need to be re-eliminated | ||
*/ | ||
static std::unordered_set<Key> | ||
gatherAdditionalKeysToReEliminate( | ||
const BayesTree& bayesTree, | ||
const KeyVector& marginalizableKeys) { | ||
const bool debug = ISDEBUG("BayesTreeMarginalizationHelper"); | ||
|
||
std::unordered_set<const Clique*> additionalCliques = | ||
gatherAdditionalCliquesToReEliminate(bayesTree, marginalizableKeys); | ||
|
||
std::unordered_set<Key> additionalKeys; | ||
for (const Clique* clique : additionalCliques) { | ||
addCliqueToKeySet(clique, &additionalKeys); | ||
} | ||
|
||
if (debug) { | ||
std::cout << "BayesTreeMarginalizationHelper: Additional keys to re-eliminate: "; | ||
for (const Key& key : additionalKeys) { | ||
std::cout << DefaultKeyFormatter(key) << " "; | ||
} | ||
std::cout << std::endl; | ||
} | ||
|
||
return additionalKeys; | ||
} | ||
|
||
protected: | ||
/** | ||
* This function identifies cliques that need to be re-eliminated before | ||
* performing marginalization. | ||
* See the docstring of @ref gatherAdditionalKeysToReEliminate(). | ||
*/ | ||
static std::unordered_set<const Clique*> | ||
gatherAdditionalCliquesToReEliminate( | ||
const BayesTree& bayesTree, | ||
const KeyVector& marginalizableKeys) { | ||
std::unordered_set<const Clique*> additionalCliques; | ||
std::unordered_set<Key> marginalizableKeySet( | ||
marginalizableKeys.begin(), marginalizableKeys.end()); | ||
CachedSearch cachedSearch; | ||
|
||
// Check each clique that contains a marginalizable key | ||
for (const Clique* clique : | ||
getCliquesContainingKeys(bayesTree, marginalizableKeySet)) { | ||
if (additionalCliques.count(clique)) { | ||
// The clique has already been visited. This can happen when an | ||
// ancestor of the current clique also contain some marginalizable | ||
// varaibles and it's processed beore the current. | ||
continue; | ||
} | ||
|
||
if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) { | ||
// Add the current clique | ||
additionalCliques.insert(clique); | ||
|
||
// Then add the dependent cliques | ||
gatherDependentCliques(clique, marginalizableKeySet, &additionalCliques, | ||
&cachedSearch); | ||
} | ||
} | ||
return additionalCliques; | ||
} | ||
|
||
/** | ||
* Gather the cliques containing any of the given keys. | ||
* | ||
* @param[in] bayesTree The Bayes tree | ||
* @param[in] keysOfInterest Set of keys of interest | ||
* @return Set of cliques that contain any of the given keys | ||
*/ | ||
static std::unordered_set<const Clique*> getCliquesContainingKeys( | ||
const BayesTree& bayesTree, | ||
const std::unordered_set<Key>& keysOfInterest) { | ||
std::unordered_set<const Clique*> cliques; | ||
for (const Key& key : keysOfInterest) { | ||
cliques.insert(bayesTree[key].get()); | ||
} | ||
return cliques; | ||
} | ||
|
||
/** | ||
* A struct to cache the results of the below two functions. | ||
*/ | ||
struct CachedSearch { | ||
std::unordered_map<const Clique*, bool> wholeMarginalizableCliques; | ||
std::unordered_map<const Clique*, bool> wholeMarginalizableSubtrees; | ||
}; | ||
|
||
/** | ||
* Check if all variables in the clique are marginalizable. | ||
* | ||
* Note we use a cache map to avoid repeated searches. | ||
*/ | ||
static bool isWholeCliqueMarginalizable( | ||
const Clique* clique, | ||
const std::unordered_set<Key>& marginalizableKeys, | ||
CachedSearch* cache) { | ||
auto it = cache->wholeMarginalizableCliques.find(clique); | ||
if (it != cache->wholeMarginalizableCliques.end()) { | ||
return it->second; | ||
} else { | ||
bool ret = true; | ||
for (Key key : clique->conditional()->frontals()) { | ||
if (!marginalizableKeys.count(key)) { | ||
ret = false; | ||
break; | ||
} | ||
} | ||
cache->wholeMarginalizableCliques.insert({clique, ret}); | ||
return ret; | ||
} | ||
} | ||
|
||
/** | ||
* Check if all variables in the subtree are marginalizable. | ||
* | ||
* Note we use a cache map to avoid repeated searches. | ||
*/ | ||
static bool isWholeSubtreeMarginalizable( | ||
const Clique* subtree, | ||
const std::unordered_set<Key>& marginalizableKeys, | ||
CachedSearch* cache) { | ||
auto it = cache->wholeMarginalizableSubtrees.find(subtree); | ||
if (it != cache->wholeMarginalizableSubtrees.end()) { | ||
return it->second; | ||
} else { | ||
bool ret = true; | ||
if (isWholeCliqueMarginalizable(subtree, marginalizableKeys, cache)) { | ||
for (const sharedClique& child : subtree->children) { | ||
if (!isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) { | ||
ret = false; | ||
break; | ||
} | ||
} | ||
} else { | ||
ret = false; | ||
} | ||
cache->wholeMarginalizableSubtrees.insert({subtree, ret}); | ||
return ret; | ||
} | ||
} | ||
|
||
/** | ||
* Check if a clique contains variables that need reelimination due to | ||
* elimination ordering conflicts. | ||
* | ||
* @param[in] clique The clique to check | ||
* @param[in] marginalizableKeys Set of keys to be marginalized | ||
* @return true if any variables in the clique need re-elimination | ||
*/ | ||
static bool needsReelimination( | ||
const Clique* clique, | ||
const std::unordered_set<Key>& marginalizableKeys, | ||
CachedSearch* cache) { | ||
bool hasNonMarginalizableAhead = false; | ||
|
||
// Check each frontal variable in order | ||
for (Key key : clique->conditional()->frontals()) { | ||
if (marginalizableKeys.count(key)) { | ||
// If we've seen non-marginalizable variables before this one, | ||
// we need to reeliminate | ||
if (hasNonMarginalizableAhead) { | ||
return true; | ||
} | ||
|
||
// Check if any child depends on this marginalizable key and the | ||
// subtree rooted at that child contains non-marginalizables. | ||
for (const sharedClique& child : clique->children) { | ||
if (hasDependency(child.get(), key) && | ||
!isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) { | ||
return true; | ||
} | ||
} | ||
} else { | ||
hasNonMarginalizableAhead = true; | ||
} | ||
} | ||
return false; | ||
} | ||
|
||
/** | ||
* Gather all dependent nodes that lie on a path from the root clique | ||
* to a clique containing a non-marginalizable variable at the leaf side. | ||
* | ||
* @param[in] rootClique The root clique | ||
* @param[in] marginalizableKeys Set of keys to be marginalized | ||
*/ | ||
static void gatherDependentCliques( | ||
const Clique* rootClique, | ||
const std::unordered_set<Key>& marginalizableKeys, | ||
std::unordered_set<const Clique*>* additionalCliques, | ||
CachedSearch* cache) { | ||
std::vector<const Clique*> dependentChildren; | ||
dependentChildren.reserve(rootClique->children.size()); | ||
for (const sharedClique& child : rootClique->children) { | ||
if (additionalCliques->count(child.get())) { | ||
// This child has already been visited. This can happen if the | ||
// child itself contains a marginalizable variable and it's | ||
// processed before the current rootClique. | ||
continue; | ||
} | ||
if (hasDependency(child.get(), marginalizableKeys)) { | ||
dependentChildren.push_back(child.get()); | ||
} | ||
} | ||
gatherDependentCliquesFromChildren( | ||
dependentChildren, marginalizableKeys, additionalCliques, cache); | ||
} | ||
|
||
/** | ||
* A helper function for the above gatherDependentCliques(). | ||
*/ | ||
static void gatherDependentCliquesFromChildren( | ||
const std::vector<const Clique*>& dependentChildren, | ||
const std::unordered_set<Key>& marginalizableKeys, | ||
std::unordered_set<const Clique*>* additionalCliques, | ||
CachedSearch* cache) { | ||
std::deque<const Clique*> descendants( | ||
dependentChildren.begin(), dependentChildren.end()); | ||
while (!descendants.empty()) { | ||
const Clique* descendant = descendants.front(); | ||
descendants.pop_front(); | ||
|
||
// If the subtree rooted at this descendant contains non-marginalizables, | ||
// it must lie on a path from the root clique to a clique containing | ||
// non-marginalizables at the leaf side. | ||
if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) { | ||
additionalCliques->insert(descendant); | ||
|
||
// Add children of the current descendant to the set descendants. | ||
for (const sharedClique& child : descendant->children) { | ||
if (additionalCliques->count(child.get())) { | ||
// This child has already been visited. | ||
continue; | ||
} else { | ||
descendants.push_back(child.get()); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Add all frontal variables from a clique to a key set. | ||
* | ||
* @param[in] clique Clique to add keys from | ||
* @param[out] additionalKeys Pointer to the output key set | ||
*/ | ||
static void addCliqueToKeySet( | ||
const Clique* clique, | ||
std::unordered_set<Key>* additionalKeys) { | ||
for (Key key : clique->conditional()->frontals()) { | ||
additionalKeys->insert(key); | ||
} | ||
} | ||
|
||
/** | ||
* Check if the clique depends on the given key. | ||
* | ||
* @param[in] clique Clique to check | ||
* @param[in] key Key to check for dependencies | ||
* @return true if clique depends on the key | ||
*/ | ||
static bool hasDependency( | ||
const Clique* clique, Key key) { | ||
auto& conditional = clique->conditional(); | ||
if (std::find(conditional->beginParents(), | ||
conditional->endParents(), key) | ||
!= conditional->endParents()) { | ||
return true; | ||
} else { | ||
return false; | ||
} | ||
} | ||
|
||
/** | ||
* Check if the clique depends on any of the given keys. | ||
*/ | ||
static bool hasDependency( | ||
const Clique* clique, const std::unordered_set<Key>& keys) { | ||
auto& conditional = clique->conditional(); | ||
for (auto it = conditional->beginParents(); | ||
it != conditional->endParents(); ++it) { | ||
if (keys.count(*it)) { | ||
return true; | ||
} | ||
} | ||
|
||
return false; | ||
} | ||
}; | ||
// BayesTreeMarginalizationHelper | ||
|
||
}/// namespace gtsam |
Oops, something went wrong.