-
Notifications
You must be signed in to change notification settings - Fork 846
TableFactor and TableDistribution #1953
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
94 commits
Select commit
Hold shift + click to select a range
34fba68
use TableFactor instead of DecisionTreeFactor in discrete elimination
varunagrawal de652ea
initial DiscreteTableConditional
varunagrawal b57e448
DiscreteConditional evaluate method for conditionals
varunagrawal d18f23c
setData method
varunagrawal 4ff7014
use a TableFactor as the underlying data representation for DiscreteT…
varunagrawal b39b200
fix return type
varunagrawal d9faa82
add evaluate and getter
varunagrawal 60945c8
add override methods to DiscreteTableConditional
varunagrawal e46e9d6
use DiscreteTableConditional in EliminateDiscrete
varunagrawal b7b2734
small cleanup
varunagrawal 214043d
use DiscreteConditional shared_ptr for dynamic dispatch
varunagrawal dfec840
use TableFactor for discrete elimination
varunagrawal 5019153
small cleanup
varunagrawal 623bd63
fix hybrid tests
varunagrawal 9f85d4c
fix equals
varunagrawal 9cacb98
undo changes to DiscreteFactorGraph
varunagrawal c6e9bfc
remove unused methods
varunagrawal 34eb0fc
Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
varunagrawal 462a5b8
return DiscreteTableConditional from hybrid elimination
varunagrawal 5e1931e
update testGaussianMixture
varunagrawal 3119d13
remove evaluate method
varunagrawal 9e1c0d7
fix constructor and equals
varunagrawal 782f39a
Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
varunagrawal e854d15
evaluate needed for correct test results
varunagrawal ec5d87e
custom discreteMaxProduct
varunagrawal 2a5833b
custom ProductAndNormalize for TableFactor
varunagrawal 6f19ffd
fixed maxProduct
varunagrawal cc237a2
Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
varunagrawal e56fac2
fix TableProduct name
varunagrawal 26e1f08
fix testGaussianMixture
varunagrawal c7c42af
undo HybridBayesNet changes
varunagrawal 27e3a04
fix testHybridGaussianFactorGraph
varunagrawal cafac63
fix to use DiscreteTableConditional
varunagrawal 35502f3
custom max-product for HybridBayesTree
varunagrawal 62a6558
fix discreteMaxProduct declaration
varunagrawal 5d2d879
make asDiscrete a template
varunagrawal 4c5b842
add checks
varunagrawal e620729
fix testHybridEstimation
varunagrawal d18569b
fix testGaussianMixture
varunagrawal 769e2c7
fix testHybridMotionModel
varunagrawal da22055
formatting
varunagrawal fcc56f5
fix pruning test in testHybridBayesNet
varunagrawal f80a3a1
fix testHybridGaussianFactorGraph
varunagrawal b343a80
more helper methods in DiscreteTableConditional
varunagrawal e6db6d1
cleaner API
varunagrawal fd2820e
fix testHybridNonlinearFactorGraph
varunagrawal 446263c
Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
varunagrawal b9293b4
fix testHybridGaussianISAM
varunagrawal 8e36361
fix testHybridNonlinearISAM
varunagrawal 62a35c0
serialize table inside TableDistribution
varunagrawal 0098112
Merge branch 'hybrid-timing' into discrete-table-conditional
varunagrawal 9b1918c
rename from DiscreteTableConditional to TableDistribution
varunagrawal e1628e3
rename source files
varunagrawal 83bb404
export TableDistribution for serialization
varunagrawal 2e06954
improved docstring
varunagrawal 35e1e61
kill operator* method
varunagrawal bc449c1
formatting
varunagrawal bd30bef
remove constructors that need parents
varunagrawal f9e3280
add helpful constructors
varunagrawal 3abff90
fix tests
varunagrawal 11a740e
use template
varunagrawal b7bddde
fix TableDistribution constructor call
varunagrawal d6bc1e1
pass DiscreteConditional& for pruning instead of shared_ptr
varunagrawal 9a40be6
normalize values in sparse_table so it forms a proper distribution
varunagrawal 7cb8181
fix TableDistribution constructors in tests
varunagrawal d39641d
get rid of setData and make prune() imperative for non-factors
varunagrawal d378015
update pruning in BayesNet and BayesTree
varunagrawal 14f3254
update test
varunagrawal 5a8a942
add argmax method to TableDistribution
varunagrawal 2410d4f
use TableDistribution::argmax in discreteMaxProduct
varunagrawal 5e4cf89
max returns DiscreteFactor
varunagrawal ffc20f8
wrap TableDistribution
varunagrawal e9abd5c
wrap TableFactor
varunagrawal 9a356f1
typo fix
varunagrawal aba691d
fix python test
varunagrawal 69b5e7d
return DiscreteValues directly
varunagrawal 07a6829
code cleanup
varunagrawal bcc52be
emplace then prune
varunagrawal 77f3874
remove deleted constructors
varunagrawal 8658f25
Merge branch 'hybrid-timing' into discrete-table-conditional
varunagrawal 5913fd1
updates to get things working
varunagrawal 90825b9
remove hybrid timing flag from DiscreteFactorGraph
varunagrawal 82dba63
new scaledProduct method instead of DiscreteProduct
varunagrawal 9960f2d
kill TableProduct in favor of DiscreteFactorGraph::scaledProduct
varunagrawal 96a136b
override sum and max in TableDistribution
varunagrawal 3fb6f39
override operator/ in TableDistribution
varunagrawal 3d2dd7c
update scaledProduct docs
varunagrawal 9228f0f
fix headers
varunagrawal b81ab86
make ADT with nullptr in TableDistribution
varunagrawal 3629c33
override sample in TableDistribution
varunagrawal 9dfdf55
add hack to multiply DiscreteConditional with TableDistribution
varunagrawal 9c2ecc3
simplify multiplication
varunagrawal 4fc2387
fix relinearization in HybridNonlinearISAM
varunagrawal 3ecc232
fix tests
varunagrawal File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
/* ---------------------------------------------------------------------------- | ||
|
||
* GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||
* Atlanta, Georgia 30332-0415 | ||
* All Rights Reserved | ||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||
|
||
* See LICENSE for the license information | ||
|
||
* -------------------------------------------------------------------------- */ | ||
|
||
/** | ||
* @file TableDistribution.cpp | ||
* @date Dec 22, 2024 | ||
* @author Varun Agrawal | ||
*/ | ||
|
||
#include <gtsam/base/Testable.h> | ||
#include <gtsam/base/debug.h> | ||
#include <gtsam/discrete/Ring.h> | ||
#include <gtsam/discrete/Signature.h> | ||
#include <gtsam/discrete/TableDistribution.h> | ||
#include <gtsam/hybrid/HybridValues.h> | ||
|
||
#include <algorithm> | ||
#include <cassert> | ||
#include <random> | ||
#include <set> | ||
#include <stdexcept> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
using namespace std; | ||
using std::pair; | ||
using std::stringstream; | ||
using std::vector; | ||
namespace gtsam { | ||
|
||
/// Normalize sparse_table | ||
static Eigen::SparseVector<double> normalizeSparseTable( | ||
const Eigen::SparseVector<double>& sparse_table) { | ||
return sparse_table / sparse_table.sum(); | ||
} | ||
|
||
/* ************************************************************************** */ | ||
TableDistribution::TableDistribution(const TableFactor& f) | ||
: BaseConditional(f.keys().size(), f.discreteKeys(), ADT(nullptr)), | ||
table_(f / (*std::dynamic_pointer_cast<TableFactor>( | ||
f.sum(f.keys().size())))) {} | ||
|
||
/* ************************************************************************** */ | ||
TableDistribution::TableDistribution(const DiscreteKeys& keys, | ||
const std::vector<double>& potentials) | ||
: BaseConditional(keys.size(), keys, ADT(nullptr)), | ||
table_(TableFactor( | ||
keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) { | ||
} | ||
|
||
/* ************************************************************************** */ | ||
TableDistribution::TableDistribution(const DiscreteKeys& keys, | ||
const std::string& potentials) | ||
: BaseConditional(keys.size(), keys, ADT(nullptr)), | ||
table_(TableFactor( | ||
keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) { | ||
} | ||
|
||
/* ************************************************************************** */ | ||
void TableDistribution::print(const string& s, | ||
const KeyFormatter& formatter) const { | ||
cout << s << " P( "; | ||
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { | ||
cout << formatter(*it) << " "; | ||
} | ||
cout << "):\n"; | ||
table_.print("", formatter); | ||
cout << endl; | ||
} | ||
|
||
/* ************************************************************************** */ | ||
bool TableDistribution::equals(const DiscreteFactor& other, double tol) const { | ||
auto dtc = dynamic_cast<const TableDistribution*>(&other); | ||
if (!dtc) { | ||
return false; | ||
} else { | ||
const DiscreteConditional& f( | ||
static_cast<const DiscreteConditional&>(other)); | ||
return table_.equals(dtc->table_, tol) && | ||
DiscreteConditional::BaseConditional::equals(f, tol); | ||
} | ||
} | ||
|
||
/* ****************************************************************************/ | ||
DiscreteFactor::shared_ptr TableDistribution::sum(size_t nrFrontals) const { | ||
return table_.sum(nrFrontals); | ||
} | ||
|
||
/* ****************************************************************************/ | ||
DiscreteFactor::shared_ptr TableDistribution::sum(const Ordering& keys) const { | ||
return table_.sum(keys); | ||
} | ||
|
||
/* ****************************************************************************/ | ||
DiscreteFactor::shared_ptr TableDistribution::max(size_t nrFrontals) const { | ||
return table_.max(nrFrontals); | ||
} | ||
|
||
/* ****************************************************************************/ | ||
DiscreteFactor::shared_ptr TableDistribution::max(const Ordering& keys) const { | ||
return table_.max(keys); | ||
} | ||
|
||
/* ****************************************************************************/ | ||
DiscreteFactor::shared_ptr TableDistribution::operator/( | ||
const DiscreteFactor::shared_ptr& f) const { | ||
return table_ / f; | ||
} | ||
|
||
/* ************************************************************************ */ | ||
DiscreteValues TableDistribution::argmax() const { | ||
uint64_t maxIdx = 0; | ||
double maxValue = 0.0; | ||
|
||
Eigen::SparseVector<double> sparseTable = table_.sparseTable(); | ||
|
||
for (SparseIt it(sparseTable); it; ++it) { | ||
if (it.value() > maxValue) { | ||
maxIdx = it.index(); | ||
maxValue = it.value(); | ||
} | ||
} | ||
|
||
return table_.findAssignments(maxIdx); | ||
} | ||
|
||
/* ****************************************************************************/ | ||
void TableDistribution::prune(size_t maxNrAssignments) { | ||
table_ = table_.prune(maxNrAssignments); | ||
} | ||
|
||
/* ****************************************************************************/ | ||
size_t TableDistribution::sample(const DiscreteValues& parentsValues) const { | ||
static mt19937 rng(2); // random number generator | ||
|
||
DiscreteKeys parentsKeys; | ||
for (auto&& [key, _] : parentsValues) { | ||
parentsKeys.push_back({key, table_.cardinality(key)}); | ||
} | ||
|
||
// Get the correct conditional distribution: P(F|S=parentsValues) | ||
TableFactor pFS = table_.choose(parentsValues, parentsKeys); | ||
|
||
// TODO(Duy): only works for one key now, seems horribly slow this way | ||
if (nrFrontals() != 1) { | ||
throw std::invalid_argument( | ||
"TableDistribution::sample can only be called on single variable " | ||
"conditionals"); | ||
} | ||
Key key = firstFrontalKey(); | ||
size_t nj = cardinality(key); | ||
vector<double> p(nj); | ||
DiscreteValues frontals; | ||
for (size_t value = 0; value < nj; value++) { | ||
frontals[key] = value; | ||
p[value] = pFS(frontals); // P(F=value|S=parentsValues) | ||
if (p[value] == 1.0) { | ||
return value; // shortcut exit | ||
} | ||
} | ||
std::discrete_distribution<size_t> distribution(p.begin(), p.end()); | ||
return distribution(rng); | ||
} | ||
|
||
} // namespace gtsam |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.