Skip to content

Commit

Permalink
Fix math, results look kinda good now
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis committed May 30, 2024
1 parent 599f895 commit bc6b982
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 139 deletions.
193 changes: 89 additions & 104 deletions desc/compute/_neoclassical.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,85 +116,71 @@ def _poloidal_average(grid, f, name=""):


@register_compute_fun(
name="V_psi(r)*L",
label="\\int_{0}^{L} d \\ell / \\vert B \\vert",
units="m^3 / Wb",
units_long="Cubic meters per Weber",
description=(
"Volume enclosed by flux surfaces, derivative with respect to toroidal flux, "
"computed along field line, scaled by dimensionless length of field line"
),
dim=1,
name="L|r,a",
label="\\int_{\\zeta_{\\mathrm{min}}}^{\\zeta_{\\mathrm{max}}"
" \\frac{d\\zeta}{B^{\\zeta} \\vert B \\vert}",
units="m",
units_long="meters",
description="Length along field line",
dim=2,
params=[],
transforms={"grid": []},
profiles=[],
coordinates="r",
coordinates="ra",
data=["B^zeta", "|B|"],
grid_requirement=[
"source_grid",
lambda grid: grid.source_grid.coordinates == "raz"
and grid.source_grid.is_meshgrid,
],
)
def _V_psi_L(data, transforms, profiles, **kwargs):
def _L_ra(data, transforms, profiles, **kwargs):
g = transforms["grid"].source_grid
shape = (g.num_rho, g.num_alpha, g.num_zeta)
V_psi_L = _poloidal_average(
g,
quadax.simpson(
jnp.reshape(1 / (data["B^zeta"] * data["|B|"]), shape),
jnp.reshape(g.nodes[:, 2], shape),
axis=-1,
),
name="V_psi(r)*L",
data["L|r,a"] = quadax.simpson(
jnp.reshape(1 / (data["B^zeta"] * data["|B|"]), shape),
jnp.reshape(g.nodes[:, 2], shape),
axis=-1,
)
data["V_psi(r)*L"] = g.expand(V_psi_L)
return data


@register_compute_fun(
name="S(r)*L",
label="\\int_{0}^{L} d \\ell \\vert \\nabla \\psi \\vert / \\vert B \\vert",
units="m^2",
units_long="Square meters",
description=(
"Surface area of flux surfaces, computed along field line, "
"scaled by dimensionless length of field line."
),
dim=1,
name="G|r,a",
label="\\int_{\\zeta_{\\mathrm{min}}}^{\\zeta_{\\mathrm{max}}"
" \\frac{d\\zeta}{B^{\\zeta} \\vert B \\vert \\sqrt g}",
units="m^{-2}",
units_long="inverse meters squared",
description="Length over volume along field line",
dim=2,
params=[],
transforms={"grid": []},
profiles=[],
coordinates="r",
data=["B^zeta", "|B|", "|grad(psi)|"],
coordinates="ra",
data=["B^zeta", "|B|", "sqrt(g)"],
grid_requirement=[
"source_grid",
lambda grid: grid.source_grid.coordinates == "raz"
and grid.source_grid.is_meshgrid,
],
)
def _S_L(data, transforms, profiles, **kwargs):
def _G_ra(data, transforms, profiles, **kwargs):
g = transforms["grid"].source_grid
shape = (g.num_rho, g.num_alpha, g.num_zeta)
S_L = _poloidal_average(
g,
quadax.simpson(
jnp.reshape(data["|grad(psi)|"] / (data["B^zeta"] * data["|B|"]), shape),
jnp.reshape(g.nodes[:, 2], shape),
axis=-1,
),
name="S(r)*L",
)
data["S(r)*L"] = g.expand(S_L)
data["G|r,a"] = quadax.simpson(
jnp.reshape(1 / (data["B^zeta"] * data["|B|"] * data["sqrt(g)"]), shape),
jnp.reshape(g.nodes[:, 2], shape),
axis=-1,
) / (4 * jnp.pi**2)
return data


@register_compute_fun(
name="effective ripple raw",
label="-∫ dλ λ⁻² ∑ⱼ Hⱼ² / Iⱼ",
units="Wb / m",
units_long="Webers per meter",
description="Effective ripple modulation amplitude, not normalized",
label="∫ dλ λ⁻² \\langle ∑ⱼ Hⱼ²/Iⱼ \\rangle",
units="T^2",
units_long="Tesla squared",
description="Effective ripple modulation amplitude, not dimensionless",
dim=1,
params=[],
transforms={"grid": []},
Expand All @@ -204,10 +190,11 @@ def _S_L(data, transforms, profiles, **kwargs):
"min_tz |B|",
"max_tz |B|",
"B^zeta",
"|B|_z|r,a",
"|B|",
"|B|_z|r,a",
"|grad(psi)|",
"kappa_g",
"L|r,a",
],
grid_requirement=[
"source_grid",
Expand Down Expand Up @@ -243,6 +230,7 @@ def _effective_ripple_raw(params, transforms, profiles, data, **kwargs):
pitch_endpoint = 1 / jnp.stack([max_B, min_B])

def dH(grad_psi_norm, kappa_g, B, pitch, Z):
# absorbed 1/λ into H
return jnp.sqrt(1 / pitch - B) * (4 / B - pitch) * grad_psi_norm * kappa_g / B

def dI(B, pitch, Z):
Expand All @@ -254,7 +242,7 @@ def dI(B, pitch, Z):
)

def d_ripple(pitch):
"""Return λ⁻² ∑ⱼ Hⱼ² / Iⱼ evaluated at pitch.
"""Return λ⁻² ∑ⱼ Hⱼ²/Iⱼ evaluated at λ = pitch.
Parameters
----------
Expand All @@ -264,10 +252,9 @@ def d_ripple(pitch):
Returns
-------
d_ripple : Array, shape(pitch.shape)
λ⁻² ∑ⱼ Hⱼ² / Iⱼ
λ⁻² ∑ⱼ Hⱼ²/Iⱼ
"""
# absorbed 1/λ into H
H = bounce_integrate(
dH, [data["|grad(psi)|"], data["kappa_g"]], pitch, batch=batch
)
Expand All @@ -283,7 +270,6 @@ def d_ripple(pitch):
# Use adaptive quadrature.

def d_ripple(pitch, B_sup_z, B, B_z_ra, grad_psi_norm, kappa_g):
# Quadax requires scalar integration interval, so we need to return scalar.
bounce_integrate, _ = _bounce_integral(B_sup_z, B, B_z_ra, knots)
H = bounce_integrate(dH, [grad_psi_norm, kappa_g], pitch, batch=batch)
I = bounce_integrate(dI, [], pitch, batch=batch)
Expand All @@ -304,15 +290,15 @@ def d_ripple(pitch, B_sup_z, B, B_z_ra, grad_psi_norm, kappa_g):
ripple = quad(d_ripple, pitch_endpoint, *args)

ripple = _poloidal_average(
g, ripple.reshape(g.num_rho, g.num_alpha), name="effective ripple raw"
g, ripple.reshape(g.num_rho, g.num_alpha) / data["L|r,a"]
)
data["effective ripple raw"] = g.expand(ripple)
return data


@register_compute_fun(
name="effective ripple",
label="\\epsilon_{\\text{eff}}",
name="effective ripple", # this is ε¹ᐧ⁵
label="π/(8√2) (R₀(∂_ψ V)/S)² ∫ dλ λ⁻² \\langle ∑ⱼ Hⱼ²/Iⱼ \\rangle",
units="~",
units_long="None",
description="Effective ripple modulation amplitude",
Expand All @@ -321,7 +307,7 @@ def d_ripple(pitch, B_sup_z, B, B_z_ra, grad_psi_norm, kappa_g):
transforms={},
profiles=[],
coordinates="r",
data=["effective ripple raw", "R0", "V_psi(r)*L", "S(r)*L"],
data=["R0", "V_r(r)", "psi_r", "S(r)", "effective ripple raw"],
)
def _effective_ripple(params, transforms, profiles, data, **kwargs):
"""Evaluation of 1/ν neoclassical transport in stellarators.
Expand All @@ -332,21 +318,38 @@ def _effective_ripple(params, transforms, profiles, data, **kwargs):
"""
data["effective ripple"] = (
jnp.pi
* data["R0"] ** 2
/ (8 * 2**0.5)
* data["V_psi(r)*L"]
/ data["S(r)*L"] ** 2
* (data["R0"] * data["V_r(r)"] / data["psi_r"] / data["S(r)"]) ** 2
* data["effective ripple raw"]
) ** (2 / 3)
)
return data


@register_compute_fun(
name="Gamma_c raw",
label="-∫ dλ ∑ⱼ (γ_c² ∂I/∂(λ⁻¹) λ⁻²)ⱼ",
units="m^3 / Wb",
units_long="Cubic meters per Weber",
description="Energetic ion confinement proxy, not normalized",
name="Gamma_c",
# When comparing Velasco's Γ_c: https://doi.org/10.1088/1741-4326/ac2994
# with Nemov's Γ_c: https://doi.org/10.1063/1.2912456,
# note that
# dλ v τ_b = 8 (∂I/∂b) db = -8 ∂I/∂(λ⁻¹) dλ λ⁻²
# and
# 4π² ∂Ψₜ/∂V = lim{L → ∞} ( [∫₀ᴸ ds/(B √g)] / [∫₀ᴸ ds/B] )
# where the integrals are along an irrational field line with
# ds given by dζ / B^ζ
# √g the (Ψ, θ, ζ)-coordinate Jacobian
# There is also the dimensionless difference between Nemov's and Velascos's γ_c,
# mentioned in Velasco's footnote 4. If this difference is γ_c ignored, and the
# (missing?) √g factor is pushed into the integral over alpha in Velasco eq. 18
# then we have that
# Velasco Γ_c = Nemov Γ_c * lim{L → ∞} ∫₀ᴸ ds/(B √g) / (2π).
# In particular, Velasco Γ_c grows with the length of the field line
# which isn't what we want.
# TODO:
# We currently implement Nemov Γ_c with Velasco γ_c.
# Switch to Nemov γ_c too.
label="π/(2√2) ∫ dλ λ⁻² \\langle ∑ⱼ [γ_c² ∂I/∂(λ⁻¹)]ⱼ \\rangle",
units="~",
units_long="None",
description="Nemov's energetic ion confinement proxy",
dim=1,
params=[],
transforms={"grid": []},
Expand All @@ -356,10 +359,11 @@ def _effective_ripple(params, transforms, profiles, data, **kwargs):
"min_tz |B|",
"max_tz |B|",
"B^zeta",
"|B|_z|r,a",
"|B|",
"|B|_z|r,a",
"cvdrift0",
"gbdrift",
"L|r,a",
],
grid_requirement=[
"source_grid",
Expand All @@ -378,7 +382,13 @@ def _effective_ripple(params, transforms, profiles, data, **kwargs):
),
quad_res="int : Resolution for quadrature over velocity coordinate.",
)
def _Gamma_c_raw(params, transforms, profiles, data, **kwargs):
def _Gamma_c(params, transforms, profiles, data, **kwargs):
"""Poloidal motion of trapped particle orbits in real-space coordinates.
V. V. Nemov, S. V. Kasilov, W. Kernbichler, G. O. Leitold.
Phys. Plasmas 1 May 2008; 15 (5): 052501.
https://doi.org/10.1063/1.2912456.
"""
g = transforms["grid"].source_grid
knots = g.compress(g.nodes[:, 2], surface_label="zeta")
_bounce_integral = kwargs.get("bounce_integral", bounce_integral)
Expand All @@ -404,8 +414,8 @@ def dK(B, pitch, Z):
data["B^zeta"], data["|B|"], data["|B|_z|r,a"], knots
)

def d_Gamma_c_raw(pitch):
"""Return ∑ⱼ (γ_c² ∂I/∂(λ⁻¹) λ⁻²)ⱼ evaluated at pitch.
def d_Gamma_c(pitch):
"""Return ∑ⱼ [γ_c² ∂I/∂(λ⁻¹)]ⱼ evaluated at λ = pitch.
Parameters
----------
Expand All @@ -414,13 +424,13 @@ def d_Gamma_c_raw(pitch):
Returns
-------
d_Gamma_c_raw : Array, shape(pitch.shape)
∑ⱼ (γ_c² ∂I/∂(λ⁻¹) λ⁻²)
d_Gamma_c : Array, shape(pitch.shape)
∑ⱼ [γ_c² ∂I/∂(λ⁻¹)]
"""
# TODO: Currently we have implemented Velasco's Gamma_c.
# If we add a 1/|grad(psi)| into the arctan of little
# gamma_c, we implement Nemov's Gamma_c.
# gamma_c, we implement Nemov's Gamma_c. (Check this again).
# This will affect the gamma_c profile since |grad(psi)|
# depends on alpha.
gamma_c = (
Expand All @@ -431,20 +441,19 @@ def d_Gamma_c_raw(pitch):
/ bounce_integrate(d_gamma_c, data["gbdrift"], pitch, batch=batch)
)
)
K = bounce_integrate(dK, [], pitch, batch=batch) # ∂I/∂(λ⁻¹) λ⁻²
K = bounce_integrate(dK, [], pitch, batch=batch) # λ⁻² ∂I/∂(λ⁻¹)
return jnp.nansum(gamma_c**2 * K, axis=-1)

