Skip to content

Commit

Permalink
Merge pull request #139 from ucl-bug/equinox
Browse files Browse the repository at this point in the history
modules moved to equinox
  • Loading branch information
astanziola authored Nov 24, 2023
2 parents ee5ba44 + c124817 commit a15d053
Show file tree
Hide file tree
Showing 47 changed files with 2,816 additions and 2,451 deletions.
20 changes: 13 additions & 7 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,30 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased]
### Changed
- The Quickstart tutorial has been upgdated.
- The `@new_discretization` decorator has been renamed `@discretization`
- The property `Field.ndim` has now been moved into `Field.domain.ndim`, as it is fundamentally a property of the domain
- Before, `OnGrid` fields were able to automatically add an extra dimension if needed at initialization. This however can easily clash with some of the internal operations of jax during compliation. This is now not possible, use `.from_grid` instead, which implements the same functionality.
- The `init_params` function now will inherit the default parameters from its operator, to remove any source of ambiguity. This means that it should not have any default values, and an error is raised if it does.

### Removed
- The `__about__` file has been removed, as it is redundant
- The function `params_map` is removed, use `jax.tree_util.tree_map` instead.
- Operators are now expected to return only their outputs, and not parameters. If you need to get the parameters of an operator use its `default_params` method. To minimize problems for packages relying on `jaxdf`, in this release the outputs of an `operator` are filtered to keep only the first one. This will change soon to allow the user to return arbitrary PyTrees.

