Adds Trust Region Reflective algorithm (#2)
* Initial TRR implementation

* Kinda works

* Added regularization

* Small refactoring

* Somehow it works

* Finished TRR

* Adding TRR documentation
lucaspellegrinelli authored Jun 24, 2024
1 parent 4543e89 commit dc96ff2
Showing 8 changed files with 406 additions and 11 deletions.
18 changes: 9 additions & 9 deletions README.md
@@ -8,23 +8,23 @@ library from Elixir under the hood to perform matrix operations.

## Which method should I use?

-The library provides three functions for curve fitting: `least_squares`, `gauss_newton` and `levenberg_marquardt`.
+The library provides four functions for curve fitting: `least_squares`, `gauss_newton`, `levenberg_marquardt` and `trust_region_reflective`.

### Least Squares

The `least_squares` function is just an alias for the `levenberg_marquardt` function.

-### Gauss-Newton
+## Levenberg-Marquardt

-The `gauss_newton` function is best for least squares problems with good initial guesses and small residuals.
-It is less computationally intensive and thus can be faster than the Levenberg-Marquardt method but can be unstable
-with poor initial guesses or large residuals.
+Ideal for non-linear least squares problems, particularly when the initial guess is far from the solution. It combines the benefits of the Gauss-Newton method and gradient descent, making it robust and efficient for various scenarios. However, it requires careful tuning of the damping parameter to balance convergence speed and stability.

-### Levenberg-Marquardt
+## Trust-Region Reflective

-The `levenberg_marquardt` function is robust for nonlinear least squares problems, handling large residuals and poor
-initial guesses effectively. It is more computationally intensive but provides reliable convergence for a wider range
-of problems, especially in challenging or ill-conditioned cases.
+Best suited for large-scale problems or those with constraints, this method ensures that each iteration stays within a predefined "trust region," preventing large, unstable steps. It is reliable and effective for challenging optimization problems but can be computationally intensive.

+## Gauss-Newton
+
+Efficient for problems where residuals are small and the initial guess is close to the true solution. It approximates the Hessian matrix, leading to faster convergence for well-behaved problems. However, it may struggle with highly non-linear problems or poor initial guesses, as it lacks the robustness of the Levenberg-Marquardt and trust-region reflective methods.
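
As a quick illustration (a sketch, not part of this diff — it reuses the `parabola` model from the doc-comment example added later in this commit), all four entry points share the same signature, so switching methods is a one-line change:

```gleam
import gleam/io
import gleastsq

fn parabola(x: Float, params: List(Float)) -> Float {
  let assert [a, b, c] = params
  a *. x *. x +. b *. x +. c
}

pub fn main() {
  let x = [0.0, 1.0, 2.0, 3.0, 4.0]
  let y = [0.0, 1.0, 4.0, 9.0, 16.0]
  // Any of least_squares, gauss_newton, levenberg_marquardt or
  // trust_region_reflective can be swapped in here.
  let assert Ok(params) =
    gleastsq.trust_region_reflective(x, y, parabola, [1.0, 1.0, 1.0], opts: [])
  io.debug(params) // expect roughly [1.0, 0.0, 0.0]
}
```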

## Installation

60 changes: 60 additions & 0 deletions src/gleastsq.gleam
@@ -1,5 +1,6 @@
import gleastsq/internal/methods/gauss_newton as gn
import gleastsq/internal/methods/levenberg_marquardt as lm
import gleastsq/internal/methods/trust_region_reflective as trr
import gleastsq/internal/params.{decode_params}
import gleastsq/options.{type LeastSquareOptions}

@@ -134,3 +135,62 @@ pub fn gauss_newton(
) {
  gn.gauss_newton(x, y, func, initial_params, decode_params(opts))
}

/// The `trust_region_reflective` function performs a least squares optimization using the Trust Region Reflective algorithm.
/// It is used to find the best-fit parameters for a given model function to a set of data points.
/// This function takes as input the data points, the model function, and several optional parameters to control the optimization process.
///
/// # Parameters
/// - `x` (List(Float))
///   A list of x-values of the data points.
/// - `y` (List(Float))
///   A list of y-values of the data points.
/// - `func` (fn(Float, List(Float)) -> Float)
///   The model function that takes an x-value and a list of parameters, and returns the corresponding y-value.
/// - `initial_params` (List(Float))
///   A list of initial guesses for the parameters of the model function.
/// - `opts` (List(LeastSquareOptions))
///   A list of optional parameters to control the optimization process.
///   The available options are:
///   - `Iterations(Int)`: The maximum number of iterations to perform. Default is 100.
///   - `Epsilon(Float)`: The step size used when numerically approximating derivatives of the model function. Default is 0.0001.
///   - `Tolerance(Float)`: The convergence tolerance. Default is 0.00001.
///   - `Damping(Float)`: The regularization factor added to `JᵀJ` to stabilize each step. Default is 0.0001.
///
/// # Example
/// ```gleam
/// import gleam/io
/// import gleastsq
/// import gleastsq/options.{Iterations, Tolerance}
///
/// fn parabola(x: Float, params: List(Float)) -> Float {
///   let assert [a, b, c] = params
///   a *. x *. x +. b *. x +. c
/// }
///
/// pub fn main() {
///   let x = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
///   let y = [0.0, 1.0, 4.0, 9.0, 16.0, 25.0]
///   let initial_guess = [1.0, 1.0, 1.0]
///
///   let assert Ok(result) =
///     gleastsq.trust_region_reflective(
///       x,
///       y,
///       parabola,
///       initial_guess,
///       opts: [Iterations(1000), Tolerance(0.001)],
///     )
///
///   io.debug(result) // [1.0, 0.0, 0.0] (within numerical error)
/// }
/// ```
pub fn trust_region_reflective(
  x: List(Float),
  y: List(Float),
  func: fn(Float, List(Float)) -> Float,
  initial_params: List(Float),
  opts opts: List(LeastSquareOptions),
) {
  trr.trust_region_reflective(x, y, func, initial_params, decode_params(opts))
}
184 changes: 184 additions & 0 deletions src/gleastsq/internal/methods/trust_region_reflective.gleam
@@ -0,0 +1,184 @@
import gleam/bool
import gleam/float
import gleam/list
import gleam/option
import gleam/result
import gleastsq/errors.{
  type FitErrors, JacobianTaskError, NonConverged, WrongParameters,
}
import gleastsq/internal/jacobian.{jacobian}
import gleastsq/internal/nx.{type NxTensor}
import gleastsq/internal/params.{type FitParams}

/// The `trust_region_reflective` function performs the Trust Region Reflective optimization algorithm.
/// It is used to solve non-linear least squares problems. This function takes as input the data points,
/// the model function, and several optional parameters to control the optimization process.
///
/// # Parameters
/// - `x` (List(Float))
///   A list of x-values of the data points.
/// - `y` (List(Float))
///   A list of y-values of the data points.
/// - `func` (fn(Float, List(Float)) -> Float)
///   The model function that takes an x-value and a list of parameters, and returns the corresponding y-value.
/// - `initial_params` (List(Float))
///   A list of initial guesses for the parameters of the model function.
/// - `opts` (FitParams)
///   A record with the following fields:
///   - `iterations` (Option(Int))
///     The maximum number of iterations to perform. Default is 100.
///   - `epsilon` (Option(Float))
///     The step size used to calculate the numerical gradient. Default is 0.0001.
///   - `tolerance` (Option(Float))
///     The tolerance used to stop the optimization. Default is 0.00001.
///   - `damping` (Option(Float))
///     The damping factor used to stabilize the optimization. Default is 0.0001.
pub fn trust_region_reflective(
  x: List(Float),
  y: List(Float),
  func: fn(Float, List(Float)) -> Float,
  initial_params: List(Float),
  opts opts: FitParams,
) -> Result(List(Float), FitErrors) {
  use <- bool.guard(
    list.length(x) != list.length(y),
    Error(WrongParameters("x and y must have the same length")),
  )

  let x = nx.tensor(x) |> nx.to_list_1d
  let y = nx.tensor(y)
  let iter = option.unwrap(opts.iterations, 100)
  let eps = option.unwrap(opts.epsilon, 0.0001)
  let tol = option.unwrap(opts.tolerance, 0.00001)
  let reg = option.unwrap(opts.damping, 0.0001)
  // Initial trust region radius
  let delta = 1.0

  use fitted <- result.try(do_trust_region_reflective(
    x,
    y,
    func,
    initial_params,
    iter,
    eps,
    tol,
    delta,
    reg,
  ))
  Ok(fitted)
}
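
For reference (an editorial summary, not part of the diff): the quantities set up above and assembled in the helpers below follow the standard trust-region formulation of non-linear least squares, with notation matching the code (`r` for residuals, `J` for the Jacobian, `lambda_reg` for $\lambda$):

$$
\min_{\beta}\ F(\beta) = \tfrac{1}{2}\sum_{i=1}^{n}\bigl(f(x_i, \beta) - y_i\bigr)^2,
\qquad B = J^{\top}J + \lambda I,
\qquad g = J^{\top}r
$$

Each iteration minimizes the quadratic model built from `B` and `g` inside a ball of radius $\Delta$ (the trust region, starting at $\Delta = 1$), rather than taking an unconstrained step.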

fn ternary(cond: Bool, a: a, b: a) -> a {
  bool.guard(cond, a, fn() { b })
}

fn dogleg(j: NxTensor, g: NxTensor, b: NxTensor, delta: Float) -> NxTensor {
  let jt = nx.transpose(j)
  let jg = nx.dot(j, g)
  // Cauchy point: the steepest descent minimizer along -g
  let p_u_numerator = nx.negate(nx.dot(g, g))
  let p_u_denominator = nx.dot(g, nx.dot(jt, jg))
  let p_u = nx.multiply_mat(nx.divide_mat(p_u_numerator, p_u_denominator), g)

  // Gauss-Newton point: full step from the regularized normal equations
  let p_b = nx.negate(nx.solve(b, g))

  let p_b_norm = nx.norm(p_b) |> nx.to_number
  let p_u_norm = nx.norm(p_u) |> nx.to_number

  // Take the full Gauss-Newton step if it lies inside the trust region
  use <- bool.guard(p_b_norm <=. delta, p_b)
  // If even the Cauchy point falls outside, clip it to the boundary
  use <- bool.guard(p_u_norm >=. delta, nx.multiply(p_u, delta /. p_u_norm))

  // Otherwise move from the Cauchy point toward the Gauss-Newton point
  // until the step reaches the trust region boundary
  let p_b_u = nx.subtract(p_b, p_u)
  let assert Ok(delta_sq) = float.power(delta, 2.0)
  let assert Ok(u_norm_sq) = float.power(p_u_norm, 2.0)
  let assert Ok(d_pu_sqrt) = float.square_root(delta_sq -. u_norm_sq)
  let pb_u_norm = nx.norm(p_b_u) |> nx.to_number
  let pc_factor = d_pu_sqrt /. pb_u_norm
  let p_c = nx.add(p_u, nx.multiply(p_b_u, pc_factor))
  p_c
}
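
For readers unfamiliar with the dogleg method, the step above combines two classical candidate steps (an editorial summary; the variable names match the code, and `p_b` is computed here from the regularized `b`):

$$
p_U = -\frac{g^{\top}g}{g^{\top}J^{\top}Jg}\,g \quad\text{(Cauchy point)},
\qquad
p_B = -B^{-1}g \quad\text{(Gauss-Newton point)}
$$

Note that `pc_factor` is $\sqrt{\Delta^2 - \lVert p_U\rVert^2}\,/\,\lVert p_B - p_U\rVert$, a simplification of the boundary intersection that is exact only when $p_U \perp (p_B - p_U)$; the textbook dogleg parameter solves a quadratic in $\tau$.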

/// Gain ratio: the actual reduction in the sum of squared residuals
/// divided by the reduction predicted by the local quadratic model.
pub fn rho(
  x: List(Float),
  y: NxTensor,
  func: fn(Float, List(Float)) -> Float,
  params: List(Float),
  p: NxTensor,
  g: NxTensor,
) -> Float {
  let fx = list.map(x, func(_, params)) |> nx.tensor

  let offset_params =
    list.zip(params, nx.to_list_1d(p))
    |> list.map(fn(p) { p.0 +. p.1 })

  let fx_p = list.map(x, func(_, offset_params)) |> nx.tensor

  // Squared residuals before and after taking the step p
  let fx_diff = nx.pow(nx.subtract(fx, y), 2.0)
  let fxp_diff = nx.pow(nx.subtract(fx_p, y), 2.0)

  let actual_reduction_sum =
    nx.sum(nx.subtract(fx_diff, fxp_diff)) |> nx.to_number
  let actual_reduction = 0.5 *. actual_reduction_sum

  // Predicted reduction, using the simplification -0.5 * g'p
  // (exact for the full Gauss-Newton step, approximate otherwise)
  let g_dot_p = nx.dot(g, p) |> nx.to_number
  let predicted_reduction = -0.5 *. g_dot_p

  actual_reduction /. predicted_reduction
}
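
In trust-region terms, `rho` is the usual gain ratio (again an editorial note; the second form is what the code computes):

$$
\rho = \frac{F(\beta) - F(\beta + p)}{m(0) - m(p)}
\approx \frac{\tfrac{1}{2}\lVert r(\beta)\rVert^2 - \tfrac{1}{2}\lVert r(\beta + p)\rVert^2}{-\tfrac{1}{2}\,g^{\top}p}
$$

A $\rho$ near 1 means the quadratic model predicted the actual reduction well, so the trust region can grow; a small or negative $\rho$ means the model is unreliable at the current radius.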

fn do_trust_region_reflective(
  x: List(Float),
  y: NxTensor,
  func: fn(Float, List(Float)) -> Float,
  params: List(Float),
  iterations: Int,
  epsilon: Float,
  tolerance: Float,
  delta: Float,
  lambda_reg: Float,
) {
  use <- bool.guard(iterations == 0, Error(NonConverged))
  let m = list.length(params)

  let f = list.map(x, func(_, params)) |> nx.tensor
  let r = nx.subtract(f, y)
  use j <- result.try(result.replace_error(
    jacobian(x, f, func, params, epsilon),
    JacobianTaskError,
  ))

  // Regularized Gauss-Newton system: B = J'J + lambda*I, g = J'r
  let lambda_eye = nx.eye(m) |> nx.multiply(lambda_reg)
  let jt = nx.transpose(j)
  let b = nx.add(nx.dot(jt, j), lambda_eye)
  let g = nx.dot(jt, r)

  // Converged if the gradient is (near) zero
  let g_norm = nx.norm(g) |> nx.to_number
  use <- bool.guard(g_norm <. tolerance, Ok(params))

  let p = dogleg(j, g, b, delta)
  let rho = rho(x, y, func, params, p, g)

  // Converged if the proposed step is negligibly small
  let p_norm = nx.norm(p) |> nx.to_number
  use <- bool.guard(p_norm <. tolerance, Ok(params))

  // Standard radius update: grow on a very successful step, shrink on a
  // poor one, keep it otherwise
  let new_delta = case rho {
    x if x >. 0.75 -> float.max(delta, 2.0 *. nx.to_number(nx.norm(p)))
    x if x <. 0.25 -> delta *. 0.5
    _ -> delta
  }

  let new_params =
    list.zip(params, nx.to_list_1d(p))
    |> list.map(fn(p) { p.0 +. p.1 })

  do_trust_region_reflective(
    x,
    y,
    func,
    // Accept the step only if it actually reduced the residuals
    ternary(rho >. 0.0, new_params, params),
    iterations - 1,
    epsilon,
    tolerance,
    new_delta,
    lambda_reg,
  )
}
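
The recursion above applies the conventional acceptance and radius rules (summarized here for the reader; the thresholds are the ones hard-coded in the `case`):

$$
\Delta_{k+1} =
\begin{cases}
\max(\Delta_k,\ 2\lVert p\rVert) & \text{if } \rho > 0.75\\
0.5\,\Delta_k & \text{if } \rho < 0.25\\
\Delta_k & \text{otherwise}
\end{cases}
$$

and the candidate parameters are kept only when $\rho > 0$, i.e. when the step produced an actual decrease in the residuals.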
12 changes: 12 additions & 0 deletions src/gleastsq/internal/nx.gleam
@@ -17,6 +17,9 @@ pub fn dot(a: NxTensor, b: NxTensor) -> NxTensor
@external(erlang, "Elixir.Nx", "multiply")
pub fn multiply(a: NxTensor, b: Float) -> NxTensor

@external(erlang, "Elixir.Nx", "multiply")
pub fn multiply_mat(a: NxTensor, b: NxTensor) -> NxTensor

@external(erlang, "Elixir.Nx", "add")
pub fn add(a: NxTensor, b: NxTensor) -> NxTensor

@@ -32,6 +35,9 @@ pub fn subtract(a: NxTensor, b: NxTensor) -> NxTensor
@external(erlang, "Elixir.Nx", "transpose")
pub fn transpose(a: NxTensor) -> NxTensor

@external(erlang, "Elixir.Nx", "negate")
pub fn negate(a: NxTensor) -> NxTensor

@external(erlang, "Elixir.Nx", "shape")
pub fn shape(a: NxTensor) -> #(Int)

@@ -44,11 +50,17 @@ pub fn to_number(a: NxTensor) -> Float
@external(erlang, "Elixir.Nx.LinAlg", "solve")
pub fn solve(a: NxTensor, b: NxTensor) -> NxTensor

@external(erlang, "Elixir.Nx.LinAlg", "norm")
pub fn norm(a: NxTensor) -> NxTensor

@external(erlang, "Elixir.Nx", "new_axis")
pub fn new_axis(a: NxTensor, axis: Int) -> NxTensor

@external(erlang, "Elixir.Nx", "divide")
pub fn divide(a: NxTensor, b: Float) -> NxTensor

@external(erlang, "Elixir.Nx", "divide")
pub fn divide_mat(a: NxTensor, b: NxTensor) -> NxTensor

@external(erlang, "Elixir.Nx", "concatenate")
pub fn concatenate(a: List(NxTensor), opts opts: List(NxOpts)) -> NxTensor
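
A note on the bindings added above: `multiply_mat` and `divide_mat` point at the same `Elixir.Nx` functions as the existing `multiply` and `divide`. Since a Gleam external carries a single fixed type signature and the library keeps tensor–scalar and tensor–tensor arithmetic distinct, re-binding the same Erlang function under a second name is the idiomatic workaround.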
22 changes: 21 additions & 1 deletion test/gn_test.gleam
@@ -1,6 +1,8 @@
import gleastsq
import gleeunit/should
-import utils/curves.{double_gaussian, exponential, gaussian, parabola}
+import utils/curves.{
+  double_gaussian, exponential, gaussian, parabola, triple_gaussian,
+}
import utils/helpers.{are_fits_equivalent, fit_to_curve, generate_x_axis}

pub fn gn(
@@ -89,6 +91,24 @@ pub fn noisy_double_gaussian_fit_test() {
}
}

pub fn noisy_triple_gaussian_fit_test() {
  // Gauss-Newton will not generally converge on this function
  let x = generate_x_axis(-3, 7, 100)
  let params = [1.2, 0.3, 0.5, 2.5, 2.0, 1.0, 1.0, -2.0, 0.1]
  let result = fit_to_curve(x, triple_gaussian, params, gn, noisy: True)
  case result {
    Ok(result) -> {
      // If it converges, it should be a bad fit
      are_fits_equivalent(x, triple_gaussian, params, result)
      |> should.be_false
    }
    Error(_) -> {
      // We expect it to not converge
      should.be_true(True)
    }
  }
}

pub fn should_error_when_x_y_different_sizes_test() {
  gn([0.0], [], parabola, []) |> should.be_error
}
13 changes: 12 additions & 1 deletion test/lm_test.gleam
@@ -1,6 +1,8 @@
import gleastsq
import gleeunit/should
-import utils/curves.{double_gaussian, exponential, gaussian, parabola}
+import utils/curves.{
+  double_gaussian, exponential, gaussian, parabola, triple_gaussian,
+}
import utils/helpers.{are_fits_equivalent, fit_to_curve, generate_x_axis}

pub fn lm(
@@ -80,6 +82,15 @@ pub fn noisy_double_gaussian_fit_test() {
  |> should.be_true
}

pub fn noisy_triple_gaussian_fit_test() {
  let x = generate_x_axis(-3, 7, 100)
  let params = [1.2, 0.3, 0.5, 2.5, 2.0, 1.0, 1.0, -2.0, 0.1]
  let assert Ok(result) =
    fit_to_curve(x, triple_gaussian, params, lm, noisy: True)
  are_fits_equivalent(x, triple_gaussian, params, result)
  |> should.be_true
}

pub fn should_error_when_x_y_different_sizes_test() {
  lm([0.0], [], parabola, []) |> should.be_error
}