From b2dd416d23b0e881aefc5f14844a00c11fe42023 Mon Sep 17 00:00:00 2001 From: patwie Date: Fri, 30 May 2025 17:52:58 +0000 Subject: [PATCH] Fix off-by-one error in L-BFGS Solver Implementation Accessing the memory did skip one entry. --- include/cppoptlib/solver/lbfgs.h | 71 ++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 31 deletions(-) diff --git a/include/cppoptlib/solver/lbfgs.h b/include/cppoptlib/solver/lbfgs.h index fd8fce4..d2bd39a 100644 --- a/include/cppoptlib/solver/lbfgs.h +++ b/include/cppoptlib/solver/lbfgs.h @@ -33,7 +33,7 @@ #include "../linesearch/more_thuente.h" #include "Eigen/Core" -#include "solver.h" // NOLINT +#include "solver.h" // NOLINT namespace cppoptlib::solver { @@ -49,7 +49,7 @@ class Lbfgs cppoptlib::function::DifferentiabilityMode::Second, "L-BFGS only supports first- or second-order differentiable functions"); - private: +private: using StateType = typename cppoptlib::function::FunctionState< typename FunctionType::ScalarType, FunctionType::Dimension>; using Superclass = Solver; @@ -63,7 +63,7 @@ class Lbfgs using memory_MatrixType = Eigen::Matrix; using memory_VectorType = Eigen::Matrix; - public: +public: EIGEN_MAKE_ALIGNED_OPERATOR_NEW using Superclass::Superclass; @@ -107,43 +107,52 @@ class Lbfgs // Start with the preconditioned gradient as the initial search direction. VectorType search_direction = grad_precond; - // Determine the number of corrections available for the two-loop recursion. - // We exclude the most recent correction (which was just computed) from use. - int k = (mem_count_ > 0 ? static_cast(mem_count_) - 1 : 0); + // Determine the actual number of stored corrections to use + const int k = static_cast(mem_count_); - // --- First Loop (Backward Pass) --- - // Iterate over stored corrections in reverse chronological order. 
+ // First loop: computes q = q - alpha_i * y_i + // Iterates from the newest correction (i = k-1) to the oldest (i = 0) + // i indexes the corrections in chronological order: 0=oldest, + // k-1=newest for (int i = k - 1; i >= 0; i--) { // Compute the index in chronological order. // When mem_count_ < m, corrections are stored in order [0 ... // mem_count_-1]. When full, they are stored cyclically starting at // mem_pos_ (oldest) up to (mem_pos_ + m - 1) mod m. - int idx = (mem_count_ < m ? i : ((mem_pos_ + i) % m)); - const ScalarType denom = - x_diff_memory_.col(idx).dot(grad_diff_memory_.col(idx)); - if (std::abs(denom) < eps) { + const int idx = (mem_count_ < m) ? i : (mem_pos_ + i) % m; + + const VectorType &s_col = x_diff_memory_.col(idx); + const VectorType &y_col = grad_diff_memory_.col(idx); + + const ScalarType s_dot_y = s_col.dot(y_col); + if (std::abs(s_dot_y) < eps) { // Avoid division by zero or near-zero continue; } - const ScalarType rho = 1.0 / denom; - alpha(i) = rho * x_diff_memory_.col(idx).dot(search_direction); - search_direction -= alpha(i) * grad_diff_memory_.col(idx); + const ScalarType rho_val = static_cast(1.0) / s_dot_y; + alpha(i) = rho_val * s_col.dot(search_direction); + search_direction -= alpha(i) * y_col; } - // Apply the initial Hessian approximation. + // Apply the initial Hessian approximation H_k^0 = gamma_k * I + // gamma_k = s_{k-1}^T y_{k-1} / (y_{k-1}^T y_{k-1}) + // Here, scaling_factor_ is this gamma_k from the *previous* iteration. search_direction *= scaling_factor_; - // --- Second Loop (Forward Pass) --- + // Second loop: computes r = r + s_i * (alpha_i - beta_i) + // Iterates from the oldest correction (i = 0) to the newest (i = k-1) for (int i = 0; i < k; i++) { - int idx = (mem_count_ < m ? i : ((mem_pos_ + i) % m)); - const ScalarType denom = - x_diff_memory_.col(idx).dot(grad_diff_memory_.col(idx)); - if (std::abs(denom) < eps) { + const int idx = (mem_count_ < m) ? 
i : (mem_pos_ + i) % m; + + const VectorType &s_col = x_diff_memory_.col(idx); + const VectorType &y_col = grad_diff_memory_.col(idx); + + const ScalarType s_dot_y = s_col.dot(y_col); + if (std::abs(s_dot_y) < eps) { continue; } - const ScalarType rho = 1.0 / denom; - const ScalarType beta = - rho * grad_diff_memory_.col(idx).dot(search_direction); - search_direction += x_diff_memory_.col(idx) * (alpha(i) - beta); + const ScalarType rho_val = static_cast(1.0) / s_dot_y; + const ScalarType beta = rho_val * y_col.dot(search_direction); + search_direction += s_col * (alpha(i) - beta); } // Check descent direction validity. @@ -210,21 +219,21 @@ class Lbfgs return next; } - private: +private: memory_MatrixType x_diff_memory_; memory_MatrixType grad_diff_memory_; // Circular buffer state: - size_t mem_count_ = 0; // Number of corrections stored so far (max m). - size_t mem_pos_ = 0; // Index of the oldest correction in the buffer. + size_t mem_count_ = 0; // Number of corrections stored so far (max m). + size_t mem_pos_ = 0; // Index of the oldest correction in the buffer. memory_VectorType - alpha; // Storage for the coefficients in the two-loop recursion. + alpha; // Storage for the coefficients in the two-loop recursion. ScalarType scaling_factor_ = 1; // Cautious factor to determine whether to accept a new correction pair. // You may want to expose this parameter or adjust its default value. ScalarType cautious_factor_ = 1e-6; }; -} // namespace cppoptlib::solver +} // namespace cppoptlib::solver -#endif // INCLUDE_CPPOPTLIB_SOLVER_LBFGS_H_ +#endif // INCLUDE_CPPOPTLIB_SOLVER_LBFGS_H_