This repository provides the implementation for "Replica-Based Federated Learning with Heterogeneous Architectures for Graph Super-Resolution", which was accepted at the International Workshop on Machine Learning in Medical Imaging (MLMI), held in conjunction with MICCAI 2023.
Replica-Based Federated Learning with Heterogeneous Architectures for Graph Super-Resolution
BASIRA Lab, Imperial-X and Department of Computing, Imperial College London, London, UK
Abstract: Having access to brain connectomes at various resolutions is important for clinicians, as they can reveal vital information about brain anatomy and function. However, deriving such graphs from magnetic resonance imaging (MRI) is computationally expensive and error-prone. A further challenge in the medical domain is the small amount of available data, together with privacy concerns. In this work, we propose a new federated learning framework, named RepFL. At its core, RepFL is a replica-based federated learning approach for heterogeneous models: it creates replicas of each participating client by copying its model architecture and perturbing its local training dataset. This enables learning from limited data with a small number of participating clients by aggregating multiple local models and diversifying the data distributions of the clients. Specifically, we apply the framework to graph super-resolution using heterogeneous model architectures. In addition, to the best of our knowledge, this is the first federated multi-resolution graph generation approach. Our experiments show that the method outperforms other federated learning methods on the task of brain graph super-resolution. Our RepFL code is available at https://github.com/basiralab/RepFL.
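To illustrate the core idea, below is a minimal, hypothetical sketch of how replicas could be created from a client: each replica copies the client's model architecture and trains on a perturbed copy of the client's local dataset. The function names and the perturbation used here (additive Gaussian noise on a fraction of the samples) are illustrative assumptions, not the exact implementation in this repository.

```python
import copy
import torch

def perturb_dataset(dataset, fraction, noise_std=0.01):
    # Return a copy of the dataset where a fraction of the samples is
    # perturbed with additive Gaussian noise (illustrative only).
    perturbed = [sample.clone() for sample in dataset]
    num_perturbed = int(fraction * len(perturbed))
    for i in range(num_perturbed):
        perturbed[i] = perturbed[i] + noise_std * torch.randn_like(perturbed[i])
    return perturbed

def make_replicas(client_model, client_dataset, num_replicas, fraction):
    # Each replica reuses the client's architecture (deep copy of the model)
    # and receives its own perturbed version of the client's local data.
    replicas = []
    for _ in range(num_replicas):
        replica_model = copy.deepcopy(client_model)
        replica_data = perturb_dataset(client_dataset, fraction)
        replicas.append((replica_model, replica_data))
    return replicas
```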
The code for the project was implemented using Python 3.8.16 and PyTorch 2.0.0 on macOS (Apple M1) and Python 3.10.12 and PyTorch 2.0.1 on Linux (Nvidia Tesla A30, Nvidia Tesla T4 and Nvidia GeForce GTX Titan Xp).
- Create a new Anaconda environment and activate it (macOS / Apple Silicon setup)
conda create --name pyg python=3.8
conda activate pyg
- Install the arm64 compilers (needed to build torch-geometric and its companion packages on the M1)
conda install -y clang_osx-arm64 clangxx_osx-arm64 gfortran_osx-arm64
- Install the dependencies (PyTorch)
Make sure to replace MACOSX_DEPLOYMENT_TARGET=13.0 with the macOS version that you have, and adjust the torch version to the one you need.
MACOSX_DEPLOYMENT_TARGET=13.0 CC=clang CXX=clang++ python -m pip --no-cache-dir install torch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1
You can check the PyTorch version by running the following command:
python -c "import torch; print(torch.__version__)"
It will display something like this:
2.0.0
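Optionally, you can also check whether the MPS (Apple Silicon GPU) backend is available to PyTorch:
python -c "import torch; print(torch.backends.mps.is_available())"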
- Install the dependencies (PyTorch Geometric)
The CPU wheels are used here, since CUDA is not available on macOS:
MACOSX_DEPLOYMENT_TARGET=13.0 CC=clang CXX=clang++ python -m pip --no-cache-dir install torch-scatter -f https://data.pyg.org/whl/torch-1.12.1+cpu.html
MACOSX_DEPLOYMENT_TARGET=13.0 CC=clang CXX=clang++ python -m pip --no-cache-dir install torch-sparse -f https://data.pyg.org/whl/torch-1.12.1+cpu.html
MACOSX_DEPLOYMENT_TARGET=13.0 CC=clang CXX=clang++ python -m pip --no-cache-dir install torch-cluster -f https://data.pyg.org/whl/torch-1.12.1+cpu.html
MACOSX_DEPLOYMENT_TARGET=13.0 CC=clang CXX=clang++ python -m pip --no-cache-dir install torch-geometric
- Install other dependencies
conda install -c conda-forge matplotlib
conda install -c anaconda networkx
conda install -c anaconda seaborn
- Create a new Python environment and activate it (Linux setup)
python3 -m venv pyg
source pyg/bin/activate
- Install the dependencies (PyTorch)
Make sure to have the appropriate CUDA version in the link. The available versions can be found at https://data.pyg.org/whl/.
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
You can check if the PyTorch and CUDA versions match by running the following command:
python -c "import torch; print(torch.__version__)"
It will display something like this:
2.0.1+cu117
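You can also check that the GPU is visible to PyTorch:
python -c "import torch; print(torch.cuda.is_available())"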
- Install the dependencies (PyTorch Geometric)
pip install torch-sparse==0.6.17 -f https://pytorch-geometric.com/whl/torch-2.0.1+cu117.html
pip install torch-scatter==2.1.1 -f https://pytorch-geometric.com/whl/torch-2.0.1+cu117.html
pip install torch-cluster==1.6.1 -f https://pytorch-geometric.com/whl/torch-2.0.1+cu117.html
pip install torch-geometric
- Install other dependencies
pip install networkx
pip install matplotlib
pip install seaborn
We provide simulated data of 279 samples, each having three graphs: one graph with resolution 35, one graph with resolution 160, and one graph with resolution 268. If you want to use your own dataset, please modify the method utils/utils/read_and_preprocess_files so that it returns the desired data.
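For reference, a hypothetical replacement could look like the sketch below, which returns one symmetric adjacency matrix per resolution for each subject. The dictionary keys and the return type are assumptions for illustration only; check the original function for the exact format expected by the training scripts.

```python
import numpy as np

def read_and_preprocess_files(data_dir):
    # Hypothetical sketch of a custom loader: for each subject, build
    # symmetric adjacency matrices at resolutions 35, 160 and 268.
    subjects = []
    for _ in range(279):  # replace with reading your own files from data_dir
        a35 = np.random.rand(35, 35)
        a160 = np.random.rand(160, 160)
        a268 = np.random.rand(268, 268)
        # enforce symmetry, since brain connectomes are undirected graphs
        subjects.append({
            "lr": (a35 + a35.T) / 2,
            "hr1": (a160 + a160.T) / 2,
            "hr2": (a268 + a268.T) / 2,
        })
    return subjects
```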
python train/train_slim.py --alg {algorithm} --seed {seed} --le {local_epochs} --r {rounds} --batch {batch_size} --folds {folds} --global_run_name {global_run_name} --run_name {run_name} --perturbed {no_perturbed_samples} --replicas {no_replicas}
Parameters
Parameter | Values | Definition |
---|---|---|
algorithm | baseline, fedavg, feddyn, feddc, repfl | The federated learning algorithm to run |
seed | int | Seed for random number generation |
local_epochs | int | Number of local epochs |
rounds | int | Number of rounds |
batch_size | int | Batch size |
folds | int | Number of folds for K-fold cross validation |
global_run_name | string | Name of directory where the plots will be saved |
run_name | string | Name of directory where the models will be saved |
no_perturbed_samples | int | Percentage of perturbed samples in each replica's dataset |
no_replicas | int | Number of replicas for each client (anchor) |
Examples
python train/train_slim.py --alg fedavg --le 10 --r 10 --batch 5 --folds 5 --global_run_name global_fedavg --run_name fedavg
python train/train_slim.py --alg repfl --le 10 --r 10 --batch 5 --folds 5 --global_run_name global_repfl --run_name repfl --perturbed 30 --replicas 3
python test/test_slim.py --alg {algorithm} --global_run_name {global_run_name} --run_name {run_name}
Parameters
Parameter | Values | Definition |
---|---|---|
algorithm | baseline, fedavg, feddyn, feddc, repfl | The federated learning algorithm to run |
global_run_name | string | Name of directory where the plots will be saved |
run_name | string | Name of directory where the models were saved and where test results will be saved |
Examples
python test/test_slim.py --alg fedavg --global_run_name global_fedavg --run_name fedavg
python test/test_slim.py --alg repfl --global_run_name global_repfl --run_name repfl
@inproceedings{ghilea2023replica,
title={Replica-Based Federated Learning with Heterogeneous Architectures for Graph Super-Resolution},
author={Ghilea, Ramona and Rekik, Islem},
booktitle={International Workshop on Machine Learning in Medical Imaging},
pages={273--282},
year={2023},
organization={Springer}
}
For questions regarding the code, please contact ramonaghilea9@gmail.com.