-
Notifications
You must be signed in to change notification settings - Fork 1
/
search_fine_tune.py
134 lines (111 loc) · 4.23 KB
/
search_fine_tune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from datasets import load_dataset, load_metric
from transformers import (
Trainer,
TrainingArguments,
AutoTokenizer,
AutoModelForSequenceClassification
)
import numpy as np
import torch
import torch.nn as nn
import transformers
import logging
logging.basicConfig(level=logging.INFO)

# Checkpoint shared by the tokenizer and by every model trial. The two must
# match: token ids produced by one checkpoint's vocabulary are not valid
# indices into a different checkpoint's embedding table.
MODEL_NAME = 'roberta-base'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# GLUE MNLI is the task the hyperparameter search below fine-tunes on.
dataset = load_dataset('glue', 'mnli')
# NOTE(review): `load_metric` is deprecated in newer `datasets` releases in
# favour of the `evaluate` library -- confirm against the pinned version.
metric = load_metric('glue', 'mnli')

# Size the classification head from the dataset's label feature instead of
# relying on the model's default of 2 labels.
NUM_LABELS = dataset["train"].features["label"].num_classes

# MNLI has no plain "validation" split; it ships "validation_matched" and
# "validation_mismatched".
EVAL_SPLIT = "validation_matched"

SUPERGLUE_TASKS = ("cb", "boolq", "copa", "multirc", "record", "rte", "wsc",
                   "wic")
# NOTE(review): these datasets are only used by the preview loop below;
# nothing else in this script reads `dataset_dict`.
dataset_dict = {
    name: load_dataset('super_glue', name=name) for name in SUPERGLUE_TASKS
}

# Print the first training example of each SuperGLUE task. Use a distinct
# loop name: reusing `dataset` here would rebind the MNLI dataset above to the
# last task in the dict, and `dataset.map(...)` below would then tokenize that.
for task_name, task_dataset in dataset_dict.items():
    print(task_name)
    print(task_dataset["train"][0])
    print()


class MultitaskModel(transformers.PreTrainedModel):
    """Several single-task models that share one encoder transformer.

    NOTE(review): not used by the hyperparameter search below.
    """

    # (class-name prefix, encoder attribute name) pairs, checked in order.
    _ENCODER_ATTR_BY_PREFIX = (
        ("Bert", "bert"),
        ("Roberta", "roberta"),
        ("Deberta", "deberta"),
        ("DistilBert", "distilbert"),
        ("Albert", "albert"),
    )

    def __init__(self, encoder, taskmodels_dict):
        """Wrap a shared ``encoder`` and a ``{task_name: model}`` mapping.

        Setting MultitaskModel up as a PretrainedModel allows us
        to take better advantage of Trainer features.
        """
        super().__init__(transformers.PretrainedConfig())
        self.encoder = encoder
        self.taskmodels_dict = nn.ModuleDict(taskmodels_dict)

    @classmethod
    def create(cls, model_name, model_type_dict, model_config_dict):
        """Build a MultitaskModel whose task models share one encoder.

        Args:
            model_name: checkpoint name/path passed to every
                ``model_type.from_pretrained`` call.
            model_type_dict: ``{task_name: model class}``.
            model_config_dict: ``{task_name: config}``; must have an entry for
                every key of ``model_type_dict``.

        The first task model's encoder is kept; every later task model has its
        encoder attribute replaced by that same module, so the encoder weights
        are shared across tasks.
        """
        shared_encoder = None
        taskmodels_dict = {}
        for task_name, model_type in model_type_dict.items():
            model = model_type.from_pretrained(
                model_name,
                config=model_config_dict[task_name],
            )
            encoder_attr = cls.get_encoder_attr_name(model)
            if shared_encoder is None:
                shared_encoder = getattr(model, encoder_attr)
            else:
                setattr(model, encoder_attr, shared_encoder)
            taskmodels_dict[task_name] = model
        return cls(encoder=shared_encoder, taskmodels_dict=taskmodels_dict)

    @classmethod
    def get_encoder_attr_name(cls, model):
        """Return the attribute under which ``model`` stores its encoder.

        The encoder transformer is named differently in each model
        architecture; the name is looked up by class-name prefix.

        Raises:
            KeyError: if the architecture is not in ``_ENCODER_ATTR_BY_PREFIX``.
        """
        model_class_name = model.__class__.__name__
        for prefix, attr_name in cls._ENCODER_ATTR_BY_PREFIX:
            if model_class_name.startswith(prefix):
                return attr_name
        raise KeyError(f"Add support for new model {model_class_name}")

    def forward(self, task_name, **kwargs):
        """Run the model registered for ``task_name`` on ``kwargs``."""
        return self.taskmodels_dict[task_name](**kwargs)


def encode(examples):
    """Tokenize a batch of MNLI premise/hypothesis pairs.

    Called through ``Dataset.map(batched=True)``, so each field is a list.
    """
    return tokenizer(
        examples['premise'], examples['hypothesis'], truncation=True)


encoded_dataset = dataset.map(encode, batched=True)


def model_init():
    """Return a fresh classifier for each hyperparameter-search trial.

    Uses the same checkpoint as ``tokenizer`` and a head sized to the
    dataset's number of labels.
    """
    return AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=NUM_LABELS, return_dict=True)


def compute_metrics(eval_pred):
    """Score ``(logits, labels)`` with the MNLI metric via arg-max predictions."""
    predictions, labels = eval_pred
    predictions = predictions.argmax(axis=-1)
    return metric.compute(predictions=predictions, references=labels)


# Evaluate during training and a bit more often
# than the default to be able to prune bad trials early.
# Disabling tqdm is a matter of preference.
# NOTE(review): newer `transformers` releases rename `evaluation_strategy` to
# `eval_strategy` -- confirm against the pinned version.
training_args = TrainingArguments(
    "test", evaluation_strategy="steps", eval_steps=500, disable_tqdm=True)
trainer = Trainer(
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset[EVAL_SPLIT],
    model_init=model_init,
    compute_metrics=compute_metrics,
)

# Default objective is the sum of all metrics
# when metrics are provided, so we have to maximize it.
trainer.hyperparameter_search(
    direction="maximize",
    backend="ray",
    n_trials=10,  # number of trials
)