Skip to content

Commit

Permalink
soft_one_hot_linspace
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed May 24, 2022
1 parent 24bcd8f commit b53f418
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 16 deletions.
1 change: 1 addition & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Changed
- use `dataclasses.dataclass` instead of custom `dataclass`
- Get Clebsch-Gordan coefficients from qutip and a change of basis
- Add `start_zero` and `end_zero` arguments to function `soft_one_hot_linspace`

## [0.4.3] - 2022-03-26
### Added
Expand Down
73 changes: 57 additions & 16 deletions e3nn_jax/_soft_one_hot_linspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@ def sus(x):


def soft_one_hot_linspace(
input: jnp.ndarray, *, start: float, end: float, number: int, basis: str = None, cutoff: bool = None
input: jnp.ndarray,
*,
start: float,
end: float,
number: int,
basis: str = None,
cutoff: bool = None,
start_zero: bool = None,
end_zero: bool = None,
):
r"""Projection on a basis of functions
Expand Down Expand Up @@ -52,8 +60,11 @@ def soft_one_hot_linspace(
basis : {'gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel'}
choice of basis family; note that due to the :math:`1/x` term, ``bessel`` basis does not satisfy the normalization of other basis choices
cutoff : bool
if ``cutoff=True`` then for all :math:`x` outside of the interval defined by ``(start, end)``, :math:`\forall i, \; f_i(x) \approx 0`
start_zero : bool
if ``True``, the first basis function is forced to be zero (or close) at ``start``
end_zero : bool
if ``True``, the last basis function is forced to be zero (or close) at ``end``
Returns
-------
Expand Down Expand Up @@ -82,7 +93,10 @@ def soft_one_hot_linspace(
for axs, b in zip(axss, bases):
for ax, c in zip(axs, [True, False]):
plt.sca(ax)
plt.plot(x, soft_one_hot_linspace(x, start=-0.5, end=1.5, number=4, basis=b, cutoff=c))
try:
plt.plot(x, soft_one_hot_linspace(x, start=-0.5, end=1.5, number=4, basis=b, cutoff=c))
except NotImplementedError:
pass
plt.plot([-0.5]*2, [-2, 2], 'k-.')
plt.plot([1.5]*2, [-2, 2], 'k-.')
plt.title(f"{b}" + (" with cutoff" if c else ""))
Expand All @@ -97,25 +111,50 @@ def soft_one_hot_linspace(
for axs, b in zip(axss, bases):
for ax, c in zip(axs, [True, False]):
plt.sca(ax)
plt.plot(x, soft_one_hot_linspace(x, start=-0.5, end=1.5, number=4, basis=b, cutoff=c).pow(2).sum(1))
try:
plt.plot(x, soft_one_hot_linspace(x, start=-0.5, end=1.5, number=4, basis=b, cutoff=c).pow(2).sum(1))
except NotImplementedError:
pass
plt.plot([-0.5]*2, [-2, 2], 'k-.')
plt.plot([1.5]*2, [-2, 2], 'k-.')
plt.title(f"{b}" + (" with cutoff" if c else ""))
plt.ylim(0, 2)
plt.tight_layout()
"""
if cutoff not in [True, False]:
raise ValueError("cutoff must be specified")
if cutoff is not None:
assert start_zero is None
assert end_zero is None
start_zero = cutoff
end_zero = cutoff

if not cutoff:
values = jnp.linspace(start, end, number)
step = values[1] - values[0]
else:
del cutoff

if start_zero not in [True, False]:
raise ValueError("start_zero must be specified")

if end_zero not in [True, False]:
raise ValueError("end_zero must be specified")

if start_zero and end_zero:
values = jnp.linspace(start, end, number + 2)
step = values[1] - values[0]
values = values[1:-1]

if start_zero and not end_zero:
values = jnp.linspace(start, end, number + 1)
step = values[1] - values[0]
values = values[1:]

if not start_zero and end_zero:
values = jnp.linspace(start, end, number + 1)
step = values[1] - values[0]
values = values[:-1]

if not start_zero and not end_zero:
values = jnp.linspace(start, end, number)
step = values[1] - values[0]

diff = (input[..., None] - values) / step

if basis == "gaussian":
Expand All @@ -129,22 +168,24 @@ def soft_one_hot_linspace(

if basis == "fourier":
x = (input[..., None] - start) / (end - start)
if not cutoff:
if start_zero and end_zero:
i = jnp.arange(1, number + 1)
return jnp.where((0.0 < x) & (x < 1.0), jnp.sin(jnp.pi * i * x) / jnp.sqrt(0.25 + number / 2), 0.0)
elif not start_zero and not end_zero:
i = jnp.arange(0, number)
return jnp.cos(jnp.pi * i * x) / jnp.sqrt(0.25 + number / 2)
else:
i = jnp.arange(1, number + 1)
return jnp.where((0.0 < x) & (x < 1.0), jnp.sin(jnp.pi * i * x) / jnp.sqrt(0.25 + number / 2), 0.0)
raise NotImplementedError

if basis == "bessel":
x = input[..., None] - start
c = end - start
bessel_roots = jnp.arange(1, number + 1) * jnp.pi
out = jnp.sqrt(2 / c) * jnp.sin(bessel_roots * x / c) / x

if not cutoff:
if not start_zero and not end_zero:
return out
else:
return out * ((x / c) < 1) * (0 < x)
raise NotImplementedError

raise ValueError(f'basis="{basis}" is not a valid entry')

0 comments on commit b53f418

Please sign in to comment.