From b53f4181727088a7d2512700fb6f730e9b521c3e Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 24 May 2022 18:20:45 -0400 Subject: [PATCH] soft_one_hot_linspace --- ChangeLog.md | 1 + e3nn_jax/_soft_one_hot_linspace.py | 73 +++++++++++++++++++++++------- 2 files changed, 58 insertions(+), 16 deletions(-) diff --git a/ChangeLog.md b/ChangeLog.md index 03ee1e01..b819f77b 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Changed - use `dataclasses.dataclass` instead of custom `dataclass` - Get Clebsch-Gordan coefficients from qutip and a change of basis +- Add `start_zero` and `end_zero` arguments to function `soft_one_hot_linspace` ## [0.4.3] - 2022-03-26 ### Added diff --git a/e3nn_jax/_soft_one_hot_linspace.py b/e3nn_jax/_soft_one_hot_linspace.py index 17de6340..9bf86b7c 100644 --- a/e3nn_jax/_soft_one_hot_linspace.py +++ b/e3nn_jax/_soft_one_hot_linspace.py @@ -15,7 +15,15 @@ def sus(x): def soft_one_hot_linspace( - input: jnp.ndarray, *, start: float, end: float, number: int, basis: str = None, cutoff: bool = None + input: jnp.ndarray, + *, + start: float, + end: float, + number: int, + basis: str = None, + cutoff: bool = None, + start_zero: bool = None, + end_zero: bool = None, ): r"""Projection on a basis of functions @@ -52,8 +60,11 @@ def soft_one_hot_linspace( basis : {'gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel'} choice of basis family; note that due to the :math:`1/x` term, ``bessel`` basis does not satisfy the normalization of other basis choices - cutoff : bool - if ``cutoff=True`` then for all :math:`x` outside of the interval defined by ``(start, end)``, :math:`\forall i, \; f_i(x) \approx 0` + start_zero : bool + if ``True``, the first basis function is forced to be zero (or close) at ``start`` + + end_zero : bool + if ``True``, the last basis function is forced to be zero (or close) at ``end`` Returns ------- @@ -82,7 +93,10 @@ def soft_one_hot_linspace( for axs, b in zip(axss, bases): for ax, c in zip(axs, [True, False]): plt.sca(ax) - plt.plot(x, soft_one_hot_linspace(x, start=-0.5, end=1.5, number=4, basis=b, cutoff=c)) + try: + plt.plot(x, soft_one_hot_linspace(x, start=-0.5, end=1.5, number=4, basis=b, cutoff=c)) + except NotImplementedError: + pass plt.plot([-0.5]*2, [-2, 2], 'k-.') plt.plot([1.5]*2, [-2, 2], 'k-.') plt.title(f"{b}" + (" with cutoff" if c else "")) @@ -97,7 +111,10 @@ def soft_one_hot_linspace( for axs, b in zip(axss, bases): for ax, c in zip(axs, [True, False]): plt.sca(ax) - plt.plot(x, soft_one_hot_linspace(x, start=-0.5, end=1.5, number=4, basis=b, cutoff=c).pow(2).sum(1)) + try: + plt.plot(x, soft_one_hot_linspace(x, start=-0.5, end=1.5, number=4, basis=b, cutoff=c).pow(2).sum(1)) + except NotImplementedError: + pass plt.plot([-0.5]*2, [-2, 2], 'k-.') plt.plot([1.5]*2, [-2, 2], 'k-.') plt.title(f"{b}" + (" with cutoff" if c else "")) @@ -105,17 +122,39 @@ def soft_one_hot_linspace( plt.ylim(0, 2) plt.tight_layout() """ - if cutoff not in [True, False]: - raise ValueError("cutoff must be specified") + if cutoff is not None: + assert start_zero is None + assert end_zero is None + start_zero = cutoff + end_zero = cutoff - if not cutoff: - values = jnp.linspace(start, end, number) - step = values[1] - values[0] - else: + del cutoff + + if start_zero not in [True, False]: + raise ValueError("start_zero must be specified") + + if end_zero not in [True, False]: + raise ValueError("end_zero must be specified") + + if start_zero and end_zero: values = jnp.linspace(start, end, number + 2) step = values[1] - values[0] values = values[1:-1] + if start_zero and not end_zero: + values = jnp.linspace(start, end, number + 1) + step = values[1] - values[0] + values = values[1:] + + if not start_zero and end_zero: + values = jnp.linspace(start, end, number + 1) + step = values[1] - values[0] + values = values[:-1] + + if not start_zero and not end_zero: + values = jnp.linspace(start, end, number) + step = values[1] - values[0] + diff = (input[..., None] - values) / step if basis == "gaussian": @@ -129,12 +168,14 @@ def soft_one_hot_linspace( if basis == "fourier": x = (input[..., None] - start) / (end - start) - if not cutoff: + if start_zero and end_zero: + i = jnp.arange(1, number + 1) + return jnp.where((0.0 < x) & (x < 1.0), jnp.sin(jnp.pi * i * x) / jnp.sqrt(0.25 + number / 2), 0.0) + elif not start_zero and not end_zero: i = jnp.arange(0, number) return jnp.cos(jnp.pi * i * x) / jnp.sqrt(0.25 + number / 2) else: - i = jnp.arange(1, number + 1) - return jnp.where((0.0 < x) & (x < 1.0), jnp.sin(jnp.pi * i * x) / jnp.sqrt(0.25 + number / 2), 0.0) + raise NotImplementedError if basis == "bessel": x = input[..., None] - start @@ -142,9 +183,9 @@ def soft_one_hot_linspace( bessel_roots = jnp.arange(1, number + 1) * jnp.pi out = jnp.sqrt(2 / c) * jnp.sin(bessel_roots * x / c) / x - if not cutoff: + if not start_zero and not end_zero: return out else: - return out * ((x / c) < 1) * (0 < x) + raise NotImplementedError raise ValueError(f'basis="{basis}" is not a valid entry')