This Git repository implements a new approach to knowledge distillation for data with sequence-level input and binary output on an imbalanced dataset. The work was inspired by the papers [1] and [2].
The idea behind this particular distillation approach is for the student model to learn the distribution of the teacher model while also learning from the original data itself. This is done using an online learning approach (teacher and student are trained simultaneously).
In order for the student model to learn from the teacher model, a combined training loss is used. The first part of the loss is a loss for positive-unlabeled data computed on the original data (implemented in the loss class; the algorithm is described in [3]). The second part is a cross-entropy loss on data labeled by the teacher model. These two losses are combined in a convex combination with the hyperparameter $\alpha$.
With the loss described above the student model is trained. The teacher model is trained using the loss for positive-unlabeled data.
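To make the combination concrete, below is a minimal sketch of such a combined loss, assuming a PyTorch setup. The class name and signature are illustrative and not the repository's actual DistillationLoss API; `imbalanced_loss` stands in for the positive-unlabeled loss from [3] (src/loss/ImbalancedLoss.py).

```python
import torch.nn as nn

class CombinedDistillationLoss(nn.Module):
    """Convex combination of a teacher-guided loss and the loss on the true labels.

    Illustrative sketch only: `imbalanced_loss` is assumed to be the
    positive-unlabeled loss from [3]; its real signature may differ.
    """

    def __init__(self, imbalanced_loss: nn.Module, alpha: float):
        super().__init__()
        self.imbalanced_loss = imbalanced_loss      # L_s: loss on the true labels
        self.cross_entropy = nn.CrossEntropyLoss()  # L_t: loss on teacher-labeled data
        self.alpha = alpha                          # weight of the teacher part

    def forward(self, student_logits_t, teacher_labels, student_logits_s, true_labels):
        loss_t = self.cross_entropy(student_logits_t, teacher_labels)
        loss_s = self.imbalanced_loss(student_logits_s, true_labels)
        return self.alpha * loss_t + (1.0 - self.alpha) * loss_s
```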
Input:
- training data (subset of the original data);
- hyperparameters for the loss: $\alpha$, $\beta$;
- hyperparameters for the epochs: meta_epoch, teacher_epoch, student_epoch;
- models: student model S, teacher model T (both untrained)

Output:
- trained models S and T
- For each meta_epoch do:
  - For each teacher_epoch do:
    - Train the teacher model T with the training data
  - For each student_epoch do:
    - Shuffle the data and take a batch for the training iteration
    - Split the batch into two disjoint data sets $data_s$ and $data_t$ with $n_{data_s} = \beta \cdot n_{data}$ and $n_{data_t} = (1-\beta) \cdot n_{data}$
    - Make predictions with T for $data_t$
    - Train S with both data sets (use the predictions from T for $data_t$ and the true labels for $data_s$) using a combined loss weighted with $\alpha$ for $L_t$ and $(1-\alpha)$ for $L_s$
- Save S and T
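The loop above as a minimal sketch in PyTorch. All names (the optimizers, `teacher_loss` for the positive-unlabeled loss, `combined_loss` as sketched earlier) are assumptions for illustration; the actual implementation lives in src/distillation/Distillation.py and src/distillation/Train.py.

```python
import torch

def distill(teacher, student, loader, teacher_loss, combined_loss,
            teacher_opt, student_opt,
            meta_epochs, teacher_epochs, student_epochs, beta):
    """Illustrative online-distillation loop (not the repository's exact code)."""
    for _ in range(meta_epochs):
        # Train the teacher on the training data with the positive-unlabeled loss.
        for _ in range(teacher_epochs):
            for x, y in loader:                    # loader shuffles the data
                teacher_opt.zero_grad()
                teacher_loss(teacher(x), y).backward()
                teacher_opt.step()

        # Train the student with the combined loss.
        for _ in range(student_epochs):
            for x, y in loader:
                n_s = int(beta * len(x))           # n_{data_s} = beta * n_data
                x_s, y_s = x[:n_s], y[:n_s]        # data_s keeps its true labels
                x_t = x[n_s:]                      # data_t is labeled by the teacher
                with torch.no_grad():
                    y_t = teacher(x_t).argmax(dim=1)  # hard teacher labels for simplicity

                student_opt.zero_grad()
                loss = combined_loss(student(x_t), y_t, student(x_s), y_s)
                loss.backward()
                student_opt.step()
    return student, teacher
```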
In the folder src one can find the folders distillation, loss, models and visualization.
All losses are implemented in the folder loss.
The training loop and the distillation algorithm are located in the folder distillation.
The models used for this work can be found in the folder models.
The config file (hyperparameters.yml) can be found in the folder config.
Results and graphics can be found in the Wiki of this GitHub repository.
In order to reproduce our results, adjust the notebook run_file_google_colab.ipynb in the folder notebooks with your own data path and git key, and run it on Google Colab. If you are working on a device other than Google Colab, please execute:

- `!python3 -m pip install -r requirements.txt` (drop the leading `!` outside a notebook)
- adjust your path to the data (e.g. connect to Google Drive)
- `import os` and `os.chdir('./src')`
- `from main import main`
- `main({'config_path': '/Team09AppliedDL/config/hyperparameters.yml', 'data_path': 'path to data', 'wandb': True})`
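Put together as one runnable snippet (a sketch assuming the requirements are installed and the working directory is the repository root; replace 'path to data' with the path to your dataset):

```python
import os

os.chdir('./src')            # main.py lives in src
from main import main

main({
    'config_path': '/Team09AppliedDL/config/hyperparameters.yml',
    'data_path': 'path to data',  # replace with your own data path
    'wandb': True,                # presumably toggles Weights & Biases logging
})
```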
[1] Geoffrey Hinton, Oriol Vinyals, Jeff Dean, 2015. Distilling the Knowledge in a Neural Network. https://arxiv.org/abs/1503.02531
[2] Jianping Gou, Baosheng Yu, Stephen J. Maybank, Dacheng Tao, 2021. Knowledge Distillation: A Survey. https://arxiv.org/abs/2006.05525v7
[3] Guangxin Su, Weitong Chen, Miao Xu, 2021. Positive-Unlabeled Learning from Imbalanced Data. https://www.ijcai.org/proceedings/2021/412
├── config
│ ├── hyperparameters.yml <- YML-File for hyperparameters and model specification.
│
├── notebooks <- Jupyter notebooks.
│ ├── run_file_google_colab.ipynb <- Notebook for running the code on GoogleColab
│
├── reports <- folder with images containing reported final results
│ ├── figures <- folder with figures of the results with seed 123
│ │ ├── auc_student_test.png
│ │ ├── auc_student_train.png
│ │ ├── auc_teacher_test.png
│ │ ├── auc_teacher_train.png
│ │
│ ├── tables <- folder with results of seed=123
│ │ ├── adl_seed_123.txt
│
├── src <- Source code to use in this project
│ │
│ ├── data <- Scripts to preprocess data
│ │ ├── Dataset.py <- Script for data preparation (read in and one hot encoding of the original dataset) + preparation of data for DNABert
│ │ ├── make_dataset.py <- Script for generating random test data
│ │
│ ├── distillation <- Script for the distillation class with evaluation and train loop
│ │ ├── Distillation.py
│ │ ├── Train.py
│
│ ├── loss
│ │ ├── DistillationLoss.py <- Script for our distillation loss
│ │ ├── ImbalancedLoss.py <- Script for the imbalanced loss
│
│ ├── models <- Scripts for teacher and student models
│ │ ├── Students.py <- transformer and mlp's
│ │ ├── Teachers.py <- mlp's
│
│ ├── ConfigReader.py <- Script to read configurations
│ ├── Logger.py <- Script for including mlflow and wandb
│ ├── main.py <- Launch script
│
├── tests
│ ├── test_ImbalancedLoss.py <- testing the imbalanced loss for correct parameters, properties, type and shape of outputs, and for correct behaviour
│
├── LICENSE
├── Makefile <- Makefile with commands like `make data` or `make train`
├── README.md <- The top-level README for developers using this project.
├── requirements.txt <- The requirements file for reproducing the analysis environment and installing all required packages
├── setup.py <- makes project pip installable (pip install -e .) so src can be imported
├── test_environment.py <- Test for correct python version
├── tox.ini <- Tox file. Run for tests and linting
Project based on the cookiecutter data science project template. #cookiecutterdatascience