-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Effective ripple ε #1003
base: master
Are you sure you want to change the base?
Effective ripple ε #1003
Changes from all commits
8090374
6ac2251
c02306b
bcaece0
f156083
19c6413
f26e80d
452416a
f7f68d4
de9bc71
447c682
14ac6a7
ce1fe11
cd5094e
71de921
a879376
f17c773
406d07d
558f623
bebc40b
b13d7ad
7c82470
6074f7b
169556f
5044624
a647856
974244d
ad020f2
a7c43d7
9978728
5670703
75fb132
e848f91
263a15b
93ff35b
eb0088c
540a69e
de71838
f2f187c
35ada77
0fa255d
b955d7e
fa404b2
468baf6
8a02352
01f8a75
e5358b5
1a49329
f448bf0
b4ca7e3
6c0581c
65994a5
b08b3c0
a4637ea
490da1f
061f75b
d707531
91a76e0
0cc5c72
229ab46
4dc1fa1
0901d96
95ce834
fe0b099
34ff754
e6183c5
c5b5d69
c333a53
c264cff
599f895
bc6b982
beade78
808806c
b87ca1a
169af3a
f2687b6
50a0810
a9feffd
ae700f0
b45c0dc
45493ae
bc17c5d
a2903cb
9d58c4e
c5b1fbf
57d6f7f
2737194
104b3b4
4454500
db8b7be
66b24a5
bf981dd
3c2aabd
4db2468
7b5b7c0
2dbfeb0
2d01ef4
cce1018
12b39c9
2e730bd
d1c7d4a
9c68b41
f0ac159
6a3965c
fc1be1c
0319dca
9dc81b3
38d82a7
381543f
54dab1a
4b3983e
3f75c09
ee437ea
96ac3e9
6827fcd
0c995d1
8a8d9de
97e081f
9f1371b
ce1c280
a28dade
ea6e0dc
5409b40
10ac679
f54abb4
5d61f58
3e49a8b
84d1134
68efde6
a8bbbc9
5fb62bf
c1c6c16
a603df2
70a3e43
661e6cd
a3e1ba4
fb13f31
cbeda22
5b0393e
36428d3
d952252
4e9ebb9
9ff79da
04a0d52
b383358
654a5ff
9e80037
a93a6da
0f319e0
8d5f3d3
3b1ffc0
3b5441f
fd914e8
f2c26d9
9d981ff
456d1a8
fc8b393
598009f
d266505
9a968e6
b4151d9
8b656f2
3a93117
52adba9
8d9b605
5bc4b35
6146dc4
0366137
6a46efb
bd68679
06d5061
53fd368
454bf3b
41f3727
6e3b0e7
e2b58c7
df0590b
4756c38
2073547
e5d150c
1edd349
850001a
0ae5c9a
8cb2a28
21235b4
098db02
eb3370b
56a88a8
9e88270
8840c9c
2dffde2
73a6b48
6c1cd12
5b4456d
7d88eac
c819e32
157e57b
883b34d
1b3b6f2
a2a3d71
3e4c0ab
998c830
499a6c7
71b42bd
e5bfd38
ae939e9
569dcdc
59235b0
78fa50d
c2971d1
d3d239e
0dd088b
47744c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,7 @@ | |
_field, | ||
_geometry, | ||
_metric, | ||
_neoclassical, | ||
_omnigenity, | ||
_profiles, | ||
_stability, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,272 @@ | ||
"""Compute functions for neoclassical transport. | ||
|
||
Notes | ||
----- | ||
Some quantities require additional work to compute at the magnetic axis. | ||
A Python lambda function is used to lazily compute the magnetic axis limits | ||
of these quantities. These lambda functions are evaluated only when the | ||
computational grid has a node on the magnetic axis to avoid potentially | ||
expensive computations. | ||
""" | ||
|
||
from functools import partial | ||
|
||
from quadax import simpson | ||
|
||
from desc.backend import imap, jit, jnp | ||
|
||
from ..integrals.bounce_integral import Bounce1D | ||
from ..integrals.bounce_utils import get_pitch_inv_quad | ||
from ..integrals.quad_utils import chebgauss2 | ||
from ..utils import safediv | ||
from .data_index import register_compute_fun | ||
|
||
_bounce_doc = { | ||
"quad": ( | ||
"tuple[jnp.ndarray] : Quadrature points and weights for bounce integrals. " | ||
"Default option is well tested." | ||
), | ||
"num_quad": ( | ||
"int : Resolution for quadrature of bounce integrals. " | ||
"Default is 32. This option is ignored if given ``quad``." | ||
), | ||
"num_pitch": "int : Resolution for quadrature over velocity coordinate.", | ||
"num_well": ( | ||
"int : Maximum number of wells to detect for each pitch and field line. " | ||
"Default is to detect all wells, but due to limitations in JAX this option " | ||
"may consume more memory. Specifying a number that tightly upper bounds " | ||
"the number of wells will increase performance." | ||
), | ||
"batch": "bool : Whether to vectorize part of the computation. Default is true.", | ||
} | ||
|
||
|
||
def _alpha_mean(f): | ||
"""Simple mean over field lines. | ||
|
||
Simple mean rather than integrating over α and dividing by 2π | ||
(i.e. f.T.dot(dα) / dα.sum()), because when the toroidal angle extends | ||
beyond one transit we need to weight all field lines uniformly, regardless | ||
of their spacing wrt α. | ||
""" | ||
return f.mean(axis=0) | ||
|
||
|
||
def _compute(fun, interp_data, data, grid, num_pitch, reduce=True): | ||
"""Compute ``fun`` for each α and ρ value iteratively to reduce memory usage. | ||
|
||
Parameters | ||
---------- | ||
fun : callable | ||
Function to compute. | ||
interp_data : dict[str, jnp.ndarray] | ||
Data to provide to ``fun``. | ||
Names in ``Bounce1D.required_names`` will be overridden. | ||
Reshaped automatically. | ||
data : dict[str, jnp.ndarray] | ||
DESC data dict. | ||
reduce : bool | ||
Whether to compute mean over α and expand to grid. | ||
Default is true. | ||
|
||
""" | ||
pitch_inv, pitch_inv_weight = get_pitch_inv_quad( | ||
grid.compress(data["min_tz |B|"]), | ||
grid.compress(data["max_tz |B|"]), | ||
num_pitch, | ||
) | ||
|
||
def for_each_rho(x): | ||
# using same λ values for every field line α on flux surface ρ | ||
x["pitch_inv"] = pitch_inv | ||
x["pitch_inv weight"] = pitch_inv_weight | ||
return imap(fun, x) | ||
|
||
for name in Bounce1D.required_names: | ||
interp_data[name] = data[name] | ||
interp_data = dict( | ||
zip(interp_data.keys(), Bounce1D.reshape_data(grid, *interp_data.values())) | ||
) | ||
out = imap(for_each_rho, interp_data) | ||
return grid.expand(_alpha_mean(out)) if reduce else out | ||
|
||
|
||
@register_compute_fun( | ||
name="fieldline length", | ||
label="\\int_{\\zeta_{\\mathrm{min}}}^{\\zeta_{\\mathrm{max}}}" | ||
" \\frac{d\\zeta}{|B^{\\zeta}|}", | ||
units="m / T", | ||
units_long="Meter / tesla", | ||
description="(Mean) proper length of field line(s)", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. a non-rational field line is formally infinite in length, is this taking the "length" over 1 toroidal turn? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's the length of your sample of the field line. Yes, this converges to infinity. |
||
dim=1, | ||
params=[], | ||
transforms={"grid": []}, | ||
profiles=[], | ||
coordinates="r", | ||
data=["B^zeta"], | ||
resolution_requirement="z", | ||
source_grid_requirement={"coordinates": "raz", "is_meshgrid": True}, | ||
) | ||
def _fieldline_length(data, transforms, profiles, **kwargs): | ||
grid = transforms["grid"].source_grid | ||
L_ra = simpson( | ||
y=grid.meshgrid_reshape(1 / data["B^zeta"], "arz"), | ||
x=grid.compress(grid.nodes[:, 2], surface_label="zeta"), | ||
axis=-1, | ||
) | ||
data["fieldline length"] = grid.expand(jnp.abs(_alpha_mean(L_ra))) | ||
return data | ||
|
||
|
||
@register_compute_fun( | ||
name="fieldline length/volume", | ||
label="\\int_{\\zeta_{\\mathrm{min}}}^{\\zeta_{\\mathrm{max}}}" | ||
" \\frac{d\\zeta}{|B^{\\zeta} \\sqrt g|}", | ||
units="1 / Wb", | ||
units_long="Inverse webers", | ||
description="(Mean) proper length over volume of field line(s)", | ||
dim=1, | ||
params=[], | ||
transforms={"grid": []}, | ||
profiles=[], | ||
coordinates="r", | ||
data=["B^zeta", "sqrt(g)"], | ||
resolution_requirement="z", | ||
source_grid_requirement={"coordinates": "raz", "is_meshgrid": True}, | ||
) | ||
def _fieldline_length_over_volume(data, transforms, profiles, **kwargs): | ||
grid = transforms["grid"].source_grid | ||
G_ra = simpson( | ||
y=grid.meshgrid_reshape(1 / (data["B^zeta"] * data["sqrt(g)"]), "arz"), | ||
x=grid.compress(grid.nodes[:, 2], surface_label="zeta"), | ||
axis=-1, | ||
) | ||
data["fieldline length/volume"] = grid.expand(jnp.abs(_alpha_mean(G_ra))) | ||
return data | ||
|
||
|
||
@register_compute_fun( | ||
name="effective ripple 3/2", | ||
label=( | ||
# ε¹ᐧ⁵ = π/(8√2) R₀²〈|∇ψ|〉⁻² B₀⁻¹ ∫dλ λ⁻² 〈 ∑ⱼ Hⱼ²/Iⱼ 〉 | ||
"\\epsilon_{\\mathrm{eff}}^{3/2} = \\frac{\\pi}{8 \\sqrt{2}} " | ||
"R_0^2 \\langle \\vert\\nabla \\psi\\vert \\rangle^{-2} " | ||
"B_0^{-1} \\int d\\lambda \\lambda^{-2} " | ||
"\\langle \\sum_j H_j^2 / I_j \\rangle" | ||
), | ||
units="~", | ||
units_long="None", | ||
description="Effective ripple modulation amplitude to 3/2 power", | ||
dim=1, | ||
params=[], | ||
transforms={"grid": []}, | ||
profiles=[], | ||
coordinates="r", | ||
data=[ | ||
"min_tz |B|", | ||
"max_tz |B|", | ||
"kappa_g", | ||
"R0", | ||
"|grad(rho)|", | ||
"<|grad(rho)|>", | ||
"fieldline length", | ||
] | ||
+ Bounce1D.required_names, | ||
resolution_requirement="z", | ||
source_grid_requirement={"coordinates": "raz", "is_meshgrid": True}, | ||
**_bounce_doc, | ||
# Some notes on choosing the resolution hyperparameters: | ||
# The default settings were chosen such that the effective ripple profile on | ||
# the W7-X stellarator looks similar to the profile computed at higher resolution, | ||
f0uriest marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# indicating convergence. The parameters ``num_transit`` and ``knots_per_transit`` | ||
# have a stronger effect on the result. As a reference for W7-X, when computing the | ||
# effective ripple by tracing a single field line on each flux surface, a density of | ||
# 100 knots per toroidal transit accurately reconstructs the ripples along the field | ||
# line. After 10 toroidal transits convergence is apparent (after 15 the returns | ||
# diminish). Dips in the resulting profile indicates insufficient ``num_transit``. | ||
# Unreasonably high values indicates insufficient ``knots_per_transit``. | ||
# One can plot the field line with ``Bounce1D.plot`` to see if the number of knots | ||
# was sufficient to reconstruct the field line. | ||
# TODO: Improve performance... see GitHub issue #1045. | ||
# Need more efficient function approximation of |B|(α, ζ). | ||
) | ||
@partial(jit, static_argnames=["num_quad", "num_pitch", "num_well", "batch"]) | ||
def _epsilon_32(params, transforms, profiles, data, **kwargs): | ||
"""https://doi.org/10.1063/1.873749. | ||
|
||
Evaluation of 1/ν neoclassical transport in stellarators. | ||
V. V. Nemov, S. V. Kasilov, W. Kernbichler, M. F. Heyn. | ||
Phys. Plasmas 1 December 1999; 6 (12): 4622–4632. | ||
""" | ||
# noqa: unused dependency | ||
if "quad" in kwargs: | ||
quad = kwargs["quad"] | ||
else: | ||
quad = chebgauss2(kwargs.get("num_quad", 32)) | ||
num_well = kwargs.get("num_well", None) | ||
batch = kwargs.get("batch", True) | ||
grid = transforms["grid"].source_grid | ||
|
||
def dH(grad_rho_norm_kappa_g, B, pitch): | ||
# Integrand of Nemov eq. 30 with |∂ψ/∂ρ| (λB₀)¹ᐧ⁵ removed. | ||
f0uriest marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return ( | ||
jnp.sqrt(jnp.abs(1 - pitch * B)) | ||
* (4 / (pitch * B) - 1) | ||
* grad_rho_norm_kappa_g | ||
/ B | ||
) | ||
|
||
def dI(B, pitch): | ||
# Integrand of Nemov eq. 31. | ||
return jnp.sqrt(jnp.abs(1 - pitch * B)) / B | ||
f0uriest marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def eps_32(data): | ||
"""(∂ψ/∂ρ)⁻² B₀⁻² ∫ dλ λ⁻² ∑ⱼ Hⱼ²/Iⱼ.""" | ||
# B₀ has units of λ⁻¹. | ||
# Nemov's ∑ⱼ Hⱼ²/Iⱼ = (∂ψ/∂ρ)² (λB₀)³ ``(H**2 / I).sum(axis=-1)``. | ||
# (λB₀)³ d(λB₀)⁻¹ = B₀² λ³ d(λ⁻¹) = -B₀² λ dλ. | ||
bounce = Bounce1D(grid, data, quad, automorphism=None, is_reshaped=True) | ||
points = bounce.points(data["pitch_inv"], num_well=num_well) | ||
H = bounce.integrate( | ||
dH, | ||
data["pitch_inv"], | ||
data["|grad(rho)|*kappa_g"], | ||
points=points, | ||
batch=batch, | ||
) | ||
I = bounce.integrate(dI, data["pitch_inv"], points=points, batch=batch) | ||
return ( | ||
safediv(H**2, I).sum(axis=-1) | ||
* data["pitch_inv"] ** (-3) | ||
* data["pitch_inv weight"] | ||
).sum(axis=-1) | ||
|
||
# Interpolate |∇ρ| κ_g since it is smoother than κ_g alone. | ||
interp_data = {"|grad(rho)|*kappa_g": data["|grad(rho)|"] * data["kappa_g"]} | ||
B0 = data["max_tz |B|"] | ||
data["effective ripple 3/2"] = ( | ||
jnp.pi | ||
/ (8 * 2**0.5) | ||
* (B0 * data["R0"] / data["<|grad(rho)|>"]) ** 2 | ||
* _compute(eps_32, interp_data, data, grid, kwargs.get("num_pitch", 50)) | ||
/ data["fieldline length"] | ||
) | ||
return data | ||
|
||
|
||
@register_compute_fun( | ||
name="effective ripple", | ||
label="\\epsilon_{\\mathrm{eff}}", | ||
units="~", | ||
units_long="None", | ||
description="Effective ripple modulation amplitude", | ||
dim=1, | ||
params=[], | ||
transforms={}, | ||
profiles=[], | ||
coordinates="r", | ||
data=["effective ripple 3/2"], | ||
) | ||
def _effective_ripple(params, transforms, profiles, data, **kwargs): | ||
data["effective ripple"] = data["effective ripple 3/2"] ** (2 / 3) | ||
return data |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are we missing a factor to make this have units of length?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This has the desired units. dhaeseleer calls this the proper length, (p.g. VIII and chapter 12)