Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 4th order derivative transforms #586

Merged
merged 22 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
8922d94
Increase max matrix size of transform
unalmis Jul 17, 2023
3399eaf
Add fourth order zernike radial
unalmis Jul 18, 2023
cdefd9f
Merge branch 'master' into higher_order_deriv
unalmis Jul 18, 2023
8f3ca37
Lower bound transform matrix order to 1
unalmis Jul 18, 2023
e5eee12
Update transform class for deriv order
unalmis Jul 18, 2023
abfda65
Add back dropped line of code
unalmis Jul 18, 2023
3ff4ebf
Add back dummy _set_up method
unalmis Jul 19, 2023
15f816a
Document use of _set_up function and fix test to raise ValueError
unalmis Jul 19, 2023
d78e4a5
Remove _set_up method now that io save stuff was...
unalmis Jul 19, 2023
0882ba4
Merge branch 'master' into higher_order_deriv
f0uriest Jul 20, 2023
07bed02
Merge branch 'master' into higher_order_deriv
f0uriest Jul 20, 2023
e6a84a0
Tests for Zernike polynomial radial derivative
unalmis Jul 20, 2023
7574d8a
Merge branch 'master' into higher_order_deriv
unalmis Jul 20, 2023
b8c877f
Ignore division by zero warnings in compute funs
unalmis Jul 21, 2023
60a5bb0
Fix test_zernike_radial
unalmis Jul 25, 2023
8683ba2
git checkout add_all_limits desc/compute/_core.py
unalmis Jul 25, 2023
1a63fb9
Merge branch 'master' into higher_order_deriv
unalmis Jul 25, 2023
73fab8f
Sort _core using script
unalmis Jul 25, 2023
73e4c57
Merge branch 'master' into higher_order_deriv
unalmis Jul 27, 2023
b021ef2
Merge branch 'master' into higher_order_deriv
f0uriest Jul 28, 2023
9e18d49
Merge branch 'master' into higher_order_deriv
f0uriest Jul 28, 2023
d76d9dd
Merge branch 'master' into higher_order_deriv
f0uriest Jul 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 34 additions & 22 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@
For L>0, the indexing scheme defines order of the basis functions:

``'ansi'``: ANSI indexing fills in the pyramid with triangles of
decreasing size, ending in a triagle shape. For L == M,
decreasing size, ending in a triangle shape. For L == M,
the traditional ANSI pyramid indexing is recovered. For L>M, adds rows
to the bottom of the pyramid, increasing L while keeping M constant,
giving a "house" shape.
Expand Down Expand Up @@ -651,7 +651,7 @@
For L>0, the indexing scheme defines order of the basis functions:

``'ansi'``: ANSI indexing fills in the pyramid with triangles of
decreasing size, ending in a triagle shape. For L == M,
decreasing size, ending in a triangle shape. For L == M,
the traditional ANSI pyramid indexing is recovered. For L>M, adds rows
to the bottom of the pyramid, increasing L while keeping M constant,
giving a "house" shape.
Expand Down Expand Up @@ -830,7 +830,7 @@
For L>0, the indexing scheme defines order of the basis functions:

``'ansi'``: ANSI indexing fills in the pyramid with triangles of
decreasing size, ending in a triagle shape. For L == M,
decreasing size, ending in a triangle shape. For L == M,
the traditional ANSI pyramid indexing is recovered. For L>M, adds rows
to the bottom of the pyramid, increasing L while keeping M constant,
giving a "house" shape.
Expand Down Expand Up @@ -875,7 +875,7 @@
For L>0, the indexing scheme defines order of the basis functions:

``'ansi'``: ANSI indexing fills in the pyramid with triangles of
decreasing size, ending in a triagle shape. For L == M,
decreasing size, ending in a triangle shape. For L == M,
the traditional ANSI pyramid indexing is recovered. For L>M, adds rows
to the bottom of the pyramid, increasing L while keeping M constant,
giving a "house" shape.
Expand Down Expand Up @@ -1252,37 +1252,49 @@
beta = 0
n = (l - m) // 2
s = (-1) ** n
jacobi_arg = 1 - 2 * r**2

