Using Flax sow inside a function with scanning #3727
TalKachman asked this question in Q&A
-
You'll need to specify `'intermediates'` in the `variable_axes` of `nn.scan()`. For example:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class Block(nn.Module):
  width: int

  @nn.compact
  def __call__(self, carry, unused_inputs):
    carry = nn.Dense(self.width)(carry)
    carry = nn.relu(carry)
    self.sow('intermediates', 'carry', carry)
    return carry, None


class MLP(nn.Module):
  width: int
  depth: int

  @nn.compact
  def __call__(self, x):
    self.sow('intermediates', 'x', x)
    carry, unused_outputs = nn.scan(
        Block,
        # If 'intermediates' is not listed in `variable_axes` below, then
        # `self.sow('intermediates', ...)` will not work inside the `nn.scan()`.
        variable_axes={'params': 0, 'intermediates': 0},
        split_rngs={'params': True},
        length=self.depth,
    )(
        width=self.width,
    )(
        x, None,
    )
    return carry


model = MLP(width=2, depth=3)
x = jnp.zeros([1, 2])
variables = model.init(jax.random.PRNGKey(0), x)
out, state = model.apply(variables, x, mutable=['intermediates'])
print(jax.tree.map(lambda x: x.shape, dict(out=out, state=state, variables=variables)))
```

This prints the shape of every array in `out`, `state`, and `variables`. Because `'intermediates'` is scanned over axis 0, the `carry` values sown inside `Block` are stacked along a new leading axis of length `depth` (here 3).
-
Hi Everyone!
Is there a way to use the `sow` method inside a scanned function? Ideally, I would like to save some metadata and be able to inspect the intermediate states.
Here is a small example of what I want to try:
This gives an empty state. How, for example, can I access the outputs from inside the scan?
Many thanks!