
2022-11-08

@mariogeiger released this 08 Nov 21:48

This release allows e3nn.Linear to take scalars as an extra input, which are blended with its parameters.
This enables a clean and straightforward implementation of depthwise message passing (uvu convolution).

import haiku as hk
import jax.numpy as jnp
import e3nn_jax as e3nn

def message_passing_convolution(
    node_feats: e3nn.IrrepsArray,  # [n_nodes, irreps]
    edge_attrs: e3nn.IrrepsArray,  # [n_edges, irreps]
    edge_feats: e3nn.IrrepsArray,  # [n_edges, irreps]
    ...
    target_irreps: e3nn.Irreps,
) -> e3nn.IrrepsArray:
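    # The MLP output (per-edge scalars) is passed as the first argument of
    # e3nn.Linear and blended with the layer's free parameters.
    messages = e3nn.Linear(target_irreps)(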
        e3nn.MultiLayerPerceptron(3 * [64], activation)(edge_feats),  # [n_edges, 64]
        e3nn.tensor_product(node_feats[senders], edge_attrs),  # [n_edges, irreps]
    )  # [n_edges, irreps]

    # Sum the messages arriving at each receiver node, normalized by sqrt(avg_num_neighbors).
    zeros = e3nn.IrrepsArray.zeros(messages.irreps, (node_feats.shape[0],))
    node_feats = zeros.at[receivers].add(messages) / jnp.sqrt(
        avg_num_neighbors
    )  # [n_nodes, irreps]

    return node_feats

def depthwise_convolution(
    node_feats: e3nn.IrrepsArray,  # [n_nodes, channel_in, irreps]
    ...
    channel_out: int,
    target_irreps: e3nn.Irreps,
) -> e3nn.IrrepsArray:
    # Mix the input channels: [n_nodes, channel_in, irreps] -> [n_nodes, channel_out, irreps]
    node_feats = e3nn.Linear(node_feats.irreps, channel_out)(node_feats)

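    # Run the message passing on each channel independently (vmap over axis 1).
    # hk.vmap does not map the parameters, so all channels share the same weights.
    node_feats = hk.vmap(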
        lambda x: message_passing_convolution(
            x,
            edge_attrs,
            edge_feats,
            senders,
            receivers,
            avg_num_neighbors,
            target_irreps,
            activation,
        ),
        in_axes=1,
        out_axes=1,
        split_rng=False,
    )(node_feats)

    # Mix the channels again: [n_nodes, channel_out, target_irreps]
    node_feats = e3nn.Linear(target_irreps, channel_out)(node_feats)
    return node_feats
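The same mechanics can be seen without e3nn or Haiku in a minimal pure-JAX sketch. It assumes plain arrays of scalar features instead of IrrepsArray, and replaces the tensor product and e3nn.Linear with a multiplication of the sender features by a per-edge scalar weight; all function names are hypothetical. What it keeps from the example above is the scatter-sum into receivers with .at[receivers].add, the 1/sqrt(avg_num_neighbors) normalization, and the mapping of one function over the channel axis with vmap(in_axes=1, out_axes=1).

```python
import jax
import jax.numpy as jnp


def aggregate(messages, receivers, n_nodes, avg_num_neighbors):
    # Scatter-sum the messages into their receiver nodes, then normalize.
    # messages: [n_edges, d], receivers: [n_edges] -> [n_nodes, d]
    zeros = jnp.zeros((n_nodes,) + messages.shape[1:], messages.dtype)
    return zeros.at[receivers].add(messages) / jnp.sqrt(avg_num_neighbors)


def single_channel_message_passing(x, edge_weights, senders, receivers, avg_num_neighbors):
    # x: [n_nodes, d] features of a single channel.
    # edge_weights: [n_edges] hypothetical per-edge scalars (stand-in for the MLP output).
    messages = x[senders] * edge_weights[:, None]  # [n_edges, d]
    return aggregate(messages, receivers, x.shape[0], avg_num_neighbors)


def depthwise_message_passing(x, edge_weights, senders, receivers, avg_num_neighbors):
    # x: [n_nodes, channels, d]. Each channel (axis 1) is processed independently
    # with the same edge weights; the output keeps the channel axis at position 1.
    return jax.vmap(
        lambda xc: single_channel_message_passing(
            xc, edge_weights, senders, receivers, avg_num_neighbors
        ),
        in_axes=1,
        out_axes=1,
    )(x)


# Usage: 3 nodes, 2 channels, 2 features per channel, edges 0->1, 1->2, 2->0.
x = jnp.arange(12.0).reshape(3, 2, 2)
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 0])
edge_weights = jnp.array([1.0, 2.0, 3.0])
out = depthwise_message_passing(x, edge_weights, senders, receivers, avg_num_neighbors=1.0)
```

Because the vmapped function closes over edge_weights, every channel sees the same per-edge weights. Channels are never mixed inside the message passing, only by the linear layers before and after it, which is what makes the convolution depthwise.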

Added

  • s2grid: e3nn.from_s2grid and e3nn.to_s2grid, thanks to @songk42 for the contribution
  • argument max_order: int to the function reduced_tensor_product_basis, to limit the polynomial order of the basis
  • MultiLayerPerceptron accepts IrrepsArray as input and output
  • e3nn.Linear accepts optional weights as arguments, which are internally mixed with the free parameters. This is very useful for implementing the depthwise convolution
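The changelog does not spell out how the optional weights are mixed in. One plausible reading, sketched below in pure JAX for a single block of scalar features, is that the layer holds a free parameter matrix and maps each sample's input scalars through it to obtain that sample's own weight matrix. The helper names, the parameter shape [n_scalars, mul_in * mul_out] and the 1/sqrt normalizations are assumptions for illustration, not e3nn's actual implementation.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_scalars, mul_in, mul_out):
    # Hypothetical free parameters: one [mul_in * mul_out] weight vector per input scalar.
    return jax.random.normal(key, (n_scalars, mul_in * mul_out))


def linear_with_weights(params, scalars, x, mul_out):
    # scalars: [batch, n_scalars] (e.g. an MLP of edge features)
    # x:       [batch, mul_in]    (one block of scalar features)
    n_scalars = scalars.shape[-1]
    mul_in = x.shape[-1]
    # Blend the input scalars with the free parameters -> per-sample weights.
    w = scalars @ params / jnp.sqrt(n_scalars)  # [batch, mul_in * mul_out]
    w = w.reshape(-1, mul_in, mul_out)  # [batch, mul_in, mul_out]
    # Apply each sample's own linear map.
    return jnp.einsum("bu,buv->bv", x, w) / jnp.sqrt(mul_in)


params = init_params(jax.random.PRNGKey(0), n_scalars=4, mul_in=3, mul_out=5)
scalars = jax.random.normal(jax.random.PRNGKey(1), (7, 4))
x = jax.random.normal(jax.random.PRNGKey(2), (7, 3))
y = linear_with_weights(params, scalars, x, mul_out=5)  # [7, 5]
```

With this design each sample still goes through a linear map of x, while the map itself varies with the scalars, for example an MLP of edge features as in the convolution above.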

Changed

  • breaking change: e3nn.normal has a new argument to return normalized vectors.
  • breaking change: e3nn.tensor_square now distinguishes between normalization=norm and normalized_input=True.
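The entry does not name the new argument of e3nn.normal, so it is not reproduced here. As an illustration of what normalized random vectors are, the sketch below draws Gaussian samples with plain JAX and rescales each vector to unit norm. It assumes, without confirmation from the changelog, that the normalization applies to each irrep copy separately; random_unit_vectors is a hypothetical helper, not part of e3nn.

```python
import jax
import jax.numpy as jnp


def random_unit_vectors(key, shape, dim):
    # Hypothetical helper (not an e3nn function): draw Gaussian vectors of size `dim`
    # and rescale each one to unit Euclidean norm. Since the Gaussian is isotropic,
    # the resulting directions are uniformly distributed on the sphere.
    r = jax.random.normal(key, shape + (dim,))
    return r / jnp.linalg.norm(r, axis=-1, keepdims=True)


# 10 samples, 4 copies of an l=1 (3-dimensional) vector each.
v = random_unit_vectors(jax.random.PRNGKey(0), (10, 4), dim=3)
```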