Gradient Routing Replication

This repository contains an implementation of the paper "Gradient Routing: Masking Gradients to Localize Computation in Neural Networks". The code specifically replicates the MNIST autoencoder experiments described in Section 4.1 of the paper. Some training details differ, but this should give a solid start for building on the experiments in the paper.

Repository Status: This repository is currently being updated with more experiments and replications. The plan is to eventually use this repository as starter code for research projects related to gradient routing for UChicago's AI safety club. Pull requests are welcome!

Overview

Gradient routing is a training method that isolates capabilities to specific regions of a neural network by setting gradients to zero in certain regions of the network during backpropagation. This forces the model not to learn certain data in specific parts of the network. The implementation in this repository demonstrates how gradient routing can split MNIST digit representations into distinct halves of an autoencoder's 32-dimensional latent space.
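To make this concrete, below is a minimal sketch of one common way to implement gradient routing on a latent vector in PyTorch. The `route_gradients` and `make_mask` helpers are illustrative names, not the repository's actual API, and the choice of dimensions 0-15 as the "bottom" half is an assumption.

```python
import torch

def route_gradients(z: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Forward pass is unchanged; gradients flow only where mask == 1."""
    # mask * z carries gradients; (1 - mask) * z.detach() carries none,
    # so the forward value is still exactly z.
    return mask * z + (1 - mask) * z.detach()

def make_mask(labels: torch.Tensor, latent_dim: int = 32) -> torch.Tensor:
    """Route digits 0-4 to the bottom half of the latent space and
    digits 5-9 to the top half (which dimensions count as "bottom"
    is an assumption here)."""
    mask = torch.zeros(labels.shape[0], latent_dim, device=labels.device)
    low = labels < 5                       # digits 0-4
    mask[low, : latent_dim // 2] = 1.0     # bottom half learns digits 0-4
    mask[~low, latent_dim // 2 :] = 1.0    # top half learns digits 5-9
    return mask
```

During training, something like `decoder(route_gradients(encoder(x), make_mask(y)))` would replace the unrouted forward pass; because the forward value is unchanged, only the backward pass is affected.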

If we train an encoder/decoder architecture on the MNIST handwritten digit dataset, we can route gradients to localize representations of digits 0-4 in the bottom half of the latent space and digits 5-9 in the top half. We can confirm that this works by training decoders with access to only one half of the latent space. The image below shows original images (from the validation set) on the top and, on the bottom, images generated by a decoder with access to only the top half of the latent space. As you can see, performance is reasonable for these digits.

[Image: top_cert_high_digits]

Above, the images generated by the decoder (bottom row) look fairly similar to the original data (top row). However, when you test this decoder on digits 0-4 (from the validation set), the performance is much worse. This is because we routed gradients that learn from these digits away from the top half of the latent space:

[Image: top_cert_low_digits]

You can see a similar effect when you train a decoder on only the bottom half of the encoding:

[Image: bottom_cert_low_digits]

[Image: bottom_cert_high_digits]
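As a rough sketch of how a decoder can be restricted to one half of the encoding (the slicing convention and layer sizes below are assumptions, not the repository's exact architecture):

```python
import torch
import torch.nn as nn

class HalfDecoder(nn.Module):
    """Decoder that reconstructs 28x28 MNIST images from only 16 of
    the 32 latent dimensions (top half if top=True, else bottom half)."""

    def __init__(self, latent_dim: int = 32, top: bool = True):
        super().__init__()
        half = latent_dim // 2
        self.dims = slice(half, latent_dim) if top else slice(0, half)
        self.net = nn.Sequential(
            nn.Linear(half, 256), nn.ReLU(),
            nn.Linear(256, 28 * 28), nn.Sigmoid(),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # The decoder never sees the other half of the encoding.
        return self.net(z[:, self.dims]).view(-1, 1, 28, 28)
```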

Overall, this shows that we have successfully isolated representations of certain features to one side of the latent space and representations of other features to the other. This can be shown more robustly by comparing the MAE losses of both the top and bottom decoders on each digit in the validation set.
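A sketch of that per-digit comparison, assuming a trained `encoder` and the half decoders from the sketch above (all names here are illustrative):

```python
import torch

@torch.no_grad()
def per_digit_mae(encoder, decoder, loader, device="cpu"):
    """Mean absolute reconstruction error for each digit 0-9."""
    totals, counts = torch.zeros(10), torch.zeros(10)
    for images, labels in loader:
        images = images.to(device)
        recon = decoder(encoder(images))
        # Per-example MAE, accumulated per digit class.
        mae = (recon - images).abs().mean(dim=(1, 2, 3)).cpu()
        for digit in range(10):
            sel = labels == digit
            totals[digit] += mae[sel].sum()
            counts[digit] += sel.sum()
    return totals / counts  # tensor of 10 per-digit MAE values
```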

Training

The process to train the encoder/decoders should be almost the same as outlined in the original paper. Correlation loss is described in the appendix of the paper but is not included in this repository. Note that I also measure an L1 loss for the output of the decoder (as described in the original paper). In the paper, the authors train for 200 epochs, while I trained for 400 epochs to get the results shown above; the returns after epoch 200 are minimal.

(Note that the purple line in the image above represents the decoder loss for the bottom decoder. There is no separate line for the top decoder because the two decoders' losses are roughly the same and the lines would overlap.)
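Putting these pieces together, one routed training step might look like the sketch below. This is a minimal sketch under my reading of the setup: it uses MAE as the reconstruction loss and reuses the illustrative `route_gradients` and `make_mask` helpers from the overview, so it is not the repository's exact training loop.

```python
import torch.nn.functional as F

def train_step(encoder, decoder, optimizer, images, labels):
    """One gradient-routed training step (illustrative, not the repo's exact loop)."""
    optimizer.zero_grad()
    z = encoder(images)                         # (batch, 32) latent codes
    z = route_gradients(z, make_mask(labels))   # mask gradients by digit class
    recon = decoder(z)
    loss = F.l1_loss(recon, images)             # MAE reconstruction loss
    loss.backward()                             # gradients respect the routing mask
    optimizer.step()
    return loss.item()
```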

Theory of Change

There are three main reasons why I find this research direction personally interesting:

  1. Gradient routing can make models more interpretable. When we localize computation related to certain features, we will know where to find them later when we want to understand more about a model's internals.
  2. Related to #1, there are early results indicating that when we localize computation for one subject, the model routes related concepts to the same area. This suggests that gradient routing can scale to domains where labeled data is limited.
  3. It offers a glimpse into a world where models have modules we can turn on and off. If we are concerned that an area of a model is related to dangerous behavior, we could shut it off.

In my opinion, the most promising research directions are extending #1 and testing #2 in ways that practically increase the safety of LLMs of at least 7 billion parameters.

Acknowledgments

If you use any of the ideas from this repository in your own work, please cite the original paper:

@article{cloud2024gradient,
	title={Gradient Routing: Masking Gradients to Localize Computation in Neural Networks},
	url={https://arxiv.org/abs/2410.04332v1},
	journal={arXiv.org},
	author={Cloud, Alex and Goldman-Wetzler, Jacob and Wybitul, Evžen and Miller, Joseph and Turner, Alexander Matt},
	year={2024},
}

Made with ❤️ and PyTorch
