This repository contains an implementation of the paper "Gradient Routing: Masking Gradients to Localize Computation in Neural Networks". The code specifically replicates the MNIST autoencoder experiments described in Section 4.1 of the paper. Some training details differ, but this should give a solid start for building on the experiments in the paper.
Repository Status: The repository is currently being updated with more experiments and replications. The plan is to eventually use this repository as starter code for research projects on gradient routing in UChicago's AI safety club. Pull requests are welcome!
Gradient routing is a training method that isolates capabilities to specific regions of a neural network by setting gradients to zero in chosen regions of the network during backpropagation. This prevents the model from learning certain data in specific parts of the network. The implementation in this repository demonstrates how gradient routing can split MNIST digit representations into distinct halves of an autoencoder's 32-dimensional latent space.
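A minimal sketch of the masking idea, assuming binary masks applied to the latent code with the stop-gradient trick (`z.detach()`). The helpers `route_gradients` and `digit_routing_mask` are hypothetical names, and the choice of dims 0-15 as the "top" half and dims 16-31 as the "bottom" half is an assumption; the repository's own code may differ.

```python
import torch

LATENT_DIM = 32
HALF = LATENT_DIM // 2

def route_gradients(z: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Identity in the forward pass; gradients reach z only where mask == 1."""
    # mask * z carries gradient where mask is 1; (1 - mask) * z.detach()
    # supplies the same values elsewhere but blocks their gradient.
    return mask * z + (1 - mask) * z.detach()

def digit_routing_mask(labels: torch.Tensor) -> torch.Tensor:
    """(batch, 32) mask. ASSUMPTION: "top" half = dims 0-15, "bottom" = 16-31."""
    mask = torch.zeros(labels.shape[0], LATENT_DIM)
    low = labels < 5
    mask[low, HALF:] = 1.0   # digits 0-4 may only update the bottom half
    mask[~low, :HALF] = 1.0  # digits 5-9 may only update the top half
    return mask

# Toy check with a stand-in "encoding" for a 3 and a 7.
z = torch.randn(2, LATENT_DIM, requires_grad=True)
labels = torch.tensor([3, 7])
routed = route_gradients(z, digit_routing_mask(labels))
routed.sum().backward()
print(torch.equal(routed.detach(), z.detach()))  # True: forward pass unchanged
print(z.grad[0, :HALF].abs().sum().item())        # 0.0: the 3 never updates the top half
```

Because the forward value is unchanged, the rest of the network sees the full encoding; only the learning signal is restricted.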
If we train an encoder/decoder architecture on the MNIST handwritten digit dataset, we can route gradients so that representations of digits 0-4 are localized in the bottom half of the latent space and representations of digits 5-9 in the top half. We can confirm that this works by training decoders that have access to only one half of the latent space. The image below shows original images from the validation set (top row) and images generated by a decoder with access to only the top half of the latent space (bottom row). As you can see, performance is reasonable for these digits.
Above, the images generated by the decoder (bottom row) look fairly similar to the original data (top row). However, when you test this decoder on digits 0-4 from the validation set, the performance is much worse. This is because we routed the gradients from these digits away from the top half of the latent space:
You can see a similar effect when you train a decoder on only the bottom half of the encoding:
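The half-only decoders described above can be sketched as probes trained on top of a frozen encoder. This is a minimal sketch, assuming an MLP probe, an L1 reconstruction loss, and dims 0-15 as the "top" half; `HalfDecoder` and `probe_step` are hypothetical names, not this repository's API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

LATENT_DIM, HALF = 32, 16

class HalfDecoder(nn.Module):
    """Hypothetical probe decoder that reads only one half of the latent code.

    ASSUMPTION: "top" half = dims 0-15, "bottom" half = dims 16-31.
    """
    def __init__(self, use_top: bool):
        super().__init__()
        self.use_top = use_top
        self.net = nn.Sequential(
            nn.Linear(HALF, 256), nn.ReLU(),
            nn.Linear(256, 28 * 28), nn.Sigmoid(),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        half = z[:, :HALF] if self.use_top else z[:, HALF:]
        return self.net(half)

def probe_step(encoder: nn.Module, probe: HalfDecoder,
               opt: torch.optim.Optimizer, images: torch.Tensor) -> float:
    """One training step for a probe decoder on top of a frozen encoder."""
    with torch.no_grad():  # the encoder is not updated by the probe
        z = encoder(images)
    recon = probe(z)
    loss = F.l1_loss(recon, images.flatten(1))
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()
```

Freezing the encoder matters here: the probe can only reconstruct digits whose information is actually present in its half of the latent code, which is what makes the comparison between digits meaningful.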
Overall, this shows that we have successfully isolated the representations of one set of features to one half of the latent space and the representations of the other features to the other half. This can be shown more robustly by comparing the mean absolute error (MAE) of the top and bottom decoders on each digit in the validation set:
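A per-digit MAE comparison like the one described above could be computed with a helper along these lines. This is a sketch: `per_digit_mae` is a hypothetical name, and it assumes flattened 28×28 reconstructions and integer labels.

```python
import torch

def per_digit_mae(recon: torch.Tensor, target: torch.Tensor,
                  labels: torch.Tensor) -> dict[int, float]:
    """Mean absolute error of reconstructions, grouped by digit label.

    recon, target: (N, 784) tensors; labels: (N,) integer digit labels.
    """
    # per-example MAE averaged over pixels
    err = (recon - target).abs().flatten(1).mean(dim=1)
    # average the per-example errors within each digit present in the batch
    return {d: err[labels == d].mean().item()
            for d in range(10) if (labels == d).any()}
```

Running this once with the top decoder's reconstructions and once with the bottom decoder's gives two per-digit curves; if routing worked, each decoder should have low error on its own digits and high error on the others.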
The training process for the encoder/decoders should be nearly the same as the one outlined in the original paper. However, the correlation loss described in the paper's appendix is not included in this repository. Note that I also measure an L1 loss on the output of the decoder (as described in the original paper). The paper trains for 200 epochs, whereas I trained for 400 epochs to get the results shown above; the gains after epoch 200 are minimal.
(Note that the purple line in the image above represents the loss of the bottom decoder. There is no visible line for the top decoder because the two decoders' losses are roughly the same, so their lines overlap.)
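Putting the pieces together, one routed training step might look like the sketch below. It assumes a small MLP autoencoder, Adam, and the same half assignment as before (digits 0-4 to dims 16-31, digits 5-9 to dims 0-15), and it reads "L1 loss on the output of the decoder" as an L1 reconstruction loss. Like this repository, it omits the correlation loss. All names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

LATENT_DIM, HALF = 32, 16

# Hypothetical, deliberately small MLP encoder/decoder; the repository's
# actual architecture and hyperparameters may differ.
encoder = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(),
                        nn.Linear(256, LATENT_DIM))
decoder = nn.Sequential(nn.Linear(LATENT_DIM, 256), nn.ReLU(),
                        nn.Linear(256, 784), nn.Sigmoid())
opt = torch.optim.Adam([*encoder.parameters(), *decoder.parameters()], lr=1e-3)

def train_step(images: torch.Tensor, labels: torch.Tensor) -> float:
    z = encoder(images)
    # ASSUMPTION: digits 0-4 -> dims 16-31, digits 5-9 -> dims 0-15
    mask = torch.zeros_like(z)
    low = labels < 5
    mask[low, HALF:] = 1.0
    mask[~low, :HALF] = 1.0
    # gradient routing: same forward value, gradients only where mask == 1
    z_routed = mask * z + (1 - mask) * z.detach()
    recon = decoder(z_routed)
    # L1 (mean absolute error) reconstruction loss on the decoder output
    loss = F.l1_loss(recon, images.flatten(1))
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()
```

The decoder still receives the full 32-dimensional encoding; the mask only controls which encoder dimensions each digit's reconstruction error is allowed to shape.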
There are two main reasons why I find this research direction personally interesting:
1. Gradient routing can make models more interpretable. When we localize computation related to certain features, we will know where to find them later, when we want to understand more about a model's internals.
   - Relatedly, there are early results indicating that when we localize computation for one subject, the model routes related concepts to the same area. This suggests that gradient routing could scale to domains where labeled data is limited.
2. It offers a glimpse into a world where models have modules we can turn on and off. If we are concerned that an area of a model is related to dangerous behavior, we could shut it off.
In my opinion, extending #1, and testing whether #2 can practically increase safety in LLMs with at least 7 billion parameters, are the most promising research directions.
If you use any of the ideas from this repository in your own work, please cite the original paper:
```bibtex
@article{cloud2024gradient,
  title={Gradient Routing: Masking Gradients to Localize Computation in Neural Networks},
  url={https://arxiv.org/abs/2410.04332v1},
  journal={arXiv.org},
  author={Cloud, Alex and Goldman-Wetzler, Jacob and Wybitul, Evžen and Miller, Joseph and Turner, Alexander Matt},
  year={2024},
}
```
Made with ❤️ and PyTorch