From f42ed21137e78cec7af346dbde2d6027f8293871 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 2 Jan 2025 14:35:13 -0500 Subject: [PATCH 1/5] use faster key storage --- gtsam/discrete/TableFactor.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index bf9662e346..459f36ccbd 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -87,7 +87,7 @@ static Eigen::SparseVector ComputeSparseTable( }); sparseTable.reserve(nrValues); - std::set allKeys(dt.keys().begin(), dt.keys().end()); + KeySet allKeys(dt.keys().begin(), dt.keys().end()); /** * @brief Functor which is called by the DecisionTree for each leaf. @@ -102,13 +102,13 @@ static Eigen::SparseVector ComputeSparseTable( auto op = [&](const Assignment& assignment, double p) { if (p > 0) { // Get all the keys involved in this assignment - std::set assignmentKeys; + KeySet assignmentKeys; for (auto&& [k, _] : assignment) { assignmentKeys.insert(k); } // Find the keys missing in the assignment - std::vector diff; + KeyVector diff; std::set_difference(allKeys.begin(), allKeys.end(), assignmentKeys.begin(), assignmentKeys.end(), std::back_inserter(diff)); From 3fca55acc3f21c42a149a4ea3a29559098e0a266 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 2 Jan 2025 15:17:16 -0500 Subject: [PATCH 2/5] add test exposing issue with reverse ordered keys --- gtsam/discrete/tests/testTableFactor.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index a455faaaa4..d920f978f4 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -173,6 +173,27 @@ TEST(TableFactor, Conversion) { TableFactor tf(dtf.discreteKeys(), dtf); EXPECT(assert_equal(dtf, tf.toDecisionTreeFactor())); + + // Test for correct construction when keys are not in reverse order. + // This is possible in conditionals e.g. P(x1 | x0) + DiscreteKey X(1, 2), Y(0, 2); + DiscreteConditional dtf2( + X, {Y}, std::vector{0.33333333, 0.6, 0.66666667, 0.4}); + + TableFactor tf2(dtf2); + // GTSAM_PRINT(dtf2); + // GTSAM_PRINT(tf2); + // GTSAM_PRINT(tf2.toDecisionTreeFactor()); + + // Check for ADT equality since the order of keys is irrelevant + EXPECT(assert_equal>(dtf2, + tf2.toDecisionTreeFactor())); +} + +/* ************************************************************************* */ +TEST_DISABLED(TableFactor, Empty) { + // TableFactor empty({1, 2}, std::vector()); + // empty.print(); } /* ************************************************************************* */ From ff93c8be292941abd90e09450a9e2b3521ada9fa Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 2 Jan 2025 15:18:22 -0500 Subject: [PATCH 3/5] use denominators to compute the correct index in ComputeSparseTable --- gtsam/discrete/TableFactor.cpp | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 459f36ccbd..742538c875 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -89,6 +89,14 @@ static Eigen::SparseVector ComputeSparseTable( KeySet allKeys(dt.keys().begin(), dt.keys().end()); + // Compute denominators to be used in computing sparse table indices + std::map denominators; + double denom = sparseTable.size(); + for (const DiscreteKey& dkey : dkeys) { + denom /= dkey.second; + denominators.insert(std::pair(dkey.first, denom)); + } + /** * @brief Functor which is called by the DecisionTree for each leaf. * For each leaf value, we use the corresponding assignment to compute a @@ -127,12 +135,10 @@ static Eigen::SparseVector ComputeSparseTable( // Generate index and add to the sparse vector. Eigen::Index idx = 0; - size_t previousCardinality = 1; // We go in reverse since a DecisionTree has the highest label first for (auto&& it = updatedAssignment.rbegin(); it != updatedAssignment.rend(); it++) { - idx += previousCardinality * it->second; - previousCardinality *= dt.cardinality(it->first); + idx += it->second * denominators.at(it->first); } sparseTable.coeffRef(idx) = p; } @@ -252,9 +258,9 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); - std::vector table; - for (auto i = 0; i < sparse_table_.size(); i++) { - table.push_back(sparse_table_.coeff(i)); + std::vector table(sparse_table_.size(), 0.0); + for (SparseIt it(sparse_table_); it; ++it) { + table[it.index()] = it.value(); } AlgebraicDecisionTree tree(dkeys, table); From bd32eb8203d2a1ff1bc1e8fb936e751f45072574 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 2 Jan 2025 15:32:11 -0500 Subject: [PATCH 4/5] unit test to expose another bug --- gtsam/discrete/tests/testTableFactor.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index d920f978f4..e6c71e15ce 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -191,9 +191,18 @@ TEST(TableFactor, Conversion) { } /* ************************************************************************* */ -TEST_DISABLED(TableFactor, Empty) { - // TableFactor empty({1, 2}, std::vector()); - // empty.print(); +TEST(TableFactor, Empty) { + DiscreteKey X(1, 2); + + TableFactor single = *TableFactor({X}, "1 1").sum(1); + // Should not throw a segfault + EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1), + single.toDecisionTreeFactor())); + + TableFactor empty = *TableFactor({X}, "0 0").sum(1); + // Should not throw a segfault + EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1), + empty.toDecisionTreeFactor())); } /* ************************************************************************* */ From f9a7eb0937e58b26ed4f96c46df9d4054c76c0c7 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 2 Jan 2025 15:33:20 -0500 Subject: [PATCH 5/5] add fix --- gtsam/discrete/TableFactor.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 742538c875..a59095d406 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -258,6 +258,16 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); + // If no keys, then return empty DecisionTreeFactor + if (dkeys.size() == 0) { + AlgebraicDecisionTree tree; + // We can have an empty sparse_table_ or one with a single value. + if (sparse_table_.size() != 0) { + tree = AlgebraicDecisionTree(sparse_table_.coeff(0)); + } + return DecisionTreeFactor(dkeys, tree); + } + std::vector table(sparse_table_.size(), 0.0); for (SparseIt it(sparse_table_); it; ++it) { table[it.index()] = it.value();