Skip to content

TableFactor and TableDistribution #1953

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 94 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
94 commits
Select commit Hold shift + click to select a range
34fba68
use TableFactor instead of DecisionTreeFactor in discrete elimination
varunagrawal Dec 27, 2024
de652ea
initial DiscreteTableConditional
varunagrawal Dec 30, 2024
b57e448
DiscreteConditional evaluate method for conditionals
varunagrawal Dec 31, 2024
d18f23c
setData method
varunagrawal Dec 31, 2024
4ff7014
use a TableFactor as the underlying data representation for DiscreteT…
varunagrawal Dec 31, 2024
b39b200
fix return type
varunagrawal Dec 31, 2024
d9faa82
add evaluate and getter
varunagrawal Dec 31, 2024
60945c8
add override methods to DiscreteTableConditional
varunagrawal Dec 31, 2024
e46e9d6
use DiscreteTableConditional in EliminateDiscrete
varunagrawal Dec 31, 2024
b7b2734
small cleanup
varunagrawal Dec 31, 2024
214043d
use DiscreteConditional shared_ptr for dynamic dispatch
varunagrawal Dec 31, 2024
dfec840
use TableFactor for discrete elimination
varunagrawal Dec 31, 2024
5019153
small cleanup
varunagrawal Dec 31, 2024
623bd63
fix hybrid tests
varunagrawal Dec 31, 2024
9f85d4c
fix equals
varunagrawal Dec 31, 2024
9cacb98
undo changes to DiscreteFactorGraph
varunagrawal Dec 31, 2024
c6e9bfc
remove unused methods
varunagrawal Dec 31, 2024
34eb0fc
Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
varunagrawal Dec 31, 2024
462a5b8
return DiscreteTableConditional from hybrid elimination
varunagrawal Dec 31, 2024
5e1931e
update testGaussianMixture
varunagrawal Dec 31, 2024
3119d13
remove evaluate method
varunagrawal Dec 31, 2024
9e1c0d7
fix constructor and equals
varunagrawal Dec 31, 2024
782f39a
Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
varunagrawal Jan 1, 2025
e854d15
evaluate needed for correct test results
varunagrawal Jan 1, 2025
ec5d87e
custom discreteMaxProduct
varunagrawal Jan 1, 2025
2a5833b
custom ProductAndNormalize for TableFactor
varunagrawal Jan 1, 2025
6f19ffd
fixed maxProduct
varunagrawal Jan 1, 2025
cc237a2
Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
varunagrawal Jan 2, 2025
e56fac2
fix TableProduct name
varunagrawal Jan 2, 2025
26e1f08
fix testGaussianMixture
varunagrawal Jan 2, 2025
c7c42af
undo HybridBayesNet changes
varunagrawal Jan 2, 2025
27e3a04
fix testHybridGaussianFactorGraph
varunagrawal Jan 2, 2025
cafac63
fix to use DiscreteTableConditional
varunagrawal Jan 2, 2025
35502f3
custom max-product for HybridBayesTree
varunagrawal Jan 2, 2025
62a6558
fix discreteMaxProduct declaration
varunagrawal Jan 2, 2025
5d2d879
make asDiscrete a template
varunagrawal Jan 2, 2025
4c5b842
add checks
varunagrawal Jan 2, 2025
e620729
fix testHybridEstimation
varunagrawal Jan 2, 2025
d18569b
fix testGaussianMixture
varunagrawal Jan 2, 2025
769e2c7
fix testHybridMotionModel
varunagrawal Jan 2, 2025
da22055
formatting
varunagrawal Jan 2, 2025
fcc56f5
fix pruning test in testHybridBayesNet
varunagrawal Jan 2, 2025
f80a3a1
fix testHybridGaussianFactorGraph
varunagrawal Jan 2, 2025
b343a80
more helper methods in DiscreteTableConditional
varunagrawal Jan 2, 2025
e6db6d1
cleaner API
varunagrawal Jan 2, 2025
fd2820e
fix testHybridNonlinearFactorGraph
varunagrawal Jan 2, 2025
446263c
Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
varunagrawal Jan 2, 2025
b9293b4
fix testHybridGaussianISAM
varunagrawal Jan 3, 2025
8e36361
fix testHybridNonlinearISAM
varunagrawal Jan 3, 2025
62a35c0
serialize table inside TableDistribution
varunagrawal Jan 3, 2025
0098112
Merge branch 'hybrid-timing' into discrete-table-conditional
varunagrawal Jan 3, 2025
9b1918c
rename from DiscreteTableConditional to TableDistribution
varunagrawal Jan 3, 2025
e1628e3
rename source files
varunagrawal Jan 3, 2025
83bb404
export TableDistribution for serialization
varunagrawal Jan 3, 2025
2e06954
improved docstring
varunagrawal Jan 3, 2025
35e1e61
kill operator* method
varunagrawal Jan 3, 2025
bc449c1
formatting
varunagrawal Jan 3, 2025
bd30bef
remove constructors that need parents
varunagrawal Jan 4, 2025
f9e3280
add helpful constructors
varunagrawal Jan 4, 2025
3abff90
fix tests
varunagrawal Jan 4, 2025
11a740e
use template
varunagrawal Jan 4, 2025
b7bddde
fix TableDistribution constructor call
varunagrawal Jan 4, 2025
d6bc1e1
pass DiscreteConditional& for pruning instead of shared_ptr
varunagrawal Jan 4, 2025
9a40be6
normalize values in sparse_table so it forms a proper distribution
varunagrawal Jan 4, 2025
7cb8181
fix TableDistribution constructors in tests
varunagrawal Jan 4, 2025
d39641d
get rid of setData and make prune() imperative for non-factors
varunagrawal Jan 4, 2025
d378015
update pruning in BayesNet and BayesTree
varunagrawal Jan 4, 2025
14f3254
update test
varunagrawal Jan 4, 2025
5a8a942
add argmax method to TableDistribution
varunagrawal Jan 4, 2025
2410d4f
use TableDistribution::argmax in discreteMaxProduct
varunagrawal Jan 4, 2025
5e4cf89
max returns DiscreteFactor
varunagrawal Jan 4, 2025
ffc20f8
wrap TableDistribution
varunagrawal Jan 4, 2025
e9abd5c
wrap TableFactor
varunagrawal Jan 4, 2025
9a356f1
typo fix
varunagrawal Jan 4, 2025
aba691d
fix python test
varunagrawal Jan 4, 2025
69b5e7d
return DiscreteValues directly
varunagrawal Jan 4, 2025
07a6829
code cleanup
varunagrawal Jan 4, 2025
bcc52be
emplace then prune
varunagrawal Jan 4, 2025
77f3874
remove deleted constructors
varunagrawal Jan 4, 2025
8658f25
Merge branch 'hybrid-timing' into discrete-table-conditional
varunagrawal Jan 7, 2025
5913fd1
updates to get things working
varunagrawal Jan 7, 2025
90825b9
remove hybrid timing flag from DiscreteFactorGraph
varunagrawal Jan 7, 2025
82dba63
new scaledProduct method instead of DiscreteProduct
varunagrawal Jan 7, 2025
9960f2d
kill TableProduct in favor of DiscreteFactorGraph::scaledProduct
varunagrawal Jan 7, 2025
96a136b
override sum and max in TableDistribution
varunagrawal Jan 7, 2025
3fb6f39
override operator/ in TableDistribution
varunagrawal Jan 7, 2025
3d2dd7c
update scaledProduct docs
varunagrawal Jan 7, 2025
9228f0f
fix headers
varunagrawal Jan 7, 2025
b81ab86
make ADT with nullptr in TableDistribution
varunagrawal Jan 7, 2025
3629c33
override sample in TableDistribution
varunagrawal Jan 7, 2025
9dfdf55
add hack to multiply DiscreteConditional with TableDistribution
varunagrawal Jan 7, 2025
9c2ecc3
simplify multiplication
varunagrawal Jan 7, 2025
4fc2387
fix relinearization in HybridNonlinearISAM
varunagrawal Jan 7, 2025
3ecc232
fix tests
varunagrawal Jan 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions gtsam/discrete/AlgebraicDecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ namespace gtsam {

AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}

