Skip to content

Commit

Permalink
Merge branch 'master' into higher_order_deriv
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest authored Jul 29, 2023
2 parents 9e18d49 + d896b0a commit d76d9dd
Show file tree
Hide file tree
Showing 22 changed files with 587 additions and 294 deletions.
58 changes: 30 additions & 28 deletions desc/compute/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,36 +55,38 @@
# compute something, it's easier to just do it once for all quantities when we first
# import the compute module.
def _build_data_index():
for key in data_index.keys():
full = {
"data": get_data_deps(key, has_axis=False),
"transforms": get_derivs(key, has_axis=False),
"params": get_params(key, has_axis=False),
"profiles": get_profiles(key, has_axis=False),
}
data_index[key]["full_dependencies"] = full

full_with_axis_data = get_data_deps(key, has_axis=True)
if len(full["data"]) >= len(full_with_axis_data):
# Then this quantity and all its dependencies do not need anything
# extra to evaluate its limit at the magnetic axis.
# The dependencies in the `full` dictionary and the `full_with_axis`
# dictionary will be identical, so we assign the same reference to
# avoid storing a copy.
full_with_axis = full
else:
full_with_axis = {
"data": full_with_axis_data,
"transforms": get_derivs(key, has_axis=True),
"params": get_params(key, has_axis=True),
"profiles": get_profiles(key, has_axis=True),
for p in data_index.keys():
for key in data_index[p].keys():
full = {
"data": get_data_deps(key, p, has_axis=False),
"transforms": get_derivs(key, p, has_axis=False),
"params": get_params(key, p, has_axis=False),
"profiles": get_profiles(key, p, has_axis=False),
}
for _key, val in full_with_axis.items():
if full[_key] == val:
# Nothing extra was needed to evaluate this quantity's limit.
# One is a copy of the other; dereference to save memory.
full_with_axis[_key] = full[_key]
data_index[key]["full_with_axis_dependencies"] = full_with_axis
data_index[p][key]["full_dependencies"] = full

full_with_axis_data = get_data_deps(key, p, has_axis=True)
if len(full["data"]) >= len(full_with_axis_data):
# Then this quantity and all its dependencies do not need anything
# extra to evaluate its limit at the magnetic axis.
# The dependencies in the `full` dictionary and the `full_with_axis`
# dictionary will be identical, so we assign the same reference to
# avoid storing a copy.
full_with_axis = full
else:
full_with_axis = {
"data": full_with_axis_data,
"transforms": get_derivs(key, p, has_axis=True),
"params": get_params(key, p, has_axis=True),
"profiles": get_profiles(key, p, has_axis=True),
}
for _key, val in full_with_axis.items():
if full[_key] == val:
# Nothing extra was needed to evaluate this quantity's limit.
# One is a copy of the other; dereference to save memory.
full_with_axis[_key] = full[_key]
data_index[p][key]["full_with_axis_dependencies"] = full_with_axis


_build_data_index()
12 changes: 0 additions & 12 deletions desc/compute/_basis_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,10 +666,6 @@ def _b(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="rtz",
data=["e_theta", "e_zeta", "|e_theta x e_zeta|"],
parameterization=[
"desc.equilibrium.equilibrium.Equilibrium",
"desc.geometry.core.Surface",
],
)
def _n_rho(params, transforms, profiles, data, **kwargs):
# equal to e^rho / |e^rho| but works correctly for surfaces as well that don't have
Expand All @@ -692,10 +688,6 @@ def _n_rho(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="rtz",
data=["e_rho", "e_zeta", "|e_zeta x e_rho|"],
parameterization=[
"desc.equilibrium.equilibrium.Equilibrium",
"desc.geometry.core.Surface",
],
)
def _n_theta(params, transforms, profiles, data, **kwargs):
data["n_theta"] = (
Expand All @@ -716,10 +708,6 @@ def _n_theta(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="rtz",
data=["e_rho", "e_theta", "|e_rho x e_theta|"],
parameterization=[
"desc.equilibrium.equilibrium.Equilibrium",
"desc.geometry.core.Surface",
],
)
def _n_zeta(params, transforms, profiles, data, **kwargs):
data["n_zeta"] = (
Expand Down
61 changes: 57 additions & 4 deletions desc/compute/data_index.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""data_index contains all the quantities calculated by the compute functions."""

data_index = {}


def register_compute_fun(
name,
Expand All @@ -15,8 +13,9 @@ def register_compute_fun(
profiles,
coordinates,
data,
parameterization="desc.equilibrium.equilibrium.Equilibrium",
axis_limit_data=None,
**kwargs
**kwargs,
):
"""Decorator to wrap a function and add it to the list of things we can compute.
Expand Down Expand Up @@ -47,6 +46,9 @@ def register_compute_fun(
a flux function, etc.
data : list of str
Names of other items in the data index needed to compute qty.
parameterization: str
Name of desc types the method is valid for. eg 'desc.geometry.FourierXYZCurve'
or `desc.equilibrium.Equilibrium`.
axis_limit_data : list of str
Names of other items in the data index needed to compute axis limit of qty.
Expand All @@ -55,6 +57,9 @@ def register_compute_fun(
Should only list *direct* dependencies. The full dependencies will be built
recursively at runtime using each quantity's direct dependencies.
"""
if not isinstance(parameterization, (tuple, list)):
parameterization = [parameterization]

deps = {
"params": params,
"transforms": transforms,
Expand All @@ -75,7 +80,55 @@ def _decorator(func):
"coordinates": coordinates,
"dependencies": deps,
}
data_index[name] = d
for p in parameterization:
flag = False
for base_class, superclasses in _class_inheritance.items():
if p in superclasses or p == base_class:
data_index[base_class][name] = d.copy()
flag = True
if not flag:
raise ValueError(
f"Can't register function with unknown parameterization: {p}"
)
return func

return _decorator


# This allows us to handle subclasses whos data_index stuff should inherit
# from parent classes.
# This is the least bad solution I've found, since everything else requires
# crazy circular imports
# could maybe make this fancier with a registry of compute-able objects?
_class_inheritance = {
"desc.equilibrium.equilibrium.Equilibrium": [],
"desc.geometry.curve.FourierRZCurve": [
"desc.geometry.core.Curve",
],
"desc.geometry.curve.FourierXYZCurve": [
"desc.geometry.core.Curve",
],
"desc.geometry.curve.FourierPlanarCurve": [
"desc.geometry.core.Curve",
],
"desc.geometry.surface.FourierRZToroidalSurface": [
"desc.geometry.core.Surface",
],
"desc.geometry.surface.ZernikeRZToroidalSection": [
"desc.geometry.core.Surface",
],
"desc.coils.FourierRZCoil": [
"desc.geometry.curve.FourierRZCurve",
"desc.geometry.core.Curve",
],
"desc.coils.FourierXYZCoil": [
"desc.geometry.curve.FourierXYZCurve",
"desc.geometry.core.Curve",
],
"desc.coils.FourierPlanarCoil": [
"desc.geometry.curve.FourierPlanarCurve",
"desc.geometry.core.Curve",
],
}

data_index = {p: {} for p in _class_inheritance.keys()}
Loading

0 comments on commit d76d9dd

Please sign in to comment.