Skip to content

Commit

Permalink
make the Unary and Binary ops common
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Dec 9, 2024
1 parent 88b36da commit e0fedda
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 15 deletions.
12 changes: 6 additions & 6 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ namespace gtsam {
}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const {
DecisionTreeFactor DecisionTreeFactor::apply(Unary op) const {
// apply operand
ADT result = ADT::apply(op);
// Make a new factor
return DecisionTreeFactor(discreteKeys(), result);
}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
DecisionTreeFactor DecisionTreeFactor::apply(UnaryAssignment op) const {
// apply operand
ADT result = ADT::apply(op);
// Make a new factor
Expand All @@ -100,7 +100,7 @@ namespace gtsam {

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
ADT::Binary op) const {
Binary op) const {
map<Key, size_t> cs; // new cardinalities
// make unique key-cardinality map
for (Key j : keys()) cs[j] = cardinality(j);
Expand All @@ -118,8 +118,8 @@ namespace gtsam {
}

/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
size_t nrFrontals, ADT::Binary op) const {
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals,
Binary op) const {
if (nrFrontals > size()) {
throw invalid_argument(
"DecisionTreeFactor::combine: invalid number of frontal "
Expand All @@ -146,7 +146,7 @@ namespace gtsam {

/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
const Ordering& frontalKeys, ADT::Binary op) const {
const Ordering& frontalKeys, Binary op) const {
if (frontalKeys.size() > size()) {
throw invalid_argument(
"DecisionTreeFactor::combine: invalid number of frontal "
Expand Down
15 changes: 10 additions & 5 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ namespace gtsam {
typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT;

// Needed since we have definitions in both DiscreteFactor and DecisionTree
using Base::Binary;
using Base::Unary;
using Base::UnaryAssignment;

/// @name Standard Constructors
/// @{

Expand Down Expand Up @@ -182,37 +187,37 @@ namespace gtsam {
* Apply unary operator (*this) "op" f
* @param op a unary operator that operates on AlgebraicDecisionTree
*/
DecisionTreeFactor apply(ADT::Unary op) const;
DecisionTreeFactor apply(Unary op) const;

/**
* Apply unary operator (*this) "op" f
* @param op a unary operator that operates on AlgebraicDecisionTree. Takes
* both the assignment and the value.
*/
DecisionTreeFactor apply(ADT::UnaryAssignment op) const;
DecisionTreeFactor apply(UnaryAssignment op) const;

/**
* Apply binary operator (*this) "op" f
* @param f the second argument for op
* @param op a binary operator that operates on AlgebraicDecisionTree
*/
DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const;
DecisionTreeFactor apply(const DecisionTreeFactor& f, Binary op) const;

/**
* Combine frontal variables using binary operator "op"
* @param nrFrontals nr. of frontal to combine variables in this factor
* @param op a binary operator that operates on AlgebraicDecisionTree
* @return shared pointer to newly created DecisionTreeFactor
*/
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
shared_ptr combine(size_t nrFrontals, Binary op) const;

/**
* Combine frontal variables in an Ordering using binary operator "op"
* @param nrFrontals nr. of frontal to combine variables in this factor
* @param op a binary operator that operates on AlgebraicDecisionTree
* @return shared pointer to newly created DecisionTreeFactor
*/
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
shared_ptr combine(const Ordering& keys, Binary op) const;

/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
Expand Down
5 changes: 5 additions & 0 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {

using Values = DiscreteValues; ///< backwards compatibility

using Unary = std::function<double(const double&)>;
using UnaryAssignment =
std::function<double(const Assignment<Key>&, const double&)>;
using Binary = std::function<double(const double, const double)>;

protected:
/// Map of Keys and their cardinalities.
std::map<Key, size_t> cardinalities_;
Expand Down
4 changes: 0 additions & 4 deletions gtsam/discrete/TableFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
typedef std::shared_ptr<TableFactor> shared_ptr;
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;
using Unary = std::function<double(const double&)>;
using UnaryAssignment =
std::function<double(const Assignment<Key>&, const double&)>;
using Binary = std::function<double(const double, const double)>;

public:
/// @name Standard Constructors
Expand Down

0 comments on commit e0fedda

Please sign in to comment.