Check warning on line 1255 in desc/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/basis.py#L1255

Added line #L1255 was not covered by tests
if dr == 0:
out = r**m * _jacobi(n, alpha, beta, 1 - 2 * r**2, 0)
out = r**m * _jacobi(n, alpha, beta, jacobi_arg, 0)

Check warning on line 1257 in desc/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/basis.py#L1257

Added line #L1257 was not covered by tests
elif dr == 1:
f = _jacobi(n, alpha, beta, 1 - 2 * r**2, 0)
df = _jacobi(n, alpha, beta, 1 - 2 * r**2, 1)
f = _jacobi(n, alpha, beta, jacobi_arg, 0)
df = _jacobi(n, alpha, beta, jacobi_arg, 1)

Check warning on line 1260 in desc/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/basis.py#L1259-L1260

Added lines #L1259 - L1260 were not covered by tests
out = m * r ** jnp.maximum(m - 1, 0) * f - 4 * r ** (m + 1) * df
elif dr == 2:
f = _jacobi(n, alpha, beta, 1 - 2 * r**2, 0)
df = _jacobi(n, alpha, beta, 1 - 2 * r**2, 1)
d2f = _jacobi(n, alpha, beta, 1 - 2 * r**2, 2)
f = _jacobi(n, alpha, beta, jacobi_arg, 0)
df = _jacobi(n, alpha, beta, jacobi_arg, 1)
d2f = _jacobi(n, alpha, beta, jacobi_arg, 2)

Check warning on line 1265 in desc/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/basis.py#L1263-L1265

Added lines #L1263 - L1265 were not covered by tests
out = (
m * (m - 1) * r ** jnp.maximum((m - 2), 0) * f
- 2 * 4 * m * r**m * df
+ r**m * (16 * r**2 * d2f - 4 * df)
(m - 1) * m * r ** jnp.maximum(m - 2, 0) * f
- 4 * (2 * m + 1) * r**m * df
+ 16 * r ** (m + 2) * d2f
)
elif dr == 3:
f = _jacobi(n, alpha, beta, 1 - 2 * r**2, 0)
df = _jacobi(n, alpha, beta, 1 - 2 * r**2, 1)
d2f = _jacobi(n, alpha, beta, 1 - 2 * r**2, 2)
d3f = _jacobi(n, alpha, beta, 1 - 2 * r**2, 3)
f = _jacobi(n, alpha, beta, jacobi_arg, 0)
df = _jacobi(n, alpha, beta, jacobi_arg, 1)
d2f = _jacobi(n, alpha, beta, jacobi_arg, 2)
d3f = _jacobi(n, alpha, beta, jacobi_arg, 3)

Check warning on line 1275 in desc/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/basis.py#L1272-L1275

