Skip to content

Commit

Permalink
Update README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger authored Apr 18, 2022
1 parent 25132ac commit 24bcd8f
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,33 +78,46 @@ out.list
The two fields `contiguous` and `list` contain the same information under different forms.
This is not a performence issue, we rely on `jax.jit` to optimize the code and get rid of the unused operations.

## FullyConnectedTensorProduct
## Complete example

`torch` version ([e3nn](github.com/e3nn/e3nn) repo):
Usage of `FullyConnectedTensorProduct` in the `torch` version ([e3nn](github.com/e3nn/e3nn) repo):
```python
f = o3.FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
from e3nn import o3

f(x, y)
irreps_in1 = o3.Irreps("1e")
irreps_in2 = o3.Irreps("1e")
irreps_out = o3.Irreps("1e")

tp = o3.FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)

x1 = irreps_in1.randn(10, -1)
x2 = irreps_in2.randn(10, -1)

out = tp(x1, x2)
```

`jax` version (this repo):
and in the `jax` version (this repo):
```python
import jax
import e3nn_jax as e3nn
import haiku as hk

irreps_in1 = e3nn.Irreps("1e")
irreps_in2 = e3nn.Irreps("1e")
irreps_out = e3nn.Irreps("1e")

@hk.without_apply_rng
@hk.transform
def tp(x1, x2):
return e3nn.FullyConnectedTensorProduct(irreps_out)(x1, x2)

irreps_in1 = e3nn.Irreps("1e")
irreps_in2 = e3nn.Irreps("1e")

x1 = e3nn.IrrepsData.randn(irreps_in1, jax.random.PRNGKey(0), (10,))
x2 = e3nn.IrrepsData.randn(irreps_in2, jax.random.PRNGKey(1), (10,))
w = tp.init(jax.random.PRNGKey(2), x1, x2)

out = tp.apply(w, x1, x2)
```

The `jax` version require more boiler-plate (haiku) and more verbose code (with the random keys).
However note that the input irreps does not need to be provided to `FullyConnectedTensorProduct` because it will obtain it from its inputs.

0 comments on commit 24bcd8f

Please sign in to comment.