Skip to content

Commit

Permalink
Merge pull request #1954 from borglab/hybrid-with-tablefactor
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Jan 1, 2025
2 parents 02d461e + 71ea8c5 commit 30670ab
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 10 deletions.
10 changes: 5 additions & 5 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@

#include <gtsam/base/utilities.h>
#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridEliminationTree.h>
#include <gtsam/hybrid/HybridFactor.h>
Expand Down Expand Up @@ -241,18 +241,18 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
/* ************************************************************************ */
/**
* @brief Take negative log-values, shift them so that the minimum value is 0,
* and then exponentiate to create a DecisionTreeFactor (not normalized yet!).
* and then exponentiate to create a TableFactor (not normalized yet!).
*
* @param errors DecisionTree of (unnormalized) errors.
* @return DecisionTreeFactor::shared_ptr
* @return TableFactor::shared_ptr
*/
static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors(
static TableFactor::shared_ptr DiscreteFactorFromErrors(
const DiscreteKeys &discreteKeys,
const AlgebraicDecisionTree<Key> &errors) {
double min_log = errors.min();
AlgebraicDecisionTree<Key> potentials(
errors, [&min_log](const double x) { return exp(-(x - min_log)); });
return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
return std::make_shared<TableFactor>(discreteKeys, potentials);
}

/* ************************************************************************ */
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
EXPECT(HybridConditional::CheckInvariants(*result.first, values));

// Check that factor is discrete and correct
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
auto factor = std::dynamic_pointer_cast<TableFactor>(result.second);
CHECK(factor);
// regression test
EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5));
EXPECT(assert_equal(TableFactor{m1, "1 1"}, *factor, 1e-5));
}

/* ************************************************************************* */
Expand Down Expand Up @@ -329,7 +329,7 @@ TEST(HybridBayesNet, Switching) {

// Check the remaining factor for x1
CHECK(factor_x1);
auto phi_x1 = std::dynamic_pointer_cast<DecisionTreeFactor>(factor_x1);
auto phi_x1 = std::dynamic_pointer_cast<TableFactor>(factor_x1);
CHECK(phi_x1);
EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0
// We can't really check the error of the decision tree factor phi_x1, because
Expand Down
3 changes: 1 addition & 2 deletions gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,9 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
EXPECT_LONGS_EQUAL(1, hybridGaussianConditional->nrParents());

// This is now a discreteFactor
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(factorOnModes);
auto discreteFactor = dynamic_pointer_cast<TableFactor>(factorOnModes);
CHECK(discreteFactor);
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
EXPECT(discreteFactor->root_->isLeaf() == false);
}

/****************************************************************************
Expand Down

0 comments on commit 30670ab

Please sign in to comment.