Improve MESA's linesearch #761

Open
pmocz opened this issue Jan 7, 2025 · 2 comments
@pmocz
Member

pmocz commented Jan 7, 2025

MESA uses a Newton-Raphson scheme with line search to find the solution to the stellar structure equations. Each iteration linearizes the system, solves the resulting linear system for a correction (using the bcyclic solver with pivoting and equilibration), and updates the solution; this repeats until the errors fall below prescribed tolerances.

See Section 6.3 of https://arxiv.org/abs/1009.1622
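
To make the scheme described above concrete, here is a minimal Python sketch of a damped Newton iteration with a backtracking line search on a scalar toy problem. It is not MESA's Fortran implementation: the function name `newton_with_backtracking`, the sufficient-decrease constant `alpha`, the floor `min_coeff`, and the simple step-halving rule are all assumptions chosen for illustration.

```python
import math

def newton_with_backtracking(f, fprime, x0, tol=1e-12, max_iter=50,
                             alpha=1e-4, min_coeff=1e-3):
    """Damped Newton iteration for a scalar equation f(x) = 0.

    Returns (root, number_of_newton_iterations). All parameter names and
    default values are illustrative assumptions, not MESA controls.
    """
    x = x0
    for it in range(max_iter):
        r = f(x)
        if abs(r) < tol:
            return x, it
        dx = -r / fprime(x)        # full Newton correction
        phi0 = 0.5 * r * r         # merit function 0.5*f(x)^2
        slope = -2.0 * phi0        # d(phi)/d(coeff) at coeff = 0 along dx
        coeff = 1.0
        while coeff >= min_coeff:
            r_new = f(x + coeff * dx)
            # sufficient-decrease (Armijo) test on the merit function
            if 0.5 * r_new * r_new <= phi0 + alpha * coeff * slope:
                break
            coeff *= 0.5           # backtrack: halve the step
        else:
            raise RuntimeError("line search could not find an acceptable step")
        x += coeff * dx
    raise RuntimeError("Newton iteration did not converge")

# atan(x) = 0 is a standard case where the full Newton step overshoots
# when started far from the root, so the damping actually matters.
root, iters = newton_with_backtracking(
    math.atan, lambda x: 1.0 / (1.0 + x * x), 3.0)
```

Starting from `x0 = 3.0`, the full Newton step overshoots the root and increases the residual, so the first iteration has to shrink the step before it is accepted. The halving rule here is the simplest possible choice; it stands in for whatever step-size heuristics a production line search uses.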

It's a known issue that MESA's solver can struggle in certain parts of parameter space where the system is stiff.

@Debraheem created a(/another) reproducer in #753 where @Debraheem and @pmocz are investigating precisely why the solver fails.

So far, we've found that when things break down, the Jacobian has a high condition number (>10^15), but the bcyclic solver is still able to solve the system with errors of O(10^-11).
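
As a reminder of why the condition number and the linear-solve error are worth tracking separately, here is a toy Python sketch (a made-up 2x2 matrix, not a MESA Jacobian). It assumes that "error" means the residual of the linear system, and shows that for an ill-conditioned matrix a tiny residual does not by itself imply a tiny error in the solution vector.

```python
import math

eps = 1e-10
A = [[1.0, 1.0], [1.0, 1.0 + eps]]   # symmetric, nearly singular toy matrix
x_true = [1.0, 1.0]
x_bad = [2.0, 0.0]                   # a candidate solution far from x_true

def matvec(M, v):
    """Product of a 2x2 matrix with a length-2 vector."""
    return [M[0][0] * v[0] + M[0][1] * v[1],
            M[1][0] * v[0] + M[1][1] * v[1]]

def norm(v):
    """Euclidean norm of a vector."""
    return math.sqrt(sum(c * c for c in v))

b = matvec(A, x_true)
# Residual of the linear system for the wrong candidate: about eps in size.
residual = norm([bi - ri for bi, ri in zip(b, matvec(A, x_bad))])
# Distance between the candidate and the true solution: order one.
error = norm([t - c for t, c in zip(x_true, x_bad)])

# For a symmetric positive-definite 2x2 matrix, the 2-norm condition
# number is lam_max / lam_min; lam_min is taken from det / lam_max to
# avoid cancellation in the quadratic formula.
trace = A[0][0] + A[1][1]
det = A[0][0] * A[1][1] - A[0][1] * A[1][0]
lam_max = 0.5 * (trace + math.sqrt(trace * trace - 4.0 * det))
lam_min = det / lam_max
cond = lam_max / lam_min

print(f"cond ~ {cond:.1e}, residual ~ {residual:.1e}, error ~ {error:.2f}")
```

The gap between the residual and the solution error scales with the condition number, so a small linear residual alone cannot rule out a poor Newton direction. Whether that matters for the failures discussed here is exactly what would need checking.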

There are indications the problem is in the line search algorithm:

star/private/star_solver.f90: subroutine adjust_correction

The function being minimized is oscillatory, and the line search is not able to take the right-sized steps to find the minimum. At present, the line search uses some heuristics and hard-coded parameters.

So an open research question is how to improve the robustness of the line search.

@matthiasfabry
Contributor

I've run into similar issues with high-mass stellar models with a strong stellar wind.
I've tried a very hacky solution: detect when such oscillations are happening, then cut the correction in half in the hope of landing in the valley of the solution. I call it "loop detection." It helps somewhat, but the effect is mostly marginal.
Below is my do_solver routine (modified from the r22.11.1 release) that accomplishes this. I also designed a "compressed" solver message that is not so wide.

subroutine do_solver( &
            s, nvar, AF1, ldAF, neq, skip_global_corr_coeff_limit, &
            gold_tolerances_level, tol_max_correction, tol_correction_norm, &
            work, lwork, iwork, liwork, &
            convergence_failure, ierr)

         ...

         ! loop protection
         real(dp), dimension(:, :), allocatable:: results
         real(dp), parameter :: loop_tol = 1d-2
         integer, parameter :: loop_lim = 5
         integer :: loop_cnt, ii
         logical :: close, do_loop_action, loop_protection = .true.

        ...

         allocate(results(loop_lim + 1, 6))
         loop_cnt = 0
         do_loop_action = .false.

         
      iter_loop: do while (.not. passed_tol_tests)

            ...

            ! compute size of scaled correction B
            call sizeB(s, nvar, soln, &
                  max_correction, correction_norm, max_corr_k, max_corr_j, ierr)
            if (ierr /= 0) then
               call oops('correction rejected by sizeB')
               exit iter_loop
            end if

            correction_norm = abs(correction_norm)
            max_abs_correction = abs(max_correction)
            corr_norm_min = min(correction_norm, corr_norm_min)
            max_corr_min = min(max_abs_correction, max_corr_min)

            if (is_bad_num(correction_norm) .or. is_bad_num(max_abs_correction)) then
               ! bad news -- bogus correction
               call oops('bad result from sizeB -- correction info either NaN or Inf')
               if (s% stop_for_bad_nums) then
                  write(*,1) 'correction_norm', correction_norm
                  write(*,1) 'max_correction', max_correction
                  call mesa_error(__FILE__,__LINE__,'solver')
               end if
               exit iter_loop
            end if

            if (.not. s% ignore_too_large_correction) then
               if ((correction_norm > s% corr_param_factor*s% scale_correction_norm) .and. &
                     .not. s% doing_first_model_of_run) then
                  call oops('avg corr too large')
                  exit iter_loop
               endif
            end if

            ! shrink the correction if it is too large
            correction_factor = 1d0
            temp_correction_factor = 1d0

            if (correction_norm*correction_factor > s% scale_correction_norm) then
               correction_factor = min(correction_factor,s% scale_correction_norm/correction_norm)
            end if

            if (max_abs_correction*correction_factor > s% scale_max_correction) then
               temp_correction_factor = s% scale_max_correction/max_abs_correction
            end if

            if (iter > s% solver_itermin_until_reduce_min_corr_coeff) then
               if (min_corr_coeff == 1d0 .and. &
                  s% solver_reduced_min_corr_coeff < 1d0) then
                     min_corr_coeff = s% solver_reduced_min_corr_coeff
               end if
            end if

            correction_factor = max(min_corr_coeff, correction_factor)
            if (.not. s% ignore_min_corr_coeff_for_scale_max_correction) then
               temp_correction_factor = max(min_corr_coeff, temp_correction_factor)
            end if
            correction_factor = min(correction_factor, temp_correction_factor)

            if (do_loop_action) then
               correction_factor = correction_factor / 2d0
            end if

            ! fix B if out of definition domain
            call Bdomain(s, nvar, soln, correction_factor, ierr)
            if (ierr /= 0) then ! correction cannot be fixed
               call oops('correction rejected by Bdomain')
               exit iter_loop
            end if

...

            f = 0d0
            call adjust_correction( &
               min_corr_coeff, correction_factor, grad_f1, f, slope, coeff, err_msg, ierr)
            if (ierr /= 0) then
               call oops(err_msg)
               exit iter_loop
            end if
            s% solver_adjust_iter = 0

...

            residual_norm = abs(residual_norm)
            max_residual = abs(max_residual)
            s% residual_norm = residual_norm
            s% max_residual = max_residual
            resid_norm_min = min(residual_norm, resid_norm_min)
            max_resid_min = min(max_residual, max_resid_min)

            call loop_detection

...

         end do iter_loop

         if (max_residual > s% warning_limit_for_max_residual .and. .not. convergence_failure) &
            write(*,2) 'WARNING: max_residual > warning_limit_for_max_residual', &
               s% model_number, max_residual, s% warning_limit_for_max_residual

         deallocate(results)

      contains

         subroutine loop_detection

            loop_cnt = loop_cnt + 1
            call set_loop_info()
            if (loop_cnt > 2) then
               call search_for_loops()
            else
               do_loop_action = .false.
            end if

         end subroutine loop_detection


         subroutine search_for_loops
            logical :: loop_found
            integer :: n

            loop_found = .false.
            ! check if results numbers are "different"
            do n = 1, loop_cnt - 1
               if (residual_norm > tol_residual_norm .and. &
                     abs(results(loop_cnt, 1) - results(loop_cnt - n, 1)) / results(loop_cnt, 1) < loop_tol) then
                  loop_found = .true.
               end if
               if (max_residual > tol_max_residual .and. &
                     (abs(results(loop_cnt, 2) - results(loop_cnt - n, 2)) / results(loop_cnt, 2) < loop_tol .and. &
                        results(loop_cnt, 5) == results(loop_cnt - n, 5))) then
                  loop_found = .true.
               end if
               if (correction_norm > tol_correction_norm*coeff .and. &
                     abs(results(loop_cnt, 3) - results(loop_cnt - n, 3)) / results(loop_cnt, 3) < loop_tol) then
                  loop_found = .true.
               end if
               if (max_abs_correction > tol_max_correction*coeff .and. &
                     (abs((results(loop_cnt, 4) - results(loop_cnt - n, 4)) / results(loop_cnt, 4)) < loop_tol .and. &
                        results(loop_cnt, 6) == results(loop_cnt - n, 6))) then
                  loop_found = .true.
               end if
            end do


            if (loop_found) then  ! loop found, do action and reset loop_cnt to 0
               do_loop_action = .true.
               loop_cnt = 0
            else  ! loop still alive
               do_loop_action = .false.
               if (loop_cnt == loop_lim) then  ! no loop found within loop_lim iterations, move results over and continue
                  do n = 2, loop_cnt
                     results(n-1, :) = results(n, :)
                  end do
                  loop_cnt = loop_lim - 1
               end if
            end if

         end subroutine search_for_loops


         subroutine set_loop_info

            results(loop_cnt, 1) = residual_norm
            results(loop_cnt, 2) = max_residual
            results(loop_cnt, 3) = correction_norm
            results(loop_cnt, 4) = max_correction
            results(loop_cnt, 5) = max_resid_k
            results(loop_cnt, 6) = max_corr_k

         end subroutine set_loop_info


...

         subroutine write_msg(msg)
            use const_def, only: secyer
            character(*)  :: msg
            logical :: short_msg = .true.
            
            integer :: k
            character (len=64) :: max_resid_str, max_corr_str
            character (len=5) :: max_resid_mix_type_str, max_corr_mix_type_str
            character (len=10) :: integer_string
            include 'formats'
            
            if (.not. dbg_msg) return
            
            if (max_resid_j < 0) then
               call sizequ(s, nvar, residual_norm, max_residual, max_resid_k, max_resid_j, ierr)
            end if
            
            if (max_resid_j > 0) then
               if (short_msg) then
                  write(max_resid_str,*) 'max_r ' // trim(s% nameofequ(max_resid_j))
               else
                  write(max_resid_str,*) 'max resid ' // trim(s% nameofequ(max_resid_j))
               end if

            else
               max_resid_str = ''
            end if
            
            if (max_corr_j < 0) then
               call sizeB(s, nvar, B, &
                  max_correction, correction_norm, max_corr_k, max_corr_j, ierr)
            end if
            
            if (max_corr_j > 0) then
               if (short_msg) then
                  write(max_corr_str,*) 'max_c ' // trim(s% nameofvar(max_corr_j))
               else
                  write(max_corr_str,*) 'max corr ' // trim(s% nameofvar(max_corr_j))
               end if

            else
               max_corr_str = ''
            end if
            
            integer_string = '0123456789'
            k = max_corr_k
            call store_mix_type_str(max_corr_mix_type_str, integer_string, 1, k-2)
            call store_mix_type_str(max_corr_mix_type_str, integer_string, 2, k-1)
            call store_mix_type_str(max_corr_mix_type_str, integer_string, 3, k)
            call store_mix_type_str(max_corr_mix_type_str, integer_string, 4, k+1)
            call store_mix_type_str(max_corr_mix_type_str, integer_string, 5, k+2)
            
            k = max_resid_k
            call store_mix_type_str(max_resid_mix_type_str, integer_string, 1, k-2)
            call store_mix_type_str(max_resid_mix_type_str, integer_string, 2, k-1)
            call store_mix_type_str(max_resid_mix_type_str, integer_string, 3, k)
            call store_mix_type_str(max_resid_mix_type_str, integer_string, 4, k+1)
            call store_mix_type_str(max_resid_mix_type_str, integer_string, 5, k+2)

  111       format(i6, i3, 2x, a, f7.4, &
               1x, a, 1x, e10.3, 2x, a18, 1x, i5, e13.5, a, &
               1x, a, 1x, e10.3, 2x, a16, 1x, i5, e13.5, a, &
               1x, a, 1x, i1, 1x,&
               a)
  112       format(i5, i3, 1x, a, f7.4, &
               1x, a, 1x, e10.3, 1x, a12, 1x, i5, e13.5, &
               1x, a, 1x, e10.3, 1x, a10, 1x, i5, e13.5, &
               1x, a, 1x, i1, 1x,&
               a)

            if (short_msg) then
               write(*,112) &
                  s% model_number, iter, &
                  'coeff', coeff,  &
                  'av_r', residual_norm,  &
   !               '   avg resid', residual_norm,  &
                  trim(max_resid_str), max_resid_k, max_residual, &
                  'av_c', correction_norm,  &
   !               'mix type ' // trim(max_resid_mix_type_str),  &
   !               '   avg corr', correction_norm,  &
                  trim(max_corr_str), max_corr_k, max_correction,  &
                  'lp_ct', loop_cnt, &
                  ' ' // trim(msg)
   !               'mix type ' // trim(max_corr_mix_type_str),  &
   !               '   ' // trim(msg)
            else
               write(*,111) &
                  s% model_number, iter, &
                  'coeff', coeff,  &
                  'avg resid', residual_norm,  &
   !               '   avg resid', residual_norm,  &
                  trim(max_resid_str), max_resid_k, max_residual, &
                  ' mix type ' // trim(max_resid_mix_type_str),  &
                  'avg corr', correction_norm,  &
   !               'mix type ' // trim(max_resid_mix_type_str),  &
   !               '   avg corr', correction_norm,  &
                  trim(max_corr_str), max_corr_k, max_correction,  &
                  ' mix type ' // trim(max_corr_mix_type_str),  &
                  'loop_cnt', loop_cnt, &
                  ' ' // trim(msg)
   !               'mix type ' // trim(max_corr_mix_type_str),  &
   !               '   ' // trim(msg)
            end if
               
            if (is_bad(slope)) call mesa_error(__FILE__,__LINE__,'write_msg')

         end subroutine write_msg
...
      end subroutine do_solver
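
For readers who want the gist of the loop-detection logic in the routine above without the surrounding solver, here is a simplified Python sketch. It tracks only a single quantity (the residual norm), whereas the Fortran version also compares the max residual, the correction norm, the max correction, and the cell indices. The class name `LoopDetector` is hypothetical; `loop_tol` and `loop_lim` mirror the Fortran parameters.

```python
from collections import deque

class LoopDetector:
    """Flag a 'loop' when a new value repeats an earlier one within loop_tol."""

    def __init__(self, loop_tol=1e-2, loop_lim=5):
        self.loop_tol = loop_tol
        # Sliding window: the oldest entry is dropped once the window is full.
        self.history = deque(maxlen=loop_lim)

    def update(self, residual_norm):
        """Record a value; return True if it matches an earlier one."""
        # As in the Fortran version, only start searching once at least
        # two earlier values have been recorded.
        found = len(self.history) >= 2 and any(
            abs(residual_norm - old) < self.loop_tol * abs(residual_norm)
            for old in self.history
        )
        if found:
            self.history.clear()      # loop found: reset and start over
        else:
            self.history.append(residual_norm)
        return found

detector = LoopDetector()
correction_factor = 1.0
flags = []
for resid in [1.0, 0.5, 1.001, 0.4, 0.5, 0.6]:
    looping = detector.update(resid)
    flags.append(looping)
    if looping:
        correction_factor /= 2.0      # the "loop action": halve the correction
```

In this made-up sequence the third value (1.001) is within 1% of the first (1.0), so it is flagged and the correction factor is halved once. The later 0.5 is not flagged, because the history was cleared when the loop was found.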

@pmocz
Member Author

pmocz commented Jan 21, 2025

Thanks for sharing this insight, @matthiasfabry. As we figure out how to make this more robust, I'll try out your ideas! We'll also want to test new approaches on your strong stellar wind problems.
