Skip to content

Commit

Permalink
Merge pull request #1962 from borglab/fix-dtf-division
Browse files Browse the repository at this point in the history
Fix DecisionTreeFactor division
  • Loading branch information
varunagrawal authored Jan 6, 2025
2 parents e9e52ad + b83aadb commit ffd04fd
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
10 changes: 9 additions & 1 deletion gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,15 @@ namespace gtsam {

static double safe_div(const double& a, const double& b);

/// divide by factor f (safely)
/**
* @brief Divide by factor f (safely).
* Division of a factor \f$f(x, y)\f$ by another factor \f$g(y, z)\f$
* results in a function which involves all keys
* \f$(\frac{f}{g})(x, y, z) = f(x, y) / g(y, z)\f$
*
* @param f The DecisinTreeFactor to divide by.
* @return DecisionTreeFactor
*/
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div);
}
Expand Down
33 changes: 27 additions & 6 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
using namespace std;
using namespace gtsam;

/** Convert Signature into CPT */
DecisionTreeFactor create(const Signature& signature) {
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
return p;
}

/* ************************************************************************* */
TEST(DecisionTreeFactor, ConstructorsMatch) {
// Declare two keys
Expand Down Expand Up @@ -105,6 +111,27 @@ TEST(DecisionTreeFactor, multiplication) {
CHECK(assert_equal(expected2, actual));
}

/* ************************************************************************* */
TEST(DecisionTreeFactor, Divide) {
DiscreteKey A(0, 2), S(1, 2);
DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50");
DecisionTreeFactor joint = pA * pS;

DecisionTreeFactor s = joint / pA;

// Factors are not equal due to difference in keys
EXPECT(assert_inequal(pS, s));

// The underlying data should be the same
using ADT = AlgebraicDecisionTree<Key>;
EXPECT(assert_equal(ADT(pS), ADT(s)));

KeySet keys(joint.keys());
keys.insert(pA.keys().begin(), pA.keys().end());
EXPECT(assert_inequal(KeySet(pS.keys()), keys));

}

/* ************************************************************************* */
TEST(DecisionTreeFactor, sum_max) {
DiscreteKey v0(0, 3), v1(1, 2);
Expand Down Expand Up @@ -217,12 +244,6 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
#endif
}

/** Convert Signature into CPT */
DecisionTreeFactor create(const Signature& signature) {
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
return p;
}

/* ************************************************************************* */
// test Asia Joint
TEST(DecisionTreeFactor, joint) {
Expand Down

0 comments on commit ffd04fd

Please sign in to comment.