Most of the code is taken from the original implementation of "TabNet: Attentive Interpretable Tabular Learning" by Sercan O. Arik and Tomas Pfister (paper: https://arxiv.org/abs/1908.07442).
The modified model, reduced TabNet, is defined in `model/tabnet_reduced.py`. There are two modifications:
- there is now 1 shared feature transformer and 1 decision-dependent feature transformer (down from 2 of each), and
- the SparseMax mask for feature selection has been replaced by EntMax 1.5 (using an existing TensorFlow implementation).
The combination of these modifications has improved the performance of TabNet with fewer parameters, particularly with a sharper mask for feature selection.
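To make the second modification concrete, here is a minimal pure-Python sketch of the two mask functions. It is not the TensorFlow implementation used in this repository: the function names are hypothetical, and the threshold is found by simple bisection rather than by the exact sorting-based algorithm. It relies only on the standard definitions, where sparsemax computes `max(z - tau, 0)` and entmax with alpha = 1.5 computes `max(z / 2 - tau, 0) ** 2`, with `tau` chosen so the outputs sum to 1.

```python
def _sparse_mask(scores, power, iters=60):
    """Return p_i = max(s_i - tau, 0) ** power, with tau chosen so sum(p) is 1."""
    top = max(scores)
    # The total mass is >= 1 at tau = top - 1 and 0 at tau = top,
    # and it decreases monotonically in between, so bisection applies.
    lo, hi = top - 1.0, top
    for _ in range(iters):
        mid = (lo + hi) / 2.0
        mass = sum(max(s - mid, 0.0) ** power for s in scores)
        if mass >= 1.0:
            lo = mid
        else:
            hi = mid
    probs = [max(s - lo, 0.0) ** power for s in scores]
    total = sum(probs)  # >= 1 by construction, so never zero
    return [p / total for p in probs]  # remove the residual bisection error


def sparsemax(logits):
    """Sparsemax (entmax with alpha = 2), illustrative version."""
    return _sparse_mask(list(logits), power=1)


def entmax15(logits):
    """Entmax with alpha = 1.5, illustrative version."""
    return _sparse_mask([z / 2.0 for z in logits], power=2)
```

Both functions return a probability vector in which entries whose logit lies far enough below the largest one receive exactly zero weight, which is what lets the mask select a subset of features.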
As in the original repository, this repository contains an example implementation of TabNet on the Forest Covertype dataset (https://archive.ics.uci.edu/ml/datasets/covertype).
To run the example, run `run.sh`. Alternatively, run the steps manually as follows.
First, run `python download_prepare_covertype.py` to download and prepare the Forest Covertype dataset. This command creates the `train.csv`, `val.csv`, and `test.csv` files under the `data/` directory, creating the directory if it does not exist.
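As a quick sanity check after the preparation step, the following sketch counts the rows in each split. The helper name `count_rows` is hypothetical and not part of this repository, and whether the files carry a header row is an assumption controlled by the `has_header` argument.

```python
import csv
from pathlib import Path

# File names and directory come from the text above.
SPLITS = ("train.csv", "val.csv", "test.csv")


def count_rows(data_dir="data", has_header=False):
    """Return {file name: number of data rows, or None if the file is missing}."""
    counts = {}
    for name in SPLITS:
        path = Path(data_dir) / name
        if not path.is_file():
            counts[name] = None
            continue
        with path.open(newline="") as handle:
            rows = sum(1 for _ in csv.reader(handle))
        # Assumption: at most one header line, skipped only when requested.
        counts[name] = max(rows - 1, 0) if has_header else rows
    return counts


if __name__ == "__main__":
    for name, rows in count_rows().items():
        print(name, "missing" if rows is None else f"{rows} rows")
```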
To run the pipeline for training and evaluation, run `python train_classifier.py`. Note that TensorBoard logs are written to `tflog/`.
For simplicity, the hyperparameters for the reduced TabNet and the original TabNet model are kept the same. These can be found in `config/covertype.py`. To train the reduced TabNet, set `REDUCED = True`; otherwise, set `REDUCED = False`.
To adapt the experiment to other tabular datasets:
- Substitute the `train.csv`, `val.csv`, and `test.csv` files under the `data/` directory,
- Create a new config in `config/` by copying `config/covertype.py` and adapting it to the numerical and categorical features of the new dataset and to its hyperparameters,
- Reoptimize the TabNet hyperparameters for the new dataset in your config,
- Import the parameters in `train_classifier.py`,
- Select the reduced TabNet architecture by setting `REDUCED = True`, and
- Change `MODEL_NAME` in your config to a name you desire.
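For the first step in the list above, a new dataset that comes as a single CSV file has to be split into the three expected files. The sketch below is one way to do that with the standard library only. The helper `split_csv`, its default fractions, and the header handling are assumptions for illustration, not part of this repository, and the column layout must still match what your config declares.

```python
import csv
import random
from pathlib import Path


def split_csv(source, out_dir="data", val_frac=0.1, test_frac=0.1,
              seed=0, has_header=False):
    """Shuffle the rows of `source` and write train/val/test CSVs to `out_dir`."""
    with open(source, newline="") as handle:
        rows = list(csv.reader(handle))
    # Keep the header (if any) out of the shuffle so it can be rewritten per file.
    header = rows.pop(0) if has_header and rows else None
    random.Random(seed).shuffle(rows)  # seeded for a reproducible split

    n_val = int(len(rows) * val_frac)
    n_test = int(len(rows) * test_frac)
    parts = {
        "val.csv": rows[:n_val],
        "test.csv": rows[n_val:n_val + n_test],
        "train.csv": rows[n_val + n_test:],
    }

    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for name, part in parts.items():
        with (out / name).open("w", newline="") as handle:
            writer = csv.writer(handle)
            if header is not None:
                writer.writerow(header)
            writer.writerows(part)
    return {name: len(part) for name, part in parts.items()}
```

A call such as `split_csv("my_dataset.csv")`, where `my_dataset.csv` is a placeholder file name, writes the three files into `data/` and returns the number of rows in each split.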