diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 80ee10a7b2..eb6d9eaa24 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -154,7 +154,15 @@ namespace gtsam { static double safe_div(const double& a, const double& b); - /// divide by factor f (safely) + /** + * @brief Divide by factor f (safely). + * Division of a factor \f$f(x, y)\f$ by another factor \f$g(y, z)\f$ + * results in a function which involves all keys + * \f$(\frac{f}{g})(x, y, z) = f(x, y) / g(y, z)\f$ + * + * @param f The DecisinTreeFactor to divide by. + * @return DecisionTreeFactor + */ DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { return apply(f, safe_div); } diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 756a0cebe8..ba8714783f 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -30,6 +30,12 @@ using namespace std; using namespace gtsam; +/** Convert Signature into CPT */ +DecisionTreeFactor create(const Signature& signature) { + DecisionTreeFactor p(signature.discreteKeys(), signature.cpt()); + return p; +} + /* ************************************************************************* */ TEST(DecisionTreeFactor, ConstructorsMatch) { // Declare two keys @@ -105,6 +111,27 @@ TEST(DecisionTreeFactor, multiplication) { CHECK(assert_equal(expected2, actual)); } +/* ************************************************************************* */ +TEST(DecisionTreeFactor, Divide) { + DiscreteKey A(0, 2), S(1, 2); + DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50"); + DecisionTreeFactor joint = pA * pS; + + DecisionTreeFactor s = joint / pA; + + // Factors are not equal due to difference in keys + EXPECT(assert_inequal(pS, s)); + + // The underlying data should be the same + using ADT = AlgebraicDecisionTree; + EXPECT(assert_equal(ADT(pS), ADT(s))); + + KeySet keys(joint.keys()); + keys.insert(pA.keys().begin(), pA.keys().end()); + EXPECT(assert_inequal(KeySet(pS.keys()), keys)); + +} + /* ************************************************************************* */ TEST(DecisionTreeFactor, sum_max) { DiscreteKey v0(0, 3), v1(1, 2); @@ -217,12 +244,6 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) { #endif } -/** Convert Signature into CPT */ -DecisionTreeFactor create(const Signature& signature) { - DecisionTreeFactor p(signature.discreteKeys(), signature.cpt()); - return p; -} - /* ************************************************************************* */ // test Asia Joint TEST(DecisionTreeFactor, joint) {