diff --git a/e3nn_jax/_irreps.py b/e3nn_jax/_irreps.py index 2ff0ad8e..40b8a313 100644 --- a/e3nn_jax/_irreps.py +++ b/e3nn_jax/_irreps.py @@ -764,6 +764,10 @@ def __repr__(self): def shape(self): return self.contiguous.shape[:-1] + @property + def ndim(self): + return len(self.shape) + def reshape(self, shape) -> "IrrepsData": list = [ None if x is None else x.reshape(shape + (mul, ir.dim))