Skip to content

Commit 06dac43

Browse files
Some refinement in BayesTreeMarginalizationHelper:
1. Skip subtrees that have already been visited when searching for dependent cliques; 2. Avoid copying shared_ptrs (which needs extra expensive atomic operations) in the searching. Use const Clique* instead of sharedClique whenever possible; 3. Use std::unordered_set instead of std::set to improve average searching speed.
1 parent 0d9c3a9 commit 06dac43

File tree

2 files changed

+105
-66
lines changed

2 files changed

+105
-66
lines changed

gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h

Lines changed: 104 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#pragma once
2222

2323
#include <unordered_map>
24+
#include <unordered_set>
2425
#include <deque>
2526
#include <gtsam/inference/BayesTree.h>
2627
#include <gtsam/inference/BayesTreeCliqueBase.h>
@@ -62,30 +63,18 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
6263
* @param[in] marginalizableKeys Keys to be marginalized
6364
* @return Set of additional keys that need to be re-eliminated
6465
*/
65-
static std::set<Key> gatherAdditionalKeysToReEliminate(
66+
static std::unordered_set<Key>
67+
gatherAdditionalKeysToReEliminate(
6668
const BayesTree& bayesTree,
6769
const KeyVector& marginalizableKeys) {
6870
const bool debug = ISDEBUG("BayesTreeMarginalizationHelper");
6971

70-
std::set<Key> additionalKeys;
71-
std::set<Key> marginalizableKeySet(
72-
marginalizableKeys.begin(), marginalizableKeys.end());
73-
CachedSearch cachedSearch;
74-
75-
// Check each clique that contains a marginalizable key
76-
for (const sharedClique& clique :
77-
getCliquesContainingKeys(bayesTree, marginalizableKeySet)) {
72+
std::unordered_set<const Clique*> additionalCliques =
73+
gatherAdditionalCliquesToReEliminate(bayesTree, marginalizableKeys);
7874

79-
if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) {
80-
// Add frontal variables from current clique
81-
addCliqueToKeySet(clique, &additionalKeys);
82-
83-
// Then add the dependent cliques
84-
for (const sharedClique& dependent :
85-
gatherDependentCliques(clique, marginalizableKeySet, &cachedSearch)) {
86-
addCliqueToKeySet(dependent, &additionalKeys);
87-
}
88-
}
75+
std::unordered_set<Key> additionalKeys;
76+
for (const Clique* clique : additionalCliques) {
77+
addCliqueToKeySet(clique, &additionalKeys);
8978
}
9079

9180
if (debug) {
@@ -100,6 +89,41 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
10089
}
10190

10291
protected:
92+
/**
93+
* This function identifies cliques that need to be re-eliminated before
94+
* performing marginalization.
95+
* See the docstring of @ref gatherAdditionalKeysToReEliminate().
96+
*/
97+
static std::unordered_set<const Clique*>
98+
gatherAdditionalCliquesToReEliminate(
99+
const BayesTree& bayesTree,
100+
const KeyVector& marginalizableKeys) {
101+
std::unordered_set<const Clique*> additionalCliques;
102+
std::unordered_set<Key> marginalizableKeySet(
103+
marginalizableKeys.begin(), marginalizableKeys.end());
104+
CachedSearch cachedSearch;
105+
106+
// Check each clique that contains a marginalizable key
107+
for (const Clique* clique :
108+
getCliquesContainingKeys(bayesTree, marginalizableKeySet)) {
109+
if (additionalCliques.count(clique)) {
110+
// The clique has already been visited. This can happen when an
111+
// ancestor of the current clique also contain some marginalizable
112+
// varaibles and it's processed beore the current.
113+
continue;
114+
}
115+
116+
if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) {
117+
// Add the current clique
118+
additionalCliques.insert(clique);
119+
120+
// Then add the dependent cliques
121+
gatherDependentCliques(clique, marginalizableKeySet, &additionalCliques,
122+
&cachedSearch);
123+
}
124+
}
125+
return additionalCliques;
126+
}
103127

104128
/**
105129
* Gather the cliques containing any of the given keys.
@@ -108,12 +132,12 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
108132
* @param[in] keysOfInterest Set of keys of interest
109133
* @return Set of cliques that contain any of the given keys
110134
*/
111-
static std::set<sharedClique> getCliquesContainingKeys(
135+
static std::unordered_set<const Clique*> getCliquesContainingKeys(
112136
const BayesTree& bayesTree,
113-
const std::set<Key>& keysOfInterest) {
114-
std::set<sharedClique> cliques;
137+
const std::unordered_set<Key>& keysOfInterest) {
138+
std::unordered_set<const Clique*> cliques;
115139
for (const Key& key : keysOfInterest) {
116-
cliques.insert(bayesTree[key]);
140+
cliques.insert(bayesTree[key].get());
117141
}
118142
return cliques;
119143
}
@@ -122,8 +146,8 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
122146
* A struct to cache the results of the below two functions.
123147
*/
124148
struct CachedSearch {
125-
std::unordered_map<Clique*, bool> wholeMarginalizableCliques;
126-
std::unordered_map<Clique*, bool> wholeMarginalizableSubtrees;
149+
std::unordered_map<const Clique*, bool> wholeMarginalizableCliques;
150+
std::unordered_map<const Clique*, bool> wholeMarginalizableSubtrees;
127151
};
128152

129153
/**
@@ -132,10 +156,10 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
132156
* Note we use a cache map to avoid repeated searches.
133157
*/
134158
static bool isWholeCliqueMarginalizable(
135-
const sharedClique& clique,
136-
const std::set<Key>& marginalizableKeys,
159+
const Clique* clique,
160+
const std::unordered_set<Key>& marginalizableKeys,
137161
CachedSearch* cache) {
138-
auto it = cache->wholeMarginalizableCliques.find(clique.get());
162+
auto it = cache->wholeMarginalizableCliques.find(clique);
139163
if (it != cache->wholeMarginalizableCliques.end()) {
140164
return it->second;
141165
} else {
@@ -146,7 +170,7 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
146170
break;
147171
}
148172
}
149-
cache->wholeMarginalizableCliques.insert({clique.get(), ret});
173+
cache->wholeMarginalizableCliques.insert({clique, ret});
150174
return ret;
151175
}
152176
}
@@ -157,25 +181,25 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
157181
* Note we use a cache map to avoid repeated searches.
158182
*/
159183
static bool isWholeSubtreeMarginalizable(
160-
const sharedClique& subtree,
161-
const std::set<Key>& marginalizableKeys,
184+
const Clique* subtree,
185+
const std::unordered_set<Key>& marginalizableKeys,
162186
CachedSearch* cache) {
163-
auto it = cache->wholeMarginalizableSubtrees.find(subtree.get());
187+
auto it = cache->wholeMarginalizableSubtrees.find(subtree);
164188
if (it != cache->wholeMarginalizableSubtrees.end()) {
165189
return it->second;
166190
} else {
167191
bool ret = true;
168192
if (isWholeCliqueMarginalizable(subtree, marginalizableKeys, cache)) {
169193
for (const sharedClique& child : subtree->children) {
170-
if (!isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) {
194+
if (!isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) {
171195
ret = false;
172196
break;
173197
}
174198
}
175199
} else {
176200
ret = false;
177201
}
178-
cache->wholeMarginalizableSubtrees.insert({subtree.get(), ret});
202+
cache->wholeMarginalizableSubtrees.insert({subtree, ret});
179203
return ret;
180204
}
181205
}
@@ -189,8 +213,8 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
189213
* @return true if any variables in the clique need re-elimination
190214
*/
191215
static bool needsReelimination(
192-
const sharedClique& clique,
193-
const std::set<Key>& marginalizableKeys,
216+
const Clique* clique,
217+
const std::unordered_set<Key>& marginalizableKeys,
194218
CachedSearch* cache) {
195219
bool hasNonMarginalizableAhead = false;
196220

@@ -206,8 +230,8 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
206230
// Check if any child depends on this marginalizable key and the
207231
// subtree rooted at that child contains non-marginalizables.
208232
for (const sharedClique& child : clique->children) {
209-
if (hasDependency(child, key) &&
210-
!isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) {
233+
if (hasDependency(child.get(), key) &&
234+
!isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) {
211235
return true;
212236
}
213237
}
@@ -225,47 +249,59 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
225249
* @param[in] rootClique The root clique
226250
* @param[in] marginalizableKeys Set of keys to be marginalized
227251
*/
228-
static std::set<sharedClique> gatherDependentCliques(
229-
const sharedClique& rootClique,
230-
const std::set<Key>& marginalizableKeys,
252+
static void gatherDependentCliques(
253+
const Clique* rootClique,
254+
const std::unordered_set<Key>& marginalizableKeys,
255+
std::unordered_set<const Clique*>* additionalCliques,
231256
CachedSearch* cache) {
232-
std::vector<sharedClique> dependentChildren;
257+
std::vector<const Clique*> dependentChildren;
233258
dependentChildren.reserve(rootClique->children.size());
234259
for (const sharedClique& child : rootClique->children) {
235-
if (hasDependency(child, marginalizableKeys)) {
236-
dependentChildren.push_back(child);
260+
if (additionalCliques->count(child.get())) {
261+
// This child has already been visited. This can happen if the
262+
// child itself contains a marginalizable variable and it's
263+
// processed before the current rootClique.
264+
continue;
265+
}
266+
if (hasDependency(child.get(), marginalizableKeys)) {
267+
dependentChildren.push_back(child.get());
237268
}
238269
}
239-
return gatherDependentCliquesFromChildren(dependentChildren, marginalizableKeys, cache);
270+
gatherDependentCliquesFromChildren(
271+
dependentChildren, marginalizableKeys, additionalCliques, cache);
240272
}
241273

242274
/**
243275
* A helper function for the above gatherDependentCliques().
244276
*/
245-
static std::set<sharedClique> gatherDependentCliquesFromChildren(
246-
const std::vector<sharedClique>& dependentChildren,
247-
const std::set<Key>& marginalizableKeys,
277+
static void gatherDependentCliquesFromChildren(
278+
const std::vector<const Clique*>& dependentChildren,
279+
const std::unordered_set<Key>& marginalizableKeys,
280+
std::unordered_set<const Clique*>* additionalCliques,
248281
CachedSearch* cache) {
249-
std::deque<sharedClique> descendants(
282+
std::deque<const Clique*> descendants(
250283
dependentChildren.begin(), dependentChildren.end());
251-
std::set<sharedClique> dependentCliques;
252284
while (!descendants.empty()) {
253-
sharedClique descendant = descendants.front();
285+
const Clique* descendant = descendants.front();
254286
descendants.pop_front();
255287

256288
// If the subtree rooted at this descendant contains non-marginalizables,
257289
// it must lie on a path from the root clique to a clique containing
258290
// non-marginalizables at the leaf side.
259291
if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) {
260-
dependentCliques.insert(descendant);
261-
}
262-
263-
// Add all children of the current descendant to the set descendants.
264-
for (const sharedClique& child : descendant->children) {
265-
descendants.push_back(child);
292+
additionalCliques->insert(descendant);
293+
294+
// Add children of the current descendant to the set descendants.
295+
for (const sharedClique& child : descendant->children) {
296+
if (additionalCliques->count(child.get())) {
297+
// This child has already been visited.
298+
continue;
299+
} else {
300+
descendants.push_back(child.get());
301+
}
302+
}
266303
}
267304
}
268-
return dependentCliques;
269305
}
270306

271307
/**
@@ -275,8 +311,8 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
275311
* @param[out] additionalKeys Pointer to the output key set
276312
*/
277313
static void addCliqueToKeySet(
278-
const sharedClique& clique,
279-
std::set<Key>* additionalKeys) {
314+
const Clique* clique,
315+
std::unordered_set<Key>* additionalKeys) {
280316
for (Key key : clique->conditional()->frontals()) {
281317
additionalKeys->insert(key);
282318
}
@@ -290,8 +326,8 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
290326
* @return true if clique depends on the key
291327
*/
292328
static bool hasDependency(
293-
const sharedClique& clique, Key key) {
294-
auto conditional = clique->conditional();
329+
const Clique* clique, Key key) {
330+
auto& conditional = clique->conditional();
295331
if (std::find(conditional->beginParents(),
296332
conditional->endParents(), key)
297333
!= conditional->endParents()) {
@@ -305,12 +341,15 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
305341
* Check if the clique depends on any of the given keys.
306342
*/
307343
static bool hasDependency(
308-
const sharedClique& clique, const std::set<Key>& keys) {
309-
for (Key key : keys) {
310-
if (hasDependency(clique, key)) {
344+
const Clique* clique, const std::unordered_set<Key>& keys) {
345+
auto& conditional = clique->conditional();
346+
for (auto it = conditional->beginParents();
347+
it != conditional->endParents(); ++it) {
348+
if (keys.count(*it)) {
311349
return true;
312350
}
313351
}
352+
314353
return false;
315354
}
316355
};

gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ FixedLagSmoother::Result IncrementalFixedLagSmoother::update(
9393
std::cout << std::endl;
9494
}
9595

96-
std::set<Key> additionalKeys =
96+
std::unordered_set<Key> additionalKeys =
9797
BayesTreeMarginalizationHelper<ISAM2>::gatherAdditionalKeysToReEliminate(
9898
isam_, marginalizableKeys);
9999
KeyList additionalMarkedKeys(additionalKeys.begin(), additionalKeys.end());

0 commit comments

Comments
 (0)