Skip to content

Commit 137a503

Browse files
authored
Merge pull request #1928 from borglab/fix-table-factor
2 parents a648bd8 + ed742d3 commit 137a503

File tree

6 files changed

+49
-18
lines changed

6 files changed

+49
-18
lines changed

gtsam/discrete/DecisionTreeFactor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ namespace gtsam {
4848
return false;
4949
} else {
5050
const auto& f(static_cast<const DecisionTreeFactor&>(other));
51-
return ADT::equals(f, tol);
51+
return Base::equals(other, tol) && ADT::equals(f, tol);
5252
}
5353
}
5454

gtsam/discrete/DiscreteFactor.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ using namespace std;
2828

2929
namespace gtsam {
3030

31+
/* ************************************************************************* */
32+
bool DiscreteFactor::equals(const DiscreteFactor& lf, double tol) const {
33+
return Base::equals(lf, tol) && cardinalities_ == lf.cardinalities_;
34+
}
35+
3136
/* ************************************************************************ */
3237
DiscreteKeys DiscreteFactor::discreteKeys() const {
3338
DiscreteKeys result;

gtsam/discrete/DiscreteFactor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
7777
/// @{
7878

7979
/// equals
80-
virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const = 0;
80+
virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const;
8181

8282
/// print
8383
void print(

gtsam/discrete/TableFactor.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,28 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
9292
const DecisionTreeFactor& dtf)
9393
: TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {}
9494

95+
/* ************************************************************************ */
96+
TableFactor::TableFactor(const DecisionTreeFactor& dtf)
97+
: TableFactor(dtf.discreteKeys(),
98+
ComputeLeafOrdering(dtf.discreteKeys(), dtf)) {}
99+
95100
/* ************************************************************************ */
96101
TableFactor::TableFactor(const DiscreteConditional& c)
97102
: TableFactor(c.discreteKeys(), c) {}
98103

99104
/* ************************************************************************ */
100105
Eigen::SparseVector<double> TableFactor::Convert(
101-
const std::vector<double>& table) {
106+
const DiscreteKeys& keys, const std::vector<double>& table) {
107+
size_t max_size = 1;
108+
for (auto&& [_, cardinality] : keys.cardinalities()) {
109+
max_size *= cardinality;
110+
}
111+
if (table.size() != max_size) {
112+
throw std::runtime_error(
113+
"The cardinalities of the keys don't match the number of values in the "
114+
"input.");
115+
}
116+
102117
Eigen::SparseVector<double> sparse_table(table.size());
103118
// Count number of nonzero elements in table and reserve the space.
104119
const uint64_t nnz = std::count_if(table.begin(), table.end(),
@@ -113,13 +128,14 @@ Eigen::SparseVector<double> TableFactor::Convert(
113128
}
114129

115130
/* ************************************************************************ */
116-
Eigen::SparseVector<double> TableFactor::Convert(const std::string& table) {
131+
Eigen::SparseVector<double> TableFactor::Convert(const DiscreteKeys& keys,
132+
const std::string& table) {
117133
// Convert string to doubles.
118134
std::vector<double> ys;
119135
std::istringstream iss(table);
120136
std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(),
121137
std::back_inserter(ys));
122-
return Convert(ys);
138+
return Convert(keys, ys);
123139
}
124140

125141
/* ************************************************************************ */
@@ -128,7 +144,8 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
128144
return false;
129145
} else {
130146
const auto& f(static_cast<const TableFactor&>(other));
131-
return sparse_table_.isApprox(f.sparse_table_, tol);
147+
return Base::equals(other, tol) &&
148+
sparse_table_.isApprox(f.sparse_table_, tol);
132149
}
133150
}
134151

@@ -250,7 +267,8 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
250267
for (auto&& kv : assignment) {
251268
cout << "(" << formatter(kv.first) << ", " << kv.second << ")";
252269
}
253-
cout << " | " << it.value() << " | " << it.index() << endl;
270+
cout << " | " << std::setw(10) << std::left << it.value() << " | "
271+
<< it.index() << endl;
254272
}
255273
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
256274
}

gtsam/discrete/TableFactor.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,16 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
8080
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
8181
}
8282

83-
/// Convert probability table given as doubles to SparseVector.
84-
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
85-
static Eigen::SparseVector<double> Convert(const std::vector<double>& table);
83+
/**
84+
* Convert probability table given as doubles to SparseVector.
85+
* Example: {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
86+
*/
87+
static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
88+
const std::vector<double>& table);
8689

8790
/// Convert probability table given as string to SparseVector.
88-
static Eigen::SparseVector<double> Convert(const std::string& table);
91+
static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
92+
const std::string& table);
8993

9094
public:
9195
// typedefs needed to play nice with gtsam
@@ -111,11 +115,11 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
111115

112116
/** Constructor from doubles */
113117
TableFactor(const DiscreteKeys& keys, const std::vector<double>& table)
114-
: TableFactor(keys, Convert(table)) {}
118+
: TableFactor(keys, Convert(keys, table)) {}
115119

116120
/** Constructor from string */
117121
TableFactor(const DiscreteKeys& keys, const std::string& table)
118-
: TableFactor(keys, Convert(table)) {}
122+
: TableFactor(keys, Convert(keys, table)) {}
119123

120124
/// Single-key specialization
121125
template <class SOURCE>
@@ -128,6 +132,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
128132

129133
/// Constructor from DecisionTreeFactor
130134
TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf);
135+
TableFactor(const DecisionTreeFactor& dtf);
131136

132137
/// Constructor from DecisionTree<Key, double>/AlgebraicDecisionTree
133138
TableFactor(const DiscreteKeys& keys, const DecisionTree<Key, double>& dtree);

gtsam/discrete/tests/testTableFactor.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,17 @@ TEST(TableFactor, constructors) {
134134
EXPECT(assert_equal(expected, f4));
135135

136136
// Test for 9=3x3 values.
137-
DiscreteKey V(0, 3), W(1, 3);
137+
DiscreteKey V(0, 3), W(1, 3), O(100, 3);
138138
DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11");
139139
TableFactor f5(conditional5);
140-
// GTSAM_PRINT(f5);
141-
TableFactor expected_f5(
142-
X & Y,
143-
"0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667");
140+
141+
std::string expected_values =
142+
"0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667";
143+
TableFactor expected_f5(V & W, expected_values);
144144
EXPECT(assert_equal(expected_f5, f5, 1e-6));
145+
146+
TableFactor f5_with_wrong_keys(V & O, expected_values);
147+
EXPECT(assert_inequal(f5_with_wrong_keys, f5, 1e-9));
145148
}
146149

147150
/* ************************************************************************* */

0 commit comments

Comments
 (0)