Added lines #L1272 - L1275 were not covered by tests
out = (
(m - 2) * (m - 1) * m * r ** jnp.maximum(m - 3, 0) * f
- 12 * (m - 1) * m * r ** jnp.maximum(m - 1, 0) * df
+ 48 * r ** (m + 1) * d2f
- 12 * m**2 * r ** jnp.maximum(m - 1, 0) * df
+ 48 * (m + 1) * r ** (m + 1) * d2f
dpanici marked this conversation as resolved.
Show resolved Hide resolved
- 64 * r ** (m + 3) * d3f
+ 48 * m * r ** (m + 1) * d2f
- 12 * m * r ** jnp.maximum(m - 1, 0) * df
)
elif dr == 4:
f = _jacobi(n, alpha, beta, jacobi_arg, 0)
df = _jacobi(n, alpha, beta, jacobi_arg, 1)
d2f = _jacobi(n, alpha, beta, jacobi_arg, 2)
d3f = _jacobi(n, alpha, beta, jacobi_arg, 3)
d4f = _jacobi(n, alpha, beta, jacobi_arg, 4)
out = (

Check warning on line 1288 in desc/basis.py

View check run for this annotation

Codecov / codecov/patch

desc/basis.py#L1282-L1288

Added lines #L1282 - L1288 were not covered by tests
(m - 3) * (m - 2) * (m - 1) * m * r ** jnp.maximum(m - 4, 0) * f
- 8 * m * (2 * m**2 - 3 * m + 1) * r ** jnp.maximum(m - 2, 0) * df
+ 48 * (2 * m**2 + 2 * m + 1) * r**m * d2f
- 128 * (2 * m + 3) * r ** (m + 2) * d3f
+ 256 * r ** (m + 4) * d4f
)
else:
raise NotImplementedError(
"Analytic radial derivatives of zernike polynomials for order>3 "
"Analytic radial derivatives of zernike polynomials for order>4 "
+ "have not been implemented"
)
return s * jnp.where((l - m) % 2 == 0, out, 0)
Expand Down
69 changes: 27 additions & 42 deletions desc/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@
self._rcond = rcond if rcond is not None else "auto"

if (
not np.all(self.grid.nodes[:, 2] == 0)
and not (self.grid.NFP == self.basis.NFP)
np.any(self.grid.nodes[:, 2] != 0)
and self.grid.NFP != self.basis.NFP
and grid.node_pattern != "custom"
):
warnings.warn(
Expand All @@ -71,29 +71,19 @@
)
)

self._derivatives = self._get_derivatives(derivs)
self._sort_derivatives()
self._method = method

self._built = False
self._built_pinv = False
self._set_up()
self._derivatives = self._get_derivatives(derivs)
self._sort_derivatives()

Check warning on line 77 in desc/transform.py

View check run for this annotation

Codecov / codecov/patch

desc/transform.py#L76-L77

Added lines #L76 - L77 were not covered by tests
# assign according to logic in setter function
self.method = method

Check warning on line 79 in desc/transform.py

View check run for this annotation

Codecov / codecov/patch

desc/transform.py#L79

Added line #L79 was not covered by tests
# assign according to logic in property function
self._matrices = self.matrices

Check warning on line 81 in desc/transform.py

View check run for this annotation

Codecov / codecov/patch

desc/transform.py#L81

Added line #L81 was not covered by tests
unalmis marked this conversation as resolved.
Show resolved Hide resolved
if build:
self.build()
if build_pinv:
self.build_pinv()

def _set_up(self):

self.method = self._method
self._matrices = {
"direct1": {
i: {j: {k: {} for k in range(4)} for j in range(4)} for i in range(4)
},
"fft": {i: {j: {} for j in range(4)} for i in range(4)},
"direct2": {i: {} for i in range(4)},
}

def _get_derivatives(self, derivs):
"""Get array of derivatives needed for calculating objective function.

Expand Down Expand Up @@ -471,41 +461,36 @@
return np.zeros(self.grid.num_nodes)

if self.method == "direct1":
A = self.matrices["direct1"][dr][dt][dz]
A = self.matrices["direct1"].get(dr, {}).get(dt, {}).get(dz, {})

Check warning on line 464 in desc/transform.py

View check run for this annotation

Codecov / codecov/patch

desc/transform.py#L464

Added line #L464 was not covered by tests
if isinstance(A, dict):
raise ValueError(
colored("Derivative orders are out of initialized bounds", "red")
)
return jnp.matmul(A, c)
return A @ c

Check warning on line 469 in desc/transform.py

View check run for this annotation

Codecov / codecov/patch

desc/transform.py#L469

Added line #L469 was not covered by tests

elif self.method == "direct2":
A = self.matrices["fft"][dr][dt]
B = self.matrices["direct2"][dz]

A = self.matrices["fft"].get(dr, {}).get(dt, {})
B = self.matrices["direct2"].get(dz, {})

Check warning on line 473 in desc/transform.py

View check run for this annotation

Codecov / codecov/patch

desc/transform.py#L472-L473

Added lines #L472 - L473 were not covered by tests
if isinstance(A, dict) or isinstance(B, dict):
raise ValueError(
colored("Derivative orders are out of initialized bounds", "red")
)
c_mtrx = jnp.zeros((self.num_lm_modes * self.num_n_modes,))
c_mtrx = put(c_mtrx, self.fft_index, c).reshape((-1, self.num_n_modes))

cc = jnp.matmul(A, c_mtrx)
return jnp.matmul(cc, B.T).flatten(order="F")
cc = A @ c_mtrx
return (cc @ B.T).flatten(order="F")

Check warning on line 481 in desc/transform.py

View check run for this annotation

Codecov / codecov/patch

desc/transform.py#L480-L481

Added lines #L480 - L481 were not covered by tests

elif self.method == "fft":
A = self.matrices["fft"][dr][dt]
A = self.matrices["fft"].get(dr, {}).get(dt, {})

Check warning on line 484 in desc/transform.py

View check run for this annotation

Codecov / codecov/patch

desc/transform.py#L484

Added line #L484 was not covered by tests
if isinstance(A, dict):
raise ValueError(
colored("Derivative orders are out of initialized bounds", "red")
)

# reshape coefficients
c_mtrx = jnp.zeros((self.num_lm_modes * self.num_n_modes,))
c_mtrx = put(c_mtrx, self.fft_index, c).reshape((-1, self.num_n_modes))

# differentiate
c_diff = c_mtrx[:, :: (-1) ** dz] * self.dk**dz * (-1) ** (dz > 1)

# re-format in complex notation
c_real = jnp.pad(
(self.num_z_nodes / 2)
Expand All @@ -520,10 +505,9 @@
jnp.fliplr(jnp.conj(c_real)),
)
)

# transform coefficients
c_fft = jnp.real(jnp.fft.ifft(c_cplx))
return jnp.matmul(A, c_fft).flatten(order="F")
return (A @ c_fft).flatten(order="F")

Check warning on line 510 in desc/transform.py

View check run for this annotation

Codecov / codecov/patch

desc/transform.py#L510

Added line #L510 was not covered by tests

def fit(self, x):
"""Transform from physical domain to spectral using weighted least squares fit.
Expand Down Expand Up @@ -733,24 +717,25 @@
# if we actually added derivatives and didn't build them, then its not built
self._built = False
if build:
# we don't update self._built here because it is still built from before
# we don't update self._built here because it is still built from before,
# but it still might have unbuilt matrices from new derivatives
self.build()

@property
def matrices(self):
"""dict: transform matrices such that x=A*c."""
return self.__dict__.setdefault(
"_matrices",
{
if not hasattr(self, "_matrices"):

Check warning on line 727 in desc/transform.py

View check run for this annotation

Codecov / codecov/patch

desc/transform.py#L727

Added line #L727 was not covered by tests
# to allow computing of highest order derivative
n = np.amax(self.derivatives) + 1
self._matrices = {

Check warning on line 730 in desc/transform.py

View check run for this annotation

Codecov / codecov/patch

desc/transform.py#L729-L730

Added lines #L729 - L730 were not covered by tests
"direct1": {
i: {j: {k: {} for k in range(4)} for j in range(4)}
for i in range(4)
i: {j: {k: {} for k in range(n)} for j in range(n)}
for i in range(n)
},
"fft": {i: {j: {} for j in range(4)} for i in range(4)},
"direct2": {i: {} for i in range(4)},
},
)
"fft": {i: {j: {} for j in range(n)} for i in range(n)},
"direct2": {i: {} for i in range(n)},
}
return self._matrices

Check warning on line 738 in desc/transform.py

View check run for this annotation

Codecov / codecov/patch

desc/transform.py#L738

Added line #L738 was not covered by tests

@property
def num_nodes(self):
Expand Down
Loading