Skip to content

Commit

Permalink
faster-root-solver
Browse files Browse the repository at this point in the history
  • Loading branch information
rahulgaur104 committed Aug 22, 2022
1 parent ac91628 commit c5b09ba
Showing 1 changed file with 9 additions and 149 deletions.
158 changes: 9 additions & 149 deletions src/simsopt/mhd/vmec_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import numpy as np
from scipy.interpolate import interp1d, InterpolatedUnivariateSpline
from scipy.optimize import root_scalar
from scipy.optimize import newton

from .vmec import Vmec
from .._core.util import Struct
Expand Down Expand Up @@ -980,160 +980,21 @@ def vmec_fieldlines(vs, s, alpha, theta1d=None, phi1d=None, phi_center=0, plot=F
theta_pest[js, :, :] = theta1d[None, :]
phi[js, :, :] = phi_center + (theta1d[None, :] - alpha[:, None]) / iota[js]


def residual(theta_v, phi0, theta_p_target, jradius):
"""
This function is used for computing the value of theta_vmec that
gives a desired theta_pest.
"""
return theta_p_target - (theta_v + np.sum(lmns[js, :, None] * np.sin(xm[:, None] * theta_v - xn[:, None] * phi0), axis=0))

"""
theta_p = theta_v
for jmn in range(len(xn)):
angle = xm[jmn] * theta_v - xn[jmn] * phi0
theta_p += lmns[jradius, jmn] * np.sin(angle)
return theta_p_target - theta_p
"""
return theta_p_target - (theta_v + np.sum(lmns[jradius, :] * np.sin(xm * theta_v - xn * phi0)))

# A helper function that returns 0 when theta_vmec_trys are close to the
# proper toroidal angle in magnetic coordinates
# vmec_fieldlines new residual routine (written by Alan Goodman)
def fzero_residuals_function(theta_vmec_trys, theta_pest_targets, phis, xm, xn, lmns, lmnc, lasym):
angle = xm * theta_vmec_trys - xn * phis
fzero_residual = theta_pest_targets - (theta_vmec_trys + np.reshape(np.sum(lmns * np.sin(angle), axis=1), (-1, 1)))
return fzero_residual

# A helper function that numerically finds values that minimize fzero_residuals_function
# This is the non-linear root finder that turns out to be ~8X faster than scipy root_scalar
# written by Alan Goodman
def get_roots(a0, b0, theta_pest_targets, phis, xm, xn, lmns, lmnc, lasym):
converged = False
nconv = 0
a = 1 * a0
b = 1 * b0
roots = np.zeros(a.shape)
toteval = len(a.flatten())
itmax_bracket = 10
itmax_root = 100
tol = 1e-10

fa = fzero_residuals_function(a, theta_pest_targets, phis, xm, xn, lmns, lmnc, lasym)
fb = fzero_residuals_function(b, theta_pest_targets, phis, xm, xn, lmns, lmnc, lasym)
for it in range(itmax_bracket):
eps = np.finfo(a[0, 0]).eps
inds = (((fa > 0.0) * (fb > 0.0)) + ((fa < 0.0) * (fb < 0.0)))
if np.sum(inds) != 0:
a[inds] = a[inds] - 0.3 * np.ones(a[inds].shape)
b[inds] = b[inds] + 0.3 * np.ones(a[inds].shape)
fa = fzero_residuals_function(a, theta_pest_targets, phis, xm, xn, lmns, lmnc, lasym)
fb = fzero_residuals_function(b, theta_pest_targets, phis, xm, xn, lmns, lmnc, lasym)
else:
break

c = 1 * b
fc = 1 * fb
d = 0 * fb
e = 0 * fb
convds = False * np.ones(b.shape)

for it in range(itmax_root):
inds = (((fb > 0.0) * (fc > 0.0)) + ((fb < 0.0) * (fc < 0.0)))
c[inds] = a[inds]
fc[inds] = fa[inds]
d[inds] = b[inds] - a[inds]
e[inds] = d[inds]

inds = (np.abs(fc) < np.abs(fb))
a[inds] = b[inds]
b[inds] = c[inds]
c[inds] = a[inds]
fa[inds] = fb[inds]
fb[inds] = fc[inds]
fc[inds] = fa[inds]

tol1 = 2.0 * eps * np.abs(b) + 0.5 * tol
Xm = 0.5 * (c - b)
indsC = ((np.abs(Xm) <= tol1) + (fb == 0.0))

roots[indsC] = b[indsC]
convds[indsC] = True
nconv = np.sum(indsC)
if nconv == toteval:
converged = True
return roots, converged, nconv

p = 0 * a
s = 0 * p
q = 0 * p
r = 0 * p

inds1 = (((np.abs(e) >= tol1) * (np.abs(fa) > np.abs(fb))))
s[inds1] = (fb[inds1])/(fa[inds1])

inds2 = (a == c)
inds = inds2*inds1
s[inds] = (fb[inds])/(fa[inds])
p[inds] = 2.0 * Xm[inds] * s[inds]
q[inds] = 1.0 - s[inds]

inds = np.invert(inds2)*inds1
q[inds] = fa[inds] / fc[inds]
r[inds] = fb[inds] / fc[inds]
p[inds] = s[inds] * (2.0 * Xm[inds] * q[inds] * (q[inds] - r[inds]) - (b[inds] - a[inds]) * (r[inds] - 1.0))
q[inds] = (q[inds] - 1.0) * (r[inds] - 1.0) * (s[inds] - 1.0)

inds = (p > 0.0)*inds1
q[inds] = -q[inds]
p[inds1] = np.abs(p[inds1])

inds2 = (2.0 * p < np.minimum(3.0 * Xm * q - np.abs(tol1 * q), np.abs(e * q)))
inds = inds2*inds1
e[inds] = d[inds]
d[inds] = (p[inds])/(q[inds])

inds = np.invert(inds2)*inds1
d[inds] = Xm[inds]
e[inds] = d[inds]

inds = np.invert(inds1)
d[inds] = Xm[inds]
e[inds] = d[inds]

a = 1 * b
fa = 1 * fb

inds = np.where((np.abs(d) > tol1))
b[inds] += d[inds]

inds = np.where((np.abs(d) <= tol1))
b[inds] += np.copysign(tol1[inds], Xm[inds])
fb = fzero_residuals_function(b, theta_pest_targets, phis, xm, xn, lmns, lmnc, lasym)
return roots, converged, nconv

#theta_vmec = np.zeros((ns, nalpha, nl))
#for js in range(ns):
# for jalpha in range(nalpha):
# for jl in range(nl):
# theta_guess = theta_pest[js, jalpha, jl]
# solution = root_scalar(residual,args=(phi[js, jalpha, jl], theta_pest[js, jalpha, jl], js),
# bracket=(theta_guess - 1.0, theta_guess + 1.0))
# theta_vmec[js, jalpha, jl] = solution.root
## Solve for theta_vmec corresponding to theta_pest (new routine)
## Does the same calculation as the commented code above but faster
theta_vmec = np.zeros((ns, nalpha, nl))
for js in range(ns):
for jalpha in range(nalpha):
theta_guess = np.reshape(theta_pest[js, jalpha, :], (-1, 1))
theta_vmec_mins = theta_guess - 0.3*np.ones(theta_guess.shape)
theta_vmec_maxs = theta_guess + 0.3*np.ones(theta_guess.shape)
theta_test, converged, nconv = get_roots(theta_vmec_mins, theta_vmec_maxs, theta_guess, np.reshape(phi[js, jalpha, :], (-1, 1)), xm, xn, lmns[js], lmnc[js], 'False')
theta_vmec[js, jalpha] = np.reshape(theta_test, (-1, ))

if converged == False:
print("* Error! Conversion from theta_pest to theta_vmec failed to converge")
print("===========================================================================")
print("")
err = True
theta_guess = theta_pest[js, jalpha, :]
solution = newton(residual, x0=theta_guess - 1.0, args=(phi[js, jalpha, :], theta_pest[js, jalpha, :], js), x1=theta_guess + 1.0)
theta_vmec[js, jalpha, :] = solution

# Now that we know theta_vmec, compute all the geometric quantities
angle = xm[:, None, None, None] * theta_vmec[None, :, :, :] - xn[:, None, None, None] * phi[None, :, :, :]
Expand Down Expand Up @@ -1307,14 +1168,13 @@ def get_roots(a0, b0, theta_pest_targets, phis, xm, xn, lmns, lmnc, lasym):
gds22 = grad_psi_dot_grad_psi * shat[:, None, None] * shat[:, None, None] / (L_reference * L_reference * B_reference * B_reference * s[:, None, None])

# temporary fix. Please see issue #238 and the discussion therein
gbdrift = -1. * 2 * B_reference * L_reference * L_reference * sqrt_s[:, None, None] * B_cross_grad_B_dot_grad_alpha \
/ (modB * modB * modB) * toroidal_flux_sign
gbdrift = -1 * 2 * B_reference * L_reference * L_reference * sqrt_s[:, None, None] * B_cross_grad_B_dot_grad_alpha / (modB * modB * modB) * toroidal_flux_sign

gbdrift0 = B_cross_grad_B_dot_grad_psi * 2 * shat[:, None, None] / (modB * modB * modB * sqrt_s[:, None, None]) * toroidal_flux_sign

# temporary fix. Please see issue #238 and the discussion therein
cvdrift = gbdrift + -1 * 2 * B_reference * L_reference * L_reference * sqrt_s[:, None, None] * mu_0 * d_pressure_d_s[:, None, None] * toroidal_flux_sign / (edge_toroidal_flux_over_2pi * modB * modB)
cvdrift = gbdrift - 2 * B_reference * L_reference * L_reference * sqrt_s[:, None, None] * mu_0 * d_pressure_d_s[:, None, None] * toroidal_flux_sign / (edge_toroidal_flux_over_2pi * modB * modB)

cvdrift0 = gbdrift0

# Package results into a structure to return:
Expand Down

0 comments on commit c5b09ba

Please sign in to comment.