Authors: Youqian Gao, Ben Dai
Word-level maximum mean discrepancy (WMMD) is a novel regularization framework that recognizes and accounts for the "word-level distribution discrepancy", a phenomenon common across NLP datasets in which word distributions differ noticeably under different labels. The regularization serves a specific purpose: to enhance and preserve these distribution discrepancies in the numerical word embedding vectors and thus prevent overfitting.
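To make the idea concrete, the snippet below is a minimal sketch of a kernel-MMD penalty between the embedding vectors of words grouped by label. It is an illustration only, not the paper's exact estimator: the Gaussian kernel, the bandwidth `sigma`, and the `pos_ids`/`neg_ids` word grouping are all assumptions.

```python
import torch

def gaussian_kernel(x, y, sigma=1.0):
    # Gaussian kernel on pairwise squared Euclidean distances.
    dist2 = torch.cdist(x, y) ** 2
    return torch.exp(-dist2 / (2 * sigma ** 2))

def mmd2(x, y, sigma=1.0):
    # Biased (V-statistic) estimate of the squared MMD between two samples.
    return (gaussian_kernel(x, x, sigma).mean()
            + gaussian_kernel(y, y, sigma).mean()
            - 2 * gaussian_kernel(x, y, sigma).mean())

def wmmd_penalty(embedding, pos_ids, neg_ids, sigma=1.0):
    # Embedding vectors of the words associated with each label group.
    e_pos = embedding(pos_ids)  # (n_pos, d)
    e_neg = embedding(neg_ids)  # (n_neg, d)
    # Negative sign: minimizing the total loss *increases* the discrepancy.
    return -mmd2(e_pos, e_neg, sigma)
```

In training, such a penalty would be added to the task loss, e.g. `loss = task_loss + lam * wmmd_penalty(model.embedding, pos_ids, neg_ids)`, where `lam` is a hypothetical regularization weight.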
We visualize the word embeddings of a CNN model trained on the CE-T1 dataset under different regularizers. The WMMD regularization separates the word vectors under different labels and enforces the word-level discrepancy between the two groups.
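A plot of this kind can be produced with a simple 2-D projection of the learned embedding matrix. The sketch below assumes a NumPy array `emb` of word vectors and an array `labels` of group assignments (both hypothetical names), and uses PCA, which may differ from the projection used in the actual figure.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def plot_embeddings(emb, labels, title="Word embeddings"):
    # Project high-dimensional word vectors to 2-D for inspection.
    coords = PCA(n_components=2).fit_transform(emb)
    for lab in np.unique(labels):
        mask = labels == lab
        plt.scatter(coords[mask, 0], coords[mask, 1], s=8, label=str(lab))
    plt.title(title)
    plt.legend()
    plt.show()
```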
Furthermore, WMMD regularization offers a flexible framework for incorporating prior information and higher-order distribution discrepancies of words. For more information, please refer to the paper.
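As one illustrative guess at how higher-order discrepancies might be captured (not necessarily the paper's `biwmmd` formulation), the same MMD penalty from the sketch above can be applied to bigram representations built from the embeddings of adjacent words:

```python
import torch

def bigram_vectors(embedding, token_ids):
    # Concatenate embeddings of adjacent tokens to form bigram representations.
    e = embedding(token_ids)               # (seq_len, d)
    return torch.cat([e[:-1], e[1:]], 1)   # (seq_len - 1, 2d)

# Reusing mmd2 from the sketch above on the bigram vectors of each label group:
# penalty = -mmd2(bigram_vectors(emb, pos_ids), bigram_vectors(emb, neg_ids))
```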
The following word clouds show significant differences in the most common words between the "sports" and "business" categories of the BBC News dataset. Given this word-level distribution discrepancy, a good word embedding should produce numerical word representations that preserve it.
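For reference, word clouds of this kind can be generated with the `wordcloud` package. The sketch below assumes plain-text corpora for the two categories; `sports_text` and `business_text` are hypothetical variables holding the concatenated article text per category.

```python
import matplotlib.pyplot as plt
from wordcloud import WordCloud

def show_wordcloud(text, title):
    # Build a cloud where font size reflects word frequency in `text`.
    wc = WordCloud(width=600, height=400, background_color="white").generate(text)
    plt.imshow(wc, interpolation="bilinear")
    plt.axis("off")
    plt.title(title)
    plt.show()

# Usage (with hypothetical per-category corpora):
# show_wordcloud(sports_text, "sports")
# show_wordcloud(business_text, "business")
```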
```bash
pip install -r requirements.txt
```
```bash
python main.py dataset=<dataset> model=<model> regularizer=<regularizer>
```
For more details, refer to `python main.py --help`. Currently, the learning framework supports the following configurations:
- `dataset`: `bbc-news`, `ce-t1`, `sim1`, `sim2`, `sim3` (see the paper for more details on the three simulation examples)
- `model`: `cnn`, `bilstm`, `gru`, `mlp`, `logistic`
- `regularizer`: `wmmd`, `swmmd` (structured WMMD), `biwmmd` (bigram WMMD), `dropout`, `l1`, `none`
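For example, to train the CNN with WMMD regularization on the BBC News dataset:

```bash
python main.py dataset=bbc-news model=cnn regularizer=wmmd
```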