-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Michael McCrackan
committed
Oct 28, 2024
1 parent
78ef713
commit cd3753b
Showing
7 changed files
with
630 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Try to find the Ceres Solver library | ||
|
||
# Look for the Ceres package | ||
find_path(CERES_INCLUDE_DIR ceres/ceres.h | ||
PATH_SUFFIXES ceres | ||
) | ||
|
||
find_library(CERES_LIBRARY ceres) | ||
|
||
# Check if Eigen is needed and available | ||
find_package(Eigen3 REQUIRED) | ||
|
||
# Check if Ceres is found | ||
if (CERES_INCLUDE_DIR AND CERES_LIBRARY AND TARGET Eigen3::Eigen) | ||
# Ceres and Eigen were found | ||
set(CERES_FOUND TRUE) | ||
set(CERES_LIBRARIES ${CERES_LIBRARY}) | ||
set(CERES_INCLUDE_DIRS ${CERES_INCLUDE_DIR} ${EIGEN3_INCLUDE_DIRS}) | ||
|
||
# Optionally, find other dependencies (like gflags and glog, which Ceres uses) | ||
find_package(Glog REQUIRED) | ||
find_package(Gflags REQUIRED) | ||
|
||
set(CERES_DEPENDENCIES ${GLOG_LIBRARIES} ${GFLAGS_LIBRARIES} Eigen3::Eigen) | ||
else() | ||
# Ceres or Eigen was not found | ||
set(CERES_FOUND FALSE) | ||
endif() | ||
|
||
# Set the results so they can be used by the project | ||
mark_as_advanced(CERES_INCLUDE_DIR CERES_LIBRARY) | ||
|
||
# Provide an interface for usage | ||
if (CERES_FOUND) | ||
message(STATUS "Found Ceres: ${CERES_INCLUDE_DIR}") | ||
message(STATUS "Found Eigen3: ${EIGEN3_INCLUDE_DIR}") | ||
else() | ||
message(WARNING "Could not find Ceres Solver or Eigen3") | ||
endif() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#pragma once | ||
|
||
int get_dtype(const bp::object &); | ||
|
||
template <typename T> | ||
T _calculate_median(const T*, const int); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
#pragma once | ||
|
||
#include <ceres/ceres.h> | ||
|
||
template <int Degree> | ||
struct PolynomialModel | ||
{ | ||
// Ceres requires number of params | ||
// to be known at compile time | ||
static constexpr int nparams = Degree + 1; | ||
|
||
template <typename T> | ||
static T eval(T x, const T* params) | ||
{ | ||
const T p0 = params[0]; | ||
T result = p0; | ||
for (int i = 1; i < nparams; ++i) { | ||
const T p = params[i]; | ||
result += p * ceres::pow(x, T(i)); | ||
} | ||
|
||
return result; | ||
} | ||
// Not needed for least squares as ceres | ||
// supports boundaries | ||
template <typename T> | ||
static bool check_bounds(const T* params) | ||
{ | ||
return true; | ||
} | ||
}; | ||
|
||
struct NoiseModel | ||
{ | ||
// Ceres requires number of params | ||
// to be known at compile time | ||
static constexpr int nparams = 3; | ||
|
||
template <typename T> | ||
static T eval(T f, const T* params) | ||
{ | ||
const T fknee = params[0]; | ||
const T w = params[1]; | ||
const T alpha = params[2]; | ||
|
||
return w * (1.0 + ceres::pow(fknee / f, alpha)); | ||
} | ||
|
||
// Slightly hacky way of bounds checking but is | ||
// suggested by Ceres to ensure it never goes | ||
// out of bounds | ||
template <typename T> | ||
static bool check_bounds(const T* params) | ||
{ | ||
const T w = params[1]; | ||
if (w <= 0.0) { | ||
return false; | ||
} | ||
return true; | ||
} | ||
}; | ||
|
||
// Model independent cost function for least-squares fitting | ||
template <typename Model> | ||
struct CostFunction | ||
{ | ||
using model = Model; | ||
|
||
CostFunction(int n, const double* x_data, const double* y_data) | ||
: n_pts(n), x(x_data), y(y_data) {} | ||
|
||
template <typename T> | ||
bool operator()(const T* const params, T* residual) const { | ||
for (int i = 0; i < n_pts; ++i) { | ||
T model = Model::eval(T(x[i]), params); | ||
residual[i] = T(y[i]) - model; | ||
} | ||
return true; | ||
} | ||
|
||
static ceres::Problem create(const int n, const double* xx, const double* yy, double* p) | ||
{ | ||
ceres::Problem problem; | ||
|
||
problem.AddResidualBlock( | ||
new ceres::AutoDiffCostFunction<CostFunction<Model>, | ||
ceres::DYNAMIC, Model::nparams>( | ||
new CostFunction<Model>(n, xx, yy), n), nullptr, p); | ||
|
||
return problem; | ||
} | ||
|
||
private: | ||
const int n_pts; | ||
const double* x; | ||
const double* y; | ||
}; | ||
|
||
// Model independent Negative Log Likelihood for generalized | ||
// unconstrained minimization | ||
template <typename Model> | ||
struct NegLogLikelihood | ||
{ | ||
using model = Model; | ||
|
||
NegLogLikelihood(int n, const double* x_data, const double* y_data) | ||
: n_pts(n), x(x_data), y(y_data) {} | ||
|
||
template <typename T> | ||
bool operator()(const T* const params, T* cost) const | ||
{ | ||
// Check bounds (saves a lot of time) | ||
if (!model::check_bounds(params)) { | ||
return false; | ||
} | ||
|
||
cost[0] = T(0.); | ||
for (int i = 0; i < n_pts; ++i) { | ||
T model = Model::eval(T(x[i]), params); | ||
cost[0] += ceres::log(model) + T(y[i]) / model; | ||
} | ||
|
||
return true; | ||
} | ||
|
||
static ceres::FirstOrderFunction* create(int n, const double* xx, const double* yy) | ||
{ | ||
// Ceres takes ownership of pointers so no cleanup is required | ||
return new ceres::AutoDiffFirstOrderFunction<NegLogLikelihood<Model>, | ||
Model::nparams>(new NegLogLikelihood<Model>(n, xx, yy)); | ||
} | ||
|
||
private: | ||
const int n_pts; | ||
const double* x; | ||
const double* y; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.