Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gauss-Newton and Levenberg-Marquardt #920

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from

Conversation

gbruno16
Copy link
Contributor

@gbruno16 gbruno16 commented Apr 9, 2024

In this fork I'm trying to implement the Gauss-Newton and the Levenberg-Marquardt methods for the Optax library.
The primary objective is to provide a flexible Gauss-Newton transformation that offers options for selecting the damping parameter, the solver, and whether to consider the normal equations. Additionally, this transformation enables solving least squares problems by just providing the jvp of the residuals function and can handle compositional problems by specifying the hvp of the outer function.

A simple usage example for the Gauss-Newton optimizer:

import jax
import optax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def f(x):
  return jnp.sqrt(2) * jnp.array([10 * (x[1] - x[0]**2), (1 - x[0])])

params = jnp.array([-1.2, 1])
print('Initial objective function: ', 0.5*jnp.sum(f(params)**2))

solver = optax.gauss_newton()
opt_state = solver.init(params)

for _ in range(5):
    residuals, inner_jvp = jax.linearize(f, params)
    updates, opt_state = solver.update(residuals, opt_state, params, inner_jvp=inner_jvp)
    params = optax.apply_updates(params, updates)
    print('Objective function: {:.2E}'.format(0.5*jnp.sum(f(params)**2)))

The Gauss-Newton transformation could serve as building block for constructing more sophisticated optimization solvers. As an illustration, I have incorporated the trust region algorithm implemented in Jaxopt (algorithm 6.18 in “Introduction to Optimization and Data Fitting”, K. Madsen & H. B. Nielsen) into the scale_by_madsen_trust_region transformation. As a consquence we can seamlessly obtain the Levenberg-Marquardt method by composing it with the Gauss-Newton transformation described earlier.

The previous example becomes:

import jax
import optax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

def f(x):
  return jnp.sqrt(2) * jnp.array([10 * (x[1] - x[0]**2), (1 - x[0])])

params = jnp.array([-1.2, 1])
print('Initial objective function: ', 0.5*jnp.sum(f(params)**2))

solver = optax.levenberg_marquardt(init_damping_parameter=1.0)
opt_state = solver.init(params)

for _ in range(15):
    updates, opt_state = solver.update(opt_state, params, residuals_fn=f)
    params = optax.apply_updates(params, updates)
    print('Objective function: {:.2E}'.format(0.5*jnp.sum(f(params)**2)))

This is still a draft and will require more time, but feedbacks and suggestions for improvement are greatly appreciated. Please feel free to share your thoughts on the implementation and suggest any enhancements or modifications.

@vroulet vroulet self-assigned this Apr 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants