Skip to content

Commit a8e24ef

Browse files
committed
update ComputeLeafOrdering to give a correct vector of values
1 parent b91c470 commit a8e24ef

File tree

1 file changed

+57
-15
lines changed

1 file changed

+57
-15
lines changed

gtsam/discrete/TableFactor.cpp

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,27 +64,69 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
6464
/**
6565
* @brief Compute the correct ordering of the leaves in the decision tree.
6666
*
67-
* This is done by first taking all the values which have modulo 0 value with
68-
* the cardinality of the innermost key `n`, and we go up to modulo n.
69-
*
7067
* @param dt The DecisionTree
71-
* @return std::vector<double>
68+
* @return Eigen::SparseVector<double>
7269
*/
73-
std::vector<double> ComputeLeafOrdering(const DiscreteKeys& dkeys,
74-
const DecisionTreeFactor& dt) {
75-
std::vector<double> probs = dt.probabilities();
76-
std::vector<double> ordered;
70+
static Eigen::SparseVector<double> ComputeLeafOrdering(
71+
const DiscreteKeys& dkeys, const DecisionTreeFactor& dt) {
72+
// SparseVector needs to know the maximum possible index,
73+
// so we compute the product of cardinalities.
74+
size_t prod_cardinality = 1;
75+
for (auto&& [_, c] : dt.cardinalities()) {
76+
prod_cardinality *= c;
77+
}
78+
Eigen::SparseVector<double> sparse_table(prod_cardinality);
79+
size_t nrValues = 0;
80+
dt.visit([&nrValues](double x) {
81+
if (x > 0) nrValues += 1;
82+
});
83+
sparse_table.reserve(nrValues);
84+
85+
std::set<Key> allKeys(dt.keys().begin(), dt.keys().end());
86+
87+
auto op = [&](const Assignment<Key>& assignment, double p) {
88+
if (p > 0) {
89+
// Get all the keys involved in this assignment
90+
std::set<Key> assignment_keys;
91+
for (auto&& [k, _] : assignment) {
92+
assignment_keys.insert(k);
93+
}
7794

78-
size_t n = dkeys[0].second;
95+
// Find the keys missing in the assignment
96+
std::vector<Key> diff;
97+
std::set_difference(allKeys.begin(), allKeys.end(),
98+
assignment_keys.begin(), assignment_keys.end(),
99+
std::back_inserter(diff));
79100

80-
for (size_t k = 0; k < n; ++k) {
81-
for (size_t idx = 0; idx < probs.size(); ++idx) {
82-
if (idx % n == k) {
83-
ordered.push_back(probs[idx]);
101+
// Generate all assignments using the missing keys
102+
DiscreteKeys extras;
103+
for (auto&& key : diff) {
104+
extras.push_back({key, dt.cardinality(key)});
105+
}
106+
auto&& extra_assignments = DiscreteValues::CartesianProduct(extras);
107+
108+
for (auto&& extra : extra_assignments) {
109+
// Create new assignment using the extra assignment
110+
DiscreteValues updated_assignment(assignment);
111+
updated_assignment.insert(extra);
112+
113+
// Generate index and add to the sparse vector.
114+
Eigen::Index idx = 0;
115+
size_t prev_cardinality = 1;
116+
// We go in reverse since a DecisionTree has the highest label first
117+
for (auto&& it = updated_assignment.rbegin();
118+
it != updated_assignment.rend(); it++) {
119+
idx += prev_cardinality * it->second;
120+
prev_cardinality *= dt.cardinality(it->first);
121+
}
122+
sparse_table.coeffRef(idx) = p;
84123
}
85124
}
86-
}
87-
return ordered;
125+
};
126+
127+
dt.visitWith(op);
128+
129+
return sparse_table;
88130
}
89131

90132
/* ************************************************************************ */

0 commit comments

Comments
 (0)