Linen Module attributes and arrays/non-hashable #1111
Unanswered
PhilipVinc asked this question in Q&A
Replies: 1 comment, 4 replies
-
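You always make me wonder whether I should move back to quantum physics. It all looks exciting, but I have no clue what you are doing :P

Modules are not registered as pytrees because they are not simple, stateless containers of JAX values. Instead of a plain attribute, the symmetrizer can live in its own, non-trainable variable collection, for example:

```python
from typing import Any, Callable

import jax.numpy as jnp
import flax.linen as nn
from jax import lax


class SymmDense(nn.Module):
    features: int
    symm_init: Callable  # called as symm_init(shape); returns the fixed symmetrizer
    kernel_init: Callable = nn.initializers.lecun_normal()
    dtype: Any = jnp.float32
    precision: Any = None

    @nn.compact
    def __call__(self, x):
        x = jnp.asarray(x, self.dtype)
        kernel = self.param(
            "kernel", self.kernel_init, (x.shape[-1], self.features), self.dtype
        )
        kernel = jnp.asarray(kernel, self.dtype)
        # One entry per pair of kernel indices: shape (in, out, in, out).
        symm_shape = kernel.shape * 2
        # Stored in a separate "constant" collection, so it is not part of
        # "params" and is not optimized. symm_init receives the shape here,
        # but it is free to ignore it.
        symmetrizer = self.variable("constant", "symmetrizer", self.symm_init, symm_shape)
        symm_kernel = jnp.einsum("ijkl,kl->ij", symmetrizer.value, kernel)
        # Contract the last axis of x with the first axis of the symmetrized kernel.
        x = lax.dot_general(
            x,
            symm_kernel,
            (((x.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        return x
```

This way you could also enable more advanced things in the future, like using …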
-
One thing I noticed is that Linen modules can be compared (`==`) and hashed.
This is useful to me because, if I define the function
then it will not be re-jitted if I subsequently call it with different module instances that have identical attributes.
However, I now need to add an array as a field of the module (a matrix representing some symmetry invariances of a dense layer).
But the code above will now fail, because `module` is no longer hashable once it contains an attribute that is a JAX array. I believe this is because modules declare every attribute as 'static' (`struct.field(pytree_node=False)`). Am I right?

What I want to do is build a 'DenseSymm' layer that multiplies the kernel and bias by a symmetrization tensor/matrix.
The symmetrization tensor will not change, should not be optimised, and ideally should be passed in at module construction.
I thought it would be a good fit for an attribute, but apparently not... Is this a use case for model_state?