A simple implementation of split learning and split inference. A model is split at a specified "cut" layer and trained jointly by a set of clients (which hold the input data) and a parameter server.
We provide a few example implementations based on the following scenarios:
Type | Network Communication | Client Framework | Server Framework |
---|---|---|---|
Split Inference | Websockets | ONNX Runtime | PyTorch |
Split Learning | Websockets | PyTorch | PyTorch |
Split Learning | MPI | PyTorch | PyTorch |
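To make the message flow in the split learning rows concrete, here is a minimal, framework-free sketch of training across a cut layer. It is not the code in this repository: the `Client`, `Server`, and `train` names are hypothetical, the model is reduced to one scalar weight per side, and the network transport (Websockets or MPI) is replaced by direct function calls. It also assumes the label is sent to the server together with the activation.

```python
# Illustrative sketch only: hypothetical names, scalar weights, no networking.
# Model: prediction = w_server * relu(w_client * x), loss = (prediction - target)^2


class Client:
    """Holds the input data and the layers before the cut."""

    def __init__(self, weight, lr):
        self.weight = weight
        self.lr = lr

    def forward(self, x):
        # Cache values needed for the backward pass.
        self.x = x
        self.pre = self.weight * x
        # The ReLU output is the cut-layer activation sent to the server.
        return max(self.pre, 0.0)

    def backward(self, grad_activation):
        # Continue backpropagation from the gradient the server returned.
        relu_grad = 1.0 if self.pre > 0 else 0.0
        grad_weight = grad_activation * relu_grad * self.x
        self.weight -= self.lr * grad_weight


class Server:
    """Holds the layers after the cut and computes the loss."""

    def __init__(self, weight, lr):
        self.weight = weight
        self.lr = lr

    def train_step(self, activation, target):
        prediction = self.weight * activation
        error = prediction - target
        loss = error ** 2
        grad_prediction = 2.0 * error
        grad_weight = grad_prediction * activation
        # Gradient w.r.t. the activation, computed before the weight update.
        grad_activation = grad_prediction * self.weight
        self.weight -= self.lr * grad_weight
        return loss, grad_activation


def train(client, server, samples, epochs):
    """Run the client/server exchange and return the mean loss per epoch."""
    losses = []
    for _ in range(epochs):
        total = 0.0
        for x, target in samples:
            activation = client.forward(x)                         # client -> server
            loss, grad = server.train_step(activation, target)     # server update
            client.backward(grad)                                  # server -> client
            total += loss
        losses.append(total / len(samples))
    return losses


# Toy data following target = 2 * x.
samples = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
client = Client(weight=0.5, lr=0.01)
server = Server(weight=0.5, lr=0.01)
losses = train(client, server, samples, epochs=50)
print(losses[0], losses[-1])
```

Only the cut-layer activation and its gradient cross the client/server boundary; the raw input `x` never leaves the client. Split inference corresponds to the forward half of this exchange alone.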
The core implementation is in Python (3.11). To install, create a virtual environment using Conda:
conda env create -f environment.yml
conda activate split-learning-demo
To install the webapp:
cd apps/web
pnpm install
Demo scripts can be found in the scripts directory.
To run a simple client/server split learning setup:
python scripts/server.py --learning-rate=0.01
In a different terminal:
python scripts/client.py --learning-rate=0.01
To run the webapp demo, first make sure the webapp is installed (see above), then start the server:
python scripts/server.py --learning-rate=0.01
In a different terminal:
cd apps/web
pnpm run dev
Navigate to http://localhost:5173 in your browser. The websocket server will default to ws://127.0.0.1:8000/ws.
To run the MPI demo with 1 server and 1 client:
mpirun -n 2 python scripts/mpi.py --learning-rate=0.01
- Add a simple local baseline model for comparisons
- Add split inference model
- Introduce an adversarial attack
- Add a simple defence
- Show and address SplitNN communication overheads with compression
@article{vepakomma2018split,
title={Split learning for health: Distributed deep learning without sharing raw patient data},
author={Vepakomma, Praneeth and Gupta, Otkrist and Swedish, Tristan and Raskar, Ramesh},
journal={arXiv preprint arXiv:1812.00564},
year={2018}
}