@@ -255,15 +255,55 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
255
255
return std::make_shared<TableFactor>(discreteKeys, potentials);
256
256
}
257
257
258
+ /* *
259
+ * @brief Multiply all the `factors` using the machinery of the TableFactor.
260
+ *
261
+ * @param factors The factors to multiply as a DiscreteFactorGraph.
262
+ * @return TableFactor
263
+ */
264
+ static TableFactor TableProduct (const DiscreteFactorGraph &factors) {
265
+ // PRODUCT: multiply all factors
266
+ #if GTSAM_HYBRID_TIMING
267
+ gttic_ (DiscreteProduct);
268
+ #endif
269
+ TableFactor product;
270
+ for (auto &&factor : factors) {
271
+ if (factor) {
272
+ if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
273
+ product = product * (*f);
274
+ } else if (auto dtf =
275
+ std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
276
+ product = product * TableFactor (*dtf);
277
+ }
278
+ }
279
+ }
280
+ #if GTSAM_HYBRID_TIMING
281
+ gttoc_ (DiscreteProduct);
282
+ #endif
283
+
284
+ #if GTSAM_HYBRID_TIMING
285
+ gttic_ (DiscreteNormalize);
286
+ #endif
287
+ // Max over all the potentials by pretending all keys are frontal:
288
+ auto denominator = product.max (product.size ());
289
+ // Normalize the product factor to prevent underflow.
290
+ product = product / (*denominator);
291
+ #if GTSAM_HYBRID_TIMING
292
+ gttoc_ (DiscreteNormalize);
293
+ #endif
294
+
295
+ return product;
296
+ }
297
+
258
298
/* ************************************************************************ */
259
- static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
260
- discreteElimination (const HybridGaussianFactorGraph &factors,
261
- const Ordering &frontalKeys) {
299
+ static DiscreteFactorGraph CollectDiscreteFactors (
300
+ const HybridGaussianFactorGraph &factors) {
262
301
DiscreteFactorGraph dfg;
263
302
264
303
for (auto &f : factors) {
265
304
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
266
305
dfg.push_back (df);
306
+
267
307
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
268
308
// Case where we have a HybridGaussianFactor with no continuous keys.
269
309
// In this case, compute a discrete factor from the remaining error.
@@ -296,16 +336,48 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
296
336
}
297
337
}
298
338
339
+ return dfg;
340
+ }
341
+
342
+ /* ************************************************************************ */
343
+ static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
344
+ discreteElimination (const HybridGaussianFactorGraph &factors,
345
+ const Ordering &frontalKeys) {
346
+ DiscreteFactorGraph dfg = CollectDiscreteFactors (factors);
347
+
299
348
#if GTSAM_HYBRID_TIMING
300
349
gttic_ (EliminateDiscrete);
301
350
#endif
302
- // NOTE: This does sum-product. For max-product, use EliminateForMPE.
303
- auto result = EliminateDiscrete (dfg, frontalKeys);
351
+ // Check if separator is empty.
352
+ // This is the same as checking if the number of frontal variables
353
+ // is the same as the number of variables in the DiscreteFactorGraph.
354
+ // If the separator is empty, we have a clique of all the discrete variables
355
+ // so we can use the TableFactor for efficiency.
356
+ if (frontalKeys.size () == dfg.keys ().size ()) {
357
+ // Get product factor
358
+ TableFactor product = TableProduct (dfg);
359
+
360
+ #if GTSAM_HYBRID_TIMING
361
+ gttic_ (EliminateDiscreteFormDiscreteConditional);
362
+ #endif
363
+ auto conditional = std::make_shared<DiscreteConditional>(
364
+ frontalKeys.size (), product.toDecisionTreeFactor ());
304
365
#if GTSAM_HYBRID_TIMING
305
- gttoc_ (EliminateDiscrete );
366
+ gttoc_ (EliminateDiscreteFormDiscreteConditional );
306
367
#endif
307
368
308
- return {std::make_shared<HybridConditional>(result.first ), result.second };
369
+ TableFactor::shared_ptr sum = product.sum (frontalKeys);
370
+ #if GTSAM_HYBRID_TIMING
371
+ gttoc_ (EliminateDiscrete);
372
+ #endif
373
+
374
+ return {std::make_shared<HybridConditional>(conditional), sum};
375
+
376
+ } else {
377
+ // Perform sum-product.
378
+ auto result = EliminateDiscrete (dfg, frontalKeys);
379
+ return {std::make_shared<HybridConditional>(result.first ), result.second };
380
+ }
309
381
}
310
382
311
383
/* ************************************************************************ */
0 commit comments