Releases: e3nn/e3nn-jax
2023-01-20
The following call only takes 10 seconds:
e3nn.reduced_symmetric_tensor_product_basis("0e + 1o + 2e + 3o + 4e + 5o", 3, keep_ir="0e")
Added
- e3nn.s2grid_vectors and e3nn.pad_to_plot_on_s2grid to help plot signals on the sphere
- e3nn.util.assert_output_dtype to check the output dtype of a function
- e3nn.s2_irreps, a function that creates the irreps of the coefficients of a signal on the sphere
- e3nn.reduced_antisymmetric_tensor_product_basis to compute the basis of the reduced antisymmetric tensor product
- IrrepsArray * scalar is supported if the number of scalars matches the number of irreps
Changed
- Optimize reduced_symmetric_tensor_product. It is now up to 100x faster than the previous implementation.
- e3nn.from_s2grid and e3nn.to_s2grid are now more flexible with input and output irreps: you can skip some l's and have them in any order.
- [BREAKING] e3nn.from_s2grid requires an irreps argument instead of an lmax argument.
Fixed
- Increase the robustness of e3nn.spherical_harmonics against nan when normalize=True
Full Changelog: 0.14.0...0.15.0
2022-12-16
Added
- IrrepsArray.astype to cast the underlying array
- e3nn.flax.MultiLayerPerceptron and e3nn.haiku.MultiLayerPerceptron
- e3nn.IrrepsArray.from_list(..., dtype)
- Add a sparse tensor product as an option in e3nn.tensor_product and related functions. It sparsifies the Clebsch-Gordan coefficients and has more impact when fused=True. It is disabled by default because no improvement was observed in the benchmarks.
- Add log_coordinates alongside the other parameterizations of SO(3): e3nn.log_coordinates_to_matrix, e3nn.rand_log_coordinates, etc.
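Log coordinates of SO(3), in the usual sense, are a rotation vector whose direction is the rotation axis and whose norm is the angle; turning them into a matrix is the matrix exponential of a skew-symmetric matrix, which Rodrigues' formula gives in closed form. The pure-Python sketch below assumes e3nn.log_coordinates_to_matrix follows that convention (not confirmed here); rotation_vector_to_matrix is a hypothetical helper, not part of e3nn.

```python
import math


def rotation_vector_to_matrix(v):
    """Hypothetical sketch: rotation vector v = angle * axis -> 3x3 rotation matrix.

    Rodrigues' formula: R = I + sin(t) K + (1 - cos(t)) K^2,
    where t = |v| and K is the cross-product matrix of the unit axis.
    """
    x, y, z = v
    t = math.sqrt(x * x + y * y + z * z)
    if t < 1e-12:
        # Zero rotation vector: identity matrix
        return [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    kx, ky, kz = x / t, y / t, z / t
    # Skew-symmetric (cross-product) matrix of the unit axis
    K = [[0.0, -kz, ky], [kz, 0.0, -kx], [-ky, kx, 0.0]]
    K2 = [[sum(K[i][k] * K[k][j] for k in range(3)) for j in range(3)] for i in range(3)]
    s, c = math.sin(t), math.cos(t)
    return [
        [(1.0 if i == j else 0.0) + s * K[i][j] + (1.0 - c) * K2[i][j] for j in range(3)]
        for i in range(3)
    ]


R = rotation_vector_to_matrix([0.0, 0.0, math.pi / 2])  # 90 degrees around z
print([round(R[i][0], 6) for i in range(3)])  # image of the x-axis: [0.0, 1.0, 0.0]
```

The result is orthogonal for any input vector, which is what makes this parameterization convenient: every vector in R^3 maps to a valid rotation.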
Fixed
- Set dtype for all jnp.zeros(..., dtype) calls in the codebase
- Set dtype for all jnp.ones(..., dtype) calls in the codebase
Removed
- [BREAKING] e3nn.full_tensor_product in favor of e3nn.tensor_product
- [BREAKING] e3nn.FunctionalTensorSquare in favor of e3nn.tensor_square
- [BREAKING] e3nn.TensorSquare in favor of e3nn.tensor_square
- [BREAKING] e3nn.IrrepsArray.cat in favor of e3nn.concatenate
- [BREAKING] e3nn.IrrepsArray.randn in favor of e3nn.normal
- [BREAKING] e3nn.Irreps.randn in favor of e3nn.normal
- [BREAKING] e3nn.Irreps.transform_by_* in favor of e3nn.IrrepsArray.transform_by_*
Changed
- Move BatchNorm and Dropout to the e3nn.haiku submodule; they will be removed from the main module in the future.
- Move FullyConnectedTensorProduct into the haiku submodule (e3nn.haiku.FullyConnectedTensorProduct). Undeprecate it because it is faster than e3nn.tensor_product followed by e3nn.Linear, since opt_einsum optimizes the contraction of the two operations.
2022-12-14
Introduce flax and haiku submodules.
- port Linear
- port all modules (Dropout, BatchNorm, ...) to the submodules
Example with Linear in flax:
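import jax
import e3nn_jax as e3nn
x = e3nn.normal("2x0e + 3x1e")  # random IrrepsArray used as input
linear = e3nn.flax.Linear("3x0e + 1e")  # output irreps
w = linear.init(jax.random.PRNGKey(0), x)  # initialize the parameters
linear.apply(w, x)  # apply the layer with those parameters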
Added
- e3nn.scatter_sum to replace e3nn.index_add. e3nn.index_add is deprecated.
- Add flax and haiku submodules. The plan is to migrate all modules to flax and haiku in the future.
- Implement e3nn.flax.Linear and move e3nn.Linear to e3nn.haiku.Linear.
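The reduction that scatter_sum (and index_add before it) performs sums all rows that share a destination index, which is how edge messages are aggregated on nodes. A minimal pure-Python sketch of that semantics; scatter_sum_sketch and its signature are hypothetical and do not reproduce the parameters of e3nn.scatter_sum.

```python
def scatter_sum_sketch(data, dst, output_size):
    """Hypothetical sketch: out[i] = sum of data[k] over all k with dst[k] == i.

    data: one feature vector per row (e.g. per edge), dst: destination index per row.
    """
    dim = len(data[0])
    out = [[0.0] * dim for _ in range(output_size)]
    for row, i in zip(data, dst):
        for c in range(dim):
            out[i][c] += row[c]
    return out


# Three messages sent to nodes 0, 0 and 2 of a 3-node graph
messages = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(scatter_sum_sketch(messages, [0, 0, 2], 3))  # [[4.0, 6.0], [0.0, 0.0], [5.0, 6.0]]
```

In plain JAX the same reduction is usually written as jnp.zeros((output_size, dim)).at[dst].add(data) or jax.ops.segment_sum(data, dst, num_segments=output_size); nodes that receive no message stay at zero.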
2022-12-07
import e3nn_jax as e3nn
irreps = e3nn.Irreps("0e + 1o")
print(3 * irreps) # prints 3x0e+3x1o
Changed
- [BREAKING] 3 * e3nn.Irreps("0e + 1o") now returns 3x0e + 3x1o instead of 1x0e + 1x1o + 1x0e + 1x1o + 1x0e + 1x1o
- [BREAKING] In Linear, renamed num_weights to num_indexed_weights because the old name was confusing.
Added
- e3nn.Irreps("3x0e + 6x1o") // 3 returns 1x0e + 2x1o
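Both rules act on the multiplicities and leave the irreps themselves untouched. A minimal pure-Python sketch, assuming irreps are represented as a list of (multiplicity, irrep) pairs; parse_irreps, format_irreps, mul_irreps and floordiv_irreps are hypothetical helpers, not part of e3nn, and the sketch rejects multiplicities that are not divisible, which may differ from what e3nn.Irreps does.

```python
def parse_irreps(s):
    """Hypothetical helper: "3x0e + 6x1o" -> [(3, "0e"), (6, "1o")]."""
    out = []
    for term in s.split("+"):
        term = term.strip()
        if "x" in term:
            mul, ir = term.split("x")
            out.append((int(mul), ir))
        else:
            out.append((1, term))
    return out


def format_irreps(irreps):
    return "+".join(f"{mul}x{ir}" for mul, ir in irreps)


def mul_irreps(k, irreps):
    # k * irreps: scale each multiplicity instead of repeating the whole list
    return [(k * mul, ir) for mul, ir in irreps]


def floordiv_irreps(irreps, k):
    # irreps // k: divide each multiplicity (required to be divisible in this sketch)
    assert all(mul % k == 0 for mul, _ in irreps), "multiplicities must be divisible by k"
    return [(mul // k, ir) for mul, ir in irreps]


print(format_irreps(mul_irreps(3, parse_irreps("0e + 1o"))))  # 3x0e+3x1o
print(format_irreps(floordiv_irreps(parse_irreps("3x0e + 6x1o"), 3)))  # 1x0e+2x1o
```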
Fixed
- s2grid is now jittable
2022-11-16
New method regroup, which aims to replace .sort and .simplify most of the time.
Irreps("1e + 0e + 1e + 0x2e").regroup() # 1x0e+2x1e
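regroup is described as sort followed by simplify. The pure-Python sketch below mimics that on (multiplicity, (l, parity)) pairs: sort the entries, then merge consecutive equal irreps and drop zero multiplicities. parse_irreps, regroup and format_irreps are hypothetical helpers, and the sort key (l, parity) with odd before even is an assumption that may not match e3nn's ordering of parities for every l.

```python
def parse_irreps(s):
    """Hypothetical helper: "1e + 0x2e" -> [(1, (1, 1)), (0, (2, 1))] as (mul, (l, parity))."""
    out = []
    for term in s.split("+"):
        term = term.strip()
        mul, ir = term.split("x") if "x" in term else ("1", term)
        l, p = int(ir[:-1]), {"e": 1, "o": -1}[ir[-1]]
        out.append((int(mul), (l, p)))
    return out


def regroup(irreps):
    # 1. sort: order the entries by (l, parity); Python's sort is stable
    irreps = sorted(irreps, key=lambda entry: entry[1])
    # 2. simplify: merge consecutive equal irreps and drop zero multiplicities
    out = []
    for mul, ir in irreps:
        if mul == 0:
            continue
        if out and out[-1][1] == ir:
            out[-1] = (out[-1][0] + mul, ir)
        else:
            out.append((mul, ir))
    return out


def format_irreps(irreps):
    return "+".join(f"{mul}x{l}{'e' if p == 1 else 'o'}" for mul, (l, p) in irreps)


print(format_irreps(regroup(parse_irreps("1e + 0e + 1e + 0x2e"))))  # 1x0e+2x1e
```

Applied to the ungrouped tensor product output 1x0e+1x0e+1x1o+1x1o+1x1o+1x1o+1x1e+1x1e+1x2e+1x2e, the same sketch yields 2x0e+4x1o+2x1e+2x2e, the regrouped form.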
New default behavior of tensor_product
e3nn.tensor_product("0e + 1o + 0e", "1o + 1o") # version<0.12.0
# 1x0e+1x0e+1x1o+1x1o+1x1o+1x1o+1x1e+1x1e+1x2e+1x2e
e3nn.tensor_product("0e + 1o + 0e", "1o + 1o") # version==0.12.0
# 2x0e+4x1o+2x1e+2x2e
e3nn.tensor_product("0e + 1o + 0e", "1o + 1o", regroup_output=False) # version==0.12.0
# 1x0e+1x0e+1x1o+1x1o+1x1o+1x1o+1x1e+1x1e+1x2e+1x2e
Added
- e3nn.Irreps.regroup and e3nn.IrrepsArray.regroup to regroup irreps. Equivalent to sort followed by simplify.
- Add the regroup_output parameter to e3nn.tensor_product and e3nn.tensor_square to regroup the output irreps.
Changed
- e3nn.IrrepsArray.convert is now private (e3nn.IrrepsArray._convert) because other methods are recommended instead.
- breaking change: use input.regroup() in e3nn.Linear, which can change the structure of the parameters dictionary.
- breaking change: regroup_output is True by default in e3nn.tensor_product and e3nn.tensor_square.
- To facilitate debugging, if no key is provided to e3nn.normal, it uses the hash of the irreps.
- breaking change: changed the normalization of e3nn.tensor_square in the case of normalized_input=True.
Removed
- Deprecate e3nn.TensorSquare
2022-11-13
e3nn.Linear can create several sets of weights that are then selected by an index:
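import jax.numpy as jnp
import e3nn_jax as e3nn
z = jnp.array([0, 0, 1, 3])  # [num_nodes], index of the weight set used by each node
x = e3nn.IrrepsArray("8x0e + 8x1o", _)  # [num_nodes, irreps]
# e3nn.Linear is a haiku module: call it inside a haiku-transformed function
e3nn.Linear("16x0e + 16x1o", num_weights=4)(z, x)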
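Conceptually, an index-selected linear layer stores num_weights weight matrices and applies to each node the matrix picked by that node's integer index. The pure-Python sketch below shows only this selection for plain vectors; indexed_linear_sketch is hypothetical and ignores the irreps structure and normalization that e3nn.Linear handles.

```python
def indexed_linear_sketch(weights, z, x):
    """Hypothetical sketch: out[n] = x[n] @ weights[z[n]].

    weights: num_weights matrices of shape [dim_in][dim_out]
    z: integer index per node, x: input vector per node
    """
    out = []
    for index, vec in zip(z, x):
        w = weights[index]  # select the weight set for this node
        dim_out = len(w[0])
        out.append([sum(vec[i] * w[i][j] for i in range(len(vec))) for j in range(dim_out)])
    return out


num_weights, dim_in, dim_out = 4, 3, 2
# Deterministic weights so the selection is easy to read: weights[k][0] == [10k, 10k + 1]
weights = [
    [[float(10 * k + i + j) for j in range(dim_out)] for i in range(dim_in)]
    for k in range(num_weights)
]
z = [0, 0, 1, 3]  # same indices as in the example above
x = [[1.0, 0.0, 0.0] for _ in z]
y = indexed_linear_sketch(weights, z, x)
print(y)  # [[0.0, 1.0], [0.0, 1.0], [10.0, 11.0], [30.0, 31.0]]
```

Nodes 0 and 1 share the weight set 0, while nodes 2 and 3 use sets 1 and 3; set 2 is simply never selected.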
Added
- e3nn.Linear now supports integer "weights" inputs.
- e3nn.Linear now supports a name argument.
- Add .dtype to IrrepsArray to get the dtype of the underlying array.
Changed
- e3nn.MultiLayerPerceptron names its layers linear_0, linear_1, etc.
2022-11-08
This release allows e3nn.Linear to take some scalars as input that are blended with the parameters.
This allows a clean and straightforward implementation of depthwise message passing (uvu convolution).
import haiku as hk
import jax.numpy as jnp
import e3nn_jax as e3nn


def message_passing_convolution(
    node_feats: e3nn.IrrepsArray,  # [n_nodes, irreps]
    edge_attrs: e3nn.IrrepsArray,  # [n_edges, irreps]
    edge_feats: e3nn.IrrepsArray,  # [n_edges, irreps]
    ...
    target_irreps: e3nn.Irreps,
) -> e3nn.IrrepsArray:
    # The MLP output (scalars per edge) is blended with the parameters of the Linear
    messages = e3nn.Linear(target_irreps)(
        e3nn.MultiLayerPerceptron(3 * [64], activation)(edge_feats),  # [n_edges, 64]
        e3nn.tensor_product(node_feats[senders], edge_attrs),  # [n_edges, irreps]
    )  # [n_edges, irreps]

    # Sum the messages on the receiving nodes and normalize
    zeros = e3nn.IrrepsArray.zeros(messages.irreps, (node_feats.shape[0],))
    node_feats = zeros.at[receivers].add(messages) / jnp.sqrt(
        avg_num_neighbors
    )  # [n_nodes, irreps]
    return node_feats


def depthwise_convolution(
    node_feats: e3nn.IrrepsArray,  # [n_nodes, channel_in, irreps]
    ...
    channel_out: int,
    target_irreps: e3nn.Irreps,
) -> e3nn.IrrepsArray:
    node_feats = e3nn.Linear(node_feats.irreps, channel_out)(node_feats)
    # Run the same message passing independently on each channel (axis 1)
    node_feats = hk.vmap(
        lambda x: message_passing_convolution(
            x,
            edge_attrs,
            edge_feats,
            senders=senders,
            receivers=receivers,
            avg_num_neighbors=avg_num_neighbors,
            target_irreps=target_irreps,
            activation=activation,
        ),
        in_axes=1,
        out_axes=1,
        split_rng=False,
    )(node_feats)
    node_feats = e3nn.Linear(target_irreps, channel_out)(node_feats)
    return node_feats
Added
- s2grid: e3nn.from_s2grid and e3nn.to_s2grid, thanks to @songk42 for the contribution
- Add the argument max_order: int to the function reduced_tensor_product_basis to limit the polynomial order of the basis
- MultiLayerPerceptron accepts IrrepsArray as input and output
- e3nn.Linear accepts optional weights as arguments that are internally mixed with the free parameters. Very useful to implement the depthwise convolution
Changed
- breaking change: e3nn.normal has a new argument to get normalized vectors.
- breaking change: e3nn.tensor_square now distinguishes between normalization=norm and normalized_input=True.
2022-10-24
Added
- e3nn.SymmetricTensorProduct operation: a parameterized version of x + x^2 + x^3 + ...
- e3nn.soft_envelope, a smooth C^inf envelope radial function
- e3nn.tensor_square
2022-10-05
Added
- Irrep.generators and Irreps.generators functions to get the generators of the representations
- e3nn.bessel function
- slice_by_mul, slice_by_dim and slice_by_chunk functions to Irreps and IrrepsArray
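slice_by_mul, as the name suggests, slices over the individual copies (multiplicities) counted across the whole irreps rather than over the (multiplicity, irrep) entries. Assuming that semantics, a hypothetical pure-Python sketch on (multiplicity, irrep) pairs; slice_by_mul_sketch is not part of e3nn.

```python
def slice_by_mul_sketch(irreps, start, stop):
    """Hypothetical sketch: keep the copies whose global index is in [start, stop).

    irreps: list of (mul, ir) pairs; copies are counted across all entries.
    """
    out = []
    offset = 0  # global index of the first copy of the current entry
    for mul, ir in irreps:
        lo = max(start, offset)
        hi = min(stop, offset + mul)
        if hi > lo:
            out.append((hi - lo, ir))
        offset += mul
    return out


print(slice_by_mul_sketch([(3, "0e"), (2, "1o")], 2, 4))  # [(1, '0e'), (1, '1o')]
```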
Changed
- breaking change: e3nn.soft_one_hot_linspace does not support bessel anymore. Use e3nn.bessel instead.
- e3nn.gate is now more flexible with respect to the input format; see the examples in the docstring.
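For reference when switching from the bessel basis of soft_one_hot_linspace, the radial Bessel basis popularized by DimeNet is sqrt(2/c) * sin(n*pi*x/c) / x for n = 1..N, where c is the cutoff. Whether e3nn.bessel uses exactly this normalization and these arguments is an assumption here; bessel_sketch is a hypothetical pure-Python function, not the library's implementation.

```python
import math


def bessel_sketch(x, n, x_max=1.0):
    """Hypothetical sketch of a DimeNet-style radial Bessel basis.

    Returns [sqrt(2/x_max) * sin(k*pi*x/x_max) / x for k = 1..n], using the
    x -> 0 limit sqrt(2/x_max) * k*pi/x_max to avoid dividing by zero.
    """
    prefactor = math.sqrt(2.0 / x_max)
    values = []
    for k in range(1, n + 1):
        w = k * math.pi / x_max
        if abs(x) < 1e-12:
            values.append(prefactor * w)  # limit of sin(w x) / x as x -> 0
        else:
            values.append(prefactor * math.sin(w * x) / x)
    return values


print([round(v, 4) for v in bessel_sketch(0.5, 3)])  # [2.8284, 0.0, -2.8284]
```

Every basis function vanishes at x = x_max, so the features go to zero at the cutoff radius.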
Removed
- breaking change: IrrepsArray.split
2022-09-29
Fixed
- Fix IrrepsArray.zeros().at[...].add