Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 4th order derivative transforms #586

Merged
merged 22 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
8922d94
Increase max matrix size of transform
unalmis Jul 17, 2023
3399eaf
Add fourth order zernike radial
unalmis Jul 18, 2023
cdefd9f
Merge branch 'master' into higher_order_deriv
unalmis Jul 18, 2023
8f3ca37
Lower bound transform matrix order to 1
unalmis Jul 18, 2023
e5eee12
Update transform class for deriv order
unalmis Jul 18, 2023
abfda65
Add back dropped line of code
unalmis Jul 18, 2023
3ff4ebf
Add back dummy _set_up method
unalmis Jul 19, 2023
15f816a
Document use of _set_up function and fix test to raise ValueError
unalmis Jul 19, 2023
d78e4a5
Remove _set_up method now that io save stuff was...
unalmis Jul 19, 2023
0882ba4
Merge branch 'master' into higher_order_deriv
f0uriest Jul 20, 2023
07bed02
Merge branch 'master' into higher_order_deriv
f0uriest Jul 20, 2023
e6a84a0
Tests for Zernike polynomial radial derivative
unalmis Jul 20, 2023
7574d8a
Merge branch 'master' into higher_order_deriv
unalmis Jul 20, 2023
b8c877f
Ignore division by zero warnings in compute funs
unalmis Jul 21, 2023
60a5bb0
Fix test_zernike_radial
unalmis Jul 25, 2023
8683ba2
git checkout add_all_limits desc/compute/_core.py
unalmis Jul 25, 2023
1a63fb9
Merge branch 'master' into higher_order_deriv
unalmis Jul 25, 2023
73fab8f
Sort _core using script
unalmis Jul 25, 2023
73e4c57
Merge branch 'master' into higher_order_deriv
unalmis Jul 27, 2023
b021ef2
Merge branch 'master' into higher_order_deriv
f0uriest Jul 28, 2023
9e18d49
Merge branch 'master' into higher_order_deriv
f0uriest Jul 28, 2023
d76d9dd
Merge branch 'master' into higher_order_deriv
f0uriest Jul 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 36 additions & 22 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@

def _set_up(self):
"""Do things after loading or changing resolution."""
# Also recreates any attributes not in _io_attrs on load from input file.
# See IOAble class docstring for more info.
self._enforce_symmetry()
self._sort_modes()
self._create_idx()
Expand Down Expand Up @@ -608,7 +610,7 @@
For L>0, the indexing scheme defines order of the basis functions:

``'ansi'``: ANSI indexing fills in the pyramid with triangles of
decreasing size, ending in a triagle shape. For L == M,
decreasing size, ending in a triangle shape. For L == M,
the traditional ANSI pyramid indexing is recovered. For L>M, adds rows
to the bottom of the pyramid, increasing L while keeping M constant,
giving a "house" shape.
Expand Down Expand Up @@ -651,7 +653,7 @@
For L>0, the indexing scheme defines order of the basis functions:

``'ansi'``: ANSI indexing fills in the pyramid with triangles of
decreasing size, ending in a triagle shape. For L == M,
decreasing size, ending in a triangle shape. For L == M,
the traditional ANSI pyramid indexing is recovered. For L>M, adds rows
to the bottom of the pyramid, increasing L while keeping M constant,
giving a "house" shape.
Expand Down Expand Up @@ -830,7 +832,7 @@
For L>0, the indexing scheme defines order of the basis functions:

``'ansi'``: ANSI indexing fills in the pyramid with triangles of
decreasing size, ending in a triagle shape. For L == M,
decreasing size, ending in a triangle shape. For L == M,
the traditional ANSI pyramid indexing is recovered. For L>M, adds rows
to the bottom of the pyramid, increasing L while keeping M constant,
giving a "house" shape.
Expand Down Expand Up @@ -875,7 +877,7 @@
For L>0, the indexing scheme defines order of the basis functions:

