Skip to content

Commit

Permalink
Updated classification model tokenization logic. Added deberta, mpnet…
Browse files Browse the repository at this point in the history
…, squeezenet for classification
  • Loading branch information
Thilina Rajapakse committed Feb 1, 2021
1 parent f56302d commit 6f189e0
Show file tree
Hide file tree
Showing 16 changed files with 329 additions and 59 deletions.
20 changes: 19 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.60.0] - 2021-02-02

# Added

- Added class weights support for Longformer classification
- Added new classification models:
- SqueezeBert
- DeBERTa
- MPNet

# Changed

- Updated ClassificationModel logic to make it easier to add new models

## [0.51.16] - 2021-01-29

## Fixed
Expand Down Expand Up @@ -1386,7 +1400,11 @@ Model checkpoint is now saved for all epochs again.

- This CHANGELOG file to hopefully serve as an evolving example of a standardized open source project CHANGELOG.

[0.51.15]: https://github.com/ThilinaRajapakse/simpletransformers/compare/2af55e9...HEAD
[0.60.0]: https://github.com/ThilinaRajapakse/simpletransformers/compare/5840749...HEAD

[0.51.16]: https://github.com/ThilinaRajapakse/simpletransformers/compare/b42898e...5840749

[0.51.15]: https://github.com/ThilinaRajapakse/simpletransformers/compare/2af55e9...b42898e

[0.51.14]: https://github.com/ThilinaRajapakse/simpletransformers/compare/278fca1...2af55e9

Expand Down
37 changes: 20 additions & 17 deletions docs/_docs/04-classification-specifics.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
title: Classification Specifics
permalink: /docs/classification-specifics/
excerpt: "Specific notes for text classification tasks."
last_modified_at: 2020/12/21 22:13:56
last_modified_at: 2021/02/02 02:03:09
toc: true
---

Expand Down Expand Up @@ -32,22 +32,25 @@ The process of performing text classification in Simple Transformers does not de

New model types are regularly added to the library. Text classification tasks currently supports the model types given below.

| Model | Model code for `ClassificationModel` |
| ----------- | ------------------------------------ |
| ALBERT | albert |
| BERT | bert |
| BERTweet | bertweet |
| CamemBERT | camembert |
| RoBERTa | roberta |
| DistilBERT | distilbert |
| ELECTRA | electra |
| FlauBERT | flaubert |
| *LayoutLM | layoutlm |
| Longformer | longformer |
| *MobileBERT | mobilebert |
| XLM | xlm |
| XLM-RoBERTa | xlmroberta |
| XLNet | xlnet |
| Model | Model code for `ClassificationModel` |
| ------------ | ------------------------------------ |
| ALBERT | albert |
| BERT | bert |
| BERTweet | bertweet |
| CamemBERT | camembert |
| *DeBERTa | deberta |
| DistilBERT | distilbert |
| ELECTRA | electra |
| FlauBERT | flaubert |
| LayoutLM | layoutlm |
| *Longformer | longformer |
| *MPNet | mpnet |
| MobileBERT | mobilebert |
| RoBERTa | roberta |
| *SqueezeBert | squeezebert |
| XLM | xlm |
| XLM-RoBERTa | xlmroberta |
| XLNet | xlnet |

\* *Not available with Multi-label classification*

Expand Down
2 changes: 1 addition & 1 deletion examples/t5/mt5_translation/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@
english_preds = model.predict(to_english)

sin_eng_bleu = sacrebleu.corpus_bleu(english_preds, english_truth)
print("Sinhalese to English: ", sin_eng_bleu.score)
print("Sinhalese to English: ", sin_eng_bleu.score)
2 changes: 1 addition & 1 deletion examples/t5/training_on_a_new_task/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@
"wandb_project": "Question Generation with T5",
}

model = T5Model("t5","t5-large",args=model_args)
model = T5Model("t5", "t5-large", args=model_args)

