Experiments with neural networks and JAX
The `flax_neural_network_sin.py` script demonstrates how to build and train a simple Multi-Layer Perceptron (MLP) using the Flax library to approximate a sine wave function.
Key steps in the script:
- Data Generation: It generates `N_SAMPLES` (200) data points where `x` values are uniformly distributed between 0 and 2π. The corresponding `y` values are `sin(x)` with some added Gaussian noise.
- Model Definition: A `SimpleMLP` is defined using `flax.linen.Module`. This network consists of an input layer, three hidden layers (each with 10 neurons and sigmoid activation), and an output layer (1 neuron).
- Initialization: The model's parameters are initialized.
- Training Setup:
  - An Adam optimizer is chosen with a specified `LEARNING_RATE`.
  - A `TrainState` object from Flax is used to manage the model's parameters, apply function, and optimizer state.
  - A mean squared error loss function (`loss_fn`) is defined to measure the difference between the model's predictions and the true `y` values.
- Training Loop:
  - The model is trained for `N_EPOCHS` (30,000) epochs.
  - In each epoch, the `train_step` function (JIT-compiled for performance) calculates the loss and gradients, then updates the model's parameters.
  - The loss is recorded at each epoch.
- Results Visualization:
- A plot of the loss history over epochs is generated and saved.
- A scatter plot comparing the original noisy sine wave data with the trained model's predictions is generated and saved.
Loss History:
Inference - Model Approximation of Sine Wave:

