Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make DiscreteFactor::operator() a common base method #1925

Merged
merged 5 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ namespace gtsam {
// Construct unordered_map with values
std::vector<std::pair<DiscreteValues, double>> result;
for (const auto& assignment : assignments) {
result.emplace_back(assignment, operator()(assignment));
result.emplace_back(assignment, evaluate(assignment));
}
return result;
}
Expand Down
10 changes: 4 additions & 6 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,14 @@ namespace gtsam {
/// @name Standard Interface
/// @{

/// Calculate probability for given values `x`,
/// Calculate probability for given values,
/// is just look up in AlgebraicDecisionTree.
double evaluate(const Assignment<Key>& values) const {
virtual double evaluate(const Assignment<Key>& values) const override {
return ADT::operator()(values);
}

/// Evaluate probability distribution, sugar.
double operator()(const DiscreteValues& values) const override {
return ADT::operator()(values);
}
/// Disambiguate to use DiscreteFactor version. Mainly for wrapper
using DiscreteFactor::operator();

/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;
Expand Down
4 changes: 2 additions & 2 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,12 @@ class GTSAM_EXPORT DiscreteConditional
}

/// Evaluate, just look up in AlgebraicDecisionTree
double evaluate(const DiscreteValues& values) const {
virtual double evaluate(const Assignment<Key>& values) const override {
return ADT::operator()(values);
}

using DecisionTreeFactor::error; ///< DiscreteValues version
using DecisionTreeFactor::operator(); ///< DiscreteValues version
using DiscreteFactor::operator(); ///< DiscreteValues version

/**
* @brief restrict to given *parent* values.
Expand Down
15 changes: 14 additions & 1 deletion gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,21 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {

size_t cardinality(Key j) const { return cardinalities_.at(j); }

/**
* @brief Calculate probability for given values.
* Calls specialized evaluation under the hood.
*
* Note: Uses Assignment<Key> as it is the base class of DiscreteValues.
*
* @param values Discrete assignment.
* @return double
*/
virtual double evaluate(const Assignment<Key>& values) const = 0;

/// Find value for given assignment of values to variables
virtual double operator()(const DiscreteValues&) const = 0;
double operator()(const DiscreteValues& values) const {
return evaluate(values);
}

/// Error is just -log(value)
virtual double error(const DiscreteValues& values) const;
Expand Down
3 changes: 2 additions & 1 deletion gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
}

/* ************************************************************************ */
double TableFactor::operator()(const DiscreteValues& values) const {
double TableFactor::evaluate(const Assignment<Key>& values) const {
// a b c d => D * (C * (B * (a) + b) + c) + d
uint64_t idx = 0, card = 1;
for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) {
Expand Down Expand Up @@ -180,6 +180,7 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
for (auto i = 0; i < sparse_table_.size(); i++) {
table.push_back(sparse_table_.coeff(i));
}
// NOTE(Varun): This constructor is really expensive!!
DecisionTreeFactor f(dkeys, table);
return f;
}
Expand Down
10 changes: 2 additions & 8 deletions gtsam/discrete/TableFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
// /// @name Standard Interface
// /// @{

/// Calculate probability for given values `x`,
/// is just look up in TableFactor.
double evaluate(const DiscreteValues& values) const {
return operator()(values);
}

/// Evaluate probability distribution, sugar.
double operator()(const DiscreteValues& values) const override;
/// Evaluate probability distribution, is just look up in TableFactor.
double evaluate(const Assignment<Key>& values) const override;

/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;
Expand Down
4 changes: 2 additions & 2 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table);

DecisionTreeFactor(const gtsam::DiscreteConditional& c);

void print(string s = "DecisionTreeFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;

size_t cardinality(gtsam::Key j) const;

double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
size_t cardinality(gtsam::Key j) const;
Expand Down
2 changes: 1 addition & 1 deletion gtsam_unstable/discrete/AllDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const {
}

/* ************************************************************************* */
double AllDiff::operator()(const DiscreteValues& values) const {
double AllDiff::evaluate(const Assignment<Key>& values) const {
std::set<size_t> taken; // record values taken by keys
for (Key dkey : keys_) {
size_t value = values.at(dkey); // get the value for that key
Expand Down
2 changes: 1 addition & 1 deletion gtsam_unstable/discrete/AllDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
}

/// Calculate value = expensive !
double operator()(const DiscreteValues& values) const override;
double evaluate(const Assignment<Key>& values) const override;

/// Convert into a decisiontree, can be *very* expensive !
DecisionTreeFactor toDecisionTreeFactor() const override;
Expand Down
2 changes: 1 addition & 1 deletion gtsam_unstable/discrete/BinaryAllDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class BinaryAllDiff : public Constraint {
}

/// Calculate value
double operator()(const DiscreteValues& values) const override {
double evaluate(const Assignment<Key>& values) const override {
return (double)(values.at(keys_[0]) != values.at(keys_[1]));
}

Expand Down
2 changes: 1 addition & 1 deletion gtsam_unstable/discrete/Domain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ string Domain::base1Str() const {
}

/* ************************************************************************* */
double Domain::operator()(const DiscreteValues& values) const {
double Domain::evaluate(const Assignment<Key>& values) const {
return contains(values.at(key()));
}

Expand Down
2 changes: 1 addition & 1 deletion gtsam_unstable/discrete/Domain.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
bool contains(size_t value) const { return values_.count(value) > 0; }

/// Calculate value
double operator()(const DiscreteValues& values) const override;
double evaluate(const Assignment<Key>& values) const override;

/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;
Expand Down
2 changes: 1 addition & 1 deletion gtsam_unstable/discrete/SingleValue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ void SingleValue::print(const string& s, const KeyFormatter& formatter) const {
}

/* ************************************************************************* */
double SingleValue::operator()(const DiscreteValues& values) const {
double SingleValue::evaluate(const Assignment<Key>& values) const {
return (double)(values.at(keys_[0]) == value_);
}

Expand Down
2 changes: 1 addition & 1 deletion gtsam_unstable/discrete/SingleValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
}

/// Calculate value
double operator()(const DiscreteValues& values) const override;
double evaluate(const Assignment<Key>& values) const override;

/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;
Expand Down
3 changes: 2 additions & 1 deletion python/gtsam/tests/test_DecisionTreeFactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@

import unittest

from gtsam.utils.test_case import GtsamTestCase

from gtsam import (DecisionTreeFactor, DiscreteDistribution, DiscreteValues,
Ordering)
from gtsam.utils.test_case import GtsamTestCase


class TestDecisionTreeFactor(GtsamTestCase):
Expand Down
4 changes: 2 additions & 2 deletions python/gtsam/tests/test_DiscreteBayesTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

import gtsam
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
DiscreteConditional, DiscreteFactorGraph,
DiscreteValues, Ordering)
DiscreteConditional, DiscreteFactorGraph, DiscreteValues,
Ordering)


class TestDiscreteBayesNet(GtsamTestCase):
Expand Down
3 changes: 2 additions & 1 deletion python/gtsam/tests/test_DiscreteConditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@

import unittest

from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
from gtsam.utils.test_case import GtsamTestCase

from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys

# Some DiscreteKeys for binary variables:
A = 0, 2
B = 1, 2
Expand Down
5 changes: 4 additions & 1 deletion python/gtsam/tests/test_DiscreteFactorGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
import unittest

import numpy as np
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
from gtsam.utils.test_case import GtsamTestCase

from gtsam import (DecisionTreeFactor, DiscreteConditional,
DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering,
Symbol)

OrderingType = Ordering.OrderingType


Expand Down
Loading