Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix DecisionTreeFactor division #1962

Merged
merged 10 commits into from
Jan 6, 2025
10 changes: 9 additions & 1 deletion gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,15 @@ namespace gtsam {

static double safe_div(const double& a, const double& b);

/// divide by factor f (safely)
/**
* @brief Divide by factor f (safely).
* Division of a factor \f$f(x, y)\f$ by another factor \f$g(y, z)\f$
* results in a function which involves all keys
* \f$(\frac{f}{g})(x, y, z) = f(x, y) / g(y, z)\f$
*
* @param f The DecisinTreeFactor to divide by.
* @return DecisionTreeFactor
*/
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div);
}
Expand Down
33 changes: 27 additions & 6 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
using namespace std;
using namespace gtsam;

/** Convert Signature into CPT */
DecisionTreeFactor create(const Signature& signature) {
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
return p;
}

/* ************************************************************************* */
TEST(DecisionTreeFactor, ConstructorsMatch) {
// Declare two keys
Expand Down Expand Up @@ -105,6 +111,27 @@ TEST(DecisionTreeFactor, multiplication) {
CHECK(assert_equal(expected2, actual));
}

/* ************************************************************************* */
TEST(DecisionTreeFactor, Divide) {
DiscreteKey A(0, 2), S(1, 2);
DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50");
DecisionTreeFactor joint = pA * pS;

DecisionTreeFactor s = joint / pA;

// Factors are not equal due to difference in keys
EXPECT(assert_inequal(pS, s));

// The underlying data should be the same
using ADT = AlgebraicDecisionTree<Key>;
EXPECT(assert_equal(ADT(pS), ADT(s)));

KeySet keys(joint.keys());
keys.insert(pA.keys().begin(), pA.keys().end());
EXPECT(assert_inequal(KeySet(pS.keys()), keys));

}

/* ************************************************************************* */
TEST(DecisionTreeFactor, sum_max) {
DiscreteKey v0(0, 3), v1(1, 2);
Expand Down Expand Up @@ -217,12 +244,6 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
#endif
}

/** Convert Signature into CPT */
DecisionTreeFactor create(const Signature& signature) {
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
return p;
}

/* ************************************************************************* */
// test Asia Joint
TEST(DecisionTreeFactor, joint) {
Expand Down
Loading