21
21
#pragma once
22
22
23
23
#include < unordered_map>
24
+ #include < unordered_set>
24
25
#include < deque>
25
26
#include < gtsam/inference/BayesTree.h>
26
27
#include < gtsam/inference/BayesTreeCliqueBase.h>
@@ -62,30 +63,18 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
62
63
* @param[in] marginalizableKeys Keys to be marginalized
63
64
* @return Set of additional keys that need to be re-eliminated
64
65
*/
65
- static std::set<Key> gatherAdditionalKeysToReEliminate (
66
+ static std::unordered_set<Key>
67
+ gatherAdditionalKeysToReEliminate (
66
68
const BayesTree& bayesTree,
67
69
const KeyVector& marginalizableKeys) {
68
70
const bool debug = ISDEBUG (" BayesTreeMarginalizationHelper" );
69
71
70
- std::set<Key> additionalKeys;
71
- std::set<Key> marginalizableKeySet (
72
- marginalizableKeys.begin (), marginalizableKeys.end ());
73
- CachedSearch cachedSearch;
74
-
75
- // Check each clique that contains a marginalizable key
76
- for (const sharedClique& clique :
77
- getCliquesContainingKeys (bayesTree, marginalizableKeySet)) {
72
+ std::unordered_set<const Clique*> additionalCliques =
73
+ gatherAdditionalCliquesToReEliminate (bayesTree, marginalizableKeys);
78
74
79
- if (needsReelimination (clique, marginalizableKeySet, &cachedSearch)) {
80
- // Add frontal variables from current clique
81
- addCliqueToKeySet (clique, &additionalKeys);
82
-
83
- // Then add the dependent cliques
84
- for (const sharedClique& dependent :
85
- gatherDependentCliques (clique, marginalizableKeySet, &cachedSearch)) {
86
- addCliqueToKeySet (dependent, &additionalKeys);
87
- }
88
- }
75
+ std::unordered_set<Key> additionalKeys;
76
+ for (const Clique* clique : additionalCliques) {
77
+ addCliqueToKeySet (clique, &additionalKeys);
89
78
}
90
79
91
80
if (debug) {
@@ -100,6 +89,41 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
100
89
}
101
90
102
91
protected:
92
+ /* *
93
+ * This function identifies cliques that need to be re-eliminated before
94
+ * performing marginalization.
95
+ * See the docstring of @ref gatherAdditionalKeysToReEliminate().
96
+ */
97
+ static std::unordered_set<const Clique*>
98
+ gatherAdditionalCliquesToReEliminate (
99
+ const BayesTree& bayesTree,
100
+ const KeyVector& marginalizableKeys) {
101
+ std::unordered_set<const Clique*> additionalCliques;
102
+ std::unordered_set<Key> marginalizableKeySet (
103
+ marginalizableKeys.begin (), marginalizableKeys.end ());
104
+ CachedSearch cachedSearch;
105
+
106
+ // Check each clique that contains a marginalizable key
107
+ for (const Clique* clique :
108
+ getCliquesContainingKeys (bayesTree, marginalizableKeySet)) {
109
+ if (additionalCliques.count (clique)) {
110
+ // The clique has already been visited. This can happen when an
111
+ // ancestor of the current clique also contain some marginalizable
112
+ // varaibles and it's processed beore the current.
113
+ continue ;
114
+ }
115
+
116
+ if (needsReelimination (clique, marginalizableKeySet, &cachedSearch)) {
117
+ // Add the current clique
118
+ additionalCliques.insert (clique);
119
+
120
+ // Then add the dependent cliques
121
+ gatherDependentCliques (clique, marginalizableKeySet, &additionalCliques,
122
+ &cachedSearch);
123
+ }
124
+ }
125
+ return additionalCliques;
126
+ }
103
127
104
128
/* *
105
129
* Gather the cliques containing any of the given keys.
@@ -108,12 +132,12 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
108
132
* @param[in] keysOfInterest Set of keys of interest
109
133
* @return Set of cliques that contain any of the given keys
110
134
*/
111
- static std::set<sharedClique > getCliquesContainingKeys (
135
+ static std::unordered_set< const Clique* > getCliquesContainingKeys (
112
136
const BayesTree& bayesTree,
113
- const std::set <Key>& keysOfInterest) {
114
- std::set<sharedClique > cliques;
137
+ const std::unordered_set <Key>& keysOfInterest) {
138
+ std::unordered_set< const Clique* > cliques;
115
139
for (const Key& key : keysOfInterest) {
116
- cliques.insert (bayesTree[key]);
140
+ cliques.insert (bayesTree[key]. get () );
117
141
}
118
142
return cliques;
119
143
}
@@ -122,8 +146,8 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
122
146
* A struct to cache the results of the below two functions.
123
147
*/
124
148
struct CachedSearch {
125
- std::unordered_map<Clique*, bool > wholeMarginalizableCliques;
126
- std::unordered_map<Clique*, bool > wholeMarginalizableSubtrees;
149
+ std::unordered_map<const Clique*, bool > wholeMarginalizableCliques;
150
+ std::unordered_map<const Clique*, bool > wholeMarginalizableSubtrees;
127
151
};
128
152
129
153
/* *
@@ -132,10 +156,10 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
132
156
* Note we use a cache map to avoid repeated searches.
133
157
*/
134
158
static bool isWholeCliqueMarginalizable (
135
- const sharedClique& clique,
136
- const std::set <Key>& marginalizableKeys,
159
+ const Clique* clique,
160
+ const std::unordered_set <Key>& marginalizableKeys,
137
161
CachedSearch* cache) {
138
- auto it = cache->wholeMarginalizableCliques .find (clique. get () );
162
+ auto it = cache->wholeMarginalizableCliques .find (clique);
139
163
if (it != cache->wholeMarginalizableCliques .end ()) {
140
164
return it->second ;
141
165
} else {
@@ -146,7 +170,7 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
146
170
break ;
147
171
}
148
172
}
149
- cache->wholeMarginalizableCliques .insert ({clique. get () , ret});
173
+ cache->wholeMarginalizableCliques .insert ({clique, ret});
150
174
return ret;
151
175
}
152
176
}
@@ -157,25 +181,25 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
157
181
* Note we use a cache map to avoid repeated searches.
158
182
*/
159
183
static bool isWholeSubtreeMarginalizable (
160
- const sharedClique& subtree,
161
- const std::set <Key>& marginalizableKeys,
184
+ const Clique* subtree,
185
+ const std::unordered_set <Key>& marginalizableKeys,
162
186
CachedSearch* cache) {
163
- auto it = cache->wholeMarginalizableSubtrees .find (subtree. get () );
187
+ auto it = cache->wholeMarginalizableSubtrees .find (subtree);
164
188
if (it != cache->wholeMarginalizableSubtrees .end ()) {
165
189
return it->second ;
166
190
} else {
167
191
bool ret = true ;
168
192
if (isWholeCliqueMarginalizable (subtree, marginalizableKeys, cache)) {
169
193
for (const sharedClique& child : subtree->children ) {
170
- if (!isWholeSubtreeMarginalizable (child, marginalizableKeys, cache)) {
194
+ if (!isWholeSubtreeMarginalizable (child. get () , marginalizableKeys, cache)) {
171
195
ret = false ;
172
196
break ;
173
197
}
174
198
}
175
199
} else {
176
200
ret = false ;
177
201
}
178
- cache->wholeMarginalizableSubtrees .insert ({subtree. get () , ret});
202
+ cache->wholeMarginalizableSubtrees .insert ({subtree, ret});
179
203
return ret;
180
204
}
181
205
}
@@ -189,8 +213,8 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
189
213
* @return true if any variables in the clique need re-elimination
190
214
*/
191
215
static bool needsReelimination (
192
- const sharedClique& clique,
193
- const std::set <Key>& marginalizableKeys,
216
+ const Clique* clique,
217
+ const std::unordered_set <Key>& marginalizableKeys,
194
218
CachedSearch* cache) {
195
219
bool hasNonMarginalizableAhead = false ;
196
220
@@ -206,8 +230,8 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
206
230
// Check if any child depends on this marginalizable key and the
207
231
// subtree rooted at that child contains non-marginalizables.
208
232
for (const sharedClique& child : clique->children ) {
209
- if (hasDependency (child, key) &&
210
- !isWholeSubtreeMarginalizable (child, marginalizableKeys, cache)) {
233
+ if (hasDependency (child. get () , key) &&
234
+ !isWholeSubtreeMarginalizable (child. get () , marginalizableKeys, cache)) {
211
235
return true ;
212
236
}
213
237
}
@@ -225,47 +249,59 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
225
249
* @param[in] rootClique The root clique
226
250
* @param[in] marginalizableKeys Set of keys to be marginalized
227
251
*/
228
- static std::set<sharedClique> gatherDependentCliques (
229
- const sharedClique& rootClique,
230
- const std::set<Key>& marginalizableKeys,
252
+ static void gatherDependentCliques (
253
+ const Clique* rootClique,
254
+ const std::unordered_set<Key>& marginalizableKeys,
255
+ std::unordered_set<const Clique*>* additionalCliques,
231
256
CachedSearch* cache) {
232
- std::vector<sharedClique > dependentChildren;
257
+ std::vector<const Clique* > dependentChildren;
233
258
dependentChildren.reserve (rootClique->children .size ());
234
259
for (const sharedClique& child : rootClique->children ) {
235
- if (hasDependency (child, marginalizableKeys)) {
236
- dependentChildren.push_back (child);
260
+ if (additionalCliques->count (child.get ())) {
261
+ // This child has already been visited. This can happen if the
262
+ // child itself contains a marginalizable variable and it's
263
+ // processed before the current rootClique.
264
+ continue ;
265
+ }
266
+ if (hasDependency (child.get (), marginalizableKeys)) {
267
+ dependentChildren.push_back (child.get ());
237
268
}
238
269
}
239
- return gatherDependentCliquesFromChildren (dependentChildren, marginalizableKeys, cache);
270
+ gatherDependentCliquesFromChildren (
271
+ dependentChildren, marginalizableKeys, additionalCliques, cache);
240
272
}
241
273
242
274
/* *
243
275
* A helper function for the above gatherDependentCliques().
244
276
*/
245
- static std::set<sharedClique> gatherDependentCliquesFromChildren (
246
- const std::vector<sharedClique>& dependentChildren,
247
- const std::set<Key>& marginalizableKeys,
277
+ static void gatherDependentCliquesFromChildren (
278
+ const std::vector<const Clique*>& dependentChildren,
279
+ const std::unordered_set<Key>& marginalizableKeys,
280
+ std::unordered_set<const Clique*>* additionalCliques,
248
281
CachedSearch* cache) {
249
- std::deque<sharedClique > descendants (
282
+ std::deque<const Clique* > descendants (
250
283
dependentChildren.begin (), dependentChildren.end ());
251
- std::set<sharedClique> dependentCliques;
252
284
while (!descendants.empty ()) {
253
- sharedClique descendant = descendants.front ();
285
+ const Clique* descendant = descendants.front ();
254
286
descendants.pop_front ();
255
287
256
288
// If the subtree rooted at this descendant contains non-marginalizables,
257
289
// it must lie on a path from the root clique to a clique containing
258
290
// non-marginalizables at the leaf side.
259
291
if (!isWholeSubtreeMarginalizable (descendant, marginalizableKeys, cache)) {
260
- dependentCliques.insert (descendant);
261
- }
262
-
263
- // Add all children of the current descendant to the set descendants.
264
- for (const sharedClique& child : descendant->children ) {
265
- descendants.push_back (child);
292
+ additionalCliques->insert (descendant);
293
+
294
+ // Add children of the current descendant to the set descendants.
295
+ for (const sharedClique& child : descendant->children ) {
296
+ if (additionalCliques->count (child.get ())) {
297
+ // This child has already been visited.
298
+ continue ;
299
+ } else {
300
+ descendants.push_back (child.get ());
301
+ }
302
+ }
266
303
}
267
304
}
268
- return dependentCliques;
269
305
}
270
306
271
307
/* *
@@ -275,8 +311,8 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
275
311
* @param[out] additionalKeys Pointer to the output key set
276
312
*/
277
313
static void addCliqueToKeySet (
278
- const sharedClique& clique,
279
- std::set <Key>* additionalKeys) {
314
+ const Clique* clique,
315
+ std::unordered_set <Key>* additionalKeys) {
280
316
for (Key key : clique->conditional ()->frontals ()) {
281
317
additionalKeys->insert (key);
282
318
}
@@ -290,8 +326,8 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
290
326
* @return true if clique depends on the key
291
327
*/
292
328
static bool hasDependency (
293
- const sharedClique& clique, Key key) {
294
- auto conditional = clique->conditional ();
329
+ const Clique* clique, Key key) {
330
+ auto & conditional = clique->conditional ();
295
331
if (std::find (conditional->beginParents (),
296
332
conditional->endParents (), key)
297
333
!= conditional->endParents ()) {
@@ -305,12 +341,15 @@ class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper {
305
341
* Check if the clique depends on any of the given keys.
306
342
*/
307
343
static bool hasDependency (
308
- const sharedClique& clique, const std::set<Key>& keys) {
309
- for (Key key : keys) {
310
- if (hasDependency (clique, key)) {
344
+ const Clique* clique, const std::unordered_set<Key>& keys) {
345
+ auto & conditional = clique->conditional ();
346
+ for (auto it = conditional->beginParents ();
347
+ it != conditional->endParents (); ++it) {
348
+ if (keys.count (*it)) {
311
349
return true ;
312
350
}
313
351
}
352
+
314
353
return false ;
315
354
}
316
355
};
0 commit comments