From 35e7acbf16548a286d3677717f8dfc1b9398d1f8 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 13:05:20 -0500 Subject: [PATCH 01/12] Print factors --- gtsam/discrete/DiscreteJunctionTree.cpp | 45 ++++++++--- gtsam/discrete/DiscreteJunctionTree.h | 99 +++++++++++++++---------- 2 files changed, 92 insertions(+), 52 deletions(-) diff --git a/gtsam/discrete/DiscreteJunctionTree.cpp b/gtsam/discrete/DiscreteJunctionTree.cpp index dc24860ebc..b4657ec8dc 100644 --- a/gtsam/discrete/DiscreteJunctionTree.cpp +++ b/gtsam/discrete/DiscreteJunctionTree.cpp @@ -16,19 +16,42 @@ * @author Richard Roberts */ -#include -#include #include +#include +#include namespace gtsam { - // Instantiate base classes - template class EliminatableClusterTree; - template class JunctionTree; - - /* ************************************************************************* */ - DiscreteJunctionTree::DiscreteJunctionTree( - const DiscreteEliminationTree& eliminationTree) : - Base(eliminationTree) {} - +// Instantiate base classes +template class EliminatableClusterTree; +template class JunctionTree; + +/* ************************************************************************* */ +DiscreteJunctionTree::DiscreteJunctionTree( + const DiscreteEliminationTree& eliminationTree) + : Base(eliminationTree) {} +/* ************************************************************************* */ +namespace { +struct PrintForestVisitorPre { + const KeyFormatter& formatter; + PrintForestVisitorPre(const KeyFormatter& formatter) : formatter(formatter) {} + std::string operator()( + const std::shared_ptr& node, + const std::string& parentString) { + // Print the current node + node->print(parentString + "-", formatter); + node->factors.print(parentString + "-", formatter); + std::cout << std::endl; + // Increment the indentation + return parentString + "| "; + } +}; +} // namespace + +void DiscreteJunctionTree::print(const std::string& s, + const KeyFormatter& keyFormatter) const { + PrintForestVisitorPre visitor(keyFormatter); + treeTraversal::DepthFirstForest(*this, s, visitor); } + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteJunctionTree.h b/gtsam/discrete/DiscreteJunctionTree.h index f6171c6727..4b92410362 100644 --- a/gtsam/discrete/DiscreteJunctionTree.h +++ b/gtsam/discrete/DiscreteJunctionTree.h @@ -18,54 +18,71 @@ #pragma once -#include #include +#include #include namespace gtsam { - // Forward declarations - class DiscreteEliminationTree; +// Forward declarations +class DiscreteEliminationTree; + +/** + * An EliminatableClusterTree, i.e., a set of variable clusters with factors, + * arranged in a tree, with the additional property that it represents the + * clique tree associated with a Bayes net. + * + * In GTSAM a junction tree is an intermediate data structure in multifrontal + * variable elimination. Each node is a cluster of factors, along with a + * clique of variables that are eliminated all at once. In detail, every node k + * represents a clique (maximal fully connected subset) of an associated chordal + * graph, such as a chordal Bayes net resulting from elimination. + * + * The difference with the BayesTree is that a JunctionTree stores factors, + * whereas a BayesTree stores conditionals, that are the product of eliminating + * the factors in the corresponding JunctionTree cliques. + * + * The tree structure and elimination method are exactly analogous to the + * EliminationTree, except that in the JunctionTree, at each node multiple + * variables are eliminated at a time. + * + * \ingroup Multifrontal + * @ingroup discrete + * \nosubgrouping + */ +class GTSAM_EXPORT DiscreteJunctionTree + : public JunctionTree { + public: + typedef JunctionTree + Base; ///< Base class + typedef DiscreteJunctionTree This; ///< This class + typedef std::shared_ptr shared_ptr; ///< Shared pointer to this class + + /// @name Constructors + /// @{ /** - * An EliminatableClusterTree, i.e., a set of variable clusters with factors, arranged in a tree, - * with the additional property that it represents the clique tree associated with a Bayes net. - * - * In GTSAM a junction tree is an intermediate data structure in multifrontal - * variable elimination. Each node is a cluster of factors, along with a - * clique of variables that are eliminated all at once. In detail, every node k represents - * a clique (maximal fully connected subset) of an associated chordal graph, such as a - * chordal Bayes net resulting from elimination. - * - * The difference with the BayesTree is that a JunctionTree stores factors, whereas a - * BayesTree stores conditionals, that are the product of eliminating the factors in the - * corresponding JunctionTree cliques. - * - * The tree structure and elimination method are exactly analogous to the EliminationTree, - * except that in the JunctionTree, at each node multiple variables are eliminated at a time. - * - * \ingroup Multifrontal - * @ingroup discrete - * \nosubgrouping + * Build the elimination tree of a factor graph using precomputed column + * structure. + * @param factorGraph The factor graph for which to build the elimination tree + * @param structure The set of factors involving each variable. If this is + * not precomputed, you can call the Create(const FactorGraph&) + * named constructor instead. + * @return The elimination tree */ - class GTSAM_EXPORT DiscreteJunctionTree : - public JunctionTree { - public: - typedef JunctionTree Base; ///< Base class - typedef DiscreteJunctionTree This; ///< This class - typedef std::shared_ptr shared_ptr; ///< Shared pointer to this class + DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree); + + /// @} + /// @name Testable + /// @{ + + /** Print the tree to cout */ + void print(const std::string& name = "DiscreteJunctionTree: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const; - /** - * Build the elimination tree of a factor graph using precomputed column structure. - * @param factorGraph The factor graph for which to build the elimination tree - * @param structure The set of factors involving each variable. If this is not - * precomputed, you can call the Create(const FactorGraph&) - * named constructor instead. - * @return The elimination tree - */ - DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree); - }; + /// @} +}; - /// typedef for wrapper: - using DiscreteCluster = DiscreteJunctionTree::Cluster; -} +/// typedef for wrapper: +using DiscreteCluster = DiscreteJunctionTree::Cluster; +} // namespace gtsam From d8ed60aeada546b5145ee64ec067ba04c2c0c436 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 14:01:20 -0500 Subject: [PATCH 02/12] Refactor to slots --- gtsam/discrete/DiscreteSearch.cpp | 79 ++++++++++++++++++------------- gtsam/discrete/DiscreteSearch.h | 43 ++++++++++++----- 2 files changed, 77 insertions(+), 45 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index c5941862df..e3722c13ec 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -20,6 +20,7 @@ namespace gtsam { +using Slot = DiscreteSearch::Slot; using Solution = DiscreteSearch::Solution; /** @@ -59,12 +60,12 @@ struct SearchNode { /** * @brief Expands the node by assigning the next variable. * - * @param conditional The discrete conditional representing the next variable + * @param factor The discrete factor associated with the next variable * to be assigned. * @param fa The frontal assignment for the next variable. * @return A new SearchNode representing the expanded state. */ - SearchNode expand(const DiscreteConditional& conditional, + SearchNode expand(const DiscreteFactor& factor, const DiscreteValues& fa) const { // Combine the new frontal assignment with the current partial assignment DiscreteValues newAssignment = assignment; @@ -72,7 +73,7 @@ struct SearchNode { newAssignment[key] = value; } - return {newAssignment, error + conditional.error(newAssignment), 0.0, + return {newAssignment, error + factor.error(newAssignment), 0.0, nextConditional - 1}; } @@ -150,10 +151,20 @@ class Solutions { } }; +DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph) { + slots_.reserve(factorGraph.size()); + for (auto& factor : factorGraph) { + slots_.emplace_back(factor, std::vector{}, 0.0); + } + computeHeuristic(); +} + DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { - std::vector conditionals; - for (auto& factor : bayesNet) conditionals_.push_back(factor); - costToGo_ = computeCostToGo(conditionals_); + slots_.reserve(bayesNet.size()); + for (auto& conditional : bayesNet) { + slots_.emplace_back(conditional, conditional->frontalAssignments(), 0.0); + } + computeHeuristic(); } DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { @@ -161,22 +172,21 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { collectConditionals = [&](const auto& clique) { if (!clique) return; for (const auto& child : clique->children) collectConditionals(child); - conditionals_.push_back(clique->conditional()); + auto conditional = clique->conditional(); + slots_.emplace_back(conditional, conditional->frontalAssignments(), + 0.0); }; + + slots_.reserve(bayesTree.size()); for (const auto& root : bayesTree.roots()) collectConditionals(root); - costToGo_ = computeCostToGo(conditionals_); + computeHeuristic(); } struct SearchNodeQueue : public std::priority_queue, SearchNode::Compare> { - void expandNextNode( - const std::vector& conditionals, - const std::vector& costToGo, Solutions* solutions) { - // Pop the partial assignment with the smallest bound - SearchNode current = top(); - pop(); - + void expandNextNode(const SearchNode& current, const Slot& slot, + Solutions* solutions) { // If we already have K solutions, prune if we cannot beat the worst one. if (solutions->prune(current.bound)) { return; @@ -188,13 +198,11 @@ struct SearchNodeQueue return; } - // Expand on the next factor - const auto& conditional = conditionals[current.nextConditional]; - - for (auto& fa : conditional->frontalAssignments()) { - auto childNode = current.expand(*conditional, fa); + for (auto& fa : slot.assignments) { + auto childNode = current.expand(*slot.factor, fa); if (childNode.nextConditional >= 0) - childNode.bound = childNode.error + costToGo[childNode.nextConditional]; + // TODO(frank): this might be wrong ! + childNode.bound = childNode.error + slot.heuristic; // Again, prune if we cannot beat the worst solution if (!solutions->prune(childNode.bound)) { @@ -207,8 +215,7 @@ struct SearchNodeQueue std::vector DiscreteSearch::run(size_t K) const { Solutions solutions(K); SearchNodeQueue expansions; - expansions.push(SearchNode::Root(conditionals_.size(), - costToGo_.empty() ? 0.0 : costToGo_.back())); + expansions.push(SearchNode::Root(slots_.size(), slots_.back().heuristic)); #ifdef DISCRETE_SEARCH_DEBUG size_t numExpansions = 0; @@ -216,7 +223,13 @@ std::vector DiscreteSearch::run(size_t K) const { // Perform the search while (!expansions.empty()) { - expansions.expandNextNode(conditionals_, costToGo_, &solutions); + // Pop the partial assignment with the smallest bound + SearchNode current = expansions.top(); + expansions.pop(); + + // Get the next slot to expand + const auto& slot = slots_[current.nextConditional]; + expansions.expandNextNode(current, slot, &solutions); #ifdef DISCRETE_SEARCH_DEBUG ++numExpansions; #endif @@ -230,17 +243,19 @@ std::vector DiscreteSearch::run(size_t K) const { return solutions.extractSolutions(); } -std::vector DiscreteSearch::computeCostToGo( - const std::vector& conditionals) { - std::vector costToGo; +// We have a number of factors, each with a max value, and we want to compute +// the a lower-bound on the cost-to-go for each slot. For the first slot, this +// -log(max(factor[0])), as we only have one factor to resolve. For the second +// slot, we need to add -log(max(factor[1])) to it, etc... +void DiscreteSearch::computeHeuristic() { double error = 0.0; - for (const auto& conditional : conditionals) { - Ordering ordering(conditional->begin(), conditional->end()); - auto maxx = conditional->max(ordering); + for (size_t i = 0; i < slots_.size(); ++i) { + const auto& factor = slots_[i].factor; + Ordering ordering(factor->begin(), factor->end()); + auto maxx = factor->max(ordering); error -= std::log(maxx->evaluate({})); - costToGo.push_back(error); + slots_[i].heuristic = error; } - return costToGo; } } // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 6202880b26..ddeaf38f18 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -28,9 +28,25 @@ namespace gtsam { */ class GTSAM_EXPORT DiscreteSearch { public: - /** - * @brief A solution to a discrete search problem. - */ + /// We structure the search as a set of slots, each with a factor and + /// a set of variable assignments that need to be chosen. In addition, each + /// slot has a heuristic associated with it. + struct Slot { + /// The factors in the search problem, + /// e.g., [P(B|A),P(A)] + DiscreteFactor::shared_ptr factor; + + /// The assignments for each factor, + /// e.g., [[B0,B1] [A0,A1]] + std::vector assignments; + + /// A lower bound on the cost-to-go for each slot, e.g., + /// [-log(max_B P(B|A)), -log(max_A P(A))] + double heuristic; + }; + + /// A solution is then a set of assignments, covering all the slots. + /// as well as an associated error = -log(probability) struct Solution { double error; DiscreteValues assignment; @@ -42,13 +58,19 @@ class GTSAM_EXPORT DiscreteSearch { } }; + public: /** - * Construct from a DiscreteBayesNet and K. + * Construct from a DiscreteFactorGraph. + */ + DiscreteSearch(const DiscreteFactorGraph& bayesNet); + + /** + * Construct from a DiscreteBayesNet. */ DiscreteSearch(const DiscreteBayesNet& bayesNet); /** - * Construct from a DiscreteBayesTree and K. + * Construct from a DiscreteBayesTree. */ DiscreteSearch(const DiscreteBayesTree& bayesTree); @@ -65,14 +87,9 @@ class GTSAM_EXPORT DiscreteSearch { std::vector run(size_t K = 1) const; private: - /// Compute the cumulative cost-to-go for each conditional slot. - static std::vector computeCostToGo( - const std::vector& conditionals); - - /// Expand the next node in the search tree. - void expandNextNode() const; + /// Compute the cumulative lower-bound cost-to-go for each slot. + void computeHeuristic(); - std::vector conditionals_; - std::vector costToGo_; + std::vector slots_; }; } // namespace gtsam From 9800e110aab52433ce90494d398c1bbbf9f37523 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 14:01:31 -0500 Subject: [PATCH 03/12] Build etree and jtree --- gtsam/discrete/tests/testDiscreteSearch.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index b537dd2f08..7e715ca62f 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -18,6 +18,8 @@ #include #include +#include +#include #include #include "AsiaExample.h" @@ -28,13 +30,29 @@ using namespace gtsam; namespace asia { using namespace asia_example; static const DiscreteBayesNet bayesNet = createAsiaExample(); + +// Create factor graph and optimize with max-product for MPE static const DiscreteFactorGraph factorGraph(bayesNet); static const DiscreteValues mpe = factorGraph.optimize(); + +// Create junction tree static const Ordering ordering{D, X, B, E, L, T, S, A}; + +static const DiscreteEliminationTree etree(factorGraph, ordering); +static const DiscreteJunctionTree junctionTree(etree); + +// Create Bayes tree static const DiscreteBayesTree bayesTree = *factorGraph.eliminateMultifrontal(ordering); } // namespace asia +/* ************************************************************************* */ +TEST(DiscreteBayesNet, AsiaFactorGraphKBest) { + GTSAM_PRINT(asia::etree); + GTSAM_PRINT(asia::junctionTree); + DiscreteSearch search(asia::factorGraph); +} + /* ************************************************************************* */ TEST(DiscreteBayesNet, EmptyKBest) { DiscreteBayesNet net; // no factors From c4870cc840e32a444754cd2b61ef5d0a344099c0 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 14:49:54 -0500 Subject: [PATCH 04/12] Fix heuristic --- gtsam/discrete/DiscreteSearch.cpp | 55 ++++++++++++--------- gtsam/discrete/DiscreteSearch.h | 21 ++++++-- gtsam/discrete/tests/testDiscreteSearch.cpp | 11 +---- 3 files changed, 51 insertions(+), 36 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index e3722c13ec..439331fa1b 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -16,6 +16,8 @@ * @author Frank Dellaert */ +#include +#include #include namespace gtsam { @@ -39,9 +41,8 @@ struct SearchNode { /** * @brief Construct the root node for the search. */ - static SearchNode Root(size_t numConditionals, double bound) { - return {DiscreteValues(), 0.0, bound, - static_cast(numConditionals) - 1}; + static SearchNode Root(size_t numSlots, double bound) { + return {DiscreteValues(), 0.0, bound, static_cast(numSlots) - 1}; } struct Compare { @@ -60,20 +61,18 @@ struct SearchNode { /** * @brief Expands the node by assigning the next variable. * - * @param factor The discrete factor associated with the next variable - * to be assigned. + * @param slot The slot to be filled. * @param fa The frontal assignment for the next variable. * @return A new SearchNode representing the expanded state. */ - SearchNode expand(const DiscreteFactor& factor, - const DiscreteValues& fa) const { + SearchNode expand(const Slot& slot, const DiscreteValues& fa) const { // Combine the new frontal assignment with the current partial assignment DiscreteValues newAssignment = assignment; for (auto& [key, value] : fa) { newAssignment[key] = value; } - - return {newAssignment, error + factor.error(newAssignment), 0.0, + double errorSoFar = error + slot.factor->error(newAssignment); + return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextConditional - 1}; } @@ -151,12 +150,19 @@ class Solutions { } }; -DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph) { +DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph, + const Ordering& ordering, + bool buildJunctionTree) { + const DiscreteEliminationTree etree(factorGraph, ordering); + const DiscreteJunctionTree junctionTree(etree); + + // GTSAM_PRINT(asia::etree); + // GTSAM_PRINT(asia::junctionTree); slots_.reserve(factorGraph.size()); for (auto& factor : factorGraph) { slots_.emplace_back(factor, std::vector{}, 0.0); } - computeHeuristic(); + lowerBound_ = computeHeuristic(); } DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { @@ -164,7 +170,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { for (auto& conditional : bayesNet) { slots_.emplace_back(conditional, conditional->frontalAssignments(), 0.0); } - computeHeuristic(); + lowerBound_ = computeHeuristic(); } DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { @@ -179,7 +185,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { slots_.reserve(bayesTree.size()); for (const auto& root : bayesTree.roots()) collectConditionals(root); - computeHeuristic(); + lowerBound_ = computeHeuristic(); } struct SearchNodeQueue @@ -199,10 +205,7 @@ struct SearchNodeQueue } for (auto& fa : slot.assignments) { - auto childNode = current.expand(*slot.factor, fa); - if (childNode.nextConditional >= 0) - // TODO(frank): this might be wrong ! - childNode.bound = childNode.error + slot.heuristic; + auto childNode = current.expand(slot, fa); // Again, prune if we cannot beat the worst solution if (!solutions->prune(childNode.bound)) { @@ -212,10 +215,12 @@ struct SearchNodeQueue } }; +#define DISCRETE_SEARCH_DEBUG + std::vector DiscreteSearch::run(size_t K) const { Solutions solutions(K); SearchNodeQueue expansions; - expansions.push(SearchNode::Root(slots_.size(), slots_.back().heuristic)); + expansions.push(SearchNode::Root(slots_.size(), lowerBound_)); #ifdef DISCRETE_SEARCH_DEBUG size_t numExpansions = 0; @@ -244,18 +249,22 @@ std::vector DiscreteSearch::run(size_t K) const { } // We have a number of factors, each with a max value, and we want to compute -// the a lower-bound on the cost-to-go for each slot. For the first slot, this -// -log(max(factor[0])), as we only have one factor to resolve. For the second -// slot, we need to add -log(max(factor[1])) to it, etc... -void DiscreteSearch::computeHeuristic() { +// a lower-bound on the cost-to-go for each slot, *not* including this factor. +// For the first slot, this is 0.0, as this is the last slot to be filled, so +// the cost after that is zero. For the second slot, it is h0 = +// -log(max(factor[0])), because after we assign slot[1] we still need to assign +// slot[0], which will cost *at least* h0. +// We return the estimated lower bound of the cost for *all* slots. +double DiscreteSearch::computeHeuristic() { double error = 0.0; for (size_t i = 0; i < slots_.size(); ++i) { + slots_[i].heuristic = error; const auto& factor = slots_[i].factor; Ordering ordering(factor->begin(), factor->end()); auto maxx = factor->max(ordering); error -= std::log(maxx->evaluate({})); - slots_[i].heuristic = error; } + return error; } } // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index ddeaf38f18..3d5ee1d50c 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -61,8 +61,19 @@ class GTSAM_EXPORT DiscreteSearch { public: /** * Construct from a DiscreteFactorGraph. + * + * Internally creates either an elimination tree or a junction tree. The + * latter incurs more up-front computation but the search itself might be + * faster. Then again, for the elimination tree, the heuristic will be more + * fine-grained (more slots). + * + * @param factorGraph The factor graph to search over. + * @param ordering The ordering of the variables to search over. + * @param buildJunctionTree Whether to build a junction tree for the factor + * graph. */ - DiscreteSearch(const DiscreteFactorGraph& bayesNet); + DiscreteSearch(const DiscreteFactorGraph& factorGraph, + const Ordering& ordering, bool buildJunctionTree = false); /** * Construct from a DiscreteBayesNet. @@ -87,9 +98,11 @@ class GTSAM_EXPORT DiscreteSearch { std::vector run(size_t K = 1) const; private: - /// Compute the cumulative lower-bound cost-to-go for each slot. - void computeHeuristic(); + /// Compute the cumulative lower-bound cost-to-go after each slot is filled. + /// @return the estimated lower bound of the cost for *all* slots. + double computeHeuristic(); - std::vector slots_; + double lowerBound_; ///< Lower bound on the cost-to-go for the entire search. + std::vector slots_; ///< The slots to fill in the search. }; } // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index 7e715ca62f..0f424bd453 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -18,8 +18,6 @@ #include #include -#include -#include #include #include "AsiaExample.h" @@ -35,12 +33,9 @@ static const DiscreteBayesNet bayesNet = createAsiaExample(); static const DiscreteFactorGraph factorGraph(bayesNet); static const DiscreteValues mpe = factorGraph.optimize(); -// Create junction tree +// Create ordering static const Ordering ordering{D, X, B, E, L, T, S, A}; -static const DiscreteEliminationTree etree(factorGraph, ordering); -static const DiscreteJunctionTree junctionTree(etree); - // Create Bayes tree static const DiscreteBayesTree bayesTree = *factorGraph.eliminateMultifrontal(ordering); @@ -48,9 +43,7 @@ static const DiscreteBayesTree bayesTree = /* ************************************************************************* */ TEST(DiscreteBayesNet, AsiaFactorGraphKBest) { - GTSAM_PRINT(asia::etree); - GTSAM_PRINT(asia::junctionTree); - DiscreteSearch search(asia::factorGraph); + DiscreteSearch search(asia::factorGraph, asia::ordering); } /* ************************************************************************* */ From 0bc566f69227c5ae3f198daaf510e331c677e7ef Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 19:06:01 -0500 Subject: [PATCH 05/12] Working etree and jtree versions --- gtsam/discrete/DiscreteSearch.cpp | 75 ++++++++++++++++++++++++------- gtsam/discrete/DiscreteSearch.h | 54 +++++++++++++++++++--- 2 files changed, 107 insertions(+), 22 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 439331fa1b..4270dfbf76 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -150,21 +150,58 @@ class Solutions { } }; -DiscreteSearch::DiscreteSearch(const DiscreteFactorGraph& factorGraph, - const Ordering& ordering, - bool buildJunctionTree) { - const DiscreteEliminationTree etree(factorGraph, ordering); - const DiscreteJunctionTree junctionTree(etree); +DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { + using NodePtr = std::shared_ptr; + auto visitor = [this](const NodePtr& node, int data) { + const auto& factors = node->factors; + const auto factor = factors.size() == 1 + ? factors.back() + : DiscreteFactorGraph(factors).product(); + const size_t cardinality = factor->cardinality(node->key); + std::vector> pairs{{node->key, cardinality}}; + slots_.emplace_back(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + return data + 1; + }; - // GTSAM_PRINT(asia::etree); - // GTSAM_PRINT(asia::junctionTree); - slots_.reserve(factorGraph.size()); - for (auto& factor : factorGraph) { - slots_.emplace_back(factor, std::vector{}, 0.0); - } + const int data = 0; // unused + treeTraversal::DepthFirstForest(etree, data, visitor); + std::reverse(slots_.begin(), slots_.end()); // reverse slots + lowerBound_ = computeHeuristic(); +} + +DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) { + using NodePtr = std::shared_ptr; + auto visitor = [this](const NodePtr& cluster, int data) { + const auto& factors = cluster->factors; + const auto factor = factors.size() == 1 + ? factors.back() + : DiscreteFactorGraph(factors).product(); + std::vector> pairs; + for (Key key : cluster->orderedFrontalKeys) { + pairs.emplace_back(key, factor->cardinality(key)); + } + slots_.emplace_back(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + return data + 1; + }; + + const int data = 0; // unused + treeTraversal::DepthFirstForest(junctionTree, data, visitor); + std::reverse(slots_.begin(), slots_.end()); // reverse slots lowerBound_ = computeHeuristic(); } +DiscreteSearch DiscreteSearch::FromFactorGraph( + const DiscreteFactorGraph& factorGraph, const Ordering& ordering, + bool buildJunctionTree) { + const DiscreteEliminationTree etree(factorGraph, ordering); + if (buildJunctionTree) { + const DiscreteJunctionTree junctionTree(etree); + return DiscreteSearch(junctionTree); + } else { + return DiscreteSearch(etree); + } +} + DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { slots_.reserve(bayesNet.size()); for (auto& conditional : bayesNet) { @@ -188,6 +225,14 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { lowerBound_ = computeHeuristic(); } +void DiscreteSearch::print(const std::string& name, + const KeyFormatter& formatter) const { + std::cout << name << " with " << slots_.size() << " slots:\n"; + for (size_t i = 0; i < slots_.size(); ++i) { + std::cout << i << ": " << slots_[i] << std::endl; + } +} + struct SearchNodeQueue : public std::priority_queue, SearchNode::Compare> { @@ -215,8 +260,6 @@ struct SearchNodeQueue } }; -#define DISCRETE_SEARCH_DEBUG - std::vector DiscreteSearch::run(size_t K) const { Solutions solutions(K); SearchNodeQueue expansions; @@ -252,9 +295,9 @@ std::vector DiscreteSearch::run(size_t K) const { // a lower-bound on the cost-to-go for each slot, *not* including this factor. // For the first slot, this is 0.0, as this is the last slot to be filled, so // the cost after that is zero. For the second slot, it is h0 = -// -log(max(factor[0])), because after we assign slot[1] we still need to assign -// slot[0], which will cost *at least* h0. -// We return the estimated lower bound of the cost for *all* slots. +// -log(max(factor[0])), because after we assign slot[1] we still need to +// assign slot[0], which will cost *at least* h0. We return the estimated +// lower bound of the cost for *all* slots. double DiscreteSearch::computeHeuristic() { double error = 0.0; for (size_t i = 0; i < slots_.size(); ++i) { diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 3d5ee1d50c..44e605e349 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -10,7 +10,11 @@ * -------------------------------------------------------------------------- */ /* - * DiscreteSearch.cpp + * @file DiscreteSearch.h + * @brief Defines the DiscreteSearch class for discrete search algorithms. + * + * @details This file contains the definition of the DiscreteSearch class, which + * is used in discrete search algorithms to find the K best solutions. * * @date January, 2025 * @author Frank Dellaert @@ -43,6 +47,13 @@ class GTSAM_EXPORT DiscreteSearch { /// A lower bound on the cost-to-go for each slot, e.g., /// [-log(max_B P(B|A)), -log(max_A P(A))] double heuristic; + + friend std::ostream& operator<<(std::ostream& os, const Slot& slot) { + os << "Slot with " << slot.assignments.size() + << " assignments, heuristic=" << slot.heuristic; + os << ", factor:\n" << slot.factor->markdown() << std::endl; + return os; + } }; /// A solution is then a set of assignments, covering all the slots. @@ -59,6 +70,9 @@ class GTSAM_EXPORT DiscreteSearch { }; public: + /// @name Standard Constructors + /// @{ + /** * Construct from a DiscreteFactorGraph. * @@ -68,12 +82,26 @@ class GTSAM_EXPORT DiscreteSearch { * fine-grained (more slots). * * @param factorGraph The factor graph to search over. - * @param ordering The ordering of the variables to search over. - * @param buildJunctionTree Whether to build a junction tree for the factor - * graph. + * @param ordering The ordering used to create etree (and maybe jtree). + * @param buildJunctionTree Whether to build a junction tree or not. + */ + static DiscreteSearch FromFactorGraph(const DiscreteFactorGraph& factorGraph, + const Ordering& ordering, + bool buildJunctionTree = false); + + /** + * @brief Constructor from a DiscreteEliminationTree. + * + * @param etree The DiscreteEliminationTree to initialize from. */ - DiscreteSearch(const DiscreteFactorGraph& factorGraph, - const Ordering& ordering, bool buildJunctionTree = false); + DiscreteSearch(const DiscreteEliminationTree& etree); + + /** + * @brief Constructor from a DiscreteJunctionTree. + * + * @param junctionTree The DiscreteJunctionTree to initialize from. + */ + DiscreteSearch(const DiscreteJunctionTree& junctionTree); /** * Construct from a DiscreteBayesNet. @@ -85,6 +113,18 @@ class GTSAM_EXPORT DiscreteSearch { */ DiscreteSearch(const DiscreteBayesTree& bayesTree); + /// @} + /// @name Testable + /// @{ + + /** Print the tree to cout */ + void print(const std::string& name = "DiscreteSearch: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const; + + /// @} + /// @name Standard API + /// @{ + /** * @brief Search for the K best solutions. * @@ -97,6 +137,8 @@ class GTSAM_EXPORT DiscreteSearch { */ std::vector run(size_t K = 1) const; + /// @} + private: /// Compute the cumulative lower-bound cost-to-go after each slot is filled. /// @return the estimated lower bound of the cost for *all* slots. From 8d6b6055fe3254104ac69b0027db527d0d513e5a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 19:06:09 -0500 Subject: [PATCH 06/12] Loop over all variants --- gtsam/discrete/tests/testDiscreteSearch.cpp | 76 ++++++++------------- 1 file changed, 28 insertions(+), 48 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index 0f424bd453..dd53895586 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -41,11 +41,6 @@ static const DiscreteBayesTree bayesTree = *factorGraph.eliminateMultifrontal(ordering); } // namespace asia -/* ************************************************************************* */ -TEST(DiscreteBayesNet, AsiaFactorGraphKBest) { - DiscreteSearch search(asia::factorGraph, asia::ordering); -} - /* ************************************************************************* */ TEST(DiscreteBayesNet, EmptyKBest) { DiscreteBayesNet net; // no factors @@ -56,29 +51,6 @@ TEST(DiscreteBayesNet, EmptyKBest) { EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); } -/* ************************************************************************* */ -TEST(DiscreteBayesNet, AsiaKBest) { - const DiscreteSearch search(asia::bayesNet); - - // Ask for the MPE - auto mpe = search.run(); - - EXPECT_LONGS_EQUAL(1, mpe.size()); - // Regression test: check the MPE solution - EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); - - // Check it is equal to MPE via inference - EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); - - // Ask for top 4 solutions - auto solutions = search.run(4); - - EXPECT_LONGS_EQUAL(4, solutions.size()); - // Regression test: check the first and last solution - EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); - EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); -} - /* ************************************************************************* */ TEST(DiscreteBayesTree, EmptyTree) { DiscreteBayesTree bt; @@ -92,26 +64,34 @@ TEST(DiscreteBayesTree, EmptyTree) { } /* ************************************************************************* */ -TEST(DiscreteBayesTree, AsiaTreeKBest) { - DiscreteSearch search(asia::bayesTree); - - // Ask for MPE - auto mpe = search.run(); - - EXPECT_LONGS_EQUAL(1, mpe.size()); - // Regression test: check the MPE solution - EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); - - // Check it is equal to MPE via inference - EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); - - // Ask for top 4 solutions - auto solutions = search.run(4); - - EXPECT_LONGS_EQUAL(4, solutions.size()); - // Regression test: check the first and last solution - EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); - EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); +TEST(DiscreteBayesNet, AsiaKBest) { + auto fromETree = + DiscreteSearch::FromFactorGraph(asia::factorGraph, asia::ordering); + auto fromJunctionTree = + DiscreteSearch::FromFactorGraph(asia::factorGraph, asia::ordering, true); + const DiscreteSearch fromBayesNet(asia::bayesNet); + const DiscreteSearch fromBayesTree(asia::bayesTree); + + for (auto& search : + {fromETree, fromJunctionTree, fromBayesNet, fromBayesTree}) { + // Ask for the MPE + auto mpe = search.run(); + + EXPECT_LONGS_EQUAL(1, mpe.size()); + // Regression test: check the MPE solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); + + // Check it is equal to MPE via inference + EXPECT(assert_equal(asia::mpe, mpe[0].assignment)); + + // Ask for top 4 solutions + auto solutions = search.run(4); + + EXPECT_LONGS_EQUAL(4, solutions.size()); + // Regression test: check the first and last solution + EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); + EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5); + } } /* ************************************************************************* */ From af07409c10f9b3a1c6742a341f47d83c81cb802e Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 23:35:31 -0500 Subject: [PATCH 07/12] Fix some CI issues --- gtsam/discrete/DiscreteJunctionTree.cpp | 2 +- gtsam/discrete/DiscreteSearch.cpp | 22 ++++++++++++---------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/gtsam/discrete/DiscreteJunctionTree.cpp b/gtsam/discrete/DiscreteJunctionTree.cpp index b4657ec8dc..0c9eb10efa 100644 --- a/gtsam/discrete/DiscreteJunctionTree.cpp +++ b/gtsam/discrete/DiscreteJunctionTree.cpp @@ -51,7 +51,7 @@ struct PrintForestVisitorPre { void DiscreteJunctionTree::print(const std::string& s, const KeyFormatter& keyFormatter) const { PrintForestVisitorPre visitor(keyFormatter); - treeTraversal::DepthFirstForest(*this, s, visitor); + treeTraversal::DepthFirstForest(*this, std::string(s), visitor); } } // namespace gtsam diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 4270dfbf76..71485795e5 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -159,7 +159,8 @@ DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { : DiscreteFactorGraph(factors).product(); const size_t cardinality = factor->cardinality(node->key); std::vector> pairs{{node->key, cardinality}}; - slots_.emplace_back(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + const Slot slot(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + slots_.emplace_back(std::move(slot)); return data + 1; }; @@ -180,7 +181,8 @@ DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) { for (Key key : cluster->orderedFrontalKeys) { pairs.emplace_back(key, factor->cardinality(key)); } - slots_.emplace_back(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + const Slot slot(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + slots_.emplace_back(std::move(slot)); return data + 1; }; @@ -205,7 +207,8 @@ DiscreteSearch DiscreteSearch::FromFactorGraph( DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { slots_.reserve(bayesNet.size()); for (auto& conditional : bayesNet) { - slots_.emplace_back(conditional, conditional->frontalAssignments(), 0.0); + const Slot slot(conditional, conditional->frontalAssignments(), 0.0); + slots_.emplace_back(std::move(slot)); } lowerBound_ = computeHeuristic(); } @@ -216,8 +219,8 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { if (!clique) return; for (const auto& child : clique->children) collectConditionals(child); auto conditional = clique->conditional(); - slots_.emplace_back(conditional, conditional->frontalAssignments(), - 0.0); + const Slot slot(conditional, conditional->frontalAssignments(), 0.0); + slots_.emplace_back(std::move(slot)); }; slots_.reserve(bayesTree.size()); @@ -300,11 +303,10 @@ std::vector DiscreteSearch::run(size_t K) const { // lower bound of the cost for *all* slots. double DiscreteSearch::computeHeuristic() { double error = 0.0; - for (size_t i = 0; i < slots_.size(); ++i) { - slots_[i].heuristic = error; - const auto& factor = slots_[i].factor; - Ordering ordering(factor->begin(), factor->end()); - auto maxx = factor->max(ordering); + for (auto& slot : slots_) { + slot.heuristic = error; + Ordering ordering(slot.factor->begin(), slot.factor->end()); + auto maxx = slot.factor->max(ordering); error -= std::log(maxx->evaluate({})); } return error; From 460a9a958e2655c70eeb04bcfd775ffc4cce5ec3 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 23:41:35 -0500 Subject: [PATCH 08/12] Fix compilation issue --- gtsam/discrete/DiscreteJunctionTree.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteJunctionTree.cpp b/gtsam/discrete/DiscreteJunctionTree.cpp index 0c9eb10efa..e1fc2af113 100644 --- a/gtsam/discrete/DiscreteJunctionTree.cpp +++ b/gtsam/discrete/DiscreteJunctionTree.cpp @@ -51,7 +51,8 @@ struct PrintForestVisitorPre { void DiscreteJunctionTree::print(const std::string& s, const KeyFormatter& keyFormatter) const { PrintForestVisitorPre visitor(keyFormatter); - treeTraversal::DepthFirstForest(*this, std::string(s), visitor); + std::string parentString = s; + treeTraversal::DepthFirstForest(*this, parentString, visitor); } } // namespace gtsam From b8f265d69f041f2fab32da147cc8f32358852802 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 27 Jan 2025 23:50:08 -0500 Subject: [PATCH 09/12] Use brace init --- gtsam/discrete/DiscreteSearch.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 71485795e5..569887d7ff 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -159,7 +159,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { : DiscreteFactorGraph(factors).product(); const size_t cardinality = factor->cardinality(node->key); std::vector> pairs{{node->key, cardinality}}; - const Slot slot(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0}; slots_.emplace_back(std::move(slot)); return data + 1; }; @@ -181,7 +181,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) { for (Key key : cluster->orderedFrontalKeys) { pairs.emplace_back(key, factor->cardinality(key)); } - const Slot slot(factor, DiscreteValues::CartesianProduct(pairs), 0.0); + const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0}; slots_.emplace_back(std::move(slot)); return data + 1; }; @@ -207,7 +207,7 @@ DiscreteSearch DiscreteSearch::FromFactorGraph( DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { slots_.reserve(bayesNet.size()); for (auto& conditional : bayesNet) { - const Slot slot(conditional, conditional->frontalAssignments(), 0.0); + const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; slots_.emplace_back(std::move(slot)); } lowerBound_ = computeHeuristic(); @@ -219,7 +219,7 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { if (!clique) return; for (const auto& child : clique->children) collectConditionals(child); auto conditional = clique->conditional(); - const Slot slot(conditional, conditional->frontalAssignments(), 0.0); + const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; slots_.emplace_back(std::move(slot)); }; From 9e98b805d6cfb07373dc8fa2d6a4ebd2de7492aa Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Jan 2025 00:17:45 -0500 Subject: [PATCH 10/12] Reversed slots so we start from zero --- gtsam/discrete/DiscreteSearch.cpp | 168 ++++++++------------ gtsam/discrete/DiscreteSearch.h | 6 + gtsam/discrete/tests/testDiscreteSearch.cpp | 11 ++ 3 files changed, 87 insertions(+), 98 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 569887d7ff..43f321d4ac 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -33,16 +33,16 @@ using Solution = DiscreteSearch::Solution; * conditional to be assigned. */ struct SearchNode { - DiscreteValues assignment; ///< Partial assignment of discrete variables. - double error; ///< Current error for the partial assignment. - double bound; ///< Lower bound on the final error for unassigned variables. - int nextConditional; ///< Index of the next conditional to be assigned. + DiscreteValues assignment; ///< Partial assignment of discrete variables. + double error; ///< Current error for the partial assignment. + double bound; ///< Lower bound on the final error + std::optional next; ///< Index of the next factor to be assigned. /** * @brief Construct the root node for the search. */ static SearchNode Root(size_t numSlots, double bound) { - return {DiscreteValues(), 0.0, bound, static_cast(numSlots) - 1}; + return {DiscreteValues(), 0.0, bound, 0}; } struct Compare { @@ -51,38 +51,22 @@ struct SearchNode { } }; - /** - * @brief Checks if the node represents a complete assignment. - * - * @return True if all variables have been assigned, false otherwise. - */ - inline bool isComplete() const { return nextConditional < 0; } + /// Checks if the node represents a complete assignment. + inline bool isComplete() const { return !next; } - /** - * @brief Expands the node by assigning the next variable. - * - * @param slot The slot to be filled. - * @param fa The frontal assignment for the next variable. - * @return A new SearchNode representing the expanded state. - */ - SearchNode expand(const Slot& slot, const DiscreteValues& fa) const { + /// Expands the node by assigning the next variable(s). + SearchNode expand(const DiscreteValues& fa, const Slot& slot, + std::optional nextSlot) const { // Combine the new frontal assignment with the current partial assignment DiscreteValues newAssignment = assignment; for (auto& [key, value] : fa) { newAssignment[key] = value; } double errorSoFar = error + slot.factor->error(newAssignment); - return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, - nextConditional - 1}; + return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextSlot}; } - /** - * @brief Prints the SearchNode to an output stream. - * - * @param os The output stream. - * @param node The SearchNode to be printed. - * @return The output stream. - */ + /// Prints the SearchNode to an output stream. friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; return os; @@ -150,13 +134,18 @@ class Solutions { } }; +/// @brief Get the factor associated with a node, possibly product of factors. +template +static auto getFactor(const NodeType& node) { + const auto& factors = node->factors; + return factors.size() == 1 ? factors.back() + : DiscreteFactorGraph(factors).product(); +} + DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { using NodePtr = std::shared_ptr; auto visitor = [this](const NodePtr& node, int data) { - const auto& factors = node->factors; - const auto factor = factors.size() == 1 - ? factors.back() - : DiscreteFactorGraph(factors).product(); + const auto factor = getFactor(node); const size_t cardinality = factor->cardinality(node->key); std::vector> pairs{{node->key, cardinality}}; const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0}; @@ -164,19 +153,15 @@ DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { return data + 1; }; - const int data = 0; // unused + int data = 0; // unused treeTraversal::DepthFirstForest(etree, data, visitor); - std::reverse(slots_.begin(), slots_.end()); // reverse slots lowerBound_ = computeHeuristic(); } DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) { using NodePtr = std::shared_ptr; auto visitor = [this](const NodePtr& cluster, int data) { - const auto& factors = cluster->factors; - const auto factor = factors.size() == 1 - ? factors.back() - : DiscreteFactorGraph(factors).product(); + const auto factor = getFactor(cluster); std::vector> pairs; for (Key key : cluster->orderedFrontalKeys) { pairs.emplace_back(key, factor->cardinality(key)); @@ -186,9 +171,8 @@ DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) { return data + 1; }; - const int data = 0; // unused + int data = 0; // unused treeTraversal::DepthFirstForest(junctionTree, data, visitor); - std::reverse(slots_.begin(), slots_.end()); // reverse slots lowerBound_ = computeHeuristic(); } @@ -210,21 +194,21 @@ DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) { const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; slots_.emplace_back(std::move(slot)); } + std::reverse(slots_.begin(), slots_.end()); lowerBound_ = computeHeuristic(); } DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) { - std::function - collectConditionals = [&](const auto& clique) { - if (!clique) return; - for (const auto& child : clique->children) collectConditionals(child); - auto conditional = clique->conditional(); - const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; - slots_.emplace_back(std::move(slot)); - }; - - slots_.reserve(bayesTree.size()); - for (const auto& root : bayesTree.roots()) collectConditionals(root); + using NodePtr = DiscreteBayesTree::sharedClique; + auto visitor = [this](const NodePtr& clique, int data) { + auto conditional = clique->conditional(); + const Slot slot{conditional, conditional->frontalAssignments(), 0.0}; + slots_.emplace_back(std::move(slot)); + return data + 1; + }; + + int data = 0; // unused + treeTraversal::DepthFirstForest(bayesTree, data, visitor); lowerBound_ = computeHeuristic(); } @@ -236,59 +220,48 @@ void DiscreteSearch::print(const std::string& name, } } -struct SearchNodeQueue - : public std::priority_queue, - SearchNode::Compare> { - void expandNextNode(const SearchNode& current, const Slot& slot, - Solutions* solutions) { - // If we already have K solutions, prune if we cannot beat the worst one. - if (solutions->prune(current.bound)) { - return; - } - - // Check if we have a complete assignment - if (current.isComplete()) { - solutions->maybeAdd(current.error, current.assignment); - return; - } - - for (auto& fa : slot.assignments) { - auto childNode = current.expand(slot, fa); +using SearchNodeQueue = std::priority_queue, + SearchNode::Compare>; - // Again, prune if we cannot beat the worst solution - if (!solutions->prune(childNode.bound)) { - emplace(childNode); - } - } +std::vector DiscreteSearch::run(size_t K) const { + if (slots_.empty()) { + return {Solution(0.0, DiscreteValues())}; } -}; -std::vector DiscreteSearch::run(size_t K) const { Solutions solutions(K); SearchNodeQueue expansions; expansions.push(SearchNode::Root(slots_.size(), lowerBound_)); -#ifdef DISCRETE_SEARCH_DEBUG - size_t numExpansions = 0; -#endif - // Perform the search while (!expansions.empty()) { // Pop the partial assignment with the smallest bound SearchNode current = expansions.top(); expansions.pop(); + // If we already have K solutions, prune if we cannot beat the worst one. + if (solutions.prune(current.bound)) { + continue; + } + + // Check if we have a complete assignment + if (current.isComplete()) { + solutions.maybeAdd(current.error, current.assignment); + continue; + } + // Get the next slot to expand - const auto& slot = slots_[current.nextConditional]; - expansions.expandNextNode(current, slot, &solutions); -#ifdef DISCRETE_SEARCH_DEBUG - ++numExpansions; -#endif - } + const auto& slot = slots_[*current.next]; + std::optional nextSlot = *current.next + 1; + if (nextSlot == slots_.size()) nextSlot.reset(); + for (auto& fa : slot.assignments) { + auto childNode = current.expand(fa, slot, nextSlot); -#ifdef DISCRETE_SEARCH_DEBUG - std::cout << "Number of expansions: " << numExpansions << std::endl; -#endif + // Again, prune if we cannot beat the worst solution + if (!solutions.prune(childNode.bound)) { + expansions.emplace(childNode); + } + } + } // Extract solutions from bestSolutions in ascending order of error return solutions.extractSolutions(); @@ -296,17 +269,16 @@ std::vector DiscreteSearch::run(size_t K) const { // We have a number of factors, each with a max value, and we want to compute // a lower-bound on the cost-to-go for each slot, *not* including this factor. -// For the first slot, this is 0.0, as this is the last slot to be filled, so -// the cost after that is zero. For the second slot, it is h0 = -// -log(max(factor[0])), because after we assign slot[1] we still need to -// assign slot[0], which will cost *at least* h0. We return the estimated -// lower bound of the cost for *all* slots. +// For the last slot, this is 0.0, as the cost after that is zero. +// For the second-to-last slot, it is -log(max(factor[0])), because after we +// assign slot[1] we still need to assign slot[0], which will cost *at least* +// h0. We return the estimated lower bound of the cost for *all* slots. double DiscreteSearch::computeHeuristic() { double error = 0.0; - for (auto& slot : slots_) { - slot.heuristic = error; - Ordering ordering(slot.factor->begin(), slot.factor->end()); - auto maxx = slot.factor->max(ordering); + for (auto it = slots_.rbegin(); it != slots_.rend(); ++it) { + it->heuristic = error; + Ordering ordering(it->factor->begin(), it->factor->end()); + auto maxx = it->factor->max(ordering); error -= std::log(maxx->evaluate({})); } return error; diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 44e605e349..700e41392f 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -125,6 +125,12 @@ class GTSAM_EXPORT DiscreteSearch { /// @name Standard API /// @{ + /// Return lower bound on the cost-to-go for the entire search + double lowerBound() const { return lowerBound_; } + + /// Read access to the slots + const std::vector& slots() const { return slots_; } + /** * @brief Search for the K best solutions. * diff --git a/gtsam/discrete/tests/testDiscreteSearch.cpp b/gtsam/discrete/tests/testDiscreteSearch.cpp index dd53895586..cebddfe8de 100644 --- a/gtsam/discrete/tests/testDiscreteSearch.cpp +++ b/gtsam/discrete/tests/testDiscreteSearch.cpp @@ -77,6 +77,17 @@ TEST(DiscreteBayesNet, AsiaKBest) { // Ask for the MPE auto mpe = search.run(); + // Regression on error lower bound + EXPECT_DOUBLES_EQUAL(1.205536, search.lowerBound(), 1e-5); + + // Check that the cost-to-go heuristic decreases from there + auto slots = search.slots(); + double previousHeuristic = search.lowerBound(); + for (auto&& slot : slots) { + EXPECT(slot.heuristic <= previousHeuristic); + previousHeuristic = slot.heuristic; + } + EXPECT_LONGS_EQUAL(1, mpe.size()); // Regression test: check the MPE solution EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5); From 8c7e75bb25d0be16efada3c5ac7733ca12e58367 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Jan 2025 14:50:11 -0500 Subject: [PATCH 11/12] Review comments --- gtsam/discrete/DiscreteJunctionTree.cpp | 26 +++++++++---------------- gtsam/discrete/tests/AsiaExample.h | 2 +- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/gtsam/discrete/DiscreteJunctionTree.cpp b/gtsam/discrete/DiscreteJunctionTree.cpp index e1fc2af113..bf9f9fe18f 100644 --- a/gtsam/discrete/DiscreteJunctionTree.cpp +++ b/gtsam/discrete/DiscreteJunctionTree.cpp @@ -31,26 +31,18 @@ DiscreteJunctionTree::DiscreteJunctionTree( const DiscreteEliminationTree& eliminationTree) : Base(eliminationTree) {} /* ************************************************************************* */ -namespace { -struct PrintForestVisitorPre { - const KeyFormatter& formatter; - PrintForestVisitorPre(const KeyFormatter& formatter) : formatter(formatter) {} - std::string operator()( - const std::shared_ptr& node, - const std::string& parentString) { - // Print the current node - node->print(parentString + "-", formatter); - node->factors.print(parentString + "-", formatter); - std::cout << std::endl; - // Increment the indentation - return parentString + "| "; - } -}; -} // namespace void DiscreteJunctionTree::print(const std::string& s, const KeyFormatter& keyFormatter) const { - PrintForestVisitorPre visitor(keyFormatter); + auto visitor = [&keyFormatter]( + const std::shared_ptr& node, + const std::string& parentString) { + // Print the current node + node->print(parentString + "-", keyFormatter); + node->factors.print(parentString + "-", keyFormatter); + std::cout << std::endl; + return parentString + "| "; // Increment the indentation + }; std::string parentString = s; treeTraversal::DepthFirstForest(*this, parentString, visitor); } diff --git a/gtsam/discrete/tests/AsiaExample.h b/gtsam/discrete/tests/AsiaExample.h index 6c327daec8..ff6c4ea990 100644 --- a/gtsam/discrete/tests/AsiaExample.h +++ b/gtsam/discrete/tests/AsiaExample.h @@ -58,4 +58,4 @@ DiscreteBayesNet createAsiaExample() { return asia; } } // namespace asia_example -} // namespace gtsam \ No newline at end of file +} // namespace gtsam From 1afb0891437898a7fbace791a60792100402fa1c Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Jan 2025 14:50:39 -0500 Subject: [PATCH 12/12] Better and more consistent documentation. --- gtsam/discrete/DiscreteSearch.cpp | 75 ++++++++++++++++--------------- gtsam/discrete/DiscreteSearch.h | 74 ++++++++++++++++-------------- 2 files changed, 79 insertions(+), 70 deletions(-) diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 43f321d4ac..c046f508f9 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -9,7 +9,7 @@ * -------------------------------------------------------------------------- */ -/* +/** * DiscreteSearch.cpp * * @date January, 2025 @@ -25,22 +25,19 @@ namespace gtsam { using Slot = DiscreteSearch::Slot; using Solution = DiscreteSearch::Solution; -/** - * @brief Represents a node in the search tree for discrete search algorithms. - * - * @details Each SearchNode contains a partial assignment of discrete variables, - * the current error, a bound on the final error, and the index of the next - * conditional to be assigned. +/* + * A SearchNode represents a node in the search tree for the search algorithm. + * Each SearchNode contains a partial assignment of discrete variables, the + * current error, a bound on the final error, and the index of the next + * slot to be assigned. */ struct SearchNode { - DiscreteValues assignment; ///< Partial assignment of discrete variables. - double error; ///< Current error for the partial assignment. - double bound; ///< Lower bound on the final error - std::optional next; ///< Index of the next factor to be assigned. - - /** - * @brief Construct the root node for the search. - */ + DiscreteValues assignment; // Partial assignment of discrete variables. + double error; // Current error for the partial assignment. + double bound; // Lower bound on the final error + std::optional next; // Index of the next slot to be assigned. + + // Construct the root node for the search. static SearchNode Root(size_t numSlots, double bound) { return {DiscreteValues(), 0.0, bound, 0}; } @@ -51,10 +48,10 @@ struct SearchNode { } }; - /// Checks if the node represents a complete assignment. + // Checks if the node represents a complete assignment. inline bool isComplete() const { return !next; } - /// Expands the node by assigning the next variable(s). + // Expands the node by assigning the next variable(s). SearchNode expand(const DiscreteValues& fa, const Slot& slot, std::optional nextSlot) const { // Combine the new frontal assignment with the current partial assignment @@ -66,7 +63,7 @@ struct SearchNode { return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextSlot}; } - /// Prints the SearchNode to an output stream. + // Prints the SearchNode to an output stream. friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; return os; @@ -79,17 +76,20 @@ struct CompareSolution { } }; -// Define the Solutions class +/* + * A Solutions object maintains a priority queue of the best solutions found + * during the search. The priority queue is limited to a maximum size, and + * solutions are only added if they are better than the worst solution. + */ class Solutions { - private: - size_t maxSize_; + size_t maxSize_; // Maximum number of solutions to keep std::priority_queue, CompareSolution> pq_; public: Solutions(size_t maxSize) : maxSize_(maxSize) {} - /// Add a solution to the priority queue, possibly evicting the worst one. - /// Return true if we added the solution. + // Add a solution to the priority queue, possibly evicting the worst one. + // Return true if we added the solution. bool maybeAdd(double error, const DiscreteValues& assignment) { const bool full = pq_.size() == maxSize_; if (full && error >= pq_.top().error) return false; @@ -98,7 +98,7 @@ class Solutions { return true; } - /// Check if we have any solutions + // Check if we have any solutions bool empty() const { return pq_.empty(); } // Method to print all solutions @@ -112,9 +112,9 @@ class Solutions { return os; } - /// Check if (partial) solution with given bound can be pruned. If we have - /// room, we never prune. Otherwise, prune if lower bound on error is worse - /// than our current worst error. + // Check if (partial) solution with given bound can be pruned. If we have + // room, we never prune. Otherwise, prune if lower bound on error is worse + // than our current worst error. bool prune(double bound) const { if (pq_.size() < maxSize_) return false; return bound >= pq_.top().error; @@ -134,9 +134,9 @@ class Solutions { } }; -/// @brief Get the factor associated with a node, possibly product of factors. +// Get the factor associated with a node, possibly product of factors. template -static auto getFactor(const NodeType& node) { +static DiscreteFactor::shared_ptr getFactor(const NodeType& node) { const auto& factors = node->factors; return factors.size() == 1 ? factors.back() : DiscreteFactorGraph(factors).product(); @@ -145,7 +145,7 @@ static auto getFactor(const NodeType& node) { DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { using NodePtr = std::shared_ptr; auto visitor = [this](const NodePtr& node, int data) { - const auto factor = getFactor(node); + const DiscreteFactor::shared_ptr factor = getFactor(node); const size_t cardinality = factor->cardinality(node->key); std::vector> pairs{{node->key, cardinality}}; const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0}; @@ -266,13 +266,14 @@ std::vector DiscreteSearch::run(size_t K) const { // Extract solutions from bestSolutions in ascending order of error return solutions.extractSolutions(); } - -// We have a number of factors, each with a max value, and we want to compute -// a lower-bound on the cost-to-go for each slot, *not* including this factor. -// For the last slot, this is 0.0, as the cost after that is zero. -// For the second-to-last slot, it is -log(max(factor[0])), because after we -// assign slot[1] we still need to assign slot[0], which will cost *at least* -// h0. We return the estimated lower bound of the cost for *all* slots. +/* + * We have a number of factors, each with a max value, and we want to compute + * a lower-bound on the cost-to-go for each slot, *not* including this factor. + * For the last slot[n-1], this is 0.0, as the cost after that is zero. + * For the second-to-last slot, it is h = -log(max(factor[n-1])), because after + * we assign slot[n-2] we still need to assign slot[n-1], which will cost *at + * least* h. We return the estimated lower bound of the cost for *all* slots. + */ double DiscreteSearch::computeHeuristic() { double error = 0.0; for (auto it = slots_.rbegin(); it != slots_.rend(); ++it) { diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 700e41392f..b610955b29 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -9,7 +9,7 @@ * -------------------------------------------------------------------------- */ -/* +/** * @file DiscreteSearch.h * @brief Defines the DiscreteSearch class for discrete search algorithms. * @@ -28,24 +28,40 @@ namespace gtsam { /** - * DiscreteSearch: Search for the K best solutions. + * @brief DiscreteSearch: Search for the K best solutions. + * + * This class is used to search for the K best solutions in a DiscreteBayesNet. + * This is implemented with a modified A* search algorithm that uses a priority + * queue to manage the search nodes. That machinery is defined in the .cpp file. + * The heuristic we use is the sum of the log-probabilities of the + * maximum-probability assignments for each slot, for all slots to the right of + * the current slot. + * + * TODO: The heuristic could be refined by using the partial assignment in + * search node to refine the max-probability assignment for the remaining slots. + * This would incur more computation but will lead to fewer expansions. */ class GTSAM_EXPORT DiscreteSearch { public: - /// We structure the search as a set of slots, each with a factor and - /// a set of variable assignments that need to be chosen. In addition, each - /// slot has a heuristic associated with it. + /** + * We structure the search as a set of slots, each with a factor and + * a set of variable assignments that need to be chosen. In addition, each + * slot has a heuristic associated with it. + * + * Example: + * The factors in the search problem (always parents before descendents!): + * [P(A), P(B|A), P(C|A,B)] + * The assignments for each factor. + * [[A0,A1], [B0,B1], [C0,C1,C2]] + * A lower bound on the cost-to-go after each slot, e.g., + * [-log(max_B P(B|A)) -log(max_C P(C|A,B)), -log(max_C P(C|A,B)), 0.0] + * Note that these decrease as we move from right to left. + * We keep the global lower bound as lowerBound_. In the example, it is: + * -log(max_B P(B|A)) -log(max_C P(C|A,B)) -log(max_C P(C|A,B)) + */ struct Slot { - /// The factors in the search problem, - /// e.g., [P(B|A),P(A)] DiscreteFactor::shared_ptr factor; - - /// The assignments for each factor, - /// e.g., [[B0,B1] [A0,A1]] std::vector assignments; - - /// A lower bound on the cost-to-go for each slot, e.g., - /// [-log(max_B P(B|A)), -log(max_A P(A))] double heuristic; friend std::ostream& operator<<(std::ostream& os, const Slot& slot) { @@ -56,8 +72,10 @@ class GTSAM_EXPORT DiscreteSearch { } }; - /// A solution is then a set of assignments, covering all the slots. - /// as well as an associated error = -log(probability) + /** + * A solution is a set of assignments, covering all the slots. + * as well as an associated error = -log(probability) + */ struct Solution { double error; DiscreteValues assignment; @@ -89,28 +107,16 @@ class GTSAM_EXPORT DiscreteSearch { const Ordering& ordering, bool buildJunctionTree = false); - /** - * @brief Constructor from a DiscreteEliminationTree. - * - * @param etree The DiscreteEliminationTree to initialize from. - */ + /// Construct from a DiscreteEliminationTree. DiscreteSearch(const DiscreteEliminationTree& etree); - /** - * @brief Constructor from a DiscreteJunctionTree. - * - * @param junctionTree The DiscreteJunctionTree to initialize from. - */ + /// Construct from a DiscreteJunctionTree. DiscreteSearch(const DiscreteJunctionTree& junctionTree); - /** - * Construct from a DiscreteBayesNet. - */ + //// Construct from a DiscreteBayesNet. DiscreteSearch(const DiscreteBayesNet& bayesNet); - /** - * Construct from a DiscreteBayesTree. - */ + /// Construct from a DiscreteBayesTree. DiscreteSearch(const DiscreteBayesTree& bayesTree); /// @} @@ -146,8 +152,10 @@ class GTSAM_EXPORT DiscreteSearch { /// @} private: - /// Compute the cumulative lower-bound cost-to-go after each slot is filled. - /// @return the estimated lower bound of the cost for *all* slots. + /** + * Compute the cumulative lower-bound cost-to-go after each slot is filled. + * @return the estimated lower bound of the cost for *all* slots. + */ double computeHeuristic(); double lowerBound_; ///< Lower bound on the cost-to-go for the entire search.