``'ansi'``: ANSI indexing fills in the pyramid with triangles of
decreasing size, ending in a triagle shape. For L == M,
decreasing size, ending in a triangle shape. For L == M,
the traditional ANSI pyramid indexing is recovered. For L>M, adds rows
to the bottom of the pyramid, increasing L while keeping M constant,
giving a "house" shape.
Expand Down Expand Up @@ -1252,37 +1254,49 @@
beta = 0
n = (l - m) // 2
s = (-1) ** n
jacobi_arg = 1 - 2 * r**2
if dr == 0:
out = r**m * _jacobi(n, alpha, beta, 1 - 2 * r**2, 0)
out = r**m * _jacobi(n, alpha, beta, jacobi_arg, 0)
elif dr == 1:
f = _jacobi(n, alpha, beta, 1 - 2 * r**2, 0)
df = _jacobi(n, alpha, beta, 1 - 2 * r**2, 1)
f = _jacobi(n, alpha, beta, jacobi_arg, 0)
df = _jacobi(n, alpha, beta, jacobi_arg, 1)
out = m * r ** jnp.maximum(m - 1, 0) * f - 4 * r ** (m + 1) * df
elif dr == 2:
f = _jacobi(n, alpha, beta, 1 - 2 * r**2, 0)
df = _jacobi(n, alpha, beta, 1 - 2 * r**2, 1)
d2f = _jacobi(n, alpha, beta, 1 - 2 * r**2, 2)
f = _jacobi(n, alpha, beta, jacobi_arg, 0)
df = _jacobi(n, alpha, beta, jacobi_arg, 1)
d2f = _jacobi(n, alpha, beta, jacobi_arg, 2)
out = (
m * (m - 1) * r ** jnp.maximum((m - 2), 0) * f
- 2 * 4 * m * r**m * df
+ r**m * (16 * r**2 * d2f - 4 * df)
(m - 1) * m * r ** jnp.maximum(m - 2, 0) * f
- 4 * (2 * m + 1) * r**m * df
+ 16 * r ** (m + 2) * d2f
)
elif dr == 3:
f = _jacobi(n, alpha, beta, 1 - 2 * r**2, 0)
df = _jacobi(n, alpha, beta, 1 - 2 * r**2, 1)
d2f = _jacobi(n, alpha, beta, 1 - 2 * r**2, 2)
d3f = _jacobi(n, alpha, beta, 1 - 2 * r**2, 3)
f = _jacobi(n, alpha, beta, jacobi_arg, 0)
df = _jacobi(n, alpha, beta, jacobi_arg, 1)
d2f = _jacobi(n, alpha, beta, jacobi_arg, 2)
d3f = _jacobi(n, alpha, beta, jacobi_arg, 3)

Check warning on line 1277 in desc/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/basis.py#L1274-L1277