model.train_model(train_df, eval_data=eval_df)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="simpletransformers",
version="0.51.16",
version="0.60.0",
author="Thilina Rajapakse",
author_email="chaturangarajapakshe@gmail.com",
description="An easy-to-use wrapper library for the Transformers library.",
Expand Down
84 changes: 67 additions & 17 deletions simpletransformers/classification/classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
mean_squared_error,
roc_curve,
auc,
average_precision_score
average_precision_score,
)
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
Expand All @@ -46,11 +46,17 @@
from transformers import (
AlbertConfig,
AlbertTokenizer,
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
BertConfig,
BertTokenizer,
BertweetTokenizer,
CamembertConfig,
CamembertTokenizer,
DebertaConfig,
DebertaForSequenceClassification,
DebertaTokenizer,
DistilBertConfig,
DistilBertTokenizer,
ElectraConfig,
Expand All @@ -60,15 +66,17 @@
LayoutLMConfig,
LayoutLMTokenizer,
LongformerConfig,
LongformerForSequenceClassification,
LongformerTokenizer,
MPNetConfig,
MPNetForSequenceClassification,
MPNetTokenizer,
MobileBertConfig,
MobileBertForSequenceClassification,
MobileBertTokenizer,
ReformerConfig,
ReformerTokenizer,
RobertaConfig,
RobertaTokenizer,
SqueezeBertConfig,
SqueezeBertForSequenceClassification,
SqueezeBertTokenizer,
WEIGHTS_NAME,
XLMConfig,
XLMRobertaConfig,
Expand All @@ -91,6 +99,8 @@
from simpletransformers.classification.transformer_models.distilbert_model import DistilBertForSequenceClassification
from simpletransformers.classification.transformer_models.flaubert_model import FlaubertForSequenceClassification
from simpletransformers.classification.transformer_models.layoutlm_model import LayoutLMForSequenceClassification
from simpletransformers.classification.transformer_models.longformer_model import LongformerForSequenceClassification
from simpletransformers.classification.transformer_models.mobilebert_model import MobileBertForSequenceClassification
from simpletransformers.classification.transformer_models.roberta_model import RobertaForSequenceClassification
from simpletransformers.classification.transformer_models.xlm_model import XLMForSequenceClassification
from simpletransformers.classification.transformer_models.xlm_roberta_model import XLMRobertaForSequenceClassification
Expand All @@ -100,7 +110,6 @@
from simpletransformers.config.utils import sweep_config_to_sweep_values
from simpletransformers.custom_models.models import ElectraForSequenceClassification

from transformers.models.reformer import ReformerForSequenceClassification

try:
import wandb
Expand All @@ -112,11 +121,22 @@
logger = logging.getLogger(__name__)


MODELS_WITHOUT_CLASS_WEIGHTS_SUPPORT = ["squeezebert", "deberta", "mpnet"]

MODELS_WITH_EXTRA_SEP_TOKEN = ["roberta", "camembert", "xlmroberta", "longformer", "mpnet"]

MODELS_WITH_ADD_PREFIX_SPACE = ["roberta", "camembert", "xlmroberta", "longformer", "mpnet"]

MODELS_WITHOUT_SLIDING_WINDOW_SUPPORT = ["squeezebert"]


class ClassificationModel:
def __init__(
self,
model_type,
model_name,
tokenizer_type=None,
tokenizer_name=None,
num_labels=None,
weight=None,
args=None,
Expand All @@ -132,6 +152,9 @@ def __init__(
Args:
model_type: The type of model (bert, xlnet, xlm, roberta, distilbert)
model_name: The exact architecture and trained weights to use. This may be a Hugging Face Transformers compatible pre-trained model, a community model, or the path to a directory containing model files.
tokenizer_type: The type of tokenizer (auto, bert, xlnet, xlm, roberta, distilbert, etc.) to use. If a string is passed, Simple Transformers will try to initialize a tokenizer class from the available MODEL_CLASSES.
Alternatively, a Tokenizer class (subclassed from PreTrainedTokenizer) can be passed.
tokenizer_name: The name/path to the tokenizer. If the tokenizer_type is not specified, the model_type will be used to determine the type of the tokenizer.
num_labels (optional): The number of labels or classes in the dataset.
weight (optional): A list of length num_labels containing the weights to assign to each label for loss calculation.
args (optional): Default args will be used if this parameter is not provided. If provided, it should be a dict containing the args that should be changed in the default args.
Expand All @@ -143,17 +166,20 @@ def __init__(

MODEL_CLASSES = {
"albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
"auto": (AutoConfig, AutoModelForSequenceClassification, AutoTokenizer),
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
"bertweet": (RobertaConfig, RobertaForSequenceClassification, BertweetTokenizer),
"camembert": (CamembertConfig, CamembertForSequenceClassification, CamembertTokenizer),
"deberta": (DebertaConfig, DebertaForSequenceClassification, DebertaTokenizer),
"distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
"electra": (ElectraConfig, ElectraForSequenceClassification, ElectraTokenizer),
"flaubert": (FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer),
"layoutlm": (LayoutLMConfig, LayoutLMForSequenceClassification, LayoutLMTokenizer),
"longformer": (LongformerConfig, LongformerForSequenceClassification, LongformerTokenizer),
"mobilebert": (MobileBertConfig, MobileBertForSequenceClassification, MobileBertTokenizer),
"reformer": (ReformerConfig, ReformerForSequenceClassification, ReformerTokenizer),
"mpnet": (MPNetConfig, MPNetForSequenceClassification, MPNetTokenizer),
"roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
"squeezebert": (SqueezeBertConfig, SqueezeBertForSequenceClassification, SqueezeBertTokenizer),
"xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
"xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
"xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
Expand All @@ -166,6 +192,9 @@ def __init__(
elif isinstance(args, ClassificationArgs):
self.args = args

if model_type in MODELS_WITHOUT_SLIDING_WINDOW_SUPPORT and self.args.sliding_window:
raise ValueError("{} does not currently support sliding window".format(model_type))

if self.args.thread_count:
torch.set_num_threads(self.args.thread_count)

Expand Down Expand Up @@ -200,13 +229,24 @@ def __init__(
self.args.labels_list = [i for i in range(len_labels_list)]

config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]

if tokenizer_type is not None:
if isinstance(tokenizer_type, str):
_, _, tokenizer_class = MODEL_CLASSES[tokenizer_type]
else:
tokenizer_class = tokenizer_type

if num_labels:
self.config = config_class.from_pretrained(model_name, num_labels=num_labels, **self.args.config)
self.num_labels = num_labels
else:
self.config = config_class.from_pretrained(model_name, **self.args.config)
self.num_labels = self.config.num_labels
self.weight = weight

if model_type in MODELS_WITHOUT_CLASS_WEIGHTS_SUPPORT and weight is not None:
raise ValueError("{} does not currently support class weights".format(model_type))
else:
self.weight = weight

if use_cuda:
if torch.cuda.is_available():
Expand Down Expand Up @@ -275,17 +315,20 @@ def __init__(
except AttributeError:
raise AttributeError("fp16 requires Pytorch >= 1.6. Please update Pytorch or turn off fp16.")

if model_name in [
if tokenizer_name is None:
tokenizer_name = model_name

if tokenizer_name in [
"vinai/bertweet-base",
"vinai/bertweet-covid19-base-cased",
"vinai/bertweet-covid19-base-uncased",
]:
self.tokenizer = tokenizer_class.from_pretrained(
model_name, do_lower_case=self.args.do_lower_case, normalization=True, **kwargs
tokenizer_name, do_lower_case=self.args.do_lower_case, normalization=True, **kwargs
)
else:
self.tokenizer = tokenizer_class.from_pretrained(
model_name, do_lower_case=self.args.do_lower_case, **kwargs
tokenizer_name, do_lower_case=self.args.do_lower_case, **kwargs
)

if self.args.special_tokens_list:
Expand All @@ -294,6 +337,8 @@ def __init__(

self.args.model_name = model_name
self.args.model_type = model_type
self.args.tokenizer_name = tokenizer_name
self.args.tokenizer_type = tokenizer_type

if model_type in ["camembert", "xlmroberta"]:
warnings.warn(
Expand Down Expand Up @@ -1184,7 +1229,7 @@ def load_and_cache_examples(
sep_token=tokenizer.sep_token,
# RoBERTa uses an extra separator b/w pairs of sentences,
# cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
sep_token_extra=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]),
sep_token_extra=args.model_type in MODELS_WITH_EXTRA_SEP_TOKEN,
# PAD on the left for XLNet
pad_on_left=bool(args.model_type in ["xlnet"]),
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
Expand All @@ -1196,7 +1241,7 @@ def load_and_cache_examples(
sliding_window=args.sliding_window,
flatten=not evaluate,
stride=args.stride,
add_prefix_space=bool(args.model_type in ["roberta", "camembert", "xlmroberta", "longformer"]),
add_prefix_space=args.model_type in MODELS_WITH_ADD_PREFIX_SPACE,
# avoid padding in case of single example/online inferencing to decrease execution time
pad_to_max_length=bool(len(examples) > 1),
args=args,
Expand Down Expand Up @@ -1236,8 +1281,10 @@ def load_and_cache_examples(
else:
return dataset
else:
train_dataset = ClassificationDataset(examples, self.tokenizer, self.args, mode=mode, multi_label=multi_label, output_mode=output_mode)
return train_dataset
dataset = ClassificationDataset(
examples, self.tokenizer, self.args, mode=mode, multi_label=multi_label, output_mode=output_mode
)
return dataset

def compute_metrics(self, preds, model_outputs, labels, eval_examples=None, multi_label=False, **kwargs):
"""
Expand Down Expand Up @@ -1302,7 +1349,10 @@ def compute_metrics(self, preds, model_outputs, labels, eval_examples=None, mult
auroc = auc(fpr, tpr)
auprc = average_precision_score(labels, scores)
return (
{**{"mcc": mcc, "tp": tp, "tn": tn, "fp": fp, "fn": fn, "auroc": auroc, "auprc": auprc}, **extra_metrics},
{
**{"mcc": mcc, "tp": tp, "tn": tn, "fp": fp, "fn": fn, "auroc": auroc, "auprc": auprc},
**extra_metrics,
},
wrong,
)
else:
Expand Down Expand Up @@ -1575,7 +1625,7 @@ def _move_model_to_device(self):

def _get_inputs_dict(self, batch):
if isinstance(batch[0], dict):
inputs = {key: value.squeeze().to(self.device) for key, value in batch[0].items()}
inputs = {key: value.squeeze(1).to(self.device) for key, value in batch[0].items()}
inputs["labels"] = batch[1].to(self.device)
else:
batch = tuple(t.to(self.device) for t in batch)
Expand Down
6 changes: 3 additions & 3 deletions simpletransformers/classification/classification_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ def preprocess_data(data):
max_length=args.max_seq_length,
truncation=True,
padding="max_length",
return_tensors="pt"
return_tensors="pt",
)
else:
tokenized_example = tokenizer.encode_plus(
text=example.text_a,
max_length=args.max_seq_length,
truncation=True,
padding="max_length",
return_tensors="pt"
return_tensors="pt",
)

return {**tokenized_example, "label": example.label}
Expand Down Expand Up @@ -600,7 +600,7 @@ def __init__(
self.data = [
dict(
json.load(open(os.path.join(data_path, l + self.data_type_extension))),
**{"images": l + image_type_extension}
**{"images": l + image_type_extension},
)
for l in files_list
]
Expand Down
Loading

0 comments on commit 6f189e0

Please sign in to comment.