Code for the paper on robust data pruning.
Requires Python 3+.
- Create a conda environment: `conda env create -f environment.yml`
- Activate the environment: `conda activate environment`
The project implements both active learning (AL, `--strategy 0`) and data pruning (DP, `--strategy 1`). The command-line flag `--auto_config` fills in the appropriate hyperparameters based on the model specification and is recommended. The general flow of an experiment is as follows:
1. Trains a query model (possibly across multiple initializations, `--num_inits`) and retrieves sample scores;
2. Acquires (for AL) or deletes (for DP) samples based on the scores and other factors (e.g., class-wise quotas);
3. Potentially repeats steps 1-2 across multiple iterations (`--iterations`, common for AL; see the sketch after this list);
4. Once the final dataset is determined, trains the final model and saves its metrics in JSON format.
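The usage examples further below all exercise DP. As a sketch of what an AL run could look like under the same conventions (the iteration count, model, and scorer here are illustrative assumptions, and any acquisition-budget flags the project may additionally require are not shown):

```
# Illustrative AL sketch: 5 query/acquire rounds with EL2N scoring.
# The specific values are assumptions for demonstration only; any
# acquisition-budget options the project requires are not shown here.
python -m fair-data-pruning.main --auto_config --use_gpu --strategy 0 --iterations 5 --model_name VGG16 --scorer_name EL2N
```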
Here are a few simple usage examples. The commands should be executed from the parent directory of the project's folder.
- Prune 30% of CIFAR-10 using VGG-16 and the EL2N scorer: `python -m fair-data-pruning.main --auto_config --use_gpu --strategy 1 --final_frac 0.7 --model_name VGG16 --scorer_name EL2N`
- Randomly prune 30% of CIFAR-10 using VGG-16 and MetriQ class-wise ratios, with the query model retrained 5 times: `python -m fair-data-pruning.main --auto_config --use_gpu --strategy 1 --final_frac 0.7 --model_name VGG16 --scorer_name Random --quoter_name MetriQ --num_inits 5`
- Prune 30% of CIFAR-10 using VGG-16 and the Forgetting scorer, and train the final model with the cost-sensitive optimization algorithm CDB-W: `python -m fair-data-pruning.main --auto_config --use_gpu --cdbw_final --strategy 1 --final_frac 0.7 --model_name VGG16 --scorer_name Forgetting`
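The final model's metrics are saved in JSON format; a minimal sketch for inspecting them, assuming a hypothetical output path (the actual file name and location depend on the run configuration):

```
import json
from pathlib import Path

# Hypothetical path: adjust to wherever your run writes its metrics JSON.
metrics_path = Path("fair-data-pruning/results/metrics.json")

# Load the saved metrics and print each top-level entry.
with metrics_path.open() as f:
    metrics = json.load(f)

for name, value in metrics.items():
    print(f"{name}: {value}")
```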
Coming soon.