Added lines #L1274 - L1277 were not covered by tests
out = (
(m - 2) * (m - 1) * m * r ** jnp.maximum(m - 3, 0) * f
- 12 * (m - 1) * m * r ** jnp.maximum(m - 1, 0) * df
+ 48 * r ** (m + 1) * d2f
- 12 * m**2 * r ** jnp.maximum(m - 1, 0) * df
+ 48 * (m + 1) * r ** (m + 1) * d2f
dpanici marked this conversation as resolved.
Show resolved Hide resolved
- 64 * r ** (m + 3) * d3f
+ 48 * m * r ** (m + 1) * d2f
- 12 * m * r ** jnp.maximum(m - 1, 0) * df
)
elif dr == 4:
f = _jacobi(n, alpha, beta, jacobi_arg, 0)
df = _jacobi(n, alpha, beta, jacobi_arg, 1)
d2f = _jacobi(n, alpha, beta, jacobi_arg, 2)
d3f = _jacobi(n, alpha, beta, jacobi_arg, 3)
d4f = _jacobi(n, alpha, beta, jacobi_arg, 4)
out = (

Check warning on line 1290 in desc/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/basis.py#L1284-L1290

Added lines #L1284 - L1290 were not covered by tests
(m - 3) * (m - 2) * (m - 1) * m * r ** jnp.maximum(m - 4, 0) * f
- 8 * m * (2 * m**2 - 3 * m + 1) * r ** jnp.maximum(m - 2, 0) * df
+ 48 * (2 * m**2 + 2 * m + 1) * r**m * d2f
- 128 * (2 * m + 3) * r ** (m + 2) * d3f
+ 256 * r ** (m + 4) * d4f
)
else:
raise NotImplementedError(
"Analytic radial derivatives of zernike polynomials for order>3 "
"Analytic radial derivatives of zernike polynomials for order>4 "
+ "have not been implemented"
)
return s * jnp.where((l - m) % 2 == 0, out, 0)
Expand Down
2 changes: 1 addition & 1 deletion desc/io/equilibrium_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def load(load_from, file_format=None):
)
else:
raise ValueError("Unknown file format: {}".format(file_format))
# to set other secondary stuff that wasnt saved possibly:
# to set other secondary stuff that wasn't saved possibly:
if hasattr(obj, "_set_up"):
obj._set_up()
return obj
Expand Down
81 changes: 35 additions & 46 deletions desc/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@
self._rcond = rcond if rcond is not None else "auto"

if (
not np.all(self.grid.nodes[:, 2] == 0)
and not (self.grid.NFP == self.basis.NFP)
np.any(self.grid.nodes[:, 2] != 0)
and self.grid.NFP != self.basis.NFP
and grid.node_pattern != "custom"
):
warnings.warn(
Expand All @@ -71,29 +71,19 @@
)
)

self._built = False
self._built_pinv = False
self._derivatives = self._get_derivatives(derivs)
self._sort_derivatives()
self._method = method

self._built = False
self._built_pinv = False
self._set_up()
# assign according to logic in setter function
self.method = method
self._matrices = self._get_matrices()
if build:
self.build()
if build_pinv:
self.build_pinv()

def _set_up(self):

self.method = self._method
self._matrices = {
"direct1": {
i: {j: {k: {} for k in range(4)} for j in range(4)} for i in range(4)
},
"fft": {i: {j: {} for j in range(4)} for i in range(4)},
"direct2": {i: {} for i in range(4)},
}

