
Commit 2c7d7cd

Flax Authors committed
Merge pull request #3876 from google:nnx-v0.1
PiperOrigin-RevId: 628712571
2 parents ea3bcab + 08b1336 · commit 2c7d7cd


71 files changed: +3718 additions, -4645 deletions

.readthedocs.yml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ version: 2
 build:
   os: ubuntu-22.04
   tools:
-    python: "3.9"
+    python: "3.10"

 # Build documentation in the docs/ directory with Sphinx
 sphinx:

docs/api_reference/flax.experimental.nnx/index.rst

Lines changed: 1 addition & 0 deletions
@@ -14,4 +14,5 @@ Experimental API. See the `NNX page <https://flax.readthedocs.io/en/latest/exper
   transforms
   variables
   helpers
+  visualization


docs/api_reference/flax.experimental.nnx/module.rst

Lines changed: 2 additions & 1 deletion
@@ -7,4 +7,5 @@ module
 .. autoclass:: Module
   :members:

-.. autofunction:: merge
+  .. automethod:: sow
+  .. automethod:: iter_modules

docs/api_reference/flax.experimental.nnx/variables.rst

Lines changed: 0 additions & 2 deletions
@@ -14,8 +14,6 @@ variables
   :members:
 .. autoclass:: Param
   :members:
-.. autoclass:: Rng
-  :members:
 .. autoclass:: Variable
   :members:
 .. autoclass:: VariableMetadata

docs/api_reference/flax.experimental.nnx/visualization.rst (new file)

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+visualization
+------------------------
+
+.. automodule:: flax.experimental.nnx
+.. currentmodule:: flax.experimental.nnx
+
+.. autofunction:: display
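
For context (not part of this commit's diff): the newly documented `display` function is used the same way the updated MNIST tutorial below uses it. A minimal sketch:

```python
from flax.experimental import nnx

# Construct any NNX module (constructor signature as shown elsewhere in this diff)
linear = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

# Render a summary / visualization of the module and its state
nnx.display(linear)
```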

docs/experimental/nnx/index.rst

Lines changed: 33 additions & 22 deletions
@@ -68,16 +68,6 @@ Features
 NNX prioritizes simplicity for common use cases, building upon lessons learned from Linen
 to provide a streamlined experience.

-
-Installation
-^^^^^^^^^^^^
-NNX is under active development, we recommend using the latest version from Flax's GitHub repository:
-
-.. code-block:: bash
-
-   pip install git+https://github.com/google/flax.git
-
-
 Basic usage
 ^^^^^^^^^^^^

@@ -89,22 +79,43 @@ Basic usage
 .. testcode::

   from flax.experimental import nnx
+  import optax
+
+
+  class Model(nnx.Module):
+    def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
+      self.linear = nnx.Linear(din, dmid, rngs=rngs)
+      self.bn = nnx.BatchNorm(dmid, rngs=rngs)
+      self.dropout = nnx.Dropout(0.2, rngs=rngs)
+      self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

-  class Linear(nnx.Module):
-    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
-      key = rngs()  # get a unique random key
-      self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
-      self.b = nnx.Param(jnp.zeros((dout,)))  # initialize parameters
-      self.din, self.dout = din, dout
+    def __call__(self, x):
+      x = nnx.relu(self.dropout(self.bn(self.linear(x))))
+      return self.linear_out(x)

-    def __call__(self, x: jax.Array):
-      return x @ self.w.value + self.b.value
+  model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
+  optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing

-  rngs = nnx.Rngs(0)  # explicit RNG handling
-  model = Linear(din=2, dout=3, rngs=rngs)  # initialize the model
+  @nnx.jit  # automatic state management
+  def train_step(model, optimizer, x, y):
+    def loss_fn(model):
+      y_pred = model(x)  # call methods directly
+      return ((y_pred - y) ** 2).mean()
+
+    loss, grads = nnx.value_and_grad(loss_fn)(model)
+    optimizer.update(grads)  # inplace updates
+
+    return loss
+
+
+Installation
+^^^^^^^^^^^^
+NNX is under active development, we recommend using the latest version from Flax's GitHub repository:
+
+.. code-block:: bash
+
+   pip install git+https://github.com/google/flax.git

-  x = jnp.empty((1, 2))  # generate random data
-  y = model(x)  # forward pass

 ----

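
For context (not part of this commit): the `train_step` defined in the snippet above could be exercised with dummy data roughly as follows; the batch size and target values are assumptions, not taken from the docs.

```python
import jax.numpy as jnp

# Hypothetical toy batch matching Model(din=2, dmid=64, dout=3) from the snippet above
x = jnp.ones((8, 2))   # 8 samples, 2 input features
y = jnp.zeros((8, 3))  # 8 dummy targets, 3 outputs

loss = train_step(model, optimizer, x, y)  # model and optimizer are updated in place
print(loss)
```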

docs/experimental/nnx/mnist_tutorial.ipynb

Lines changed: 225 additions & 155 deletions
Large diffs are not rendered by default.

docs/experimental/nnx/mnist_tutorial.md

Lines changed: 50 additions & 86 deletions
@@ -28,13 +28,14 @@ Since NNX is under active development, we recommend using the latest version fro
 ```{code-cell} ipython3
 :tags: [skip-execution]

-# TODO: Fix text descriptions in this tutorial
-!pip install git+https://github.com/google/flax.git
+# !pip install git+https://github.com/google/flax.git
 ```

 ## 2. Load the MNIST Dataset

-We'll use TensorFlow Datasets (TFDS) for loading and preparing the MNIST dataset:
+First, the MNIST dataset is loaded and prepared for training and testing using
+Tensorflow Datasets. Image values are normalized, the data is shuffled and divided
+into batches, and samples are prefetched to enhance performance.

 ```{code-cell} ipython3
 import tensorflow_datasets as tfds  # TFDS for MNIST
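
For context (not part of this commit): the data-loading code itself falls outside this hunk. A minimal sketch of the kind of pipeline the new paragraph describes (normalize, shuffle, batch, prefetch); the batch size, buffer size, and helper name are assumptions, not the tutorial's exact code.

```python
import tensorflow as tf
import tensorflow_datasets as tfds

batch_size = 32  # assumed value

def normalize(sample):
  # Scale pixel values from [0, 255] to [0, 1]
  return {
      'image': tf.cast(sample['image'], tf.float32) / 255.0,
      'label': sample['label'],
  }

train_ds = (
    tfds.load('mnist', split='train')
    .map(normalize)
    .shuffle(1024)                           # shuffle with a sample buffer
    .batch(batch_size, drop_remainder=True)  # divide into batches
    .prefetch(1)                             # prefetch to overlap with training
)
test_ds = (
    tfds.load('mnist', split='test')
    .map(normalize)
    .batch(batch_size, drop_remainder=True)
    .prefetch(1)
)
```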
@@ -77,40 +78,28 @@ Create a convolutional neural network with NNX by subclassing `nnx.Module`.

 ```{code-cell} ipython3
 from flax.experimental import nnx  # NNX API
+from functools import partial

 class CNN(nnx.Module):
   """A simple CNN model."""

   def __init__(self, *, rngs: nnx.Rngs):
-    self.conv1 = nnx.Conv(
-      in_features=1, out_features=32, kernel_size=(3, 3), rngs=rngs
-    )
-    self.conv2 = nnx.Conv(
-      in_features=32, out_features=64, kernel_size=(3, 3), rngs=rngs
-    )
-    self.linear1 = nnx.Linear(in_features=3136, out_features=256, rngs=rngs)
-    self.linear2 = nnx.Linear(in_features=256, out_features=10, rngs=rngs)
+    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
+    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
+    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
+    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
+    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

   def __call__(self, x):
-    x = self.conv1(x)
-    x = nnx.relu(x)
-    x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
-    x = self.conv2(x)
-    x = nnx.relu(x)
-    x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
-    x = x.reshape((x.shape[0], -1))  # flatten
-    x = self.linear1(x)
-    x = nnx.relu(x)
+    x = self.avg_pool(nnx.relu(self.conv1(x)))
+    x = self.avg_pool(nnx.relu(self.conv2(x)))
+    x = x.reshape(x.shape[0], -1)  # flatten
+    x = nnx.relu(self.linear1(x))
     x = self.linear2(x)
     return x

-
 model = CNN(rngs=nnx.Rngs(0))
-
-print(f'model = {model}'[:500] + '\n...\n')  # print a part of the model
-print(
-  f'{model.conv1.kernel.value.shape = }'
-)  # inspect the shape of the kernel of the first convolutional layer
+nnx.display(model)
 ```

 ### Run model
@@ -123,84 +112,71 @@ Let's put our model to the test! We'll perform a forward pass with arbitrary da
 import jax.numpy as jnp  # JAX NumPy

 y = model(jnp.ones((1, 28, 28, 1)))
-y
+nnx.display(y)
 ```

-## 4. Create the `TrainState`
-
-In Flax, a common practice is to use a dataclass to encapsulate the entire training state, which would allow you to simply pass only two arguments (the train state and batched data) to functions like `train_step`. The training state would typically contain an [`nnx.Optimizer`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/training/optimizer.html#flax.experimental.nnx.optimizer.Optimizer) (which contains the step number, model and optimizer state) and an `nnx.Module` (for easier access to the model from the top-level of the train state). The training state can also be easily extended to add training and test metrics, as you will see in this tutorial (see [`nnx.metrics`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/training/metrics.html#module-flax.experimental.nnx.metrics) for more detail on NNX's metric classes).
-
-```{code-cell} ipython3
-import dataclasses
-
-@dataclasses.dataclass
-class TrainState(nnx.GraphNode):
-  optimizer: nnx.Optimizer
-  model: CNN
-  metrics: nnx.MultiMetric
-```
+## 4. Create Optimizer and Metrics

-We use `optax` to create an optimizer ([`adamw`](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adamw)) and initialize the `nnx.Optimizer`. We use `nnx.MultiMetric` to keep track of both the accuracy and average loss for both training and test batches.
+In NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model parameters and an `optax` optimizer that will define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss.

 ```{code-cell} ipython3
 import optax

 learning_rate = 0.005
 momentum = 0.9
-tx = optax.adamw(learning_rate, momentum)
-
-state = TrainState(
-  optimizer=nnx.Optimizer(model=model, tx=tx),
-  model=model,
-  metrics=nnx.MultiMetric(
-    accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average()
-  ),
+
+optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
+metrics = nnx.MultiMetric(
+  accuracy=nnx.metrics.Accuracy(),
+  loss=nnx.metrics.Average('loss'),
 )
+
+nnx.display(optimizer)
 ```

 ## 5. Training step

 We define a loss function using cross entropy loss (see more details in [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that our model will optimize over. In addition to the loss, the logits are also outputted since they will be used to calculate the accuracy metric during training and testing.

 ```{code-cell} ipython3
-def loss_fn(model, batch):
+def loss_fn(model: CNN, batch):
   logits = model(batch['image'])
   loss = optax.softmax_cross_entropy_with_integer_labels(
     logits=logits, labels=batch['label']
   ).mean()
   return loss, logits
 ```

-Next, we create the training step function. This function takes the `state` and a data `batch` and does the following:
+Next, we create the training step function. This function takes the `model` and a data `batch` and does the following:

 * Computes the loss, logits and gradients with respect to the loss function using `nnx.value_and_grad`.
-* Updates the training loss using the loss and updates the training accuracy using the logits and batch labels
-* Updates model parameters and optimizer state by applying the gradient pytree to the optimizer.
+* Updates training accuracy using the loss, logits, and batch labels.
+* Updates model parameters via the optimizer by applying the gradient updates.

 ```{code-cell} ipython3
 @nnx.jit
-def train_step(state: TrainState, batch):
+def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
   """Train for a single step."""
   grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
-  (loss, logits), grads = grad_fn(state.model, batch)
-  state.metrics.update(values=loss, logits=logits, labels=batch['label'])
-  state.optimizer.update(grads=grads)
+  (loss, logits), grads = grad_fn(model, batch)
+  metrics.update(loss=loss, logits=logits, labels=batch['label'])
+  optimizer.update(grads)
 ```

 The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/transforms.html#flax.experimental.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with
 [XLA](https://www.tensorflow.org/xla), optimizing performance on
 hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),
-except it can decorate functions that make stateful updates to NNX classes.
+except it can transforms functions that contain NNX objects as inputs and outputs.

-## 6. Metric Computation
+## 6. Evaluation step

 Create a separate function to calculate loss and accuracy metrics for the test batch, since this will be outside the `train_step` function. Loss is determined using the `optax.softmax_cross_entropy_with_integer_labels` function, since we're reusing the loss function defined earlier.

 ```{code-cell} ipython3
 @nnx.jit
-def compute_test_metrics(*, state: TrainState, batch):
-  loss, logits = loss_fn(state.model, batch)
-  state.metrics.update(values=loss, logits=logits, labels=batch['label'])
+def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
+  loss, logits = loss_fn(model, batch)
+  metrics.update(loss=loss, logits=logits, labels=batch['label'])
 ```

 ## 7. Seed randomness
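
For context (not part of the commit): the updated `nnx.jit` description above, transforming functions that take NNX objects as inputs and outputs, can be illustrated with a minimal, hypothetical example that only uses APIs already appearing in this diff.

```python
from flax.experimental import nnx
import jax.numpy as jnp

layer = nnx.Linear(4, 2, rngs=nnx.Rngs(0))

@nnx.jit
def apply(layer: nnx.Linear, x):
  # The NNX module itself is an argument to the jitted function;
  # nnx.jit takes care of splitting and merging its state under the hood.
  return layer(x)

out = apply(layer, jnp.ones((1, 4)))
```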
@@ -213,20 +189,9 @@ tf.random.set_seed(0)

 ## 8. Train and Evaluate

-**Dataset Preparation:** create a "shuffled" dataset
-- Repeat the dataset for the desired number of training epochs.
-- Establish a 1024-sample buffer (holding the dataset's initial 1024 samples).
-  Randomly draw batches from this buffer.
-- As samples are drawn, replenish the buffer with subsequent dataset samples.
-
-**Training Loop:** Iterate through epochs
-- Sample batches randomly from the dataset.
-- Execute an optimization step for each training batch.
-- Calculate mean training metrics across batches within the epoch.
-- With updated parameters, compute metrics on the test set.
-- Log train and test metrics for visualization.
-
-After 10 training and testing epochs, your model should reach approximately 99% accuracy.
+Now we train a model using batches of data for 10 epochs, evaluate its performance
+on the test set after each epoch, and log the training and testing metrics (loss and
+accuracy) throughout the process. Typically this leads to a model with around 99% accuracy.

 ```{code-cell} ipython3
 :outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87
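
For context (not part of the commit): the hunk below relies on set-up that falls outside this diff, namely a `metrics_history` dict and a `num_steps_per_epoch` count. A hypothetical sketch of that set-up, with assumed values:

```python
num_epochs = 10           # from the prose above
batch_size = 32           # assumed
# Steps per epoch for MNIST's 60,000 training images (assumed arithmetic)
num_steps_per_epoch = 60_000 // batch_size

metrics_history = {
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': [],
}
```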
@@ -245,22 +210,22 @@ for step, batch in enumerate(train_ds.as_numpy_iterator()):
   # - the train state's model parameters
   # - the optimizer state
   # - the training loss and accuracy batch metrics
-  train_step(state, batch)
+  train_step(model, optimizer, metrics, batch)

   if (step + 1) % num_steps_per_epoch == 0:  # one training epoch has passed
     # Log training metrics
-    for metric, value in state.metrics.compute().items():  # compute metrics
+    for metric, value in metrics.compute().items():  # compute metrics
       metrics_history[f'train_{metric}'].append(value)  # record metrics
-    state.metrics.reset()  # reset metrics for test set
+    metrics.reset()  # reset metrics for test set

     # Compute metrics on the test set after each training epoch
     for test_batch in test_ds.as_numpy_iterator():
-      compute_test_metrics(state=state, batch=test_batch)
+      eval_step(model, metrics, test_batch)

     # Log test metrics
-    for metric, value in state.metrics.compute().items():
+    for metric, value in metrics.compute().items():
       metrics_history[f'test_{metric}'].append(value)
-    state.metrics.reset()  # reset metrics for next training epoch
+    metrics.reset()  # reset metrics for next training epoch

     print(
       f"train epoch: {(step+1) // num_steps_per_epoch}, "
@@ -293,7 +258,6 @@ for dataset in ('train', 'test'):
 ax1.legend()
 ax2.legend()
 plt.show()
-plt.clf()
 ```

 ## 10. Perform inference on test set
@@ -302,16 +266,16 @@ Define a jitted inference function, `pred_step`, to generate predictions on the

 ```{code-cell} ipython3
 @nnx.jit
-def pred_step(state: TrainState, batch):
-  logits = state.model(batch['image'])
+def pred_step(model: CNN, batch):
+  logits = model(batch['image'])
   return logits.argmax(axis=1)
 ```

 ```{code-cell} ipython3
 :outputId: 1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e

 test_batch = test_ds.as_numpy_iterator().next()
-pred = pred_step(state, test_batch)
+pred = pred_step(model, test_batch)

 fig, axs = plt.subplots(5, 5, figsize=(12, 12))
 for i, ax in enumerate(axs.flatten()):