Skip to content

Commit

Permalink
Merge pull request #1933 from borglab/tablefactor-conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Dec 23, 2024
2 parents 6f7365e + 3690937 commit c70898a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
44 changes: 37 additions & 7 deletions gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,7 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,

/* ************************************************************************ */
TableFactor::TableFactor(const DecisionTreeFactor& dtf)
: TableFactor(dtf.discreteKeys(),
ComputeSparseTable(dtf.discreteKeys(), dtf)) {}
: TableFactor(dtf.discreteKeys(), dtf) {}

/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteConditional& c)
Expand Down Expand Up @@ -252,12 +251,43 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();
std::vector<double> table;

// Record key assignment and value pairs in pair_table.
// The assignments are stored in descending order of keys so that the order of
// the values matches what is expected by a DecisionTree.
// This is why we reverse the keys and then
// query for the key value/assignment.
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
std::vector<std::pair<uint64_t, double>> pair_table;
for (auto i = 0; i < sparse_table_.size(); i++) {
table.push_back(sparse_table_.coeff(i));
}
// NOTE(Varun): This constructor is really expensive!!
DecisionTreeFactor f(dkeys, table);
std::stringstream ss;
for (auto&& [key, _] : rdkeys) {
ss << keyValueForIndex(key, i);
}
// k will be in reverse key order already
uint64_t k;
ss >> k;
pair_table.push_back(std::make_pair(k, sparse_table_.coeff(i)));
}

// Sort the pair_table (of assignment-value pairs) based on assignment so we
// get values in reverse key order.
std::sort(
pair_table.begin(), pair_table.end(),
[](const std::pair<uint64_t, double>& a,
const std::pair<uint64_t, double>& b) { return a.first < b.first; });

// Create the table vector by extracting the values from pair_table.
// The pair_table has already been sorted in the desired order,
// so the values will be in descending key order.
std::vector<double> table;
std::for_each(pair_table.begin(), pair_table.end(),
[&table](const std::pair<uint64_t, double>& pair) {
table.push_back(pair.second);
});

AlgebraicDecisionTree<Key> tree(rdkeys, table);
DecisionTreeFactor f(dkeys, tree);
return f;
}

Expand Down
4 changes: 3 additions & 1 deletion gtsam/discrete/TableFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;

public:
/// @name Standard Constructors
/// @{

Expand Down Expand Up @@ -156,6 +155,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
// /// @name Standard Interface
// /// @{

/// Getter for the underlying sparse vector
Eigen::SparseVector<double> sparseTable() const { return sparse_table_; }

/// Evaluate probability distribution, is just look up in TableFactor.
double evaluate(const Assignment<Key>& values) const override;

Expand Down

0 comments on commit c70898a

Please sign in to comment.