pitch = composite_linspace(pitch_endpoint, quad_res)
pitch = jnp.broadcast_to(
pitch[..., jnp.newaxis], (pitch.shape[0], g.num_rho, g.num_alpha)
).reshape(pitch.shape[0], g.num_rho * g.num_alpha)
Gamma_c_raw = quad(d_Gamma_c_raw(pitch), pitch, axis=0)
Gamma_c = quad(d_Gamma_c(pitch), pitch, axis=0)
else:
# Use adaptive quadrature.

def d_Gamma_c_raw(pitch, B_sup_z, B, B_z_ra, cvdrift0, gbdrift):
# Quadax requires scalar integration interval, so we need to return scalar.
bounce_int, _ = _bounce_integral(B_sup_z, B, B_z_ra, knots)
def d_Gamma_c(pitch, B_sup_z, B, B_z_ra, cvdrift0, gbdrift):
bounce_integrate, _ = _bounce_integral(B_sup_z, B, B_z_ra, knots)
gamma_c = (
2
/ jnp.pi
Expand All @@ -468,34 +477,10 @@ def d_Gamma_c_raw(pitch, B_sup_z, B, B_z_ra, cvdrift0, gbdrift):
data["gbdrift"],
]
]
Gamma_c_raw = quad(d_Gamma_c_raw, pitch_endpoint, *args)
Gamma_c = quad(d_Gamma_c, pitch_endpoint, *args)