/// Constructor which accepts root pointer
AlgebraicDecisionTree(const typename Base::NodePtr root) : Base(root) {}

// Explicitly non-explicit constructor
AlgebraicDecisionTree(const Base& add) : Base(add) {}

Expand Down
20 changes: 20 additions & 0 deletions gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ DiscreteConditional::DiscreteConditional(const Signature& signature)
/* ************************************************************************** */
DiscreteConditional DiscreteConditional::operator*(
const DiscreteConditional& other) const {
// If the root is a nullptr, we have a TableDistribution
// TODO(Varun) Revisit this hack after RSS2025 submission
if (!other.root_) {
DiscreteConditional dc(other.nrFrontals(), other.toDecisionTreeFactor());
return dc * (*this);
}

// Take union of frontal keys
std::set<Key> newFrontals;
for (auto&& key : this->frontals()) newFrontals.insert(key);
Expand Down Expand Up @@ -479,6 +486,19 @@ double DiscreteConditional::evaluate(const HybridValues& x) const {
return this->operator()(x.discrete());
}

/* ************************************************************************* */
DiscreteFactor::shared_ptr DiscreteConditional::max(
const Ordering& keys) const {
return BaseFactor::max(keys);
}

/* ************************************************************************* */
void DiscreteConditional::prune(size_t maxNrAssignments) {
// Get as DiscreteConditional so the probabilities are normalized
DiscreteConditional pruned(nrFrontals(), BaseFactor::prune(maxNrAssignments));
this->root_ = pruned.root_;
}

/* ************************************************************************* */
double DiscreteConditional::negLogConstant() const { return 0.0; }

Expand Down
14 changes: 13 additions & 1 deletion gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class GTSAM_EXPORT DiscreteConditional
* @param parentsValues Known values of the parents
* @return sample from conditional
*/
size_t sample(const DiscreteValues& parentsValues) const;
virtual size_t sample(const DiscreteValues& parentsValues) const;

/// Single parent version.
size_t sample(size_t parent_value) const;
Expand All @@ -214,6 +214,15 @@ class GTSAM_EXPORT DiscreteConditional
*/
size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const;

/**
* @brief Create new factor by maximizing over all
* values with the same separator.
*
* @param keys The keys to sum over.
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const override;

/// @}
/// @name Advanced Interface
/// @{
Expand Down Expand Up @@ -267,6 +276,9 @@ class GTSAM_EXPORT DiscreteConditional
*/
double negLogConstant() const override;

/// Prune the conditional
virtual void prune(size_t maxNrAssignments);

/// @}

protected:
Expand Down
22 changes: 5 additions & 17 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,30 +118,18 @@ namespace gtsam {
// }
// }

/**
* @brief Multiply all the `factors`.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DiscreteFactor::shared_ptr
*/
static DiscreteFactor::shared_ptr DiscreteProduct(
const DiscreteFactorGraph& factors) {
/* ************************************************************************ */
DiscreteFactor::shared_ptr DiscreteFactorGraph::scaledProduct() const {
// PRODUCT: multiply all factors
gttic(product);
DiscreteFactor::shared_ptr product = factors.product();
DiscreteFactor::shared_ptr product = this->product();
gttoc(product);

#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Max over all the potentials by pretending all keys are frontal:
auto denominator = product->max(product->size());

// Normalize the product factor to prevent underflow.
product = product->operator/(denominator);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif

return product;
}
Expand All @@ -151,7 +139,7 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
DiscreteFactor::shared_ptr product = factors.scaledProduct();

// max out frontals, this is the factor on the separator
gttic(max);
Expand Down Expand Up @@ -229,7 +217,7 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
DiscreteFactor::shared_ptr product = factors.scaledProduct();

// sum out frontals, this is the factor on the separator
gttic(sum);
Expand Down
9 changes: 9 additions & 0 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,15 @@ class GTSAM_EXPORT DiscreteFactorGraph
/** return product of all factors as a single factor */
DiscreteFactor::shared_ptr product() const;

/**
* @brief Return product of all `factors` as a single factor,
* which is scaled by the max value to prevent underflow
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DiscreteFactor::shared_ptr
*/
DiscreteFactor::shared_ptr scaledProduct() const;

/**
* Evaluates the factor graph given values, returns the joint probability of
* the factor graph given specific instantiation of values
Expand Down
174 changes: 174 additions & 0 deletions gtsam/discrete/TableDistribution.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/* ----------------------------------------------------------------------------

* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)

* See LICENSE for the license information

* -------------------------------------------------------------------------- */

/**
* @file TableDistribution.cpp
* @date Dec 22, 2024
* @author Varun Agrawal
*/

#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <gtsam/discrete/Ring.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridValues.h>

#include <algorithm>
#include <cassert>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using namespace std;
using std::pair;
using std::stringstream;
using std::vector;
namespace gtsam {

/// Normalize sparse_table
static Eigen::SparseVector<double> normalizeSparseTable(
const Eigen::SparseVector<double>& sparse_table) {
return sparse_table / sparse_table.sum();
}

/* ************************************************************************** */
TableDistribution::TableDistribution(const TableFactor& f)
: BaseConditional(f.keys().size(), f.discreteKeys(), ADT(nullptr)),
table_(f / (*std::dynamic_pointer_cast<TableFactor>(
f.sum(f.keys().size())))) {}

/* ************************************************************************** */
TableDistribution::TableDistribution(const DiscreteKeys& keys,
const std::vector<double>& potentials)
: BaseConditional(keys.size(), keys, ADT(nullptr)),
table_(TableFactor(
keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
}

/* ************************************************************************** */
TableDistribution::TableDistribution(const DiscreteKeys& keys,
const std::string& potentials)
: BaseConditional(keys.size(), keys, ADT(nullptr)),
table_(TableFactor(
keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
}

/* ************************************************************************** */
void TableDistribution::print(const string& s,
const KeyFormatter& formatter) const {
cout << s << " P( ";
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
cout << formatter(*it) << " ";
}
cout << "):\n";
table_.print("", formatter);
cout << endl;
}

/* ************************************************************************** */
bool TableDistribution::equals(const DiscreteFactor& other, double tol) const {
auto dtc = dynamic_cast<const TableDistribution*>(&other);
if (!dtc) {
return false;
} else {
const DiscreteConditional& f(
static_cast<const DiscreteConditional&>(other));
return table_.equals(dtc->table_, tol) &&
DiscreteConditional::BaseConditional::equals(f, tol);
}
}

/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::sum(size_t nrFrontals) const {
return table_.sum(nrFrontals);
}

/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::sum(const Ordering& keys) const {
return table_.sum(keys);
}

/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::max(size_t nrFrontals) const {
return table_.max(nrFrontals);
}

/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::max(const Ordering& keys) const {
return table_.max(keys);
}

/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::operator/(
const DiscreteFactor::shared_ptr& f) const {
return table_ / f;
}

/* ************************************************************************ */
DiscreteValues TableDistribution::argmax() const {
uint64_t maxIdx = 0;
double maxValue = 0.0;

Eigen::SparseVector<double> sparseTable = table_.sparseTable();

for (SparseIt it(sparseTable); it; ++it) {
if (it.value() > maxValue) {
maxIdx = it.index();
maxValue = it.value();
}
}

return table_.findAssignments(maxIdx);
}

/* ****************************************************************************/
void TableDistribution::prune(size_t maxNrAssignments) {
table_ = table_.prune(maxNrAssignments);
}

/* ****************************************************************************/
size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator

DiscreteKeys parentsKeys;
for (auto&& [key, _] : parentsValues) {
parentsKeys.push_back({key, table_.cardinality(key)});
}

// Get the correct conditional distribution: P(F|S=parentsValues)
TableFactor pFS = table_.choose(parentsValues, parentsKeys);

// TODO(Duy): only works for one key now, seems horribly slow this way
if (nrFrontals() != 1) {
throw std::invalid_argument(
"TableDistribution::sample can only be called on single variable "
"conditionals");
}
Key key = firstFrontalKey();
size_t nj = cardinality(key);
vector<double> p(nj);
DiscreteValues frontals;
for (size_t value = 0; value < nj; value++) {
frontals[key] = value;
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
if (p[value] == 1.0) {
return value; // shortcut exit
}
}
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
return distribution(rng);
}

} // namespace gtsam
Loading
Loading