diff --git a/README.md b/README.md index f255a31d..5ad6b268 100644 --- a/README.md +++ b/README.md @@ -78,23 +78,30 @@ out.list The two fields `contiguous` and `list` contain the same information under different forms. This is not a performence issue, we rely on `jax.jit` to optimize the code and get rid of the unused operations. -## FullyConnectedTensorProduct +## Complete example -`torch` version ([e3nn](github.com/e3nn/e3nn) repo): +Usage of `FullyConnectedTensorProduct` in the `torch` version ([e3nn](github.com/e3nn/e3nn) repo): ```python -f = o3.FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out) +from e3nn import o3 -f(x, y) +irreps_in1 = o3.Irreps("1e") +irreps_in2 = o3.Irreps("1e") +irreps_out = o3.Irreps("1e") + +tp = o3.FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out) + +x1 = irreps_in1.randn(10, -1) +x2 = irreps_in2.randn(10, -1) + +out = tp(x1, x2) ``` -`jax` version (this repo): +and in the `jax` version (this repo): ```python import jax import e3nn_jax as e3nn import haiku as hk -irreps_in1 = e3nn.Irreps("1e") -irreps_in2 = e3nn.Irreps("1e") irreps_out = e3nn.Irreps("1e") @hk.without_apply_rng @@ -102,9 +109,15 @@ irreps_out = e3nn.Irreps("1e") def tp(x1, x2): return e3nn.FullyConnectedTensorProduct(irreps_out)(x1, x2) +irreps_in1 = e3nn.Irreps("1e") +irreps_in2 = e3nn.Irreps("1e") + x1 = e3nn.IrrepsData.randn(irreps_in1, jax.random.PRNGKey(0), (10,)) x2 = e3nn.IrrepsData.randn(irreps_in2, jax.random.PRNGKey(1), (10,)) w = tp.init(jax.random.PRNGKey(2), x1, x2) out = tp.apply(w, x1, x2) ``` + +The `jax` version require more boiler-plate (haiku) and more verbose code (with the random keys). +However note that the input irreps does not need to be provided to `FullyConnectedTensorProduct` because it will obtain it from its inputs.