Official code for the paper "Weight Predictor Network with Feature Selection for Small Sample Tabular Biomedical Data", accepted at the AAAI Conference on Artificial Intelligence 2023
by Andrei Margeloiu, Nikola Simidjievski, Pietro Lio, Mateja Jamnik
TL;DR: WPFS is a general framework for learning neural networks from high-dimensional, small-sample data by reducing the number of learnable parameters and performing global feature selection. In addition to the predictor network, WPFS uses two small auxiliary networks: a weight predictor network that outputs the weight matrix of the first layer, and a feature-selection network that serves as an additional mechanism for regularisation.
Paper abstract: Tabular biomedical data is often high-dimensional but with a very small number of samples. Although recent work showed that well-regularised simple neural networks could outperform more sophisticated architectures on tabular data, they are still prone to overfitting on tiny datasets with many potentially irrelevant features. To combat these issues, we propose Weight Predictor Network with Feature Selection (WPFS) for learning neural networks from high-dimensional and small sample data by reducing the number of learnable parameters and simultaneously performing feature selection. In addition to the classification network, WPFS uses two small auxiliary networks that together output the weights of the first layer of the classification model. We evaluate on nine real-world biomedical datasets and demonstrate that WPFS outperforms other standard as well as more recent methods typically applied to tabular data. Furthermore, we investigate the proposed feature selection mechanism and show that it improves performance while providing useful insights into the learning task.
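The idea of letting two auxiliary networks produce the first-layer weights can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the repository's implementation: it assumes each input feature has a fixed embedding vector, that the weight predictor maps an embedding to one column of the first-layer matrix, and that the feature-selection network maps the same embedding to a score in (0, 1) that scales that column. The sizes D, M and K, the single-layer auxiliary networks, and all helper names are hypothetical.

```python
import math
import random

random.seed(0)

# Hypothetical sizes: D input features, M-dimensional feature embeddings,
# K units in the first hidden layer.
D, M, K = 1000, 8, 16


def linear(in_dim, out_dim):
    """Return a randomly initialised (weights, bias) pair for a linear map."""
    weights = [[random.gauss(0.0, 0.5) for _ in range(in_dim)] for _ in range(out_dim)]
    bias = [0.0] * out_dim
    return weights, bias


def apply_linear(layer, x):
    """Compute weights @ x + bias for a single input vector x."""
    weights, bias = layer
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]


# Assumption: one fixed (non-learnable) embedding per input feature.
embeddings = [[random.gauss(0.0, 1.0) for _ in range(M)] for _ in range(D)]

wpn = linear(M, K)  # weight predictor: embedding -> K weights for that feature
spn = linear(M, 1)  # feature-selection network: embedding -> one score


def first_layer_weights():
    """Build the K x D first-layer matrix column by column."""
    columns = []
    for e in embeddings:
        w = [math.tanh(v) for v in apply_linear(wpn, e)]      # predicted weights
        s = 1.0 / (1.0 + math.exp(-apply_linear(spn, e)[0]))  # score in (0, 1)
        columns.append([s * v for v in w])                    # scaled column
    # Transpose the D columns into K rows.
    return [list(row) for row in zip(*columns)]


W1 = first_layer_weights()
x = [random.gauss(0.0, 1.0) for _ in range(D)]               # one sample
hidden = [sum(w * v for w, v in zip(row, x)) for row in W1]  # first-layer pre-activations

# Learnable parameters: a directly learned first layer vs. the two auxiliary networks.
direct_params = K * D
aux_params = (M * K + K) + (M * 1 + 1)
print(len(W1), len(W1[0]), len(hidden))
print(direct_params, aux_params)
```

In this toy setting the auxiliary networks hold 153 learnable parameters against 16000 for a directly learned first layer, and that count does not grow with the number of features D.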
For attribution in academic contexts, please cite this work as:
@inproceedings{margeloiu2023weights,
title={Weight Predictor Network with Feature Selection for Small Sample Tabular Biomedical Data},
author={Margeloiu, Andrei and Simidjievski, Nikola and Lio, Pietro and Jamnik, Mateja},
booktitle={37th AAAI Conference on Artificial Intelligence},
year={2023}
}
src
- main.py: code for parsing arguments and starting the experiment
  - def parse_arguments - includes all command-line arguments
  - def train - starts training the model
  - important command-line arguments
    - dataset
    - model
    - feature_extractor_dims - the size of the hidden layers in the DNN
    - max_steps - maximum training iterations
    - batchnorm, dropout_rate
    - lr, batch_size, patience_early_stopping
    - lr_scheduler - learning rate scheduler
- dataset.py: loading the datasets
- models.py: neural network architectures: WPFS, DietNetworks, FsNet and Concrete Autoencoders
- weights_predictor_network.py - defines the Weight Predictor Network (WPN)
- sparsity_network.py - defines the Sparsity Network (SPN)
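For orientation, the command-line arguments listed for main.py could be declared roughly as follows. This is a hypothetical sketch and not the repository's parse_arguments: only the argument names come from the list above, while every type, default value and the build_parser helper are assumptions made for illustration. Check main.py for the actual definitions.

```python
import argparse


def build_parser():
    """Hypothetical parser that mirrors the argument names listed for main.py.

    All types and defaults below are illustrative assumptions.
    """
    parser = argparse.ArgumentParser(description="Illustrative WPFS-style argument parser")
    parser.add_argument("--dataset", type=str, default="lung")
    parser.add_argument("--model", type=str, default="wpfs")
    # Sizes of the hidden layers in the DNN, given as a list of integers.
    parser.add_argument("--feature_extractor_dims", type=int, nargs="+", default=[100, 100, 10])
    parser.add_argument("--max_steps", type=int, default=10000)
    parser.add_argument("--batchnorm", type=int, default=1)
    parser.add_argument("--dropout_rate", type=float, default=0.2)
    parser.add_argument("--lr", type=float, default=3e-3)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--patience_early_stopping", type=int, default=200)
    parser.add_argument("--lr_scheduler", type=str, default="none")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(
    ["--dataset", "prostate", "--feature_extractor_dims", "64", "32"]
)
print(args.dataset, args.feature_extractor_dims, args.lr)
```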
data
- cll, lung, prostate, smk, toxicity
Requirements: all project dependencies are listed in requirements.txt. We assume you have conda installed.
Installing WPFS
conda create python=3.7.9 --name WPFS
conda activate WPFS
pip install -r requirements.txt
Optional: change BASE_DIR in /src/_config.py to point to the project directory on your machine.
Step 1: Run the script run_experiment.sh
Step 2: Analyze the results in the notebook analyze_experiments.ipynb
Adding a new dataset is straightforward: search for your_custom_dataset in the codebase and replace it with your dataset.
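As a rough illustration of what a replacement loader might look like, the sketch below reads a small tabular file into a feature matrix and a label vector. It is not the repository's code: the function name load_your_custom_dataset, the CSV layout (one header row, numeric features, integer class label in the last column) and the return format are all assumptions, so adapt them to whatever the existing loaders in dataset.py return.

```python
import csv
import io


def load_your_custom_dataset(csv_file):
    """Read rows of numeric features followed by an integer class label.

    Hypothetical helper: the layout (header row, label in the last column)
    is an assumption, not the repository's format.
    """
    reader = csv.reader(csv_file)
    next(reader)  # skip the header row
    X, y = [], []
    for row in reader:
        if not row:
            continue  # ignore blank lines
        X.append([float(v) for v in row[:-1]])
        y.append(int(row[-1]))
    return X, y


# Tiny in-memory example standing in for a real file on disk.
example = io.StringIO("gene_a,gene_b,gene_c,label\n0.1,2.5,1.0,0\n0.7,1.9,0.3,1\n")
X, y = load_your_custom_dataset(example)
print(len(X), len(X[0]), y)
```

For a real file, the same function can be called with an open file handle, for example `with open(path, newline="") as f: X, y = load_your_custom_dataset(f)`.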