Skip to content

Fix DecisionTreeFactor division #1962

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 6, 2025
15 changes: 14 additions & 1 deletion gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#pragma once

#include <gtsam/discrete/Signature.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
Expand Down Expand Up @@ -116,6 +117,10 @@ namespace gtsam {
DecisionTreeFactor(const DiscreteKey& key, const std::vector<double>& row)
: DecisionTreeFactor(DiscreteKeys{key}, row) {}

/// Construct from Signature
DecisionTreeFactor(const Signature& signature)
: DecisionTreeFactor(signature.discreteKeys(), signature.cpt()) {}

/** Construct from a DiscreteConditional type */
explicit DecisionTreeFactor(const DiscreteConditional& c);

Expand Down Expand Up @@ -156,7 +161,15 @@ namespace gtsam {

/// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div);
KeyVector diff;
std::set_difference(this->keys().begin(), this->keys().end(),
f.keys().begin(), f.keys().end(),
std::back_inserter(diff));
DiscreteKeys keys;
for (Key key : diff) {
keys.push_back({key, this->cardinality(key)});
}
return DecisionTreeFactor(keys, apply(f, safe_div));
}

/// Convert into a decision tree
Expand Down
31 changes: 17 additions & 14 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ TEST(DecisionTreeFactor, constructors) {
EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9);
}

/* ************************************************************************* */
TEST(DecisionTreeFactor, Divide) {
DiscreteKey A(0, 2), S(1, 2);
DecisionTreeFactor pA(A % "99/1"), pS(S % "50/50");
DecisionTreeFactor joint = pA * pS;
DecisionTreeFactor s = joint / pA;
EXPECT(assert_equal(pS, s));
}

/* ************************************************************************* */
TEST(DecisionTreeFactor, Error) {
// Declare a bunch of keys
Expand Down Expand Up @@ -217,27 +226,21 @@ void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
#endif
}

/** Convert Signature into CPT */
DecisionTreeFactor create(const Signature& signature) {
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
return p;
}

/* ************************************************************************* */
// test Asia Joint
TEST(DecisionTreeFactor, joint) {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
D(7, 2);

gttic_(asiaCPTs);
DecisionTreeFactor pA = create(A % "99/1");
DecisionTreeFactor pS = create(S % "50/50");
DecisionTreeFactor pT = create(T | A = "99/1 95/5");
DecisionTreeFactor pL = create(L | S = "99/1 90/10");
DecisionTreeFactor pB = create(B | S = "70/30 40/60");
DecisionTreeFactor pE = create((E | T, L) = "F T T T");
DecisionTreeFactor pX = create(X | E = "95/5 2/98");
DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
DecisionTreeFactor pA(A % "99/1");
DecisionTreeFactor pS(S % "50/50");
DecisionTreeFactor pT(T | A = "99/1 95/5");
DecisionTreeFactor pL(L | S = "99/1 90/10");
DecisionTreeFactor pB(B | S = "70/30 40/60");
DecisionTreeFactor pE((E | T, L) = "F T T T");
DecisionTreeFactor pX(X | E = "95/5 2/98");
DecisionTreeFactor pD((D | E, B) = "9/1 2/8 3/7 1/9");

// Create joint
gttic_(asiaJoint);
Expand Down
Loading