
Easy metric learning

A simple framework for metric learning training. Just set up the configuration file and start training. The framework uses Hydra for configuration management and Accelerate for distributed training.

Models

Backbones

Any model from the timm, openclip or unicom libraries can be used as a backbone.

Margins

Install requirements

Conda environment

The easiest way to work with this repo is to create a conda environment from the provided file. First install conda, then run the following:

conda env create -f environment.yml

This will create a new conda environment with all the required libraries.

Manually

If you want to install the required libraries without conda, you can install them manually. First install PyTorch. Then install the remaining libraries from the requirements.txt file. Just run:

pip install -r requirements.txt

Prepare dataset

Open source datasets

To download and prepare one of the supported open source datasets (listed under --dataset below), just run:

python data/prepare_dataset.py --dataset {dataset_name} --save_path {dataset save path}

--dataset - should be one of ['sop', 'cars', 'cub', 'inshop', 'aliproducts', 'rp2k', 'products10k', 'met', 'hnm', 'finefood', 'shopee', 'inaturalist_2021']

Custom dataset

The easiest way to prepare a custom dataset is to organize it into the following structure:

dataset_folder
│   
│   └───class_1
│       │   image1.jpg
│       │   image2.jpg
│       │   ...
│   └───class_2
│       │   image1.jpg
│       │   image2.jpg
│       │   ...
│
│      ...
│
│   └───class_N
│       │   ...

After that you can just run:

python data/generate_dataset_info.py --dataset_path {path to the dataset_folder}

Optional arguments:

--hashes - generate image hashes to use them for duplicates filtering
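The folder layout above maps each class subfolder to one label. As a rough illustration of that mapping (not the actual logic of generate_dataset_info.py), the sketch below walks such a folder and writes one row per image. The function names, the accepted extensions, and the `file_name` and `label` column names are assumptions made for this example.

```python
import csv
from pathlib import Path

# Assumed set of image extensions; the real script may accept others.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def collect_dataset_rows(dataset_path):
    """Return one (relative image path, class name) row per image file."""
    root = Path(dataset_path)
    rows = []
    # Each immediate subfolder of the dataset folder is treated as one class.
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        for image_path in sorted(class_dir.iterdir()):
            if image_path.is_file() and image_path.suffix.lower() in IMAGE_EXTENSIONS:
                rows.append((image_path.relative_to(root).as_posix(), class_dir.name))
    return rows


def write_dataset_info(rows, csv_path):
    """Write the rows to a CSV file; the column names are hypothetical."""
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["file_name", "label"])
        writer.writerows(rows)
```

Calling `write_dataset_info(collect_dataset_rows("dataset_folder"), "dataset_info.csv")` would produce a file in this illustrative format; check the CSV generated by the real script for the actual columns.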

(Optional) Dataset filtering

If you want to remove duplicates from your dataset, first generate the dataset_info.csv file in the previous step with the --hashes argument, then run:

python data/filter_dataset_info.py --dataset_info {path to the dataset_info.csv file} --dedup

The script will generate a dataset_info_filtered.csv file, which you can use in the next steps.

Optional arguments:

--min_size - minimum image size. Images smaller than min_size are removed from the filtered dataset_info file.

--max_size - maximum image size. Images larger than max_size are removed from the filtered dataset_info file.

--threshold - threshold for duplicate search: the maximum Hamming distance between the key image and a candidate image at which the candidate is still considered a duplicate of the key image. Should be an int between 0 and 64. Default value is 10.
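To make the threshold concrete, here is a minimal sketch of a Hamming-distance duplicate check. It assumes each image hash is stored as a 64-bit integer, which is consistent with the 0 to 64 range above but is not confirmed by this README; the function names are hypothetical, and the hashing algorithm used by the script is not shown here.

```python
def hamming_distance(hash_a: int, hash_b: int) -> int:
    """Count the bits that differ between two 64-bit hashes."""
    # XOR leaves a 1 in every position where the two hashes disagree.
    return bin((hash_a ^ hash_b) & 0xFFFFFFFFFFFFFFFF).count("1")


def is_duplicate(key_hash: int, candidate_hash: int, threshold: int = 10) -> bool:
    """Treat the candidate as a duplicate if its distance is within the threshold."""
    return hamming_distance(key_hash, candidate_hash) <= threshold
```

A threshold of 0 keeps only images with identical hashes, while larger values also flag near-duplicates such as resized or lightly edited copies.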

(Optional) K-fold split

If you want to make a stratified k-fold split on a custom dataset, you can run:

python data/get_kfold_split.py --dataset_info {path to the dataset_info.csv file}

The script will generate a folds.csv file with a 'fold' column. It will also generate folds_train_only.csv and folds_test_only.csv, for using the dataset only for training or only for testing.

Optional arguments:

--k - number of folds (default: 5)

--random_seed - random seed for reproducibility

--save_name - save file name (default: folds)
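For readers unfamiliar with stratified splitting, the sketch below shows one simple way to assign a fold index to every sample so that each class is spread evenly across folds. It is an illustration only and not necessarily how get_kfold_split.py works; `stratified_kfold` is a hypothetical name, and its `k` and `random_seed` parameters merely mirror the arguments above.

```python
import random
from collections import defaultdict


def stratified_kfold(labels, k=5, random_seed=None):
    """Return a fold index (0..k-1) for every sample, balanced per class."""
    rng = random.Random(random_seed)
    indices_by_class = defaultdict(list)
    for index, label in enumerate(labels):
        indices_by_class[label].append(index)

    folds = [None] * len(labels)
    for label in sorted(indices_by_class):
        indices = indices_by_class[label]
        rng.shuffle(indices)
        # A random starting fold avoids always giving leftovers to fold 0.
        offset = rng.randrange(k)
        # Deal the shuffled samples of this class round-robin across folds.
        for position, index in enumerate(indices):
            folds[index] = (position + offset) % k
    return folds
```

With the same seed the assignment is reproducible, which is the purpose of a --random_seed style argument.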

(Optional) Split dataset into train and test

If you want to use part of the classes for testing and the rest for training, just run:

python data/get_kfold_split.py --dataset_info {path to the dataset_info.csv file} --split_type {name of split type}

There are several ways to split the dataset into train and test:

  1. based on the min and max number of samples (split_type = minmax). Classes whose number of samples is in the range [min_n_samples, max_n_samples] are used for training and the rest for testing
  2. based on proportion and frequency (split_type = freq)

Optional script arguments:

--min_n_samples - min number of samples to select class for training (used when split_type == minmax, default: 3)

--max_n_samples - max number of samples to select class for training (used when split_type == minmax, default: 50)

--test_ratio - test classes ratio (used when split_type == freq, default: 0.1)

--min_freq - min number of samples in a frequency bin required to split the bin; if there are fewer, the whole bin is added to the training set (used when split_type == freq, default: 10)

--random_seed - random seed for reproducibility
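The minmax rule described above can be sketched in a few lines. This is a hedged illustration of the rule as stated in this README, not the script's actual code; `minmax_split` is a hypothetical helper, and its defaults simply copy the documented defaults of 3 and 50.

```python
from collections import Counter


def minmax_split(labels, min_n_samples=3, max_n_samples=50):
    """Split class names into (train_classes, test_classes) by sample count."""
    counts = Counter(labels)
    # Classes whose sample count lies in [min_n_samples, max_n_samples] train.
    train_classes = {
        name for name, n in counts.items()
        if min_n_samples <= n <= max_n_samples
    }
    # Every other class (too few or too many samples) is held out for testing.
    test_classes = set(counts) - train_classes
    return train_classes, test_classes
```

Because the split is made by class rather than by sample, no class appears in both the training and the test set.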

Training

Simple example

In the configs folder you will find the main configuration file, config.yaml. It contains the default training parameters. For example, the default backbone is efficientnetv2_b1 and the default dataset is cars196. If you want to change any of the parameters, you can pass it as an argument. Let's say you want to train a model for 10 epochs with backbone openclip_vit_b32, margin arcface, on dataset inshop, and evaluate the model on products10k with batch size 32. You can change the values in the configuration file, or set what you want as arguments, and the remaining parameters will take their defaults from config.yaml. You can use the following command:

python tools/train.py backbone=openclip_vit_b32 dataset=inshop evaluation/data=products10k batch_size=32 epochs=10

That's all; the results will be saved in the work_dirs folder.
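If you launch several runs with different parameters, it can be convenient to build the command from a dictionary of overrides. The sketch below is one possible way to do that; `build_train_command` is a hypothetical helper that is not part of this repo, and it only assembles the key=value arguments shown in the command above.

```python
import shlex


def build_train_command(overrides, script="tools/train.py", launcher=("python",)):
    """Build the argument list for a training run from key=value overrides."""
    command = list(launcher)
    command.append(script)
    # Each override becomes one key=value argument, in insertion order.
    command.extend(f"{key}={value}" for key, value in overrides.items())
    return command


command = build_train_command({
    "backbone": "openclip_vit_b32",
    "dataset": "inshop",
    "evaluation/data": "products10k",
    "batch_size": 32,
    "epochs": 10,
})
# Print the command as a single shell-ready line.
print(shlex.join(command))
```

The resulting list can be passed to `subprocess.run(command, check=True)`. Passing `launcher=("accelerate", "launch")` would build the distributed variant of the same command.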

Configuration files are hierarchical, so you can create and configure separate configurations for individual modules. For example, you can create a new configuration file for a new backbone or for a new loss function. All configurations can be found in the configs folder. Feel free to modify existing configurations or create new ones.

Distributed training

If you want to use MultiGPU, MultiCPU, MultiNPU or TPU training, or if you want to use several machines (nodes), you need to set up the Accelerate framework. Just run:

accelerate config

All Accelerate library configuration parameters can be found in the Accelerate documentation.

Once you have configured the Accelerate library, you can run distributed training using the following command:

accelerate launch tools/train.py backbone=openclip_vit_l {and other configuration parameters}