From f95ae52aff0385b947a878beb3c3eb6258ddc2dd Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 14:16:00 -0500 Subject: [PATCH 1/2] Use TableFactor everywhere in hybrid elimination --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 703684c788..4fcd420b19 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -20,12 +20,12 @@ #include #include -#include #include #include #include #include #include +#include #include #include #include @@ -241,18 +241,18 @@ continuousElimination(const HybridGaussianFactorGraph &factors, /* ************************************************************************ */ /** * @brief Take negative log-values, shift them so that the minimum value is 0, - * and then exponentiate to create a DecisionTreeFactor (not normalized yet!). + * and then exponentiate to create a TableFactor (not normalized yet!). * * @param errors DecisionTree of (unnormalized) errors. - * @return DecisionTreeFactor::shared_ptr + * @return TableFactor::shared_ptr */ -static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors( +static TableFactor::shared_ptr DiscreteFactorFromErrors( const DiscreteKeys &discreteKeys, const AlgebraicDecisionTree &errors) { double min_log = errors.min(); AlgebraicDecisionTree potentials( errors, [&min_log](const double x) { return exp(-(x - min_log)); }); - return std::make_shared(discreteKeys, potentials); + return std::make_shared(discreteKeys, potentials); } /* ************************************************************************ */ From 71ea8c5d4c22289c328c5bdad054037a9793b679 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 14:55:56 -0500 Subject: [PATCH 2/2] fix tests --- gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp | 6 +++--- gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 31e36101b2..1942e92347 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -114,10 +114,10 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) { EXPECT(HybridConditional::CheckInvariants(*result.first, values)); // Check that factor is discrete and correct - auto factor = std::dynamic_pointer_cast(result.second); + auto factor = std::dynamic_pointer_cast(result.second); CHECK(factor); // regression test - EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5)); + EXPECT(assert_equal(TableFactor{m1, "1 1"}, *factor, 1e-5)); } /* ************************************************************************* */ @@ -329,7 +329,7 @@ TEST(HybridBayesNet, Switching) { // Check the remaining factor for x1 CHECK(factor_x1); - auto phi_x1 = std::dynamic_pointer_cast(factor_x1); + auto phi_x1 = std::dynamic_pointer_cast(factor_x1); CHECK(phi_x1); EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0 // We can't really check the error of the decision tree factor phi_x1, because diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index 07f70e95c4..6e844dbcbf 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -368,10 +368,9 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) { EXPECT_LONGS_EQUAL(1, hybridGaussianConditional->nrParents()); // This is now a discreteFactor - auto discreteFactor = dynamic_pointer_cast(factorOnModes); + auto discreteFactor = dynamic_pointer_cast(factorOnModes); CHECK(discreteFactor); EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size()); - EXPECT(discreteFactor->root_->isLeaf() == false); } /****************************************************************************