Revisiting linesearches and LBFGS.
For backtracking linesearch:
- Add debugging option for backtracking_linesearch.
- Add an info entry in BacktrackingLinesearchState to help debugging by inspecting outputs (useful for example in a vmap setting; mimics the setup of the zoom linesearch).
- Add a mechanism to prevent the linesearch from taking a step that would produce NaN or infinite values of the function.

For zoom_linesearch:
- Simplify the debugging information for the zoom linesearch and add prints of some relevant values for debugging.
- Add a note in the zoom linesearch that using curv_rtol=inf makes this method an efficient alternative to the backtracking linesearch, thanks to its polynomial interpolation strategies.
- Most importantly, add an option to define the initial guess for the linesearch. Following Nocedal and Wright, this initial guess should always be one for Newton or quasi-Newton methods. It could be refined for other methods (for now, for methods like gradient descent, we simply keep the previous learning rate). This largely improved the performance in the public notebook.

For lbfgs:
- Use a clipped gradient step for the very first step (when scale_init_precond=True); see the sketch below. The scale of the preconditioner for the very first iteration is not detailed anywhere in the literature I've seen, but such a clipped gradient step captures approximately the right scale. This made, for example, one of the tests pass without any further modification of the objective's default hyperparameters.
- Revise the notebook in view of these changes. Add some tips and an example benchmark.
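
The clipped first-step scaling mentioned above can be pictured with the following minimal sketch (illustrative only: `first_step_scale` is a hypothetical helper and the exact rule used by optax.lbfgs may differ):

```python
import jax.numpy as jnp

def first_step_scale(grad):
  # Cap the reciprocal gradient norm at 1 so that the very first update
  # -scale * grad is a clipped gradient step of norm at most 1.
  grad_norm = jnp.sqrt(jnp.sum(grad ** 2))
  return 1.0 / jnp.maximum(grad_norm, 1.0)
```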

PiperOrigin-RevId: 694183028
vroulet authored and OptaxDev committed Nov 12, 2024
1 parent db6db9f commit c0693b4
Showing 5 changed files with 369 additions and 121 deletions.
171 changes: 150 additions & 21 deletions examples/lbfgs.ipynb
@@ -11,7 +11,7 @@
"L-BFGS is a classical optimization method that uses past gradients and parameters information to iteratively refine a solution to a minimization problem. In this notebook, we illustrate\n",
"1. how to use L-BFGS as a simple gradient transformation,\n",
"2. how to wrap L-BFGS in a solver, and how linesearches are incorporated,\n",
"3. how to debug the solver if needed,\n"
"3. how to debug the solver if needed.\n"
]
},
{
@@ -146,7 +146,7 @@
"\n",
"where $c_1$ is some constant set to $10^{-4}$ by default. Consider for example the update direction to be $u_k = -g_k$, i.e., moving along the negative gradient direction. In that case the criterion above reduces to $f(w_k - \\eta_k g_k) \\leq f(w_k) - c_1 \\eta_k ||g_k||_2^2$. The criterion amounts then to choosing the stepsize such that it decreases the objective by an amount proportional to the squared gradient norm.\n",
"\n",
"As long as the update direction is a *descent direction*, that is, $\\langle u_k, g_k\\rangle < 0$ the above criterion is guaranteed to be satisfied by some sufficiently small stepsize.\n",
"As long as the update direction is a *descent direction*, that is, $\\langle u_k, g_k\\rangle \u003c 0$ the above criterion is guaranteed to be satisfied by some sufficiently small stepsize.\n",
"A simple linesearch technique to ensure a sufficient decrease is then to decrease a candidate stepsize by a constant factor up until the criterion is satisfied. This amounts to the backtracking linesearch implemented in {py:func}`optax.scale_by_backtracking_linesearch` and briefly reviewed below.\n",
"\n",
"#### Small curvature (Strong wolfe criterion)\n",
@@ -286,7 +286,7 @@
},
"outputs": [],
"source": [
"def run_lbfgs(init_params, fun, opt, max_iter, tol):\n",
"def run_opt(init_params, fun, opt, max_iter, tol):\n",
" value_and_grad_fun = optax.value_and_grad_from_state(fun)\n",
"\n",
" def step(carry):\n",
@@ -303,7 +303,7 @@
" iter_num = otu.tree_get(state, 'count')\n",
" grad = otu.tree_get(state, 'grad')\n",
" err = otu.tree_l2_norm(grad)\n",
" return (iter_num == 0) | ((iter_num < max_iter) & (err >= tol))\n",
" return (iter_num == 0) | ((iter_num \u003c max_iter) \u0026 (err \u003e= tol))\n",
"\n",
" init_carry = (init_params, opt.init(init_params))\n",
" final_params, final_state = jax.lax.while_loop(\n",
@@ -338,7 +338,7 @@
" f'Initial value: {fun(init_params):.2e} '\n",
" f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params)):.2e}'\n",
")\n",
"final_params, _ = run_lbfgs(init_params, fun, opt, max_iter=50, tol=1e-3)\n",
"final_params, _ = run_opt(init_params, fun, opt, max_iter=100, tol=1e-3)\n",
"print(\n",
" f'Final value: {fun(final_params):.2e}, '\n",
" f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params)):.2e}'\n",
@@ -395,7 +395,7 @@
" f'Initial value: {fun(init_params):.2e} '\n",
" f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params)):.2e}'\n",
")\n",
"final_params, _ = run_lbfgs(init_params, fun, opt, max_iter=50, tol=1e-3)\n",
"final_params, _ = run_opt(init_params, fun, opt, max_iter=100, tol=1e-3)\n",
"print(\n",
" f'Final value: {fun(final_params):.2e}, '\n",
" f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params)):.2e}'\n",
@@ -408,9 +408,8 @@
"id": "KZIu7UDveO6D"
},
"source": [
"## Debugging solver\n",
"\n",
"In some cases, L-BFGS with a linesearch as a solver will fail. Most of the times, the culprit goes down to the linesearch. To debug the solver in such cases, we provide a `verbose` option to the `optax.scale_by_zoom_linesearch`. We show below how to proceed."
"## Debugging\n",
"\n"
]
},
{
"id": "LV8CslWpoDDq"
},
"source": [
"First we try to minimize the [Zakharov function](https://www.sfu.ca/~ssurjano/zakharov.html) without any changes. You'll observe that the final value is larger than the initial value which points out that the solver failed, and probably because the linesearch did not find a stepsize that ensured a sufficient decrease."
"### Accessing debug information\n",
"\n",
"In some cases, L-BFGS with a linesearch as a solver will fail. Most of the times, the culprit goes down to the linesearch. To debug the solver in such cases, we provide a `verbose` option to the `optax.scale_by_zoom_linesearch`. We show below how to proceed.\n",
"\n",
"To demonstrate such bug, we try to minimize the [Zakharov function](https://www.sfu.ca/~ssurjano/zakharov.html) and set the `scale_init_precond` option to `False` (by choosing the default option `scale_init_precond=True`, the algorithm would actually run fine, we just want to showcase the possibility to use debugging in the linesearch here). You'll observe that the final value is is the same as the initial value which points out that the solver failed."
]
},
{
" sum2 = (0.5 * ii * w).sum()\n",
" return sum1 + sum2**2 + sum2**4\n",
"\n",
"opt = optax.chain(print_info(), optax.lbfgs())\n",
"opt = optax.lbfgs(scale_init_precond=False)\n",
"\n",
"init_params = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4])\n",
"print(\n",
" f'Initial value: {fun(init_params)} '\n",
" f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params))}'\n",
")\n",
"final_params, _ = run_lbfgs(init_params, fun, opt, max_iter=50, tol=1e-3)\n",
"final_params, _ = run_opt(init_params, fun, opt, max_iter=50, tol=1e-3)\n",
"print(\n",
" f'Final value: {fun(final_params)}, '\n",
" f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params))}'\n",
"id": "uwcbY5UXohZB"
},
"source": [
"We can change the linesearch used in lbfgs as part of its arguments. Here we keep the default number of linesearch steps (15) and set the verbose option to `True`."
"The default implementation of the linesearch in the code is\n",
"```\n",
"scale_by_zoom_linesearch(max_linesearch_steps=20, initial_guess_strategy='one')\n",
"```\n",
"To debug we can set the verbose option of the linesearch to `True`."
]
},
{
},
"outputs": [],
"source": [
"opt = optax.chain(print_info(), optax.lbfgs(\n",
"opt = optax.chain(print_info(), optax.lbfgs(scale_init_precond=False,\n",
" linesearch=optax.scale_by_zoom_linesearch(\n",
" max_linesearch_steps=15, verbose=True\n",
" max_linesearch_steps=20, verbose=True, initial_guess_strategy='one'\n",
" )\n",
"))\n",
"\n",
" f'Initial value: {fun(init_params):.2e} '\n",
" f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params)):.2e}'\n",
")\n",
"final_params, _ = run_lbfgs(init_params, fun, opt, max_iter=50, tol=1e-3)\n",
"final_params, _ = run_opt(init_params, fun, opt, max_iter=100, tol=1e-3)\n",
"print(\n",
" f'Final value: {fun(final_params):.2e}, '\n",
" f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params)):.2e}'\n",
"id": "nCgpjzCbo7p9"
},
"source": [
"As expected, the linesearch failed at the very first step taking a stepsize that did not ensure a sufficient decrease. Multiple information is displayed. For example, the slope (derivative along the update direction) at the first step si extremely large which explains the difficulties to find an appropriate stepsize. As pointed out in the log above, the first thing to try is to use a larger number of linesearch steps."
"As expected, the linesearch failed at the very first step taking a stepsize that did not ensure a sufficient decrease. Multiple information is displayed. For example, the slope (derivative along the update direction) at the first step is extremely large which explains the difficulties to find an appropriate stepsize. As pointed out in the log above, the first thing to try is to use a larger number of linesearch steps."
]
},
{
},
"outputs": [],
"source": [
"opt = optax.chain(print_info(), optax.lbfgs(\n",
"opt = optax.chain(print_info(), optax.lbfgs(scale_init_precond=False,\n",
" linesearch=optax.scale_by_zoom_linesearch(\n",
" max_linesearch_steps=50, verbose=True\n",
" max_linesearch_steps=50, verbose=True, initial_guess_strategy='one'\n",
" )\n",
"))\n",
"\n",
" f'Initial value: {fun(init_params):.2e} '\n",
" f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params)):.2e}'\n",
")\n",
"final_params, _ = run_lbfgs(init_params, fun, opt, max_iter=50, tol=1e-3)\n",
"final_params, _ = run_opt(init_params, fun, opt, max_iter=100, tol=1e-3)\n",
"print(\n",
" f'Final value: {fun(final_params):.2e}, '\n",
" f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params)):.2e}'\n",
"id": "na-7s1Q2o1Rc"
},
"source": [
"By simply taking a maximum of 50 steps of the linesearch instead of 15, we ensured that the first stepsize taken provided a sufficient decrease and the solver worked well.\n",
"By simply taking a maximum of 50 steps of the linesearch instead of 20, we ensured that the first stepsize taken provided a sufficient decrease and the solver worked well.\n",
"Additional debugging information can be found in the source code accessible from the docs of {py:func}`optax.scale_by_zoom_linesearch`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "74ZbgzcKoJ0J"
},
"source": [
"### Tips\n",
"\n",
"- **LBFGS**\n",
" - Selecting a higher `memory_size` in lbfgs may improve performance at a memory and computational cost. No real gains may be perceived after some value.\n",
" - `scale_init_precond=True` is standard. It captures a similar scale as other well-known optimization methods like Barzilai Borwein.\n",
"\n",
"- **Zoom linesearch**\n",
" - Remember there are two conditions to be met (sufficient decrease and small curvature). If the algorithm takes too many linesearch steps, you may try\n",
" setting `curv_rtol = jnp.inf`, effectively ignoring the small curvature condition. The resulting algorithm will essentially perform a backtracking linesearch where a valid stepsize is searched by minmizing a quadratic or cubic approximation of the objective (so that would be a potentially faster algorithm than the current implementation of `scale_by_backtracking_linesearch`).\n",
" - As pointed above, if the solver gets stuck, try using a larger number of linesearch steps and print debugging information.\n",
"\n",
"You may run the solver in double precision by setting `jax.config.update(\"jax_enable_x64\", True)`. If you use double precision, consider augmenting the number of linesearch steps to reach the machine precision (like using `max_linesearch_steps=55`).\n"
]
},
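
A configuration sketch for the `curv_rtol = jnp.inf` tip above (assuming, as the tip suggests, that `curv_rtol` is exposed as a keyword argument of optax.scale_by_zoom_linesearch):

```python
import jax.numpy as jnp
import optax

# Zoom linesearch that ignores the small-curvature (Wolfe) condition, so it
# behaves like a backtracking linesearch with polynomial interpolation.
opt = optax.lbfgs(
    linesearch=optax.scale_by_zoom_linesearch(
        max_linesearch_steps=30,
        curv_rtol=jnp.inf,
    )
)
```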
{
"cell_type": "markdown",
"metadata": {
"id": "T-oGa3P2sCbH"
},
"source": [
"## Contributing and benchmarking\n",
"\n",
"Numerous other linesearch could be implemented, as well as other solvers for medium scale problems without stochasticity. Contributions are welcome.\n",
"\n",
"If you want to contribute a new solver for medium scale problems like LBFGS, benchmarks would be highly appreciated. We provide below an example of benchmark (which could also be used if you want to test some hyperparameters of the algorithm). We take here the classical Rosenbroke function, but it could be better to expand such benchmarks to e.g. the set of test functions given by [Andrei, 2008](https://camo.ici.ro/journal/vol10/v10a10.pdf)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MagDCuGjsB5x"
},
"outputs": [],
"source": [
"import time\n",
"num_fun_calls = 0\n",
"\n",
"def register_call():\n",
" global num_fun_calls\n",
" num_fun_calls += 1\n",
"\n",
"def test_hparams(lbfgs_hparams, linesearch_hparams, dimension=512):\n",
" global num_fun_calls\n",
" num_fun_calls = 0\n",
"\n",
" def fun(x):\n",
" jax.debug.callback(register_call)\n",
" return jnp.sum((x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)\n",
"\n",
" opt = optax.chain(optax.lbfgs(**lbfgs_hparams,\n",
" linesearch=optax.scale_by_zoom_linesearch(**linesearch_hparams)\n",
" )\n",
" )\n",
"\n",
" init_params = jnp.arange(dimension, dtype=jnp.float32)\n",
"\n",
" tic = time.time()\n",
" final_params, _ = run_opt(\n",
" init_params, fun, opt, max_iter=500, tol=5*1e-5\n",
" )\n",
" final_params = jax.block_until_ready(final_params)\n",
" time_run = time.time() - tic\n",
"\n",
" final_value = fun(final_params)\n",
" final_grad_norm = otu.tree_l2_norm(jax.grad(fun)(final_params))\n",
" return final_value, final_grad_norm, num_fun_calls, time_run\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7CXMxWsztGf5"
},
"outputs": [],
"source": [
"import copy\n",
"import matplotlib.pyplot as plt\n",
"\n",
"default_lbfgs_hparams = {'memory_size': 15, 'scale_init_precond': True}\n",
"default_linesearch_hparams = {\n",
" 'max_linesearch_steps': 15,\n",
" 'initial_guess_strategy': 'one'\n",
"}\n",
"\n",
"memory_sizes = [int(2**i) for i in range(7)]\n",
"times = []\n",
"calls = []\n",
"values = []\n",
"grad_norms = []\n",
"for m in memory_sizes:\n",
" lbfgs_hparams = copy.deepcopy(default_lbfgs_hparams)\n",
" lbfgs_hparams['memory_size'] = m\n",
" v, g, n, t = test_hparams(lbfgs_hparams, default_linesearch_hparams, dimension=1024)\n",
" values.append(v)\n",
" grad_norms.append(g)\n",
" calls.append(n)\n",
" times.append(t)\n",
"\n",
"fig, axs = plt.subplots(1, 4, figsize=(16, 4))\n",
"axs[0].plot(memory_sizes, values)\n",
"axs[0].set_ylabel('Final values')\n",
"axs[0].set_yscale('log')\n",
"axs[1].plot(memory_sizes, grad_norms)\n",
"axs[1].set_ylabel('Final gradient norms')\n",
"axs[1].set_yscale('log')\n",
"axs[2].plot(memory_sizes, calls)\n",
"axs[2].set_ylabel('Number of function calls')\n",
"axs[3].plot(memory_sizes, times)\n",
"axs[3].set_ylabel('Run times')\n",
"for i in range(4):\n",
" axs[i].set_xlabel('Memory size')\n",
"plt.tight_layout()"
]
}
],
"metadata": {
20 changes: 15 additions & 5 deletions optax/_src/alias.py
@@ -2393,7 +2393,9 @@ def lbfgs(
scale_init_precond: bool = True,
linesearch: Optional[
base.GradientTransformationExtraArgs
] = _linesearch.scale_by_zoom_linesearch(max_linesearch_steps=15),
] = _linesearch.scale_by_zoom_linesearch(
max_linesearch_steps=20, initial_guess_strategy='one'
),
) -> base.GradientTransformationExtraArgs:
r"""L-BFGS optimizer.
@@ -2453,7 +2455,7 @@ def lbfgs(
memory_size: number of past updates to keep in memory to approximate the
Hessian inverse.
scale_init_precond: whether to use a scaled identity as the initial
preconditioner, see formula above.
preconditioner, see formula of :math:`\gamma_k` above.
linesearch: an instance of :class:`optax.GradientTransformationExtraArgs`
such as :func:`optax.scale_by_zoom_linesearch` that computes a
learning rate, a.k.a. stepsize, to satisfy some criterion such as a
@@ -2480,9 +2482,9 @@
... )
... params = optax.apply_updates(params, updates)
... print('Objective function: ', f(params))
Objective function: 0.0
Objective function: 0.0
Objective function: 0.0
Objective function: 7.5166864
Objective function: 7.460699e-14
Objective function: 2.6505726e-28
Objective function: 0.0
Objective function: 0.0
zoom linesearch). See example above for best use in a non-stochastic
setting, where we can recycle gradients computed by the linesearch using
:func:`optax.value_and_grad_from_state`.
.. note::
We initialize the scaling of the identity as a capped reciprocal of the
gradient norm. This avoids wasting linesearch iterations for the first step
by taking into account the magnitude of the gradients. In other words, we
constrain the trust-region of the first step to a Euclidean ball of radius
1 at the first iteration. The choice of :math:`\gamma_0` is not detailed in
the references above, so this is a heuristic choice.
"""
if learning_rate is None:
base_scaling = transform.scale(-1.0)
