From 834288f9748992b24bc4d4f4cffc77c7d8461d8c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 13:18:24 -0500 Subject: [PATCH 01/10] additional Signature based constructor for DecisionTreeFactor --- gtsam/discrete/DecisionTreeFactor.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 80ee10a7b2..24a699d426 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -18,6 +18,7 @@ #pragma once +#include #include #include #include @@ -116,6 +117,10 @@ namespace gtsam { DecisionTreeFactor(const DiscreteKey& key, const std::vector& row) : DecisionTreeFactor(DiscreteKeys{key}, row) {} + /// Construct from Signature + DecisionTreeFactor(const Signature& signature) + : DecisionTreeFactor(signature.discreteKeys(), signature.cpt()) {} + /** Construct from a DiscreteConditional type */ explicit DecisionTreeFactor(const DiscreteConditional& c); From e6567457b511a6ff993efcf2710c98f72c71bdad Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 13:21:19 -0500 Subject: [PATCH 02/10] update tests --- .../discrete/tests/testDecisionTreeFactor.cpp | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 756a0cebe8..1828db5253 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -217,12 +217,6 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) { #endif } -/** Convert Signature into CPT */ -DecisionTreeFactor create(const Signature& signature) { - DecisionTreeFactor p(signature.discreteKeys(), signature.cpt()); - return p; -} - /* ************************************************************************* */ // test Asia Joint TEST(DecisionTreeFactor, joint) { @@ -230,14 +224,14 @@ TEST(DecisionTreeFactor, joint) { D(7, 2); gttic_(asiaCPTs); - DecisionTreeFactor pA = create(A % "99/1"); - DecisionTreeFactor pS = create(S % "50/50"); - DecisionTreeFactor pT = create(T | A = "99/1 95/5"); - DecisionTreeFactor pL = create(L | S = "99/1 90/10"); - DecisionTreeFactor pB = create(B | S = "70/30 40/60"); - DecisionTreeFactor pE = create((E | T, L) = "F T T T"); - DecisionTreeFactor pX = create(X | E = "95/5 2/98"); - DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); + DecisionTreeFactor pA(A % "99/1"); + DecisionTreeFactor pS(S % "50/50"); + DecisionTreeFactor pT(T | A = "99/1 95/5"); + DecisionTreeFactor pL(L | S = "99/1 90/10"); + DecisionTreeFactor pB(B | S = "70/30 40/60"); + DecisionTreeFactor pE((E | T, L) = "F T T T"); + DecisionTreeFactor pX(X | E = "95/5 2/98"); + DecisionTreeFactor pD((D | E, B) = "9/1 2/8 3/7 1/9"); // Create joint gttic_(asiaJoint); From cb9cec30e39895b4745a4727aa896718dcccc467 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 13:28:21 -0500 Subject: [PATCH 03/10] unit test exposing division bug --- gtsam/discrete/tests/testDecisionTreeFactor.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 1828db5253..73420c860b 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -69,6 +69,15 @@ TEST(DecisionTreeFactor, constructors) { EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9); } +/* ************************************************************************* */ +TEST(DecisionTreeFactor, Divide) { + DiscreteKey A(0, 2), S(1, 2); + DecisionTreeFactor pA(A % "99/1"), pS(S % "50/50"); + DecisionTreeFactor joint = pA * pS; + DecisionTreeFactor s = joint / pA; + EXPECT(assert_equal(pS, s)); +} + /* ************************************************************************* */ TEST(DecisionTreeFactor, Error) { // Declare a bunch of keys From c6c451bee102f624cf25660ec6a820c9d5c1c49c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 13:28:40 -0500 Subject: [PATCH 04/10] compute correct subset of keys for division --- gtsam/discrete/DecisionTreeFactor.h | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 24a699d426..0b94140da9 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -161,7 +161,15 @@ namespace gtsam { /// divide by factor f (safely) DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { - return apply(f, safe_div); + KeyVector diff; + std::set_difference(this->keys().begin(), this->keys().end(), + f.keys().begin(), f.keys().end(), + std::back_inserter(diff)); + DiscreteKeys keys; + for (Key key : diff) { + keys.push_back({key, this->cardinality(key)}); + } + return DecisionTreeFactor(keys, apply(f, safe_div)); } /// Convert into a decision tree From ffe14d39aae1609c4d85b863a3ffd5ebca0089ac Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 13:42:48 -0500 Subject: [PATCH 05/10] Revert "update tests" This reverts commit e6567457b511a6ff993efcf2710c98f72c71bdad. --- .../discrete/tests/testDecisionTreeFactor.cpp | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 73420c860b..61ce9038d3 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -226,6 +226,12 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) { #endif } +/** Convert Signature into CPT */ +DecisionTreeFactor create(const Signature& signature) { + DecisionTreeFactor p(signature.discreteKeys(), signature.cpt()); + return p; +} + /* ************************************************************************* */ // test Asia Joint TEST(DecisionTreeFactor, joint) { @@ -233,14 +239,14 @@ TEST(DecisionTreeFactor, joint) { D(7, 2); gttic_(asiaCPTs); - DecisionTreeFactor pA(A % "99/1"); - DecisionTreeFactor pS(S % "50/50"); - DecisionTreeFactor pT(T | A = "99/1 95/5"); - DecisionTreeFactor pL(L | S = "99/1 90/10"); - DecisionTreeFactor pB(B | S = "70/30 40/60"); - DecisionTreeFactor pE((E | T, L) = "F T T T"); - DecisionTreeFactor pX(X | E = "95/5 2/98"); - DecisionTreeFactor pD((D | E, B) = "9/1 2/8 3/7 1/9"); + DecisionTreeFactor pA = create(A % "99/1"); + DecisionTreeFactor pS = create(S % "50/50"); + DecisionTreeFactor pT = create(T | A = "99/1 95/5"); + DecisionTreeFactor pL = create(L | S = "99/1 90/10"); + DecisionTreeFactor pB = create(B | S = "70/30 40/60"); + DecisionTreeFactor pE = create((E | T, L) = "F T T T"); + DecisionTreeFactor pX = create(X | E = "95/5 2/98"); + DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); // Create joint gttic_(asiaJoint); From be7be376a9ef95b12c08b83d72bb602cc93eb852 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 13:42:56 -0500 Subject: [PATCH 06/10] Revert "additional Signature based constructor for DecisionTreeFactor" This reverts commit 834288f9748992b24bc4d4f4cffc77c7d8461d8c. --- gtsam/discrete/DecisionTreeFactor.h | 5 ----- 1 file changed, 5 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 0b94140da9..804b956fe2 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -18,7 +18,6 @@ #pragma once -#include #include #include #include @@ -117,10 +116,6 @@ namespace gtsam { DecisionTreeFactor(const DiscreteKey& key, const std::vector& row) : DecisionTreeFactor(DiscreteKeys{key}, row) {} - /// Construct from Signature - DecisionTreeFactor(const Signature& signature) - : DecisionTreeFactor(signature.discreteKeys(), signature.cpt()) {} - /** Construct from a DiscreteConditional type */ explicit DecisionTreeFactor(const DiscreteConditional& c); From 6b6283c1512467819918c90191aa8372e96a00dd Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jan 2025 15:21:49 -0500 Subject: [PATCH 07/10] fix factor construction --- gtsam/discrete/tests/testDecisionTreeFactor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 61ce9038d3..dc18e0ab2f 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -72,7 +72,7 @@ TEST(DecisionTreeFactor, constructors) { /* ************************************************************************* */ TEST(DecisionTreeFactor, Divide) { DiscreteKey A(0, 2), S(1, 2); - DecisionTreeFactor pA(A % "99/1"), pS(S % "50/50"); + DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50"); DecisionTreeFactor joint = pA * pS; DecisionTreeFactor s = joint / pA; EXPECT(assert_equal(pS, s)); From a142556c52064b5adb2926e76e4eb7e238ed0cb5 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 08:45:11 -0500 Subject: [PATCH 08/10] move create to the top --- .../discrete/tests/testDecisionTreeFactor.cpp | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index dc18e0ab2f..7210622d8d 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -30,6 +30,12 @@ using namespace std; using namespace gtsam; +/** Convert Signature into CPT */ +DecisionTreeFactor create(const Signature& signature) { + DecisionTreeFactor p(signature.discreteKeys(), signature.cpt()); + return p; +} + /* ************************************************************************* */ TEST(DecisionTreeFactor, ConstructorsMatch) { // Declare two keys @@ -69,15 +75,6 @@ TEST(DecisionTreeFactor, constructors) { EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9); } -/* ************************************************************************* */ -TEST(DecisionTreeFactor, Divide) { - DiscreteKey A(0, 2), S(1, 2); - DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50"); - DecisionTreeFactor joint = pA * pS; - DecisionTreeFactor s = joint / pA; - EXPECT(assert_equal(pS, s)); -} - /* ************************************************************************* */ TEST(DecisionTreeFactor, Error) { // Declare a bunch of keys @@ -114,6 +111,15 @@ TEST(DecisionTreeFactor, multiplication) { CHECK(assert_equal(expected2, actual)); } +/* ************************************************************************* */ +TEST(DecisionTreeFactor, Divide) { + DiscreteKey A(0, 2), S(1, 2); + DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50"); + DecisionTreeFactor joint = pA * pS; + DecisionTreeFactor s = joint / pA; + EXPECT(assert_equal(pS, s)); +} + /* ************************************************************************* */ TEST(DecisionTreeFactor, sum_max) { DiscreteKey v0(0, 3), v1(1, 2); @@ -226,12 +232,6 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) { #endif } -/** Convert Signature into CPT */ -DecisionTreeFactor create(const Signature& signature) { - DecisionTreeFactor p(signature.discreteKeys(), signature.cpt()); - return p; -} - /* ************************************************************************* */ // test Asia Joint TEST(DecisionTreeFactor, joint) { From e309bf370bd195d5f4e2171cd451ca58588e9fb1 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 11:54:09 -0500 Subject: [PATCH 09/10] improve operator/ documentation and also showcase understanding in test --- gtsam/discrete/DecisionTreeFactor.h | 24 +++++++++---------- .../discrete/tests/testDecisionTreeFactor.cpp | 14 ++++++++++- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 804b956fe2..a5b82f2772 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -154,17 +154,17 @@ namespace gtsam { static double safe_div(const double& a, const double& b); - /// divide by factor f (safely) + /** + * @brief Divide by factor f (safely). + * Division of a factor \f$f(x, y)\f$ by another factor \f$g(y, z)\f$ + * results in a function which involves all keys + * \f$(\frac{f}{g})(x, y, z) = f(x, y) / g(y, z)\f$ + * + * @param f The DecisinTreeFactor to divide by. + * @return DecisionTreeFactor + */ DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { - KeyVector diff; - std::set_difference(this->keys().begin(), this->keys().end(), - f.keys().begin(), f.keys().end(), - std::back_inserter(diff)); - DiscreteKeys keys; - for (Key key : diff) { - keys.push_back({key, this->cardinality(key)}); - } - return DecisionTreeFactor(keys, apply(f, safe_div)); + return apply(f, safe_div); } /// Convert into a decision tree @@ -181,12 +181,12 @@ namespace gtsam { } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(size_t nrFrontals) const { + DiscreteFactor::shared_ptr max(size_t nrFrontals) const { return combine(nrFrontals, Ring::max); } /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(const Ordering& keys) const { + DiscreteFactor::shared_ptr max(const Ordering& keys) const { return combine(keys, Ring::max); } diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 7210622d8d..ba8714783f 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -116,8 +116,20 @@ TEST(DecisionTreeFactor, Divide) { DiscreteKey A(0, 2), S(1, 2); DecisionTreeFactor pA = create(A % "99/1"), pS = create(S % "50/50"); DecisionTreeFactor joint = pA * pS; + DecisionTreeFactor s = joint / pA; - EXPECT(assert_equal(pS, s)); + + // Factors are not equal due to difference in keys + EXPECT(assert_inequal(pS, s)); + + // The underlying data should be the same + using ADT = AlgebraicDecisionTree; + EXPECT(assert_equal(ADT(pS), ADT(s))); + + KeySet keys(joint.keys()); + keys.insert(pA.keys().begin(), pA.keys().end()); + EXPECT(assert_inequal(KeySet(pS.keys()), keys)); + } /* ************************************************************************* */ From b83aadb20487f69c6ab932245f8524c8ec92fdde Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 5 Jan 2025 15:37:37 -0500 Subject: [PATCH 10/10] remove accidental type change --- gtsam/discrete/DecisionTreeFactor.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index a5b82f2772..eb6d9eaa24 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -181,12 +181,12 @@ namespace gtsam { } /// Create new factor by maximizing over all values with the same separator. - DiscreteFactor::shared_ptr max(size_t nrFrontals) const { + shared_ptr max(size_t nrFrontals) const { return combine(nrFrontals, Ring::max); } /// Create new factor by maximizing over all values with the same separator. - DiscreteFactor::shared_ptr max(const Ordering& keys) const { + shared_ptr max(const Ordering& keys) const { return combine(keys, Ring::max); }