2022-11-08
This release allows `e3nn.Linear` to take some scalars as input that are blended with the free parameters. This allows a clean and straightforward implementation of depth-wise message passing (`uvu` convolution).
```python
def message_passing_convolution(
    node_feats: e3nn.IrrepsArray,  # [n_nodes, irreps]
    edge_attrs: e3nn.IrrepsArray,  # [n_edges, irreps]
    edge_feats: e3nn.IrrepsArray,  # [n_edges, irreps]
    ...
    target_irreps: e3nn.Irreps,
) -> e3nn.IrrepsArray:
    messages = e3nn.Linear(target_irreps)(
        e3nn.MultiLayerPerceptron(3 * [64], activation)(edge_feats),  # [n_edges, 64]
        e3nn.tensor_product(node_feats[senders], edge_attrs),  # [n_edges, irreps]
    )  # [n_edges, irreps]

    zeros = e3nn.IrrepsArray.zeros(messages.irreps, (node_feats.shape[0],))
    node_feats = zeros.at[receivers].add(messages) / jnp.sqrt(
        avg_num_neighbors
    )  # [n_nodes, irreps]
    return node_feats


def depthwise_convolution(
    node_feats: e3nn.IrrepsArray,  # [n_nodes, channel_in, irreps]
    ...
    channel_out: int,
    target_irreps: e3nn.Irreps,
) -> e3nn.IrrepsArray:
    node_feats = e3nn.Linear(node_feats.irreps, channel_out)(node_feats)

    node_feats = hk.vmap(
        lambda x: message_passing_convolution(
            x,
            edge_attrs,
            edge_feats,
            senders,
            receivers,
            avg_num_neighbors,
            target_irreps,
            activation,
        ),
        in_axes=1,
        out_axes=1,
        split_rng=False,
    )(node_feats)

    node_feats = e3nn.Linear(target_irreps, channel_out)(node_feats)
    return node_feats
```
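Below is a minimal standalone sketch of the new call signature in isolation: the first argument of `e3nn.Linear(...)` carries plain scalars that are blended with the layer's free parameters. The irreps, shapes, and the `hk.transform` boilerplate are arbitrary choices made only to keep the snippet self-contained.

```python
import e3nn_jax as e3nn
import haiku as hk
import jax
import jax.numpy as jnp


@hk.without_apply_rng
@hk.transform
def linear_with_weights(scalars, x):
    # The leading scalar argument (0e only) is mixed with the internal
    # free parameters of the Linear layer.
    return e3nn.Linear("0e + 1o")(scalars, x)


x = e3nn.normal("0e + 1o", jax.random.PRNGKey(0), (16,))  # [16, irreps]
scalars = e3nn.IrrepsArray("8x0e", jnp.ones((16, 8)))      # [16, 8] scalars
params = linear_with_weights.init(jax.random.PRNGKey(1), scalars, x)
out = linear_with_weights.apply(params, scalars, x)        # [16, irreps]
```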
Added
- `s2grid`: `e3nn.from_s2grid` and `e3nn.to_s2grid`, thanks to @songk42 for the contribution
- argument `max_order: int` to function `reduced_tensor_product_basis` to be able to limit the polynomial order of the basis (see the sketch below)
- `MultiLayerPerceptron` accepts `IrrepsArray` as input and output
- `e3nn.Linear` accepts optional weights as arguments that will be internally mixed with the free parameters. Very useful to implement the depthwise convolution
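As a rough illustration of the new `max_order` argument (the formula, the index irreps, and the value `2` are arbitrary; only the keyword name comes from this entry):

```python
import e3nn_jax as e3nn

# Basis of symmetric rank-2 tensors built from vectors ("1o"), keeping
# only basis elements up to the given polynomial order.
basis = e3nn.reduced_tensor_product_basis("ij=ji", i="1o", max_order=2)
print(basis.irreps)  # irreps spanned by the (possibly truncated) basis
print(basis.shape)   # (3, 3, basis.irreps.dim)
```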
Changed
- breaking change: `e3nn.normal` has a new argument to get normalized vectors.
- breaking change: `e3nn.tensor_square` now distinguishes between `normalization="norm"` and `normalized_input=True` (see the sketch below).
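A minimal sketch of the new `tensor_square` distinction (irreps and shapes are arbitrary; both keyword names are taken from this entry):

```python
import e3nn_jax as e3nn
import jax

x = e3nn.normal("1o", jax.random.PRNGKey(0), (10,))  # [10, 1o]

# normalization="norm" controls how the output of the square is normalized,
y_norm = e3nn.tensor_square(x, normalization="norm")

# while normalized_input=True declares that the input itself is already
# normalized; these two notions are now kept distinct.
y_unit = e3nn.tensor_square(x, normalized_input=True)
```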