### Added
- The new `operator.abstract` decorator can be used to define an unimplemented operator, with the goal of specifying input arguments and docstrings.
- `Linear` fields are now defined as equal if they have the same set of parameters.
- `Ongrid` fields now have the property `.add_dim`, which adds an extra tailing dimension to its parameters. The method returns a new field.
- JaxDF `Field`s are now based on [equinox](https://github.com/patrick-kidger/equinox). In theory, this should allow to use `jaxdf` with all the [scientific libraries for the jax ecosystem](https://github.com/patrick-kidger/equinox#see-also-other-libraries-in-the-jax-ecosystem). In practice, please raise an issue when you encounter one of the inevitable bugs :)
- The new `operator.abstract` decorator can be used to define an unimplemented operator, for specifying input arguments and docstrings.
- `Linear` fields are now defined as equal if they have the same set of parameters and the same `Domain`.
- `Ongrid` fields now have the method `.add_dim()`, which adds an extra tailing dimension to its parameters. This **is not** an in-place update: the method returns a new field.
- The function `jaxdf.util.get_implemented` is now exposed to the user.
- Added `laplacian` operator for `FiniteDifferences` fields.
- JaxDF now uses standard Python logging. To set the logging level, use `jaxdf.logger.set_logging_level`, for example `jaxdf.logger.set_logging_level("DEBUG")`. The default level is `INFO`.
- Fields have now a handy property `` which is an alias for `.params`
- `Continuous` and `Linear` fields now have the `.is_complex` property
- `Field` and `Domain` are now `JaxDFModules`s, which are based on from `equinox.Module`. They are entirely equivalent to `equinox.Module`, but have the extra `.replace` method that is used to update a single field.

### Deprecated
- The property `.is_field_complex` is now deprecated in favor of `.is_complex`. Same argument for `.is_real`
- `Field.get_field` is now deprecated in favor of the `__call__` metho.
- The property `.is_field_complex` is now deprecated in favor of `.is_complex`. Same goes for `.is_real`.
- `Field.get_field` is now deprecated in favor of the `__call__` method.
- The `@discretization` decorator is deprecated, as now `Fields` are `equinox` modules. It is just not needed now, and until removed it will act as a simple pass-trough

### Fixed
- `OnGrid.from_grid` now automatically adds a dimension at the end of the array for scalar fields, if needed
Expand Down
10 changes: 0 additions & 10 deletions docs/exceptions.md

This file was deleted.

303 changes: 164 additions & 139 deletions docs/notebooks/api_discretization.ipynb

Large diffs are not rendered by default.

18 changes: 5 additions & 13 deletions docs/notebooks/example_1_paper.ipynb

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions docs/notebooks/pinn_burgers.ipynb

Large diffs are not rendered by default.

169 changes: 119 additions & 50 deletions docs/notebooks/quickstart.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/operators/differential.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ $$
$$

{{ implementations('jaxdf.operators.differential', 'diag_jacobian') }}

## `gradient`

Given a field $u$, it returns the vector field
Expand Down
76 changes: 0 additions & 76 deletions docs/operators/magic.md

This file was deleted.

140 changes: 69 additions & 71 deletions jaxdf/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
def reflection_conv(kernel: jnp.ndarray,
array: jnp.ndarray,
reverse: bool = True) -> jnp.ndarray:
r"""Convolves an array with a kernel, using reflection padding.
r"""Convolves an array with a kernel, using reflection padding.
The kernel is supposed to have the same number of dimensions as the array.
Args:
Expand All @@ -20,21 +20,21 @@ def reflection_conv(kernel: jnp.ndarray,
Returns:
The convolved array.
"""
# Reflection padding the array appropriately
pad = [((x - 1) // 2, (x - 1) // 2) for x in kernel.shape]
f = jnp.pad(array, pad, mode="wrap")
# Reflection padding the array appropriately
pad = [((x - 1) // 2, (x - 1) // 2) for x in kernel.shape]
f = jnp.pad(array, pad, mode="wrap")

if reverse:
# Reverse the kernel over all axes
kernel = jnp.flip(kernel, axis=tuple(range(kernel.ndim)))
if reverse:
# Reverse the kernel over all axes
kernel = jnp.flip(kernel, axis=tuple(range(kernel.ndim)))

# Apply kernel
return jsp.signal.convolve(f, kernel, mode="valid")
# Apply kernel
return jsp.signal.convolve(f, kernel, mode="valid")


def bubble_sort_abs_value(
points_list: List[Union[float, int]]) -> List[Union[float, int]]:
r"""Sorts a sequence of grid points by their absolute value.
points_list: List[Union[float, int]]) -> List[Union[float, int]]:
r"""Sorts a sequence of grid points by their absolute value.
Sorting is done __in place__. This function is written with numpy, so it can't
be transformed by JAX.
Expand All @@ -54,26 +54,25 @@ def bubble_sort_abs_value(
The sorted grid points.
"""

for i in range(len(points_list)):
for j in range(0, len(points_list) - i - 1):
magnitude_condition = abs(points_list[j]) > abs(points_list[j + 1])
same_mag_condition = abs(points_list[j]) == abs(points_list[j + 1])
sign_condition = np.sign(points_list[j]) < np.sign(points_list[j +
1])
if magnitude_condition or (same_mag_condition and sign_condition):
temp = points_list[j]
points_list[j] = points_list[j + 1]
points_list[j + 1] = temp
for i in range(len(points_list)):
for j in range(0, len(points_list) - i - 1):
magnitude_condition = abs(points_list[j]) > abs(points_list[j + 1])
same_mag_condition = abs(points_list[j]) == abs(points_list[j + 1])
sign_condition = np.sign(points_list[j]) < np.sign(points_list[j + 1])
if magnitude_condition or (same_mag_condition and sign_condition):
temp = points_list[j]
points_list[j] = points_list[j + 1]
points_list[j + 1] = temp

return points_list
return points_list


# TODO (astanziola): This fails on mypy for some reason, but can't work out how to fix.
@no_type_check
def fd_coefficients_fornberg(
order: int, grid_points: List[Union[float, int]],
x0: Union[float, int]) -> Tuple[List[None], List[Union[float, int]]]:
r"""Generate finite difference stencils for a given order and grid points, using
order: int, grid_points: List[Union[float, int]],
x0: Union[float, int]) -> Tuple[List[None], List[Union[float, int]]]:
r"""Generate finite difference stencils for a given order and grid points, using
the Fornberg algorithm described in [[Fornberg, 2018]](https://web.njit.edu/~jiang/math712/fornberg.pdf).
The grid points can be placed in any order, can be at arbitrary locations (for example, to implemente staggered
Expand All @@ -99,49 +98,48 @@ def fd_coefficients_fornberg(
Returns:
The stencil and the grid points where the stencil is evaluated.
"""
# from Generation of Finite Difference Formulas on Arbitrarily Spaced Grids
# Bengt Fornberg, 1998
# https://web.njit.edu/~jiang/math712/fornberg.pdf
M = order
N = len(grid_points) - 1

# Sort the grid points
alpha = bubble_sort_abs_value(grid_points)
delta = dict() # key: (m,n,v)
delta[(0, 0, 0)] = 1.0
c1 = 1.0

for n in range(1, N + 1):
c2 = 1.0
for v in range(n):
c3 = alpha[n] - alpha[v]
c2 = c2 * c3
if n < M:
delta[(n, n - 1, v)] = 0.0
for m in range(min([n, M]) + 1):
d1 = delta[(m, n - 1, v)] if (m, n - 1,
v) in delta.keys() else 0.0
d2 = (delta[(m - 1, n - 1, v)] if
(m - 1, n - 1, v) in delta.keys() else 0.0)
delta[(m, n, v)] = ((alpha[n] - x0) * d1 - m * d2) / c3

for m in range(min([n, M]) + 1):
d1 = (delta[(m - 1, n - 1, n - 1)] if
(m - 1, n - 1, n - 1) in delta.keys() else 0.0)
d2 = delta[(m, n - 1, n - 1)] if (m, n - 1,
n - 1) in delta.keys() else 0.0
delta[(m, n, n)] = (c1 / c2) * (m * d1 - (alpha[n - 1] - x0) * d2)
c1 = c2

# Extract the delta with m = M and n = N
coeffs = [None] * (N + 1)
for key in delta:
if key[0] == M and key[1] == N:
coeffs[key[2]] = delta[key]

# sort coefficeient and alpha by alpha
idx = np.argsort(alpha)
alpha = np.take_along_axis(np.asarray(alpha), idx, axis=-1)
coeffs = np.take_along_axis(np.asarray(coeffs), idx, axis=-1)

return coeffs, alpha
# from Generation of Finite Difference Formulas on Arbitrarily Spaced Grids
# Bengt Fornberg, 1998
# https://web.njit.edu/~jiang/math712/fornberg.pdf
M = order
N = len(grid_points) - 1

# Sort the grid points
alpha = bubble_sort_abs_value(grid_points)
delta = dict() # key: (m,n,v)
delta[(0, 0, 0)] = 1.0
c1 = 1.0

for n in range(1, N + 1):
c2 = 1.0
for v in range(n):
c3 = alpha[n] - alpha[v]
c2 = c2 * c3
if n < M:
delta[(n, n - 1, v)] = 0.0
for m in range(min([n, M]) + 1):
d1 = delta[(m, n - 1, v)] if (m, n - 1, v) in delta.keys() else 0.0
d2 = (delta[(m - 1, n - 1, v)] if
(m - 1, n - 1, v) in delta.keys() else 0.0)
delta[(m, n, v)] = ((alpha[n] - x0) * d1 - m * d2) / c3

for m in range(min([n, M]) + 1):
d1 = (delta[(m - 1, n - 1, n - 1)] if
(m - 1, n - 1, n - 1) in delta.keys() else 0.0)
d2 = delta[(m, n - 1, n - 1)] if (m, n - 1,
n - 1) in delta.keys() else 0.0
delta[(m, n, n)] = (c1 / c2) * (m * d1 - (alpha[n - 1] - x0) * d2)
c1 = c2

# Extract the delta with m = M and n = N
coeffs = [None] * (N + 1)
for key in delta:
if key[0] == M and key[1] == N:
coeffs[key[2]] = delta[key]

# sort coefficeient and alpha by alpha
idx = np.argsort(alpha)
alpha = np.take_along_axis(np.asarray(alpha), idx, axis=-1)
coeffs = np.take_along_axis(np.asarray(coeffs), idx, axis=-1)

return coeffs, alpha
Loading

0 comments on commit a15d053

Please sign in to comment.