Skip to content

Commit 73f98d8

Browse files
authored
Merge pull request #1955 from borglab/hybrid-custom-discrete
2 parents 49b74af + 7440c19 commit 73f98d8

File tree

4 files changed

+81
-26
lines changed

4 files changed

+81
-26
lines changed

gtsam/discrete/DiscreteConditional.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
#include <gtsam/hybrid/HybridValues.h>
2525

2626
#include <algorithm>
27+
#include <cassert>
2728
#include <random>
2829
#include <set>
2930
#include <stdexcept>
3031
#include <string>
3132
#include <utility>
3233
#include <vector>
33-
#include <cassert>
3434

3535
using namespace std;
3636
using std::pair;

gtsam/discrete/DiscreteFactorGraph.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,7 @@ namespace gtsam {
120120
static DecisionTreeFactor DiscreteProduct(
121121
const DiscreteFactorGraph& factors) {
122122
// PRODUCT: multiply all factors
123-
#if GTSAM_HYBRID_TIMING
124-
gttic_(DiscreteProduct);
125-
#endif
126123
DecisionTreeFactor product = factors.product();
127-
#if GTSAM_HYBRID_TIMING
128-
gttoc_(DiscreteProduct);
129-
#endif
130124

131125
#if GTSAM_HYBRID_TIMING
132126
gttic_(DiscreteNormalize);
@@ -229,13 +223,7 @@ namespace gtsam {
229223
DecisionTreeFactor product = DiscreteProduct(factors);
230224

231225
// sum out frontals, this is the factor on the separator
232-
#if GTSAM_HYBRID_TIMING
233-
gttic_(EliminateDiscreteSum);
234-
#endif
235226
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
236-
#if GTSAM_HYBRID_TIMING
237-
gttoc_(EliminateDiscreteSum);
238-
#endif
239227

240228
// Ordering keys for the conditional so that frontalKeys are really in front
241229
Ordering orderedKeys;
@@ -245,14 +233,8 @@ namespace gtsam {
245233
sum->keys().end());
246234

247235
// now divide product/sum to get conditional
248-
#if GTSAM_HYBRID_TIMING
249-
gttic_(EliminateDiscreteToDiscreteConditional);
250-
#endif
251236
auto conditional =
252237
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
253-
#if GTSAM_HYBRID_TIMING
254-
gttoc_(EliminateDiscreteToDiscreteConditional);
255-
#endif
256238

257239
return {conditional, sum};
258240
}

gtsam/hybrid/HybridGaussianFactorGraph.cpp

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -255,15 +255,55 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
255255
return std::make_shared<TableFactor>(discreteKeys, potentials);
256256
}
257257

258+
/**
259+
* @brief Multiply all the `factors` using the machinery of the TableFactor.
260+
*
261+
* @param factors The factors to multiply as a DiscreteFactorGraph.
262+
* @return TableFactor
263+
*/
264+
static TableFactor TableProduct(const DiscreteFactorGraph &factors) {
265+
// PRODUCT: multiply all factors
266+
#if GTSAM_HYBRID_TIMING
267+
gttic_(DiscreteProduct);
268+
#endif
269+
TableFactor product;
270+
for (auto &&factor : factors) {
271+
if (factor) {
272+
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
273+
product = product * (*f);
274+
} else if (auto dtf =
275+
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
276+
product = product * TableFactor(*dtf);
277+
}
278+
}
279+
}
280+
#if GTSAM_HYBRID_TIMING
281+
gttoc_(DiscreteProduct);
282+
#endif
283+
284+
#if GTSAM_HYBRID_TIMING
285+
gttic_(DiscreteNormalize);
286+
#endif
287+
// Max over all the potentials by pretending all keys are frontal:
288+
auto denominator = product.max(product.size());
289+
// Normalize the product factor to prevent underflow.
290+
product = product / (*denominator);
291+
#if GTSAM_HYBRID_TIMING
292+
gttoc_(DiscreteNormalize);
293+
#endif
294+
295+
return product;
296+
}
297+
258298
/* ************************************************************************ */
259-
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
260-
discreteElimination(const HybridGaussianFactorGraph &factors,
261-
const Ordering &frontalKeys) {
299+
static DiscreteFactorGraph CollectDiscreteFactors(
300+
const HybridGaussianFactorGraph &factors) {
262301
DiscreteFactorGraph dfg;
263302

264303
for (auto &f : factors) {
265304
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
266305
dfg.push_back(df);
306+
267307
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
268308
// Case where we have a HybridGaussianFactor with no continuous keys.
269309
// In this case, compute a discrete factor from the remaining error.
@@ -296,16 +336,48 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
296336
}
297337
}
298338

339+
return dfg;
340+
}
341+
342+
/* ************************************************************************ */
343+
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
344+
discreteElimination(const HybridGaussianFactorGraph &factors,
345+
const Ordering &frontalKeys) {
346+
DiscreteFactorGraph dfg = CollectDiscreteFactors(factors);
347+
299348
#if GTSAM_HYBRID_TIMING
300349
gttic_(EliminateDiscrete);
301350
#endif
302-
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
303-
auto result = EliminateDiscrete(dfg, frontalKeys);
351+
// Check if separator is empty.
352+
// This is the same as checking if the number of frontal variables
353+
// is the same as the number of variables in the DiscreteFactorGraph.
354+
// If the separator is empty, we have a clique of all the discrete variables
355+
// so we can use the TableFactor for efficiency.
356+
if (frontalKeys.size() == dfg.keys().size()) {
357+
// Get product factor
358+
TableFactor product = TableProduct(dfg);
359+
360+
#if GTSAM_HYBRID_TIMING
361+
gttic_(EliminateDiscreteFormDiscreteConditional);
362+
#endif
363+
auto conditional = std::make_shared<DiscreteConditional>(
364+
frontalKeys.size(), product.toDecisionTreeFactor());
304365
#if GTSAM_HYBRID_TIMING
305-
gttoc_(EliminateDiscrete);
366+
gttoc_(EliminateDiscreteFormDiscreteConditional);
306367
#endif
307368

308-
return {std::make_shared<HybridConditional>(result.first), result.second};
369+
TableFactor::shared_ptr sum = product.sum(frontalKeys);
370+
#if GTSAM_HYBRID_TIMING
371+
gttoc_(EliminateDiscrete);
372+
#endif
373+
374+
return {std::make_shared<HybridConditional>(conditional), sum};
375+
376+
} else {
377+
// Perform sum-product.
378+
auto result = EliminateDiscrete(dfg, frontalKeys);
379+
return {std::make_shared<HybridConditional>(result.first), result.second};
380+
}
309381
}
310382

311383
/* ************************************************************************ */

gtsam/hybrid/tests/testGaussianMixture.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
162162
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
163163
}
164164
}
165+
165166
/* ************************************************************************* */
166167
int main() {
167168
TestResult tr;

0 commit comments

Comments
 (0)