From 039c9b15426287e3669c5e1ffaaf50a2ce505e29 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 15 Dec 2024 12:29:31 -0500 Subject: [PATCH 1/9] add getter for sparse_table_ --- gtsam/discrete/TableFactor.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index d27c4740cc..5ddb4ab431 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -99,7 +99,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { typedef Eigen::SparseVector::InnerIterator SparseIt; typedef std::vector> AssignValList; - public: /// @name Standard Constructors /// @{ @@ -156,6 +155,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { // /// @name Standard Interface // /// @{ + /// Getter for the underlying sparse vector + Eigen::SparseVector sparseTable() const { return sparse_table_; } + /// Evaluate probability distribution, is just look up in TableFactor. double evaluate(const Assignment& values) const override; From 293c29ebf8f2f706a1dd6857fc5b30593eed5e36 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 15 Dec 2024 12:30:11 -0500 Subject: [PATCH 2/9] update toDecisionTreeFactor to use reverse key format so it's faster --- gtsam/discrete/TableFactor.cpp | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index de1e1f8670..521ca4d470 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -252,12 +252,33 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { /* ************************************************************************ */ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); - std::vector table; + + DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend()); + std::vector> pair_table; for (auto i = 0; i < sparse_table_.size(); i++) { - table.push_back(sparse_table_.coeff(i)); + std::stringstream ss; + for (auto&& [key, _] : rdkeys) { + ss << keyValueForIndex(key, i); + } + // k will be in reverse key order already + uint64_t k = std::strtoull(ss.str().c_str(), NULL, 10); + pair_table.push_back(std::make_pair(k, sparse_table_.coeff(i))); } - // NOTE(Varun): This constructor is really expensive!! - DecisionTreeFactor f(dkeys, table); + + // Sort based on key so we get values in reverse key order. + std::sort( + pair_table.begin(), pair_table.end(), + [](const std::pair& a, + const std::pair& b) { return a.first <= b.first; }); + + // Create the table vector + std::vector table; + std::for_each(pair_table.begin(), pair_table.end(), + [&table](const std::pair& pair) { + table.push_back(pair.second); + }); + + DecisionTreeFactor f(rdkeys, table); return f; } From 8fefbbf06a751bdd9e637f72571229145be77b13 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 15 Dec 2024 16:03:10 -0500 Subject: [PATCH 3/9] fix toDecisionTreeFactor so the keys are ordered correctly --- gtsam/discrete/TableFactor.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 521ca4d470..63e9e5d6b4 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -153,8 +153,7 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, /* ************************************************************************ */ TableFactor::TableFactor(const DecisionTreeFactor& dtf) - : TableFactor(dtf.discreteKeys(), - ComputeSparseTable(dtf.discreteKeys(), dtf)) {} + : TableFactor(dtf.discreteKeys(), dtf) {} /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteConditional& c) @@ -278,7 +277,8 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { table.push_back(pair.second); }); - DecisionTreeFactor f(rdkeys, table); + AlgebraicDecisionTree tree(rdkeys, table); + DecisionTreeFactor f(dkeys, tree); return f; } From 37f6de744df609a6cbc6f347d9d82bd690cf6358 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 16 Dec 2024 10:28:55 -0500 Subject: [PATCH 4/9] use c++11 method for string to int --- gtsam/discrete/TableFactor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 63e9e5d6b4..3cfae56eca 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -260,7 +260,7 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { ss << keyValueForIndex(key, i); } // k will be in reverse key order already - uint64_t k = std::strtoull(ss.str().c_str(), NULL, 10); + uint64_t k = std::stoll(ss.str().c_str()); pair_table.push_back(std::make_pair(k, sparse_table_.coeff(i))); } From df0f597ed763d028b170690846f6eddcecf0b46f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 16 Dec 2024 14:28:27 -0500 Subject: [PATCH 5/9] debug conversion --- gtsam/discrete/TableFactor.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 3cfae56eca..cb60a98187 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -260,7 +260,9 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { ss << keyValueForIndex(key, i); } // k will be in reverse key order already - uint64_t k = std::stoll(ss.str().c_str()); + uint64_t k; + ss >> k; + std::cout << "ss: " << ss.str() << ", k=" << k << std::endl; pair_table.push_back(std::make_pair(k, sparse_table_.coeff(i))); } From 17cae8c45336cdb62bf2bfcbc7c6fe6f7cdc1d39 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 16 Dec 2024 15:22:25 -0500 Subject: [PATCH 6/9] print more to debug --- gtsam/discrete/TableFactor.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index cb60a98187..38b4ddd302 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -262,16 +262,21 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { // k will be in reverse key order already uint64_t k; ss >> k; - std::cout << "ss: " << ss.str() << ", k=" << k << std::endl; + std::cout << "ss: " << ss.str() << ", k=" << k + << ", v=" << sparse_table_.coeff(i) << std::endl; pair_table.push_back(std::make_pair(k, sparse_table_.coeff(i))); } - // Sort based on key so we get values in reverse key order. + // Sort based on key assignment so we get values in reverse key order. std::sort( pair_table.begin(), pair_table.end(), [](const std::pair& a, const std::pair& b) { return a.first <= b.first; }); + std::cout << "Sorted pair_table:" << std::endl; + for (auto&& [k, v] : pair_table) { + std::cout << "k=" << k << ", v=" << v << std::endl; + } // Create the table vector std::vector table; std::for_each(pair_table.begin(), pair_table.end(), From 72306efe9830f24ed741db8b614d0c665ccf9452 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 16 Dec 2024 18:20:16 -0500 Subject: [PATCH 7/9] strict less than check --- gtsam/discrete/TableFactor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 38b4ddd302..0a93e7b5df 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -271,7 +271,7 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { std::sort( pair_table.begin(), pair_table.end(), [](const std::pair& a, - const std::pair& b) { return a.first <= b.first; }); + const std::pair& b) { return a.first < b.first; }); std::cout << "Sorted pair_table:" << std::endl; for (auto&& [k, v] : pair_table) { From 70288bc32a573cb2040d2f6f0fc115b1e5a094b2 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 16 Dec 2024 22:18:12 -0500 Subject: [PATCH 8/9] remove print statements --- gtsam/discrete/TableFactor.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 0a93e7b5df..d9bfc42c2e 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -262,8 +262,6 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { // k will be in reverse key order already uint64_t k; ss >> k; - std::cout << "ss: " << ss.str() << ", k=" << k - << ", v=" << sparse_table_.coeff(i) << std::endl; pair_table.push_back(std::make_pair(k, sparse_table_.coeff(i))); } @@ -273,10 +271,6 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { [](const std::pair& a, const std::pair& b) { return a.first < b.first; }); - std::cout << "Sorted pair_table:" << std::endl; - for (auto&& [k, v] : pair_table) { - std::cout << "k=" << k << ", v=" << v << std::endl; - } // Create the table vector std::vector table; std::for_each(pair_table.begin(), pair_table.end(), From 3690937159aa79de1570f4cac8db0ac53f8235c0 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 18 Dec 2024 08:11:42 -0500 Subject: [PATCH 9/9] add comments --- gtsam/discrete/TableFactor.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index d9bfc42c2e..22548da07f 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -252,6 +252,11 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); + // Record key assignment and value pairs in pair_table. + // The assignments are stored in descending order of keys so that the order of + // the values matches what is expected by a DecisionTree. + // This is why we reverse the keys and then + // query for the key value/assignment. DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend()); std::vector> pair_table; for (auto i = 0; i < sparse_table_.size(); i++) { @@ -265,13 +270,16 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { pair_table.push_back(std::make_pair(k, sparse_table_.coeff(i))); } - // Sort based on key assignment so we get values in reverse key order. + // Sort the pair_table (of assignment-value pairs) based on assignment so we + // get values in reverse key order. std::sort( pair_table.begin(), pair_table.end(), [](const std::pair& a, const std::pair& b) { return a.first < b.first; }); - // Create the table vector + // Create the table vector by extracting the values from pair_table. + // The pair_table has already been sorted in the desired order, + // so the values will be in descending key order. std::vector table; std::for_each(pair_table.begin(), pair_table.end(), [&table](const std::pair& pair) {