    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x

model = CNN(rngs=nnx.Rngs(0))
nnx.display(model)
```
### Run model
Let's put our model to the test! We'll perform a forward pass with arbitrary data.

```{code-cell} ipython3
import jax.numpy as jnp  # JAX NumPy

y = model(jnp.ones((1, 28, 28, 1)))
nnx.display(y)
```

## 4. Create Optimizer and Metrics

In NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model parameters and an `optax` optimizer that defines the update rules. Additionally, we define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss.
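
As a minimal sketch of this setup (the `learning_rate` and `momentum` values below are purely illustrative, not prescribed by this guide), the optimizer and metrics can be created like this:

```{code-cell} ipython3
import optax

learning_rate = 0.005  # illustrative value
momentum = 0.9         # illustrative value

# nnx.Optimizer holds the model together with the optax optimizer state.
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))

# nnx.MultiMetric tracks accuracy and a running average of the loss.
metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)
```
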
We define a loss function using cross entropy loss (see more details in [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that our model will optimize over. In addition to the loss, the logits are also returned, since they will be used to calculate the accuracy metric during training and testing.

```{code-cell} ipython3
def loss_fn(model: CNN, batch):
  logits = model(batch['image'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']
  ).mean()
  return loss, logits
```
Next, we create the training step function. This function takes the `model` and a data `batch` and does the following (see the sketch after this list):

* Computes the loss, logits and gradients with respect to the loss function using `nnx.value_and_grad`.
* Updates the training metrics (loss and accuracy) using the loss, logits, and batch labels.
* Updates model parameters via the optimizer by applying the gradient updates.
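
A minimal sketch of such a training step, assuming the `optimizer` and `metrics` objects created above (the `@nnx.jit` decorator is explained next):

```{code-cell} ipython3
@nnx.jit
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
  """Performs a single gradient update and records the training metrics."""
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])  # in-place metric update
  optimizer.update(grads)  # in-place parameter and optimizer-state update
```
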
The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/transforms.html#flax.experimental.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with [XLA](https://www.tensorflow.org/xla), optimizing performance on hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), except it can transform functions that contain NNX objects as inputs and outputs.

## 6. Evaluation step

Create a separate function to calculate loss and accuracy metrics for the test batch, since evaluation happens outside the `train_step` function. The loss is again computed with `optax.softmax_cross_entropy_with_integer_labels`, since we reuse the loss function defined earlier.
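
A sketch of such an evaluation step, under the same assumptions as the training step above:

```{code-cell} ipython3
@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
  """Computes loss and logits for a test batch and updates the metrics; no parameters change."""
  loss, logits = loss_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])
```
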