# -*- coding: utf-8 -*-
"""BERT2BERT for CNN/Dailymail
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1Ekd5pUeCX7VOrMx94_czTkwNtLN32Uyu
## **Warm-starting BERT2BERT for CNN/Dailymail**
***Note***: This notebook only uses a few training, validation, and test data samples for demonstration purposes. To fine-tune an encoder-decoder model on the full training data, the user should change the training and data preprocessing parameters accordingly as highlighted by the comments.
### **Data Preprocessing**
"""
import datasets
import transformers
import rouge_score  # not used directly, but required by the "rouge" metric loaded below
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
tokenizer.bos_token = tokenizer.cls_token
tokenizer.eos_token = tokenizer.sep_token
train_data = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train", ignore_verifications=True)
val_data = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation", ignore_verifications=True)
batch_size = 16  # 16 for full training; use a smaller value for a quick demo run
encoder_max_length = 512
decoder_max_length = 128
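# (Optional sanity check, not part of the original script.) The 512/128 limits above follow
# BERT's 512-token input limit and the fact that CNN/Dailymail highlights are much shorter
# than the articles; the snippet below tokenizes a handful of raw samples to see how often
# the encoder input would be truncated. The variable names here are ad hoc.
sample = train_data.select(range(8))
article_lens = [len(tokenizer(article).input_ids) for article in sample["article"]]
highlight_lens = [len(tokenizer(highlight).input_ids) for highlight in sample["highlights"]]
print("article token lengths:", article_lens)
print("highlight token lengths:", highlight_lens)
print("articles exceeding encoder_max_length:", sum(length > encoder_max_length for length in article_lens))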
def process_data_to_model_inputs(batch):
    # tokenize the inputs and labels
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=encoder_max_length)
    outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=decoder_max_length)

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    batch["decoder_input_ids"] = outputs.input_ids
    batch["decoder_attention_mask"] = outputs.attention_mask
    batch["labels"] = outputs.input_ids.copy()

    # because BERT automatically shifts the labels, the labels correspond exactly to `decoder_input_ids`.
    # We have to make sure that the PAD token is ignored.
    batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]]

    return batch
# only use 32 training examples for notebook - DELETE LINE FOR FULL TRAINING
train_data = train_data.select(range(32))

train_data = train_data.map(
    process_data_to_model_inputs,
    batched=True,
    batch_size=batch_size,
    remove_columns=["article", "highlights", "id"],
)
train_data.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)
# only use 16 validation examples for notebook - DELETE LINE FOR FULL TRAINING
val_data = val_data.select(range(16))

val_data = val_data.map(
    process_data_to_model_inputs,
    batched=True,
    batch_size=batch_size,
    remove_columns=["article", "highlights", "id"],
)
val_data.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)
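# (Optional, illustrative only.) Decode one preprocessed validation example to confirm that the
# inputs were padded/truncated as intended and that padded label positions were replaced by -100
# (and are therefore ignored by the loss). `example` is an ad-hoc name.
example = val_data[0]
print(tokenizer.decode(example["input_ids"], skip_special_tokens=True)[:200])
print("ignored label positions:", int((example["labels"] == -100).sum()))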
"""### **Warm-starting the Encoder-Decoder Model**"""
from transformers import EncoderDecoderModel
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
# set special tokens
bert2bert.config.decoder_start_token_id = tokenizer.bos_token_id
bert2bert.config.eos_token_id = tokenizer.eos_token_id
bert2bert.config.pad_token_id = tokenizer.pad_token_id
# sensible parameters for beam search
bert2bert.config.vocab_size = bert2bert.config.decoder.vocab_size
bert2bert.config.max_length = 142
bert2bert.config.min_length = 56
bert2bert.config.no_repeat_ngram_size = 3
bert2bert.config.early_stopping = True
bert2bert.config.length_penalty = 2.0
bert2bert.config.num_beams = 4
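# (Optional smoke test, not part of the original notebook.) Run generation once with the
# warm-started but not yet fine-tuned model; the summary will be poor at this stage, but it
# confirms that the special tokens and beam-search settings above are wired up correctly.
# Loading a fresh one-example validation split is an illustrative choice; this may be slow on CPU.
raw_val = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1]", ignore_verifications=True)
raw_inputs = tokenizer(raw_val[0]["article"], return_tensors="pt", truncation=True, max_length=encoder_max_length)
generated_ids = bert2bert.generate(raw_inputs.input_ids, attention_mask=raw_inputs.attention_mask)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))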
"""### **Fine-Tuning Warm-Started Encoder-Decoder Models**
For the `EncoderDecoderModel` framework, we will use the `Seq2SeqTrainingArguments` and the `Seq2SeqTrainer`. Let's import them.
"""
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
"""Also, we need to define a function to correctly compute the ROUGE score during validation. ROUGE is a much better metric to track during training than only language modeling loss."""
# load rouge for validation
rouge = datasets.load_metric("rouge")
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # all unnecessary tokens are removed
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }
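# (Optional, illustrative only.) A quick check of the metric on made-up strings, mirroring the
# way `compute_metrics` calls it; the two sentences below are toy examples, not dataset content.
toy_output = rouge.compute(
    predictions=["the cat sat on the mat"],
    references=["the cat sat on the red mat"],
    rouge_types=["rouge2"],
)["rouge2"].mid
print("toy rouge2 f-measure:", round(toy_output.fmeasure, 4))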
"""Cool! Finally, we start training."""
# set training arguments - these params are not really tuned, feel free to change
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    logging_steps=1000,  # set to 1000 for full training
    save_steps=500,  # set to 500 for full training
    eval_steps=8000,  # set to 8000 for full training
    warmup_steps=2000,  # set to 2000 for full training
    max_steps=16,  # delete for full training
    overwrite_output_dir=True,
    save_total_limit=3,
    fp16=True,
)
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=bert2bert,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=val_data,
)
trainer.train()
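# (Optional.) Save the fine-tuned weights and tokenizer so they can be reloaded later with
# EncoderDecoderModel.from_pretrained / BertTokenizerFast.from_pretrained; the output
# directory name is an arbitrary choice, not fixed by the original script.
bert2bert.save_pretrained("./bert2bert-cnn_dailymail")
tokenizer.save_pretrained("./bert2bert-cnn_dailymail")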
"""### **Evaluation**
Awesome, we finished training our dummy model. Let's now evaluated the model on the test data. We make use of the dataset's handy `.map()` function to generate a summary of each sample of the test data.
"""