Gamma_c_raw = _poloidal_average(
g, Gamma_c_raw.reshape(g.num_rho, g.num_alpha), name="Gamma_c raw"
Gamma_c = _poloidal_average(
g, Gamma_c.reshape(g.num_rho, g.num_alpha) / data["L|r,a"]
)
data["Gamma_c raw"] = g.expand(Gamma_c_raw)
return data


@register_compute_fun(
name="Gamma_c",
label="\\Gamma_{c}",
units="~",
units_long="None",
description="Energetic ion confinement proxy",
dim=1,
params=[],
transforms={},
profiles=[],
coordinates="r",
data=["Gamma_c raw", "V_psi(r)*L"],
)
def _Gamma_c(params, transforms, profiles, data, **kwargs):
"""Poloidal motion of trapped particle orbits in real-space coordinates.
V. V. Nemov, S. V. Kasilov, W. Kernbichler, G. O. Leitold.
Phys. Plasmas 1 May 2008; 15 (5): 052501.
https://doi.org/10.1063/1.2912456.
"""
data["Gamma_c"] = jnp.pi * data["Gamma_c raw"] / (2**1.5 * data["V_psi(r)*L"])
data["Gamma_c"] = g.expand(jnp.pi / 2**1.5 * Gamma_c)
return data
Loading

2 comments on commit bc6b982

@a957924278
Copy link

@a957924278 a957924278 commented on bc6b982 May 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi unalmis, it reports error like below when I run test_effective_ripple() with the newest version of branch ripple :

DESC version 0+untagged.1.gbc6b982,using JAX backend, jax version=0.4.28, jaxlib version=0.4.28, dtype=float64
Using device: CPU, with 248.09 GB available memory
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/d/chl/我的坚果云/desc_code/test_neoclassical.py", line 109, in <module>
    test_effective_ripple()
  File "/mnt/d/chl/我的坚果云/desc_code/test_neoclassical.py", line 69, in test_effective_ripple
    data = eq.compute(
  File "/home/chl/miniconda3/envs/desc-ripple0530/lib/python3.10/site-packages/desc/equilibrium/equilibrium.py", line 1003, in compute
    data = compute_fun(
  File "/home/chl/miniconda3/envs/desc-ripple0530/lib/python3.10/site-packages/desc/compute/utils.py", line 85, in compute
    data = _compute(
  File "/home/chl/miniconda3/envs/desc-ripple0530/lib/python3.10/site-packages/desc/compute/utils.py", line 156, in _compute
    data = data_index[parameterization][name]["fun"](
  File "/home/chl/miniconda3/envs/desc-ripple0530/lib/python3.10/site-packages/desc/compute/_neoclassical.py", line 240, in _effective_ripple_raw
    bounce_integrate, _ = _bounce_integral(
  File "/home/chl/miniconda3/envs/desc-ripple0530/lib/python3.10/site-packages/desc/compute/bounce_integral.py", line 1251, in bounce_integral
    else CubicHermiteSpline(knots, B, B_z_ra, axis=-1, check=check).c
  File "/home/chl/miniconda3/envs/desc-ripple0530/lib/python3.10/site-packages/equinox/_better_abstract.py", line 226, in __call__
    self = super().__call__(*args, **kwargs)
  File "/home/chl/miniconda3/envs/desc-ripple0530/lib/python3.10/site-packages/interpax/_ppoly.py", line 565, in __init__
    super().__init__(c, x, extrapolate=extrapolate, axis=axis)
  File "/home/chl/miniconda3/envs/desc-ripple0530/lib/python3.10/site-packages/interpax/_ppoly.py", line 139, in __init__
    errorif(
  File "/home/chl/miniconda3/envs/desc-ripple0530/lib/python3.10/site-packages/interpax/utils.py", line 19, in errorif
    if cond:
  File "/home/chl/miniconda3/envs/desc-ripple0530/lib/python3.10/site-packages/jax/_src/errors.py", line 522, in __init__
    f"{tracer._origin_msg()}")
IndexError: list index out of range

Do you have any idea about the error?

@unalmis
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's the issue I mention in the description of the pull request. The PR relies on interpax package, which fixed this bug in f0uriest/interpax#27. The bugfix has yet to be pushed to interpax's PyPI package. Until then you'll need to manually edit your interpax installation with the changes in the liked pull request.

Please sign in to comment.