@@ -64,27 +64,69 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
64
64
/* *
65
65
* @brief Compute the correct ordering of the leaves in the decision tree.
66
66
*
67
- * This is done by first taking all the values which have modulo 0 value with
68
- * the cardinality of the innermost key `n`, and we go up to modulo n.
69
- *
70
67
* @param dt The DecisionTree
71
- * @return std::vector <double>
68
+ * @return Eigen::SparseVector <double>
72
69
*/
73
- std::vector<double > ComputeLeafOrdering (const DiscreteKeys& dkeys,
74
- const DecisionTreeFactor& dt) {
75
- std::vector<double > probs = dt.probabilities ();
76
- std::vector<double > ordered;
70
+ static Eigen::SparseVector<double > ComputeLeafOrdering (
71
+ const DiscreteKeys& dkeys, const DecisionTreeFactor& dt) {
72
+ // SparseVector needs to know the maximum possible index,
73
+ // so we compute the product of cardinalities.
74
+ size_t prod_cardinality = 1 ;
75
+ for (auto && [_, c] : dt.cardinalities ()) {
76
+ prod_cardinality *= c;
77
+ }
78
+ Eigen::SparseVector<double > sparse_table (prod_cardinality);
79
+ size_t nrValues = 0 ;
80
+ dt.visit ([&nrValues](double x) {
81
+ if (x > 0 ) nrValues += 1 ;
82
+ });
83
+ sparse_table.reserve (nrValues);
84
+
85
+ std::set<Key> allKeys (dt.keys ().begin (), dt.keys ().end ());
86
+
87
+ auto op = [&](const Assignment<Key>& assignment, double p) {
88
+ if (p > 0 ) {
89
+ // Get all the keys involved in this assignment
90
+ std::set<Key> assignment_keys;
91
+ for (auto && [k, _] : assignment) {
92
+ assignment_keys.insert (k);
93
+ }
77
94
78
- size_t n = dkeys[0 ].second ;
95
+ // Find the keys missing in the assignment
96
+ std::vector<Key> diff;
97
+ std::set_difference (allKeys.begin (), allKeys.end (),
98
+ assignment_keys.begin (), assignment_keys.end (),
99
+ std::back_inserter (diff));
79
100
80
- for (size_t k = 0 ; k < n; ++k) {
81
- for (size_t idx = 0 ; idx < probs.size (); ++idx) {
82
- if (idx % n == k) {
83
- ordered.push_back (probs[idx]);
101
+ // Generate all assignments using the missing keys
102
+ DiscreteKeys extras;
103
+ for (auto && key : diff) {
104
+ extras.push_back ({key, dt.cardinality (key)});
105
+ }
106
+ auto && extra_assignments = DiscreteValues::CartesianProduct (extras);
107
+
108
+ for (auto && extra : extra_assignments) {
109
+ // Create new assignment using the extra assignment
110
+ DiscreteValues updated_assignment (assignment);
111
+ updated_assignment.insert (extra);
112
+
113
+ // Generate index and add to the sparse vector.
114
+ Eigen::Index idx = 0 ;
115
+ size_t prev_cardinality = 1 ;
116
+ // We go in reverse since a DecisionTree has the highest label first
117
+ for (auto && it = updated_assignment.rbegin ();
118
+ it != updated_assignment.rend (); it++) {
119
+ idx += prev_cardinality * it->second ;
120
+ prev_cardinality *= dt.cardinality (it->first );
121
+ }
122
+ sparse_table.coeffRef (idx) = p;
84
123
}
85
124
}
86
- }
87
- return ordered;
125
+ };
126
+
127
+ dt.visitWith (op);
128
+
129
+ return sparse_table;
88
130
}
89
131
90
132
/* ************************************************************************ */
0 commit comments