Hi guys, I was wondering if anyone could provide an example of capturing intermediate layer outputs when using the Linen API. Does
Replies: 2 comments 1 reply
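Note that you can "capture" any intermediate values by adding them to a mutable collection:

```python
from typing import Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp


class Model(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, features in enumerate(self.features):
            x = nn.relu(nn.Dense(features)(x))
            # Store the activation in the 'debug' collection under a per-layer name.
            self.variable('debug', f'dense_relu_{i}', lambda: x).value = x
        return x


model = Model([2, 3])
inputs = jnp.ones([2, 3])
variables = model.init(jax.random.PRNGKey(0), inputs)
# Mark 'debug' as mutable so the stored activations are returned alongside the output.
model.apply(variables, inputs, mutable=['debug'])
```

Because `debug` is marked mutable, `apply` returns a tuple: the model output and the updated `debug` collection, which holds the `dense_relu_0` and `dense_relu_1` entries.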
Please check out https://flax.readthedocs.io/en/latest/howtos/extracting_intermediates.html for different Linen patterns to extract intermediate values.