Creating a running average with self.variable #1005
Hi all, I tried to create a running average for a module. However, I am getting an error message. Could someone point me to what I am doing wrong? Thanks!

import jax.numpy as jnp
from jax import random
import flax.linen as nn

class Net(nn.Module):
    @nn.compact
    def __call__(self, x):
        is_initialized = self.has_variable('moving_stats', 'mean')
        mean = self.variable('moving_stats', 'mean', jnp.zeros, [3])
        if is_initialized:
            mean.value = 0.9 * mean.value + 0.1 * x
        return mean.value

key = random.PRNGKey(0)
x = random.normal(key, shape=(2, 3))
net = Net()
params = net.init(key, x)
y = net.apply(params, x)
Replies: 1 comment
By default all variables are immutable during `apply`. This is to avoid accidental side effects in otherwise stateless code. Here you should use `y, new_state = net.apply(params, x, mutable=['moving_stats'])`. The `new_state` will be a dict containing 'moving_stats' with the updated batch statistics.