[Setup](#setup) | [Running NodePert](#running-nodepert) | [Paper](#paper)
What algorithms underlie goal-directed learning in the brain? Backpropagation is the standard credit assignment algorithm used in machine learning research, but it is considered biologically implausible. Recently, biologically plausible alternatives, such as feedback alignment, target propagation, and perturbation algorithms, have been explored. The node perturbation algorithm applies random perturbations to neuron activity, monitors the resulting change in performance, and adjusts weights accordingly. This approach is simple and may be used by the brain.
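For intuition, here is a minimal sketch of the node perturbation estimator on a single linear layer. This is a toy illustration under simplifying assumptions (one layer, Gaussian perturbations, small noise scale), not the repository's implementation; the exact update used in this codebase lives in `optim.py`.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y, noise=0.0):
    h = w @ x + noise                  # neuron activity, optionally perturbed
    return 0.5 * jnp.sum((h - y) ** 2)

key = jax.random.PRNGKey(0)
w = 0.1 * jax.random.normal(key, (3, 5))
x, y = jnp.ones(5), jnp.zeros(3)

sigma, lr = 1e-3, 1e-2
for _ in range(200):
    key, sub = jax.random.split(key)
    xi = jax.random.normal(sub, (3,))                        # random perturbation of the neurons
    delta = loss(w, x, y, noise=sigma * xi) - loss(w, x, y)  # change in performance
    grad_est = (delta / sigma) * jnp.outer(xi, x)            # stochastic estimate of dL/dW
    w = w - lr * grad_est                                    # adjust weights accordingly
```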
This repository contains the accompanying code for the paper *An empirical study of perturbation methods for training deep networks*. It offers an efficient and scalable implementation of perturbation algorithms, allowing for large-scale experiments with node perturbation on modern convolutional architectures on a GPU. Our results provide insights into the diverse credit assignment algorithms used by the brain. The code was written by Yash Mehta and Timothy Lillicrap using JAX, in conjunction with Tensorflow Datasets for data loading. Reach out to yashsmehta95[at]gmail.com or timothy.lillicrap[at]gmail.com with queries or feedback.
## Setup

1. Clone the repository like the git wizard you know you are:

   ```bash
   git clone https://github.com/silverpaths/nodepert.git
   cd nodepert
   ```
2. Create a new virtual environment using `venv` or `conda`. Note, `venv` comes built in with Python, but we recommend using `conda`, especially if you want to run on a GPU.

   **conda**

   ```bash
   conda create -n nodepert python=3.11
   conda activate nodepert
   ```

   **venv**

   ```bash
   python -m venv venv
   source venv/bin/activate
   ```
3. Install JAX and the nodepert package:

   a. CPU only

   ```bash
   pip install --upgrade "jax[cpu]"
   pip install -e .
   ```

   b. GPU

   **conda**

   ```bash
   conda install -c nvidia cuda-toolkit
   pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
   pip install -e .
   ```

   **venv**

   Based on your CUDA version, check if you need to use `"jax[cuda11_pip]"` or `"jax[cuda12_pip]"`.

   ```bash
   pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
   pip install -e .
   ```
4. To ensure JAX is working properly, run a basic experiment on a fully connected network comparing node perturbation and SGD on MNIST data. This saves a learning plot and should take less than two minutes to run.

   ```bash
   python example.py
   ```
Run into any JAX installation snafus? Check out their [official install guide](https://github.com/google/jax#installation) for a helping hand.
## Running NodePert

You can customize the entire training process by passing different arguments to a single file, `main.py`. Example argparse parameters include:
- dataset: `mnist`, `fmnist`, `cifar10`
- network: `fc`, `linfc`, `conv`, `conv-large`
- update rule: `np`, `sgd`
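As a rough illustration, the flag parsing in `utils.py` presumably follows the standard argparse pattern below. The flag names and choices mirror the list above, but the defaults shown here are hypothetical; refer to the actual `parse_args()` for the real ones.

```python
import argparse

def parse_args():
    # Hypothetical sketch of parse_args(); defaults here are illustrative only.
    parser = argparse.ArgumentParser(description="NodePert training")
    parser.add_argument("-dataset", default="mnist", choices=["mnist", "fmnist", "cifar10"])
    parser.add_argument("-network", default="fc", choices=["fc", "linfc", "conv", "conv-large"])
    parser.add_argument("-update_rule", default="np", choices=["np", "sgd"])
    parser.add_argument("-lr", type=float, default=5e-3)
    parser.add_argument("-batchsize", type=int, default=100)
    parser.add_argument("-num_epochs", type=int, default=10)
    return parser.parse_args()
```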
For a full list of parameters and default values, refer to the `parse_args()` function in `utils.py`. For example, to run the training process with your desired arguments via `main.py`:

```bash
python nodepert/main.py -network fc -dataset mnist -log_expdata True -n_hl 2 -hl_size 500 -lr 5e-3 -batchsize 100 -num_epochs 10 -update_rule np
```
Inside the `experiments` folder, you'll find example code for a variety of node perturbation experiments:
- Understanding network crashes during training: see `crash-dynamics.py`, `crash_timing.py`, `grad_dynamics.py`
- Relative change in the loss with different learning rates: see `linesearch.py`, `linesearch_utils.py`
- Adam-like update for NP gradients (sketched below): see `adam_update.py`
- Visualizing the loss landscape: see `loss_landscape.py`
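For the Adam-like variant, the idea is to feed the noisy NP gradient estimate through standard Adam moment accumulators rather than applying it raw. A minimal sketch, assuming the usual Adam hyperparameters (the function name and signature are ours, not `adam_update.py`'s):

```python
import jax.numpy as jnp

def adam_like_np_update(w, grad_est, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # grad_est is the noisy node-perturbation gradient estimate for w.
    m = b1 * m + (1 - b1) * grad_est        # first-moment (mean) accumulator
    v = b2 * v + (1 - b2) * grad_est ** 2   # second-moment (variance) accumulator
    m_hat = m / (1 - b1 ** t)               # bias correction
    v_hat = v / (1 - b2 ** t)
    return w - lr * m_hat / (jnp.sqrt(v_hat) + eps), m, v

# usage: initialize m = v = jnp.zeros_like(w) and call with t = 1, 2, ...
```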
And for all you neural network aficionados, take a gander at `model/conv.py` or `model/fc.py`. The exact nodepert update can be found in `optim.py`.
You can directly run multiple configurations by specifying values in a dictionary in `cluster_scripts/scheduler.py`. It schedules all combinations of the hyperparameters specified in the dictionary, along with multiple seeds of your experiment, simultaneously. This is extremely useful on GPU clusters with resource allocation managers such as SLURM.

```bash
bash slurm-scripts/scheduler.py
```
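The underlying pattern is a Cartesian product over the dictionary's values. A minimal sketch of that idea follows; the dictionary keys, the `-jobseed` flag, and the direct `subprocess` launch are our assumptions, and the real scheduler instead wraps each job in a SLURM submission:

```python
import itertools
import subprocess

# Hypothetical hyperparameter grid; see cluster_scripts/scheduler.py for the real keys.
grid = {
    "network": ["fc", "conv"],
    "lr": [5e-3, 1e-3],
    "update_rule": ["np", "sgd"],
}

for values in itertools.product(*grid.values()):
    flags = " ".join(f"-{k} {v}" for k, v in zip(grid.keys(), values))
    for seed in range(3):  # multiple seeds per configuration
        # "-jobseed" is a hypothetical flag; on a cluster this would be an sbatch submission
        subprocess.run(f"python nodepert/main.py {flags} -jobseed {seed}", shell=True, check=True)
```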
## Paper

If you use this code in your own work, please use the following bibtex entry:

```bibtex
@misc{nodepert-2023,
  title={NodePert: An empirical study of perturbation methods for training deep networks},
  author={Mehta, Yash and Hiratani, Naoki and Humphreys, Peter and Latham, Peter and Lillicrap, Timothy},
  year={2023},
  publisher={GitHub},
  howpublished={\url{https://github.com/countzerozzz/nodepert}}
}
```