Skip to content

Commit

Permalink
add ceres-solver fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael McCrackan committed Oct 28, 2024
1 parent 78ef713 commit cd3753b
Show file tree
Hide file tree
Showing 7 changed files with 630 additions and 1 deletion.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ find_package(PythonInterp 3)
find_package(PythonLibs 3)
find_package(FLAC)
find_package(GSL)
find_package(Ceres)

find_package(OpenMP)
if(OPENMP_FOUND)
Expand Down Expand Up @@ -67,6 +68,7 @@ add_library(so3g SHARED
src/so_linterp.cxx
src/exceptions.cxx
src/array_ops.cxx
src/fitting_ops.cxx
)

# We could disable the lib prefix on the output library... but let's not.
Expand All @@ -87,6 +89,9 @@ target_link_libraries(so3g spt3g::core)
# Link GSL
target_include_directories(so3g PRIVATE ${GSL_INCLUDE_DIR})
target_link_libraries(so3g ${GSL_LIBRARIES})
# Link Ceres
target_include_directories(so3g PRIVATE ${CERES_INCLUDE_DIRS})
target_link_libraries(so3g ${CERES_LIBRARIES} ${CERES_DEPENDENCIES})

# You probably want to select openblas, so pass -DBLA_VENDOR=OpenBLAS
find_package(BLAS REQUIRED)
Expand Down
13 changes: 12 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,22 @@ RUN apt update && apt install -y \
gfortran \
libopenblas-dev \
libbz2-dev \
python-is-python3
python-is-python3 \
libgoogle-glog-dev \
libgflags-dev \
libeigen3-dev \

# Set the working directory
WORKDIR /app_lib/so3g

# Fetch and install ceres-solver
RUN git clone https://ceres-solver.googlesource.com/ceres-solver
WORKDIR /app_lib/so3g/ceres-solver
RUN mkdir build && cd build && cmake .. && make && make install

# Set the working directory back to so3g
WORKDIR /app_lib/so3g

# Copy the current directory contents into the container
ADD . /app_lib/so3g

Expand Down
39 changes: 39 additions & 0 deletions cmake/FindCeres.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Try to find the Ceres Solver library

# Look for the Ceres package
find_path(CERES_INCLUDE_DIR ceres/ceres.h
PATH_SUFFIXES ceres
)

find_library(CERES_LIBRARY ceres)

# Check if Eigen is needed and available
find_package(Eigen3 REQUIRED)

# Check if Ceres is found
if (CERES_INCLUDE_DIR AND CERES_LIBRARY AND TARGET Eigen3::Eigen)
# Ceres and Eigen were found
set(CERES_FOUND TRUE)
set(CERES_LIBRARIES ${CERES_LIBRARY})
set(CERES_INCLUDE_DIRS ${CERES_INCLUDE_DIR} ${EIGEN3_INCLUDE_DIRS})

# Optionally, find other dependencies (like gflags and glog, which Ceres uses)
find_package(Glog REQUIRED)
find_package(Gflags REQUIRED)

set(CERES_DEPENDENCIES ${GLOG_LIBRARIES} ${GFLAGS_LIBRARIES} Eigen3::Eigen)
else()
# Ceres or Eigen was not found
set(CERES_FOUND FALSE)
endif()

# Set the results so they can be used by the project
mark_as_advanced(CERES_INCLUDE_DIR CERES_LIBRARY)

# Provide an interface for usage
if (CERES_FOUND)
message(STATUS "Found Ceres: ${CERES_INCLUDE_DIR}")
message(STATUS "Found Eigen3: ${EIGEN3_INCLUDE_DIR}")
else()
message(WARNING "Could not find Ceres Solver or Eigen3")
endif()
6 changes: 6 additions & 0 deletions include/array_ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#pragma once

int get_dtype(const bp::object &);

template <typename T>
T _calculate_median(const T*, const int);
137 changes: 137 additions & 0 deletions include/fitting_ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#pragma once

#include <ceres/ceres.h>