def _get_derivatives(self, derivs):
"""Get array of derivatives needed for calculating objective function.

Expand Down Expand Up @@ -138,6 +128,18 @@
)
self._derivatives = self.derivatives[sort_idx]

def _get_matrices(self):
"""Get matrices to compute all derivatives."""
n = np.amax(self.derivatives) + 1
matrices = {
"direct1": {
i: {j: {k: {} for k in range(n)} for j in range(n)} for i in range(n)
},
"fft": {i: {j: {} for j in range(n)} for i in range(n)},
"direct2": {i: {} for i in range(n)},
}
return matrices

def _check_inputs_fft(self, grid, basis):
"""Check that inputs are formatted correctly for fft method."""
if grid.num_nodes == 0 or basis.num_modes == 0:
Expand Down Expand Up @@ -370,7 +372,7 @@

if self.method == "direct1":
for d in self.derivatives:
self._matrices["direct1"][d[0]][d[1]][d[2]] = self.basis.evaluate(
self.matrices["direct1"][d[0]][d[1]][d[2]] = self.basis.evaluate(
self.grid.nodes, d, unique=True
)

Expand Down Expand Up @@ -404,7 +406,7 @@
rcond = None if self.rcond == "auto" else self.rcond
if self.method == "direct1":
A = self.basis.evaluate(self.grid.nodes, np.array([0, 0, 0]))
self._matrices["pinv"] = (
self.matrices["pinv"] = (

Check warning on line 409 in desc/transform.py

View check run for this annotation

Codecov / codecov/patch

desc/transform.py#L409

Added line #L409 was not covered by tests
scipy.linalg.pinv(A, rcond=rcond) if A.size else np.zeros_like(A.T)
)
elif self.method == "direct2":
Expand Down Expand Up @@ -471,41 +473,36 @@
return np.zeros(self.grid.num_nodes)

if self.method == "direct1":
A = self.matrices["direct1"][dr][dt][dz]
A = self.matrices["direct1"].get(dr, {}).get(dt, {}).get(dz, {})
if isinstance(A, dict):
raise ValueError(
colored("Derivative orders are out of initialized bounds", "red")
)
return jnp.matmul(A, c)
return A @ c

elif self.method == "direct2":
A = self.matrices["fft"][dr][dt]
B = self.matrices["direct2"][dz]

A = self.matrices["fft"].get(dr, {}).get(dt, {})
B = self.matrices["direct2"].get(dz, {})
if isinstance(A, dict) or isinstance(B, dict):
raise ValueError(
colored("Derivative orders are out of initialized bounds", "red")
)
c_mtrx = jnp.zeros((self.num_lm_modes * self.num_n_modes,))
c_mtrx = put(c_mtrx, self.fft_index, c).reshape((-1, self.num_n_modes))

cc = jnp.matmul(A, c_mtrx)
return jnp.matmul(cc, B.T).flatten(order="F")
cc = A @ c_mtrx
return (cc @ B.T).flatten(order="F")

elif self.method == "fft":
A = self.matrices["fft"][dr][dt]
A = self.matrices["fft"].get(dr, {}).get(dt, {})
if isinstance(A, dict):
raise ValueError(
colored("Derivative orders are out of initialized bounds", "red")
)

# reshape coefficients
c_mtrx = jnp.zeros((self.num_lm_modes * self.num_n_modes,))
c_mtrx = put(c_mtrx, self.fft_index, c).reshape((-1, self.num_n_modes))

# differentiate
c_diff = c_mtrx[:, :: (-1) ** dz] * self.dk**dz * (-1) ** (dz > 1)

# re-format in complex notation
c_real = jnp.pad(
(self.num_z_nodes / 2)
Expand All @@ -520,10 +517,9 @@
jnp.fliplr(jnp.conj(c_real)),
)
)

# transform coefficients
c_fft = jnp.real(jnp.fft.ifft(c_cplx))
return jnp.matmul(A, c_fft).flatten(order="F")
return (A @ c_fft).flatten(order="F")

def fit(self, x):
"""Transform from physical domain to spectral using weighted least squares fit.
Expand Down Expand Up @@ -730,27 +726,20 @@
self._sort_derivatives()

if len(derivs_to_add):
# if we actually added derivatives and didn't build them, then its not built
# if we actually added derivatives and didn't build them, then it's not
# built
self._built = False
if build:
# we don't update self._built here because it is still built from before
# we don't update self._built here because it is still built from before,
# but it still might have unbuilt matrices from new derivatives
self.build()

@property
def matrices(self):
"""dict: transform matrices such that x=A*c."""
return self.__dict__.setdefault(
"_matrices",
{
"direct1": {
i: {j: {k: {} for k in range(4)} for j in range(4)}
for i in range(4)
},
"fft": {i: {j: {} for j in range(4)} for i in range(4)},
"direct2": {i: {} for i in range(4)},
},
)
if not hasattr(self, "_matrices"):
self._matrices = self._get_matrices()
return self._matrices

@property
def num_nodes(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def test_misc(self):
def test_asserts(self):
"""Test error checking when creating FourierXYZCurve."""
c = FourierXYZCurve()
with pytest.raises(KeyError):
with pytest.raises(ValueError):
c.compute_coordinates(dt=4)
with pytest.raises(TypeError):
c.grid = [1, 2, 3]
Expand Down
Loading