Text Classification

PyTorch re-implementation of some text classification models.

 

Supported Models

Choose the model to train by editing the model_name item in your config file (some example config files are provided). Follow each model's link for details.

 

Requirements

First, make sure your environment has:

  • Python >= 3.5

Then install requirements:

pip install -r requirements.txt

 

Dataset

Currently, the following datasets proposed in this paper are supported:

  • AG News
  • DBpedia
  • Yelp Review Polarity
  • Yelp Review Full
  • Yahoo Answers
  • Amazon Review Full
  • Amazon Review Polarity

All of them can be downloaded here (Google Drive). Click here for details of these datasets.

Download and unzip them first, then set their path (dataset_path) in your config files. If you would like to use other datasets, they may need to be stored in the same format as the datasets listed above.
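As a rough illustration of what "the same format" might look like, the sketch below reads CSV rows whose first field is a 1-based class index and whose remaining fields are text. That layout is an assumption about the downloaded files, not something this README states, and the sample rows, `SAMPLE_CSV`, and `read_rows` are made up for the example.

```python
import csv
import io

# Made-up sample in the ASSUMED layout: class index first, text fields after.
SAMPLE_CSV = (
    '"3","Some title","Some description text."\n'
    '"1","Another title","More text here."\n'
)


def read_rows(handle):
    """Yield (label, text) pairs, shifting the label so it starts at 0."""
    for row in csv.reader(handle):
        if not row:
            continue  # skip empty lines
        label = int(row[0]) - 1       # assumed 1-based class index
        text = " ".join(row[1:])      # merge title / description fields
        yield label, text


rows = list(read_rows(io.StringIO(SAMPLE_CSV)))
```

Check a few rows of your own dataset against the downloaded files before relying on this layout.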

 

Pre-trained Word Embeddings

If you would like to use pre-trained word embeddings (like GloVe), just set emb_pretrain to True and specify the path to the pre-trained vectors (emb_folder and emb_filename) in your config files. You can also choose whether to fine-tune the word embeddings by editing the fine_tune_embeddings item.

Or if you want to randomly initialize the embedding layer's weights, set emb_pretrain to False and specify the embedding size (embed_size).
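The two options above can be sketched in plain Python. This is only an illustration, not the code in this repository: `load_vectors`, `build_embedding_matrix`, the `<pad>`/`<unk>` tokens, and the uniform initialization range are all assumptions made for the example.

```python
import io
import random


def load_vectors(handle):
    """Parse GloVe-style text lines: a token followed by its float components."""
    vectors = {}
    for line in handle:
        parts = line.rstrip().split(" ")
        if len(parts) < 2:
            continue  # skip blank or malformed lines
        vectors[parts[0]] = [float(x) for x in parts[1:]]
    return vectors


def build_embedding_matrix(word2ix, vectors, embed_size, seed=0):
    """Copy pre-trained rows where available; randomly initialize the rest."""
    rng = random.Random(seed)
    matrix = [[0.0] * embed_size for _ in range(len(word2ix))]
    for word, ix in word2ix.items():
        if word in vectors:
            matrix[ix] = list(vectors[word])
        elif word != "<pad>":  # the padding row stays all zeros
            matrix[ix] = [rng.uniform(-0.1, 0.1) for _ in range(embed_size)]
    return matrix


# Made-up 3-dimensional vectors standing in for a real embedding file.
fake_file = io.StringIO("cat 0.1 0.2 0.3\ndog 0.4 0.5 0.6\n")
word2ix = {"<pad>": 0, "<unk>": 1, "cat": 2, "dog": 3}
matrix = build_embedding_matrix(word2ix, load_vectors(fake_file), embed_size=3)
```

In PyTorch, a matrix like this can be converted to a tensor and passed to `torch.nn.Embedding.from_pretrained`, whose `freeze` argument plays the role of a fine-tuning switch (`freeze=False` lets the weights be updated during training).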

 

Preprocess

Although torchtext can be used to preprocess data easily, it loads all data at once, which occupies too much memory and slows down training, especially when the dataset is large.

Therefore, I preprocess the data manually and store it locally first (where configs/example.yaml is the path to your config file):

python preprocess.py --config configs/example.yaml 

Then I load the data dynamically using PyTorch's DataLoader when training (see datasets/dataloader.py).

The preprocessing includes encoding and padding sentences and building the word2ix map. This may take a little time, but afterwards training occupies less memory (which allows a larger batch size) and takes less time. For example, training a fastText model on the Yahoo Answers dataset for one epoch takes 4.6 minutes (on an RTX 2080 Ti) using torchtext, but only 41 seconds using DataLoader.

torchtext.py is the script for loading data via torchtext; you can try it if you are interested.
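A minimal sketch of the three preprocessing steps (building a word2ix map, encoding, padding) is shown here. It is not the code in preprocess.py: the function names, the `<pad>`/`<unk>` tokens, and the `min_count` parameter are assumptions chosen for illustration.

```python
from collections import Counter


def build_word2ix(tokenized_docs, min_count=1):
    """Map each sufficiently frequent word to an integer index."""
    counts = Counter(w for doc in tokenized_docs for w in doc)
    word2ix = {"<pad>": 0, "<unk>": 1}  # reserved indices (assumed convention)
    for word, n in counts.most_common():
        if n >= min_count:
            word2ix[word] = len(word2ix)
    return word2ix


def encode_and_pad(tokens, word2ix, word_limit):
    """Encode tokens as indices, truncate to word_limit, then right-pad."""
    ids = [word2ix.get(w, word2ix["<unk>"]) for w in tokens[:word_limit]]
    length = len(ids)  # true length before padding
    ids += [word2ix["<pad>"]] * (word_limit - length)
    return ids, length


docs = [["the", "cat", "sat"], ["the", "dog"]]
word2ix = build_word2ix(docs)
encoded, n_words = encode_and_pad(["the", "bird", "sat"], word2ix, word_limit=5)
```

Because every encoded sentence has the same length, the results can be stored on disk once and read back in batches during training.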

 

Train

To train a model, just run:

python train.py --config configs/example.yaml

If you have enabled TensorBoard (tensorboard: True in your config file), you can visualize the losses and accuracies during training with:

tensorboard --logdir=<your_log_dir>

 

Test

To test a checkpoint and compute its accuracy on the test set:

python test.py --config configs/example.yaml

 

Classify

To predict the category for a specific sentence:

First edit the following items in classify.py:

checkpoint_path = 'str: path_to_your_checkpoint'

# pad limits
# only makes sense when model_name == 'han'
sentence_limit_per_doc = 15
word_limit_per_sentence = 20
# only makes sense when model_name != 'han'
word_limit = 200

Then, run:

python classify.py
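To make the two HAN pad limits concrete, here is a hedged sketch of how a document (a list of encoded sentences) could be truncated and padded to a fixed sentence_limit_per_doc × word_limit_per_sentence shape. `pad_document` and its return values are hypothetical and not taken from classify.py.

```python
def pad_document(sentences, sentence_limit_per_doc, word_limit_per_sentence, pad_ix=0):
    """Truncate and pad a list of encoded sentences to a fixed 2-D shape."""
    # Keep at most sentence_limit_per_doc sentences, each cut to the word limit.
    kept = [s[:word_limit_per_sentence] for s in sentences[:sentence_limit_per_doc]]
    words_per_sentence = [len(s) for s in kept]

    # Right-pad every kept sentence to the word limit.
    padded = [s + [pad_ix] * (word_limit_per_sentence - len(s)) for s in kept]

    # Append all-padding sentences until the sentence limit is reached.
    n_sentences = len(padded)
    missing = sentence_limit_per_doc - n_sentences
    padded += [[pad_ix] * word_limit_per_sentence for _ in range(missing)]
    words_per_sentence += [0] * missing
    return padded, n_sentences, words_per_sentence


doc = [[5, 6, 7, 8], [9]]  # made-up word indices
padded, n_sentences, words_per_sentence = pad_document(
    doc, sentence_limit_per_doc=3, word_limit_per_sentence=3
)
```

For the other models, only a single word_limit applies, so the document is treated as one flat sequence of words.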

 

Performance

Here I report the test accuracy (%) and training time per epoch (on an RTX 2080 Ti, in parentheses) of each model on various datasets. Model parameters are not carefully tuned, so some tuning could likely yield better results.

| Model | AG News | DBpedia | Yahoo Answers |
| --- | --- | --- | --- |
| Hierarchical Attention Network | 92.7 (45s) | 98.2 (70s) | 74.5 (2.7m) |
| fastText | 91.6 (8s) | 97.9 (25s) | 66.7 (41s) |
| Bi-LSTM + Attention | 92.0 (50s) | 99.0 (105s) | 73.5 (3.4m) |
| TextCNN | 92.2 (24s) | 98.5 (100s) | 72.8 (4m) |
| Transformer | 92.2 (60s) | 98.6 (8.2m) | 72.5 (14.5m) |

 

Notes

  • The load_embeddings method (in utils/embedding.py) tries to create a cache of the loaded embeddings under the dataset_output_path folder. This dramatically reduces loading time on subsequent runs.
  • Only the encoder part of Transformer is used.
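The caching idea in the first note can be sketched as follows. This is a generic pickle-based pattern under assumed names (`load_with_cache`, `slow_load`, `emb_cache.pkl`), not the actual implementation of load_embeddings.

```python
import os
import pickle
import tempfile


def load_with_cache(cache_path, build_fn):
    """Return the cached object if it exists; otherwise build and cache it."""
    if os.path.isfile(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    result = build_fn()
    with open(cache_path, "wb") as f:
        pickle.dump(result, f)
    return result


calls = []


def slow_load():
    """Stand-in for an expensive parse of a large embedding file."""
    calls.append(1)
    return {"cat": [0.1, 0.2]}


with tempfile.TemporaryDirectory() as out_dir:
    cache_file = os.path.join(out_dir, "emb_cache.pkl")
    first = load_with_cache(cache_file, slow_load)   # builds and writes the cache
    second = load_with_cache(cache_file, slow_load)  # served from the cache
```

A cache like this has to be deleted by hand when the vocabulary or the embedding file changes, since nothing in the sketch detects stale data.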

 

License

MIT

 

Acknowledgement

This project is based on sgrvinod/a-PyTorch-Tutorial-to-Text-Classification.