20
20
21
21
#pragma once
22
22
23
- #include < gtsam/discrete/DiscreteFactorGraph.h>
24
- #include < gtsam/discrete/DiscreteBayesTree.h>
25
23
#include < gtsam/base/Vector.h>
24
+ #include < gtsam/discrete/DiscreteBayesTree.h>
25
+ #include < gtsam/discrete/DiscreteFactorGraph.h>
26
26
27
27
namespace gtsam {
28
28
29
- /* *
30
- * A class for computing marginals of variables in a DiscreteFactorGraph
31
- * @ingroup discrete
32
- */
29
+ /* *
30
+ * A class for computing marginals of variables in a DiscreteFactorGraph
31
+ * @ingroup discrete
32
+ */
33
33
class DiscreteMarginals {
34
+ protected:
35
+ DiscreteBayesTree::shared_ptr bayesTree_;
34
36
35
- protected:
36
-
37
- DiscreteBayesTree::shared_ptr bayesTree_;
38
-
39
- public:
40
-
37
+ public:
41
38
DiscreteMarginals () {}
42
39
43
40
/* * Construct a marginals class.
44
- * @param graph The factor graph defining the full joint distribution on all variables.
41
+ * @param graph The factor graph defining the full joint
42
+ * distribution on all variables.
45
43
*/
46
44
DiscreteMarginals (const DiscreteFactorGraph& graph) {
47
45
bayesTree_ = graph.eliminateMultifrontal ();
@@ -50,8 +48,8 @@ class DiscreteMarginals {
50
48
/* * Compute the marginal of a single variable */
51
49
DiscreteFactor::shared_ptr operator ()(Key variable) const {
52
50
// Compute marginal
53
- DiscreteFactor::shared_ptr marginalFactor;
54
- marginalFactor = bayesTree_->marginalFactor (variable, &EliminateDiscrete);
51
+ DiscreteFactor::shared_ptr marginalFactor =
52
+ bayesTree_->marginalFactor (variable, &EliminateDiscrete);
55
53
return marginalFactor;
56
54
}
57
55
@@ -61,19 +59,17 @@ class DiscreteMarginals {
61
59
*/
62
60
Vector marginalProbabilities (const DiscreteKey& key) const {
63
61
// Compute marginal
64
- DiscreteFactor::shared_ptr marginalFactor;
65
- marginalFactor = bayesTree_->marginalFactor (key.first , &EliminateDiscrete);
62
+ DiscreteFactor::shared_ptr marginalFactor = this ->operator ()(key.first );
66
63
67
- // Create result
64
+ // Create result
68
65
Vector vResult (key.second );
69
- for (size_t state = 0 ; state < key.second ; ++ state) {
66
+ for (size_t state = 0 ; state < key.second ; ++state) {
70
67
DiscreteValues values;
71
68
values[key.first ] = state;
72
69
vResult (state) = (*marginalFactor)(values);
73
70
}
74
71
return vResult;
75
72
}
76
-
77
- };
73
+ };
78
74
79
75
} /* namespace gtsam */
0 commit comments