template <int Degree>
struct PolynomialModel
{
// Ceres requires number of params
// to be known at compile time
static constexpr int nparams = Degree + 1;

template <typename T>
static T eval(T x, const T* params)
{
const T p0 = params[0];
T result = p0;
for (int i = 1; i < nparams; ++i) {
const T p = params[i];
result += p * ceres::pow(x, T(i));
}

return result;
}
// Not needed for least squares as ceres
// supports boundaries
template <typename T>
static bool check_bounds(const T* params)
{
return true;
}
};

struct NoiseModel
{
// Ceres requires number of params
// to be known at compile time
static constexpr int nparams = 3;

template <typename T>
static T eval(T f, const T* params)
{
const T fknee = params[0];
const T w = params[1];
const T alpha = params[2];

return w * (1.0 + ceres::pow(fknee / f, alpha));
}

// Slightly hacky way of bounds checking but is
// suggested by Ceres to ensure it never goes
// out of bounds
template <typename T>
static bool check_bounds(const T* params)
{
const T w = params[1];
if (w <= 0.0) {
return false;
}
return true;
}
};

// Model independent cost function for least-squares fitting
template <typename Model>
struct CostFunction
{
using model = Model;

CostFunction(int n, const double* x_data, const double* y_data)
: n_pts(n), x(x_data), y(y_data) {}

template <typename T>
bool operator()(const T* const params, T* residual) const {
for (int i = 0; i < n_pts; ++i) {
T model = Model::eval(T(x[i]), params);
residual[i] = T(y[i]) - model;
}
return true;
}

static ceres::Problem create(const int n, const double* xx, const double* yy, double* p)
{
ceres::Problem problem;

problem.AddResidualBlock(
new ceres::AutoDiffCostFunction<CostFunction<Model>,
ceres::DYNAMIC, Model::nparams>(
new CostFunction<Model>(n, xx, yy), n), nullptr, p);

return problem;
}

private:
const int n_pts;
const double* x;
const double* y;
};

// Model independent Negative Log Likelihood for generalized
// unconstrained minimization
template <typename Model>
struct NegLogLikelihood
{
using model = Model;

NegLogLikelihood(int n, const double* x_data, const double* y_data)
: n_pts(n), x(x_data), y(y_data) {}

template <typename T>
bool operator()(const T* const params, T* cost) const
{
// Check bounds (saves a lot of time)
if (!model::check_bounds(params)) {
return false;
}

cost[0] = T(0.);
for (int i = 0; i < n_pts; ++i) {
T model = Model::eval(T(x[i]), params);
cost[0] += ceres::log(model) + T(y[i]) / model;
}

return true;
}

static ceres::FirstOrderFunction* create(int n, const double* xx, const double* yy)
{
// Ceres takes ownership of pointers so no cleanup is required
return new ceres::AutoDiffFirstOrderFunction<NegLogLikelihood<Model>,
Model::nparams>(new NegLogLikelihood<Model>(n, xx, yy));
}

private:
const int n_pts;
const double* x;
const double* y;
};
18 changes: 18 additions & 0 deletions src/array_ops.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ extern "C" {
#include "so3g_numpy.h"
#include "numpy_assist.h"
#include "Ranges.h"
#include "array_ops.h"

// TODO: Generalize to double precision too.
// This implements Jon's noise model for ACT. It takes in
Expand Down Expand Up @@ -993,6 +994,23 @@ void interp1d_linear(const bp::object & x, const bp::object & y,
}
}

template <typename T>
T _calculate_median(const T* data, const int n)
{
// Copy to prevent overwriting input with gsl median
// Explicitly cast to double here due to gsl
std::vector<double> data_copy(n);
std::transform(data, data + n, data_copy.begin(), [](double val) {
return static_cast<double>(val);
});

// GSL is much faster than a naive std::sort implementation
return gsl_stats_median(data_copy.data(), 1, n);
}

template double _calculate_median<double>(const double* arr, int size);
template float _calculate_median<float>(const float* arr, int size);


PYBINDINGS("so3g")
{
Expand Down
Loading

0 comments on commit cd3753b

Please sign in to comment.