From e0fedda712808b557ef229b2ec4cda50f25fb743 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 9 Dec 2024 17:22:51 -0500 Subject: [PATCH] make the Unary and Binary ops common --- gtsam/discrete/DecisionTreeFactor.cpp | 12 ++++++------ gtsam/discrete/DecisionTreeFactor.h | 15 ++++++++++----- gtsam/discrete/DiscreteFactor.h | 5 +++++ gtsam/discrete/TableFactor.h | 4 ---- 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 9ec3b0ac53..776d4bd904 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -83,7 +83,7 @@ namespace gtsam { } /* ************************************************************************ */ - DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const { + DecisionTreeFactor DecisionTreeFactor::apply(Unary op) const { // apply operand ADT result = ADT::apply(op); // Make a new factor @@ -91,7 +91,7 @@ namespace gtsam { } /* ************************************************************************ */ - DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const { + DecisionTreeFactor DecisionTreeFactor::apply(UnaryAssignment op) const { // apply operand ADT result = ADT::apply(op); // Make a new factor @@ -100,7 +100,7 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f, - ADT::Binary op) const { + Binary op) const { map cs; // new cardinalities // make unique key-cardinality map for (Key j : keys()) cs[j] = cardinality(j); @@ -118,8 +118,8 @@ namespace gtsam { } /* ************************************************************************ */ - DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( - size_t nrFrontals, ADT::Binary op) const { + DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals, + Binary op) const { if (nrFrontals > size()) { throw invalid_argument( "DecisionTreeFactor::combine: invalid number of frontal " @@ -146,7 +146,7 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( - const Ordering& frontalKeys, ADT::Binary op) const { + const Ordering& frontalKeys, Binary op) const { if (frontalKeys.size() > size()) { throw invalid_argument( "DecisionTreeFactor::combine: invalid number of frontal " diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 8f3dd85a67..9172180b10 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -51,6 +51,11 @@ namespace gtsam { typedef std::shared_ptr shared_ptr; typedef AlgebraicDecisionTree ADT; + // Needed since we have definitions in both DiscreteFactor and DecisionTree + using Base::Binary; + using Base::Unary; + using Base::UnaryAssignment; + /// @name Standard Constructors /// @{ @@ -182,21 +187,21 @@ namespace gtsam { * Apply unary operator (*this) "op" f * @param op a unary operator that operates on AlgebraicDecisionTree */ - DecisionTreeFactor apply(ADT::Unary op) const; + DecisionTreeFactor apply(Unary op) const; /** * Apply unary operator (*this) "op" f * @param op a unary operator that operates on AlgebraicDecisionTree. Takes * both the assignment and the value. */ - DecisionTreeFactor apply(ADT::UnaryAssignment op) const; + DecisionTreeFactor apply(UnaryAssignment op) const; /** * Apply binary operator (*this) "op" f * @param f the second argument for op * @param op a binary operator that operates on AlgebraicDecisionTree */ - DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const; + DecisionTreeFactor apply(const DecisionTreeFactor& f, Binary op) const; /** * Combine frontal variables using binary operator "op" @@ -204,7 +209,7 @@ namespace gtsam { * @param op a binary operator that operates on AlgebraicDecisionTree * @return shared pointer to newly created DecisionTreeFactor */ - shared_ptr combine(size_t nrFrontals, ADT::Binary op) const; + shared_ptr combine(size_t nrFrontals, Binary op) const; /** * Combine frontal variables in an Ordering using binary operator "op" @@ -212,7 +217,7 @@ namespace gtsam { * @param op a binary operator that operates on AlgebraicDecisionTree * @return shared pointer to newly created DecisionTreeFactor */ - shared_ptr combine(const Ordering& keys, ADT::Binary op) const; + shared_ptr combine(const Ordering& keys, Binary op) const; /// Enumerate all values into a map from values to double. std::vector> enumerate() const; diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 5bc3c463f5..7175f7608a 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -46,6 +46,11 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { using Values = DiscreteValues; ///< backwards compatibility + using Unary = std::function; + using UnaryAssignment = + std::function&, const double&)>; + using Binary = std::function; + protected: /// Map of Keys and their cardinalities. std::map cardinalities_; diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 53495d0781..e7ed2294f3 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -94,10 +94,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { typedef std::shared_ptr shared_ptr; typedef Eigen::SparseVector::InnerIterator SparseIt; typedef std::vector> AssignValList; - using Unary = std::function; - using UnaryAssignment = - std::function&, const double&)>; - using Binary = std::function; public: /// @name Standard Constructors