What is possible with neural network stitching?
You can install the necessary dependencies using the provided environment file:
conda env create -f environment.yml
conda activate stitch
However, many users will need to install PyTorch manually, based on their specific system configuration. In that case:
- Create an environment (using your preferred Python version):
  conda create -n stitch python=3.11
- Activate it:
  conda activate stitch
- Install PyTorch and Torchvision first, using the command appropriate for your system.
- Manually install the rest of the packages listed in environment.yml. Install conda packages before pip packages.
  - Note: The wandb package comes from the conda-forge channel:
    conda install wandb -c conda-forge
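Once the environment is set up, a quick import check can confirm that the key packages are visible from Python. This is a minimal sketch. The package list (torch, torchvision, pandas, yaml, wandb) is an assumption based on the packages this README mentions, so adjust it to match environment.yml. check_packages is a hypothetical helper and is not part of the repository.

```python
import importlib


def check_packages(names):
    """Map each import name to its version string, or None if the import fails."""
    found = {}
    for name in names:
        try:
            module = importlib.import_module(name)
        except Exception:  # missing package or broken install
            found[name] = None
            continue
        # Most packages expose __version__; report "unknown" if this one doesn't.
        found[name] = getattr(module, "__version__", "unknown")
    return found


if __name__ == "__main__":
    # Assumed package list (hypothetical); edit it to match environment.yml.
    report = check_packages(["torch", "torchvision", "pandas", "yaml", "wandb"])
    for name, version in report.items():
        print(f"{name:<12} {version or 'MISSING'}")
    if report["torch"] is not None:
        import torch
        print("CUDA available:", torch.cuda.is_available())
```

If torch imports but CUDA is reported as unavailable on a GPU machine, the manually installed PyTorch build probably does not match your system configuration.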
For convenience, you may want to create a symlink to the folder that contains your datasets. Otherwise, you must specify --data-path each time you run an experiment. For instance:
cd stitching
ln -s ~/datasets ./data
Configuration and output of all experiments will live in the experiments/ folder.
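The dataset-path rule described above works like this: an explicit --data-path wins, and otherwise the ./data symlink inside the stitching folder is used. It could be implemented roughly as follows. This sketch rests on that assumption only. resolve_data_path and parse_args are hypothetical names, and the repository's actual argument handling may differ.

```python
import argparse
from pathlib import Path


def resolve_data_path(cli_value=None, default=Path("data")):
    """Pick the dataset root: an explicit --data-path wins, else ./data (assumed)."""
    path = Path(cli_value) if cli_value is not None else Path(default)
    path = path.expanduser()
    if not path.exists():
        raise FileNotFoundError(
            f"Dataset folder {path} not found; create a ./data symlink "
            "or pass --data-path."
        )
    # resolve() follows the symlink to the real dataset folder.
    return path.resolve()


def parse_args(argv=None):
    # Hypothetical CLI; only the --data-path flag is taken from the README.
    parser = argparse.ArgumentParser(description="Stitching experiment (sketch)")
    parser.add_argument("--data-path", default=None,
                        help="Dataset root; defaults to ./data")
    args = parser.parse_args(argv)
    args.data_path = resolve_data_path(args.data_path)
    return args
```

Run from the stitching folder, parse_args([]) returns the target of the ./data symlink. parse_args(["--data-path", "~/datasets"]) bypasses the symlink entirely.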
For now, each experiment will consist of stitching together two networks. In the initial experiments, instead of stitching two separate networks, we will first knock out one or more layers of a single network and replace them with new stitching layer(s). Experiments are organized as follows:
experiments/<project name>/<experiment name>/
    config.yml
    traj.pkl
Here, config.yml is the experiment configuration, and traj.pkl is a pickled pandas DataFrame describing the training trajectory of the stitching module(s).
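For analysis, both files in an experiment folder can be read back with standard tools. This is a minimal sketch that assumes config.yml is plain YAML readable by PyYAML. load_experiment is a hypothetical helper, not a function from the repository.

```python
from pathlib import Path

import pandas as pd
import yaml


def load_experiment(project, experiment, root="experiments"):
    """Return (config dict, trajectory DataFrame) for one experiment folder."""
    exp_dir = Path(root) / project / experiment
    # config.yml: assumed to be plain YAML, parsed safely into Python objects.
    with open(exp_dir / "config.yml") as f:
        config = yaml.safe_load(f)
    # traj.pkl: a pickled pandas DataFrame. Only unpickle files you trust.
    traj = pd.read_pickle(exp_dir / "traj.pkl")
    return config, traj
```

For example, load_experiment("my-project", "my-experiment") returns the parsed configuration and the trajectory DataFrame for that run. The project and experiment names here are placeholders.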
Each experiment performs the following steps:
- Load a configured set of subnets using utils.subgraphs.create_sub_network().
- Construct a network with configured stitching modules between each pair of consecutive subnets.
- Train the stitching module(s) for a configured number of epochs using a configured optimizer.
- Write the training trajectory to a dataframe on disk.
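The four steps above can be illustrated end to end on a toy model, together with the single-network knockout setup. This is a minimal PyTorch sketch under several assumptions:

- A small MLP stands in for the real pretrained network.
- Slicing an nn.Sequential stands in for utils.subgraphs.create_sub_network(), whose signature is not documented here.
- SGD with cross-entropy loss stands in for the configured optimizer and objective.

make_base_network, StitchedNetwork, train_stitch and run_demo are all hypothetical names.

```python
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def make_base_network():
    # Hypothetical toy network standing in for a real (pretrained) model.
    return nn.Sequential(
        nn.Linear(8, 16), nn.ReLU(),   # indices 0-1: kept (front subnet)
        nn.Linear(16, 16), nn.ReLU(),  # indices 2-3: knocked out
        nn.Linear(16, 4),              # index 4: kept (back subnet)
    )


class StitchedNetwork(nn.Module):
    """front -> stitch -> back; only the stitching layer is trainable."""

    def __init__(self, front, back, in_dim, out_dim):
        super().__init__()
        self.front = front
        self.back = back
        # Assumed stitching module: a single linear map between the subnets.
        self.stitch = nn.Linear(in_dim, out_dim)
        # Freeze both subnets so only the stitch receives updates.
        for p in list(front.parameters()) + list(back.parameters()):
            p.requires_grad_(False)

    def forward(self, x):
        return self.back(self.stitch(self.front(x)))


def train_stitch(model, loader, epochs=3, lr=1e-2):
    """Train model.stitch and return one row per epoch (epoch, mean loss)."""
    optimizer = torch.optim.SGD(model.stitch.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    rows = []
    for epoch in range(epochs):
        total, count = 0.0, 0
        for x, y in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            total += loss.item() * x.size(0)
            count += x.size(0)
        rows.append({"epoch": epoch, "loss": total / count})
    return pd.DataFrame(rows)


def run_demo(epochs=3):
    torch.manual_seed(0)
    base = make_base_network()
    # Split into subnets around the knocked-out layers (indices 2-3).
    front, back = base[:2], base[4:]
    model = StitchedNetwork(front, back, in_dim=16, out_dim=16)
    # Random stand-in data: 64 samples, 8 features, 4 classes.
    data = TensorDataset(torch.randn(64, 8), torch.randint(0, 4, (64,)))
    loader = DataLoader(data, batch_size=16, shuffle=True)
    return model, train_stitch(model, loader, epochs=epochs)


if __name__ == "__main__":
    _, traj = run_demo()
    print(traj)
```

Calling traj.to_pickle(path) on the returned DataFrame produces a file in the same pickled-pandas format as traj.pkl. The two subnets are frozen, and only the stitching layer's parameters are passed to the optimizer. This confines training to the stitch, matching the "train the stitching module(s)" step.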