
Releases: e3nn/e3nn-jax

2023-01-20

21 Jan 03:04

e3nn.reduced_symmetric_tensor_product_basis("0e + 1o + 2e + 3o + 4e + 5o", 3, keep_ir="0e")

This call now takes only 10 seconds.

Added

  • e3nn.s2grid_vectors and e3nn.pad_to_plot_on_s2grid to help plot signals on the sphere
  • e3nn.util.assert_output_dtype to check the output dtype of a function
  • e3nn.s2_irreps is a function to create the irreps of the coefficients of a signal on the sphere
  • e3nn.reduced_antisymmetric_tensor_product_basis to compute the basis of the reduced antisymmetric tensor product
  • IrrepsArray * scalar is supported if the number of scalars matches the number of irreps
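
A minimal NumPy sketch of the rule behind IrrepsArray * scalar, assuming "number of irreps" counts every irrep copy (so "2x0e + 1x1o" has 3) and that each scalar multiplies all 2l + 1 components of its copy. The helper multiply_per_irrep and the (mul, l) list encoding are hypothetical, not e3nn API.

```python
import numpy as np


def multiply_per_irrep(array, scalars, irreps):
    """Scale each irrep copy of `array` by its own scalar.

    array:   [..., dim] features laid out irrep copy by irrep copy
    scalars: [..., num_irreps], one scalar per irrep copy
    irreps:  list of (mul, l) pairs, e.g. [(2, 0), (1, 1)] for "2x0e + 1x1o"
    """
    # Each copy of an irrep of order l spans 2l + 1 consecutive components.
    dims = [2 * l + 1 for mul, l in irreps for _ in range(mul)]
    if scalars.shape[-1] != len(dims):
        raise ValueError("the number of scalars must match the number of irreps")
    # Repeat every scalar over the components of its irrep copy, then multiply.
    return array * np.repeat(scalars, dims, axis=-1)


x = np.arange(1.0, 6.0)          # "2x0e + 1x1o" has dimension 1 + 1 + 3 = 5
s = np.array([10.0, 20.0, 0.5])  # one scalar for each copy: 0e, 0e, 1o
y = multiply_per_irrep(x, s, [(2, 0), (1, 1)])
```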

Changed

  • Optimize e3nn.reduced_symmetric_tensor_product_basis: it is now up to 100x faster than the previous implementation.
  • e3nn.from_s2grid and e3nn.to_s2grid are now more flexible with input and output irreps: you can skip some l's and list them in any order
  • [BREAKING] e3nn.from_s2grid requires an irreps argument instead of an lmax argument

Fixed

  • Increase robustness of e3nn.spherical_harmonics against NaN when normalize=True
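
With normalize=True the input vectors are divided by their norm, which produces NaN for a zero vector. A common guard, sketched below in NumPy, replaces the vanishing norm before dividing; whether e3nn.spherical_harmonics uses exactly this guard is an assumption, and safe_normalize is a hypothetical name. In JAX the same "double where" pattern also keeps the gradient finite at zero.

```python
import numpy as np


def safe_normalize(x):
    """Normalize vectors along the last axis, mapping zero vectors to zero."""
    norm2 = np.sum(x**2, axis=-1, keepdims=True)
    is_zero = norm2 == 0.0
    # Use a norm of 1 where it vanishes, so the division never produces NaN.
    safe_norm = np.sqrt(np.where(is_zero, 1.0, norm2))
    # Zero vectors stay zero; all other vectors get unit length.
    return np.where(is_zero, 0.0, x / safe_norm)


v = np.array([[3.0, 0.0, 4.0], [0.0, 0.0, 0.0]])
u = safe_normalize(v)
```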

Full Changelog: 0.14.0...0.15.0

2022-12-16

16 Dec 22:24

Added

  • IrrepsArray.astype to cast the underlying array
  • e3nn.flax.MultiLayerPerceptron and e3nn.haiku.MultiLayerPerceptron
  • e3nn.IrrepsArray.from_list(..., dtype)
  • Add sparse tensor product as an option in e3nn.tensor_product and related functions. It sparsifies the Clebsch-Gordan coefficients and has more impact when fused=True. It is disabled by default because no improvement was observed in the benchmarks.
  • Add log_coordinates alongside the other parameterizations of SO(3): e3nn.log_coordinates_to_matrix, e3nn.rand_log_coordinates, etc.
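
Log coordinates describe a rotation by a single 3-vector. Assuming the usual convention (direction = rotation axis, length = angle), the map to a rotation matrix is the matrix exponential of the skew-symmetric matrix of that vector, which Rodrigues' formula gives in closed form. A NumPy sketch; log_to_matrix is a hypothetical name, not e3nn.log_coordinates_to_matrix itself.

```python
import numpy as np


def log_to_matrix(v):
    """Rotation matrix exp([v]_x) for a rotation vector v (axis * angle)."""
    angle = np.linalg.norm(v)
    # Skew-symmetric cross-product matrix: K @ w == np.cross(v, w).
    K = np.array([[0.0, -v[2], v[1]],
                  [v[2], 0.0, -v[0]],
                  [-v[1], v[0], 0.0]])
    if angle < 1e-8:
        # Second-order Taylor expansion of exp(K) near the identity.
        return np.eye(3) + K + 0.5 * K @ K
    # Rodrigues: exp(K) = I + sin(a)/a K + (1 - cos(a))/a^2 K^2
    return (np.eye(3)
            + np.sin(angle) / angle * K
            + (1.0 - np.cos(angle)) / angle**2 * K @ K)


R = log_to_matrix(np.array([0.0, 0.0, np.pi / 2]))  # 90 degrees about z
```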

Fixed

  • set dtype for all jnp.zeros(..., dtype) calls in the codebase
  • set dtype for all jnp.ones(..., dtype) calls in the codebase

Removed

  • [BREAKING] e3nn.full_tensor_product in favor of e3nn.tensor_product
  • [BREAKING] e3nn.FunctionalTensorSquare in favor of e3nn.tensor_square
  • [BREAKING] e3nn.TensorSquare in favor of e3nn.tensor_square
  • [BREAKING] e3nn.IrrepsArray.cat in favor of e3nn.concatenate
  • [BREAKING] e3nn.IrrepsArray.randn in favor of e3nn.normal
  • [BREAKING] e3nn.Irreps.randn in favor of e3nn.normal
  • [BREAKING] e3nn.Irreps.transform_by_* in favor of e3nn.IrrepsArray.transform_by_*

Changed

  • Move BatchNorm and Dropout to the e3nn.haiku submodule; they will be removed from the main module in the future.
  • Move FullyConnectedTensorProduct to the haiku submodule (e3nn.haiku.FullyConnectedTensorProduct) and undeprecate it, because it is faster than e3nn.tensor_product followed by e3nn.Linear: opt_einsum optimizes the contraction of the two operations together.
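
A toy NumPy illustration of why fusing helps, using random dense stand-ins: C plays the role of the Clebsch-Gordan coefficients and W the weights of the linear layer (both are assumptions of this sketch, not e3nn internals). Writing the tensor product and the linear layer as one einsum lets the contraction-path optimizer (opt_einsum in JAX, optimize=True in NumPy) pick the order, for example contracting C with W first.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, d1, d2, d_tp, d_out = 32, 4, 5, 20, 6
x = rng.normal(size=(batch, d1))
y = rng.normal(size=(batch, d2))
C = rng.normal(size=(d1, d2, d_tp))  # stand-in for Clebsch-Gordan coefficients
W = rng.normal(size=(d_tp, d_out))   # weights of the linear layer

# Two separate steps: tensor product, then linear layer.
tp = np.einsum("bi,bj,ijk->bk", x, y, C)
two_step = tp @ W

# One fused contraction; the optimizer chooses the contraction order.
fused = np.einsum("bi,bj,ijk,ko->bo", x, y, C, W, optimize=True)
```

Both results are mathematically identical; only the amount of work differs.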

2022-12-14

14 Dec 15:22

Introduce flax and haiku submodules.

  • port Linear
  • port all modules (Dropout, Batchnorm, ...) to the submodules

Example with Linear in flax

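import jax
import e3nn_jax as e3nn

x = e3nn.normal("2x0e + 3x1e")  # random IrrepsArray
linear = e3nn.flax.Linear("3x0e + 1e")
w = linear.init(jax.random.PRNGKey(0), x)  # initialize the parameters
output = linear.apply(w, x)  # apply the layer to the input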

Added

  • e3nn.scatter_sum to replace e3nn.index_add. e3nn.index_add is deprecated.
  • Add flax and haiku submodules. All modules are planned to migrate to flax and haiku in the future.
  • Implement e3nn.flax.Linear and move e3nn.Linear to e3nn.haiku.Linear.
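
scatter_sum accumulates rows of an array into output slots given by an index array, as in message passing, where edge messages are summed onto their receiving nodes. A NumPy sketch of these semantics; scatter_sum_sketch is a hypothetical helper, not the e3nn signature. In plain JAX the same operation is jax.ops.segment_sum or jnp.zeros(...).at[idx].add(data).

```python
import numpy as np


def scatter_sum_sketch(data, dst, num_segments):
    """out[n] = sum of data[e] over all e with dst[e] == n."""
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    # np.add.at accumulates repeated indices instead of overwriting them.
    np.add.at(out, dst, data)
    return out


messages = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # [n_edges, features]
receivers = np.array([0, 2, 0])                            # receiving node of each edge
node_sums = scatter_sum_sketch(messages, receivers, num_segments=3)
```

Node 1 receives no message, so its row stays zero.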

2022-12-07

07 Dec 23:49

import e3nn_jax as e3nn

irreps = e3nn.Irreps("0e + 1o")

print(3 * irreps)  # prints 3x0e+3x1o

Changed

  • [BREAKING] 3 * e3nn.Irreps("0e + 1o") now returns 3x0e + 3x1o instead of 1x0e + 1x1o + 1x0e + 1x1o + 1x0e + 1x1o
  • [BREAKING] in Linear, renamed num_weights to num_indexed_weights because it was confusing.

Added

  • e3nn.Irreps("3x0e + 6x1o") // 3 returns 1x0e + 2x1o
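
A pure-Python sketch of the new multiplicity arithmetic, where both n * irreps and irreps // n act on each multiplicity. The helpers are hypothetical, and requiring every multiplicity to be divisible by n is an assumption of this sketch.

```python
def parse_muls(irreps):
    """Split "3x0e + 6x1o" into [(3, "0e"), (6, "1o")]; a bare "1o" means 1x1o."""
    terms = []
    for term in irreps.replace(" ", "").split("+"):
        mul, _, ir = term.rpartition("x")
        terms.append((int(mul) if mul else 1, ir))
    return terms


def format_muls(terms):
    return "+".join(f"{mul}x{ir}" for mul, ir in terms)


def times(n, irreps):
    """n * irreps: every multiplicity is multiplied by n."""
    return format_muls([(n * mul, ir) for mul, ir in parse_muls(irreps)])


def floor_div(irreps, n):
    """irreps // n: every multiplicity is divided by n (assumed to divide exactly)."""
    terms = parse_muls(irreps)
    if any(mul % n for mul, _ in terms):
        raise ValueError(f"{irreps!r} has a multiplicity not divisible by {n}")
    return format_muls([(mul // n, ir) for mul, ir in terms])


print(times(3, "0e + 1o"))          # 3x0e+3x1o
print(floor_div("3x0e + 6x1o", 3))  # 1x0e+2x1o
```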

Fixed

  • s2grid is now jittable

2022-11-16

16 Nov 23:14

The new method regroup aims to replace .sort and .simplify in most cases.

Irreps("1e + 0e + 1e + 0x2e").regroup()  # 1x0e+2x1e

New default behavior of tensor_product

e3nn.tensor_product("0e + 1o + 0e", "1o + 1o")  # version<0.12.0
# 1x0e+1x0e+1x1o+1x1o+1x1o+1x1o+1x1e+1x1e+1x2e+1x2e

e3nn.tensor_product("0e + 1o + 0e", "1o + 1o")  # version==0.12.0
# 2x0e+4x1o+2x1e+2x2e

e3nn.tensor_product("0e + 1o + 0e", "1o + 1o", regroup_output=False)  # version==0.12.0
# 1x0e+1x0e+1x1o+1x1o+1x1o+1x1o+1x1e+1x1e+1x2e+1x2e
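
The outputs above follow the selection rule for coupling two irreps: l1 ⊗ l2 contains every l from |l1 - l2| to l1 + l2, with parity p1 * p2. The pure-Python sketch below reproduces both outputs, implementing regroup as a sort followed by merging equal neighbours and dropping zero multiplicities. The sort key (within each l, the natural parity (-1)^l first) is an assumption chosen to match the 2x0e+4x1o+2x1e+2x2e ordering shown; all helper names are hypothetical.

```python
def parse_irreps(irreps):
    """Parse "0e + 2x1o" into [(1, (0, 1)), (2, (1, -1))], i.e. (mul, (l, parity))."""
    terms = []
    for term in irreps.replace(" ", "").split("+"):
        mul, _, ir = term.rpartition("x")
        l, p = int(ir[:-1]), {"e": 1, "o": -1}[ir[-1]]
        terms.append((int(mul) if mul else 1, (l, p)))
    return terms


def format_irreps(terms):
    return "+".join(f"{mul}x{l}{'e' if p == 1 else 'o'}" for mul, (l, p) in terms)


def sort_key(ir):
    l, p = ir
    # Assumed order: by l, then the natural parity (-1)**l before the other one.
    return (l, -p * (-1) ** l)


def regroup(terms):
    """Sort by irrep, merge equal neighbours, drop zero multiplicities."""
    merged = []
    for mul, ir in sorted(terms, key=lambda t: sort_key(t[1])):
        if mul == 0:
            continue
        if merged and merged[-1][1] == ir:
            merged[-1] = (merged[-1][0] + mul, ir)
        else:
            merged.append((mul, ir))
    return merged


def tensor_product_irreps(irreps1, irreps2):
    """Output irreps of the full tensor product, one entry per coupling path."""
    out = []
    for mul1, (l1, p1) in parse_irreps(irreps1):
        for mul2, (l2, p2) in parse_irreps(irreps2):
            for l in range(abs(l1 - l2), l1 + l2 + 1):
                out.append((mul1 * mul2, (l, p1 * p2)))
    return out


paths = tensor_product_irreps("0e + 1o + 0e", "1o + 1o")
print(format_irreps(regroup(paths)))  # 2x0e+4x1o+2x1e+2x2e
```

Sorting the paths without merging them gives the regroup_output=False string shown above.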

Added

  • e3nn.Irreps.regroup and e3nn.IrrepsArray.regroup to regroup irreps. Equivalent to sort followed by simplify.
  • add regroup_output parameter to e3nn.tensor_product and e3nn.tensor_square to regroup the output irreps.

Changed

  • e3nn.IrrepsArray.convert is now private (e3nn.IrrepsArray._convert) because other methods are recommended instead.
  • breaking change e3nn.Linear now uses input.regroup(), which can change the structure of the parameters dictionary.
  • breaking change regroup_output is True by default in e3nn.tensor_product and e3nn.tensor_square.
  • To facilitate debugging, if no key is provided to e3nn.normal, it uses the hash of the irreps.
  • breaking change changed normalization of e3nn.tensor_square in the case of normalized_input=True

Removed

  • Deprecate e3nn.TensorSquare

2022-11-13

13 Nov 18:05

e3nn.Linear can create several sets of weights; each input element then uses the set selected by an integer index.

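import jax.numpy as jnp
import e3nn_jax as e3nn

z = jnp.array([0, 0, 1, 3])  # [num_nodes], index of the weight set each node uses
x = e3nn.IrrepsArray("8x0e + 8x1o", _)  # [num_nodes, irreps], `_` is the feature array
e3nn.Linear("16x0e + 16x1o", num_weights=4)(z, x)  # call inside hk.transform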

Added

  • e3nn.Linear now supports integer "weights" inputs.
  • e3nn.Linear now supports a name argument.
  • Add .dtype to IrrepsArray to get the dtype of the underlying array.
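
Conceptually, indexed weights mean the layer owns num_weights independent weight matrices and each node is transformed by the one its integer index selects. A NumPy sketch restricted to scalar (0e) features, where a linear layer is a plain matrix product; the shapes and names are illustrative assumptions, and e3nn.Linear additionally keeps different irreps from mixing.

```python
import numpy as np

rng = np.random.default_rng(0)
num_weights, c_in, c_out = 4, 8, 16
# One weight matrix per index, scaled by 1/sqrt(c_in) to keep output variance moderate.
W = rng.normal(size=(num_weights, c_in, c_out)) / np.sqrt(c_in)

z = np.array([0, 0, 1, 3])           # [num_nodes] index of the weight set per node
x = rng.normal(size=(len(z), c_in))  # [num_nodes, c_in] scalar features

# Gather each node's weight matrix, then apply it: out[n] = x[n] @ W[z[n]].
out = np.einsum("ni,nio->no", x, W[z])
```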

Changed

  • e3nn.MultiLayerPerceptron names its layers linear_0, linear_1, etc.

2022-11-08

08 Nov 21:48

This release allows e3nn.Linear to take scalars as input, which are blended with the parameters.
This allows a clean and straightforward implementation of depthwise message passing (uvu convolution).

def message_passing_convolution(
    node_feats: e3nn.IrrepsArray,  # [n_nodes, irreps]
    edge_attrs: e3nn.IrrepsArray,  # [n_edges, irreps]
    edge_feats: e3nn.IrrepsArray,  # [n_edges, irreps]
    ...
    target_irreps: e3nn.Irreps,
) -> e3nn.IrrepsArray:
    messages = e3nn.Linear(target_irreps)(
        e3nn.MultiLayerPerceptron(3 * [64], activation)(edge_feats),  # [n_edges, 64]
        e3nn.tensor_product(node_feats[senders], edge_attrs),  # [n_edges, irreps]
    )  # [n_edges, irreps]

    zeros = e3nn.IrrepsArray.zeros(messages.irreps, (node_feats.shape[0],))
    node_feats = zeros.at[receivers].add(messages) / jnp.sqrt(
        avg_num_neighbors
    )  # [n_nodes, irreps]

    return node_feats

def depthwise_convolution(
    node_feats: e3nn.IrrepsArray,  # [n_nodes, channel_in, irreps]
    ...
    channel_out: int,
    target_irreps: e3nn.Irreps,
) -> e3nn.IrrepsArray:
    node_feats = e3nn.Linear(node_feats.irreps, channel_out)(node_feats)

    node_feats = hk.vmap(
        lambda x: message_passing_convolution(
            x,
            edge_attrs,
            edge_feats,
            senders,
            receivers,
            avg_num_neighbors,
            target_irreps,
            activation,
        ),
        in_axes=1,
        out_axes=1,
        split_rng=False,
    )(node_feats)

    node_feats = e3nn.Linear(target_irreps, channel_out)(node_feats)
    return node_feats
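
The key ingredient above is that e3nn.Linear mixes per-edge scalars (the MLP output) with its free parameters. One natural reading, sketched here in NumPy as an assumption rather than e3nn's exact internals: the scalars weight a bank of parameter matrices, giving every edge its own effective linear map.

```python
import numpy as np

rng = np.random.default_rng(0)
n_edges, n_scalars, d_in, d_out = 5, 64, 8, 3
s = rng.normal(size=(n_edges, n_scalars))      # per-edge scalars, e.g. an MLP output
W = rng.normal(size=(n_scalars, d_in, d_out))  # free parameters
x = rng.normal(size=(n_edges, d_in))           # per-edge features

# Effective per-edge weights: W_eff[e] = sum_k s[e, k] * W[k]
W_eff = np.einsum("ek,kio->eio", s, W)
out = np.einsum("ei,eio->eo", x, W_eff)

# Same result in one contraction, without materializing W_eff.
out_fused = np.einsum("ek,ei,kio->eo", s, x, W, optimize=True)
```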

Added

  • s2grid: e3nn.from_s2grid and e3nn.to_s2grid; thanks to @songk42 for the contribution
  • argument max_order: int to reduced_tensor_product_basis to limit the polynomial order of the basis
  • MultiLayerPerceptron accepts IrrepsArray as input and output
  • e3nn.Linear accepts optional weights as arguments, which are internally mixed with the free parameters. Very useful for implementing the depthwise convolution

Changed

  • breaking change e3nn.normal has a new argument to get normalized vectors.
  • breaking change e3nn.tensor_square now distinguishes between normalization=norm and normalized_input=True.

2022-10-24

24 Oct 22:22

Added

  • e3nn.SymmetricTensorProduct operation: a parameterized version of x + x^2 + x^3 + ...
  • e3nn.soft_envelope a smooth C^inf envelope radial function.
  • e3nn.tensor_square
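
A worked example of what a tensor square of a single vector contains: the square of 1o decomposes as 0e + 1e + 2e, but for x ⊗ x the antisymmetric 1e part vanishes (it is x × x = 0), leaving the trace (0e) and the 5-dimensional traceless symmetric part (2e). The NumPy sketch works in Cartesian form; e3nn itself works in the spherical basis with its own normalization.

```python
import numpy as np

x = np.array([1.0, -2.0, 0.5])
T = np.outer(x, x)  # x ⊗ x, a 3x3 tensor

scalar = np.trace(T)                                  # 0e part: |x|^2
antisym = 0.5 * (T - T.T)                             # 1e part: zero for x ⊗ x
traceless = 0.5 * (T + T.T) - scalar / 3 * np.eye(3)  # 2e part: symmetric, traceless

# The 0e and 2e pieces rebuild x ⊗ x exactly.
rebuilt = traceless + scalar / 3 * np.eye(3)
```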

2022-10-05

05 Oct 14:49

Added

  • Irrep.generators and Irreps.generators functions to get the generators of the representations.
  • e3nn.bessel function
  • slice_by_mul, slice_by_dim and slice_by_chunk functions to Irreps and IrrepsArray
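
e3nn.bessel provides a radial basis for distances. The sketch below uses the Bessel-style basis popularized by DimeNet, sqrt(2/c) * sin(n π x / c) / x for n = 1..N with cutoff c; that e3nn.bessel uses exactly this formula and normalization is an assumption, and bessel_basis is a hypothetical name.

```python
import numpy as np


def bessel_basis(x, n, x_max=1.0):
    """Evaluate sqrt(2/x_max) * sin(k*pi*x/x_max) / x for k = 1..n.

    x: array of positive distances, shape [...]; returns shape [..., n].
    """
    x = np.asarray(x, dtype=float)[..., None]
    k = np.arange(1, n + 1)
    return np.sqrt(2.0 / x_max) * np.sin(k * np.pi * x / x_max) / x


r = np.array([0.25, 0.5, 1.0])
B = bessel_basis(r, n=4)
```

Every basis function vanishes at the cutoff x = x_max, which keeps features smooth as neighbours leave the cutoff sphere.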

Changed

  • breaking change e3nn.soft_one_hot_linspace does not support bessel anymore. Use e3nn.bessel instead.
  • e3nn.gate is now more flexible about the input format; see examples in the docstring.

Removed

  • breaking change IrrepsArray.split

2022-09-29

29 Sep 22:54

Fixed

  • fix IrrepsArray.zeros().at[...].add