-
Notifications
You must be signed in to change notification settings - Fork 13
/
hugnlp_trainer.py
658 lines (560 loc) · 28.8 KB
/
hugnlp_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
# -*- coding: utf-8 -*-
# @Time : 2022/1/7 3:07 下午
# @Author : JianingWang
# @File : HugTrainer
"""
This file is the runner of HugNLP.
Use HugTrainer to perform task training and evaluating.
"""
import os
import torch
from torch import nn
import numpy as np
from tqdm import tqdm
from packaging import version
import datasets
from datasets import Dataset
from processors.dataset import DatasetK
from torch.utils.data import RandomSampler, DistributedSampler
from typing import Dict, Union, Any, Optional, Callable, List, Tuple, Iterator, OrderedDict
from transformers import PreTrainedModel, DataCollator, PreTrainedTokenizerBase, EvalPrediction, TrainerCallback
from transformers.trainer_pt_utils import DistributedSamplerWithLoop, get_length_grouped_indices
from transformers.trainer_pt_utils import DistributedLengthGroupedSampler as DistributedLengthGroupedSamplerOri
from transformers.trainer_pt_utils import LengthGroupedSampler as LengthGroupedSamplerOri
# from transformers.trainer_utils import has_length
from transformers.training_args import ParallelMode
from transformers.trainer import Trainer
from transformers.trainer import *
from transformers.trainer_utils import denumpify_detensorize, TrainOutput
from config import TrainingArguments
from transformers.file_utils import is_datasets_available
from models.adversarial import FGM
from tools.processing_utils.sampler import random_sampling
from tools.model_utils.uncertainty import sample_by_bald_class_easiness
from tools.runner_utils.log_util import logging
# Module-level logger shared by the trainers below.
logger = logging.getLogger(__name__)

# Feature flag consumed by HugTrainer._get_train_sampler: torch.Generator is
# only used for sampler seeding when running torch >= 1.6.
# NOTE(review): on torch < 1.6 this name is never assigned, so later reads of
# ``_is_torch_generator_available`` would raise NameError — confirm the
# minimum supported torch version.
if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_torch_generator_available = True

# Checkpoint weight file names used when (re)loading teacher/student models.
WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
class LengthGroupedSampler(LengthGroupedSamplerOri):
    """Length-grouped sampler that fixes ``mega_batch_mult`` at 256.

    Identical to the upstream sampler except that the mega-batch multiplier
    (how many batches are shuffled together before length-sorting) is pinned
    instead of derived from the dataset size.
    """

    def __iter__(self):
        grouped = get_length_grouped_indices(
            self.lengths,
            self.batch_size,
            generator=self.generator,
            mega_batch_mult=256,
        )
        return iter(grouped)
class DistributedLengthGroupedSampler(DistributedLengthGroupedSamplerOri):
    """Distributed length-grouped sampler with ``mega_batch_mult`` pinned at 400."""

    def __iter__(self) -> Iterator:
        # Deterministic per-epoch shuffle: seed a fresh generator from the
        # configured seed plus the current epoch.
        rng = torch.Generator()
        rng.manual_seed(self.seed + self.epoch)
        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=rng, mega_batch_mult=400)

        if self.drop_last:
            # Trim the tail so the total divides evenly across replicas.
            indices = indices[: self.total_size]
        else:
            # Pad by repeating the head so the total divides evenly.
            indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Each replica takes a strided slice of the shared permutation.
        indices = indices[self.rank: self.total_size: self.num_replicas]
        assert len(indices) == self.num_samples
        return iter(indices)
"""
Trainer for running HugNLP
"""
class HugTrainer(Trainer):
    """
    Task trainer for HugNLP, extending the HuggingFace ``Trainer`` with:

    - optional FGM adversarial training when ``args.do_adv`` is set;
    - a console loss report every 10 training steps;
    - length-grouped samplers using the custom mega-batch multipliers above;
    - :meth:`mc_evaluate`, an MC-dropout uncertainty-estimation loop over
      unlabeled data, used by ``HugSelfTrainer`` for pseudo-labeling.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    ):
        super(HugTrainer, self).__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers)
        self.metric_for_best_model = self.args.metric_for_best_model
        if self.args.do_adv:
            # Fast Gradient Method: perturbs embedding parameters for
            # adversarial training (see models.adversarial.FGM).
            self.fgm = FGM(self.model)
        # BUG FIX: ``callbacks`` defaults to None, and the original code
        # iterated it unconditionally, raising ``TypeError: 'NoneType' object
        # is not iterable`` whenever no callbacks were passed (as done by
        # HugSelfTrainer.get_teacher_trainer / get_student_trainer below).
        if callbacks:
            for callback in callbacks:
                # Back-reference so callbacks can inspect trainer state.
                callback.trainer = self
        self.best_metrics = OrderedDict({
            "best_epoch": 0,
            f"best_eval_{self.metric_for_best_model}": 0,
        })
        # Internal step counter used only for the periodic loss report below
        # (kept separate from ``self.state.global_step``).
        self.global_step_ = 0

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        self.global_step_ += 1
        model.train()
        inputs = self._prepare_inputs(inputs)
        with self.autocast_smart_context_manager():
            loss = self.compute_loss(model, inputs)
        if self.args.n_gpu > 1 or len(loss.shape) > 0:
            # Multi-GPU training, or the loss came back as a non-scalar tensor.
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
            # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
            loss = loss / self.args.gradient_accumulation_steps
        if self.global_step_ % 10 == 0:
            print("[step={}, loss={}]".format(self.global_step_, loss))
        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        elif self.deepspeed:
            # loss gets scaled under gradient_accumulation_steps in deepspeed
            loss = self.deepspeed.backward(loss)
        else:
            loss.backward()
        # Adversarial training (FGM): perturb the embeddings, backprop the
        # adversarial loss on top of the normal gradients, then restore the
        # original embedding parameters.
        if self.args.do_adv:
            self.fgm.attack()
            with self.autocast_smart_context_manager():
                loss_adv = self.compute_loss(model, inputs)
            if self.args.n_gpu > 1:
                loss_adv = loss_adv.mean()
            if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
                loss_adv = loss_adv / self.args.gradient_accumulation_steps
            if self.do_grad_scaling:
                self.scaler.scale(loss_adv).backward()
            else:
                loss_adv.backward()
            self.fgm.restore()  # restore the embedding parameters
        return loss.detach()

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        """Pick a train sampler mirroring the upstream logic, but routing the
        length-grouped cases through the custom samplers defined above."""
        generator = None
        if self.args.world_size <= 1 and _is_torch_generator_available:
            generator = torch.Generator()
            generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
        # Build the sampler.
        if self.args.group_by_length:
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = (
                    self.train_dataset[self.args.length_column_name]
                    if self.args.length_column_name in self.train_dataset.column_names
                    else None
                )
            else:
                lengths = None
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
            if self.args.world_size <= 1:
                return LengthGroupedSampler(
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
                    dataset=self.train_dataset,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    generator=generator,
                )
            else:
                return DistributedLengthGroupedSampler(
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
                    dataset=self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    seed=self.args.seed,
                )
        else:
            if self.args.world_size <= 1:
                if _is_torch_generator_available:
                    return RandomSampler(self.train_dataset, generator=generator)
                return RandomSampler(self.train_dataset)
            elif (
                self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
                and not self.args.dataloader_drop_last
            ):
                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
                return DistributedSamplerWithLoop(
                    self.train_dataset,
                    batch_size=self.args.per_device_train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=self.args.seed,
                )
            else:
                return DistributedSampler(
                    self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=self.args.seed,
                )

    def mc_evaluate(
        self,
        unlabeled_dataset: Optional[Dataset] = None,
        unlabeled_data_num: int = -1,
        description: str = "Evaluate on Unlabeled Data via MC Dropout Uncertainty Estimation",
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        T: int = 30,
        num_classes: int = 0
    ):
        """
        MC-dropout uncertainty estimation over unlabeled data.

        Runs ``T`` stochastic forward passes (dropout kept active via
        ``model.train()``) over ``unlabeled_dataset`` and aggregates the
        per-pass softmax distributions.

        Args:
            unlabeled_dataset: dataset to pseudo-label.
            unlabeled_data_num: number of examples to estimate on; ``-1`` (or
                any value >= the dataset size) uses the whole dataset,
                otherwise a class-balanced random subset is drawn.
            description: text shown in the progress log.
            prediction_loss_only: falls back to ``args.prediction_loss_only``.
            ignore_keys: forwarded to ``prediction_step``.
            metric_key_prefix: kept for interface parity; currently unused.
            T: number of stochastic forward passes.
            num_classes: number of target classes.

        Returns:
            Tuple ``(unlabeled_dataset, y_mean, y_var, y_pred, y_T)`` where
            ``y_mean`` / ``y_var`` are the per-example per-class mean /
            variance over the ``T`` passes (shape ``(n, num_classes)``),
            ``y_pred`` is the per-example majority-vote label (shape ``(n,)``),
            and ``y_T`` is the raw ``(T, n, num_classes)`` probability array.
        """
        args = self.args
        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
        is_sample = True
        if unlabeled_data_num == -1 or unlabeled_data_num >= len(unlabeled_dataset):
            unlabeled_data_num = len(unlabeled_dataset)
            is_sample = False
        if is_sample:
            # Draw a class-balanced random subset of the unlabeled pool.
            recalled_examples_idx_list = random_sampling(
                raw_datasets=unlabeled_dataset,
                num_examples_per_label=unlabeled_data_num // num_classes
            )
            unlabeled_dataset = unlabeled_dataset.select(recalled_examples_idx_list)
            unlabeled_data_num = len(unlabeled_dataset)
        unlabeled_dataloader = self.get_eval_dataloader(unlabeled_dataset)
        model = self._wrap_model(self.model, training=True, dataloader=unlabeled_dataloader)  # reset training to True
        batch_size = unlabeled_dataloader.batch_size
        logger.info(f"***** Running {description} *****")
        logger.info(f" Num examples = {unlabeled_data_num}")
        logger.info(f" Batch size = {batch_size}")
        # Keep dropout active so that each of the T forward passes is
        # stochastic (this is the "MC" in MC dropout).
        model.train()
        if args.past_index >= 0:
            self._past = None
        self.callback_handler.eval_dataloader = unlabeled_dataloader
        y_T = list()
        for i in tqdm(range(T)):
            y_pred = []
            for step, inputs in enumerate(unlabeled_dataloader):
                _, logits, __ = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
                y_pred.extend(logits.detach().cpu().numpy().tolist())
            # Softmax over the class dimension for this pass: [n, num_classes].
            predict_proba = torch.softmax(torch.Tensor(y_pred).to(logits.device), -1)
            y_T.append(predict_proba.detach().cpu().numpy().tolist())
        y_T = np.array(y_T)
        # Mean probability over the T passes.
        y_mean = np.mean(y_T, axis=0)
        assert y_mean.shape == (unlabeled_data_num, num_classes)
        # Majority-vote label over the T per-pass argmax predictions.
        y_pred = np.array([np.argmax(np.bincount(row)) for row in np.transpose(np.argmax(y_T, axis=-1))])
        assert y_pred.shape == (unlabeled_data_num,)
        # Per-class variance over the T passes (uncertainty signal).
        y_var = np.var(y_T, axis=0)
        assert y_var.shape == (unlabeled_data_num, num_classes)
        return unlabeled_dataset, y_mean, y_var, y_pred, y_T
"""
Self-trainer for self-training HugNLP
"""
class HugSelfTrainer(object):
    """
    Self-training (teacher-student) runner for HugNLP.

    Flow (see :meth:`train`):
      1. Fine-tune the teacher on labeled data, or load it from a checkpoint.
      2. For each self-training iteration: pseudo-label unlabeled data with
         the teacher via ``HugTrainer.mc_evaluate`` (MC dropout), select
         reliable examples with ``sample_by_bald_class_easiness``, train a
         fresh student on pseudo-labeled + labeled data, and promote the
         student to be the next round's teacher.
      3. Optionally run a final "post training" round from the best teacher.
    """

    def __init__(
        self,
        teacher_base_model: torch.nn.Module,
        student_base_model: torch.nn.Module,
        training_args,
        semi_training_args,
        train_dataset: Optional[Dataset]=None,
        unlabeled_dataset: Optional[Dataset]=None,
        eval_dataset=None,
        compute_metrics=None,
        tokenizer=None,
        teacher_data_collator=None,
        student_data_collator=None,
        num_classes=0,
    ) -> None:
        logger.info("This is a Self-trainer.")
        self.teacher_base_model = teacher_base_model
        self.student_base_model = student_base_model
        self.training_args = training_args
        self.metric_for_best_model = self.training_args.metric_for_best_model
        self.semi_training_args = semi_training_args
        self.train_dataset = train_dataset
        self.unlabeled_dataset = unlabeled_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.tokenizer = tokenizer
        self.teacher_data_collator = teacher_data_collator
        self.student_data_collator = student_data_collator
        self.num_classes = num_classes
        # self.set_teacher_trainer()
        # self.set_student_trainer()
        self.training_args.per_device_train_batch_size = self.semi_training_args.unlabeled_data_batch_size
        self.teacher_training_epoch = self.semi_training_args.teacher_training_epoch  # epochs the teacher first trains on labeled data
        self.teacher_tuning_epoch = self.semi_training_args.teacher_tuning_epoch  # epochs the teacher keeps tuning on labeled data per self-training round
        self.student_training_epoch = self.semi_training_args.student_training_epoch  # epochs the student trains on pseudo-labeled data per self-training round
        self.self_training_epoch = self.semi_training_args.self_training_epoch  # number of self-training iterations
        self.unlabeled_data_num = self.semi_training_args.unlabeled_data_num  # how many unlabeled examples get MC-dropout uncertainty per round; -1 means all
        self.pseudo_sample_num_or_ratio = self.semi_training_args.pseudo_sample_num_or_ratio  # ratio (<=1.0) or absolute number of pseudo-labeled examples to keep after MC dropout
        self.student_learning_rate = self.semi_training_args.student_learning_rate
        self.output_dir = self.training_args.output_dir

    def get_teacher_trainer(
        self,
        base_model: torch.nn.Module,
        num_train_epochs: int,
        output_dir: str = None,
    ):
        """
        Build a ``HugTrainer`` that trains/tunes the teacher on labeled data.

        NOTE(review): ``training_args`` is an alias of the shared
        ``self.training_args``; the assignments below mutate it in place for
        every later trainer as well — confirm this is intended.
        """
        training_args = self.training_args
        training_args.num_train_epochs = num_train_epochs
        if output_dir is not None:
            training_args.output_dir = output_dir
        # Initialize the teacher trainer.
        teacher_trainer = HugTrainer(
            model=base_model,
            args=training_args,
            train_dataset=self.train_dataset if self.training_args.do_train else None,
            eval_dataset=self.eval_dataset if self.training_args.do_eval else None,
            compute_metrics=self.compute_metrics,
            tokenizer=self.tokenizer,
            data_collator=self.teacher_data_collator,
        )
        return teacher_trainer

    def get_student_trainer(
        self,
        base_model: torch.nn.Module,
        num_train_epochs: int,
        student_learning_rate: float,
        pseudo_labeled_dataset: Optional[Dataset] = None,
        output_dir: str = None,
    ):
        """
        Build a ``HugTrainer`` that trains the student on pseudo-labeled data.

        NOTE(review): mutates the shared ``self.training_args`` in place
        (epochs, learning rate, output dir) — same caveat as
        :meth:`get_teacher_trainer`.
        """
        training_args = self.training_args
        training_args.num_train_epochs = num_train_epochs
        training_args.learning_rate = student_learning_rate
        if output_dir is not None:
            training_args.output_dir = output_dir
        # Initialize the student trainer.
        student_trainer = HugTrainer(
            model=base_model,
            args=training_args,
            train_dataset=pseudo_labeled_dataset,
            eval_dataset=self.eval_dataset,
            compute_metrics=self.compute_metrics,
            tokenizer=self.tokenizer,
            data_collator=self.student_data_collator,
        )
        return student_trainer

    def train(self, resume_from_checkpoint=None):
        """
        Run the full self-training procedure.

        Args:
            resume_from_checkpoint: directory holding an already-trained
                teacher (``pytorch_model.bin`` / its index); if found, the
                initial teacher training is skipped and the weights are
                loaded directly.

        Returns:
            ``TrainOutput`` carrying the last trainer's global step and the
            most recently computed ``metrics`` dict (the post-training
            student's metrics when ``post_student_train`` is set, otherwise
            the final teacher's).
        """
        if not os.path.exists(os.path.join(self.output_dir, "iteration")):
            os.makedirs(os.path.join(self.output_dir, "iteration"))
        teacher_model = self.teacher_base_model
        teacher_trainer: HugTrainer = self.get_teacher_trainer(base_model=teacher_model, num_train_epochs=self.teacher_training_epoch)
        if resume_from_checkpoint is not None and (os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) or os.path.isfile(
            os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME))
        ):
            logger.info("*"*80)
            logger.info("* Directly loading the trained teacher model from {} *".format(resume_from_checkpoint))
            logger.info("*"*80)
            print("*"*80)
            logger.info("* Directly loading the trained teacher model from {} *".format(resume_from_checkpoint))
            print("*"*80)
            # A trained teacher model already exists; load it directly.
            teacher_trainer._load_from_checkpoint(resume_from_checkpoint)
        else:
            # First run full-parameter fine-tuning of the teacher on labeled data.
            logger.info("*"*66)
            logger.info("* Training teacher model over labeled data before self-training. *")
            logger.info("*"*66)
            print("*"*66)
            print("* Training teacher model over labeled data before self-training. *")
            print("*"*66)
            teacher_trainer.train()
            teacher_model.load_state_dict(torch.load(os.path.join(teacher_trainer.state.best_model_checkpoint, "pytorch_model.bin")))
            teacher_trainer.model = teacher_model
        # Baseline result from conventional fine-tuning (reported at the end).
        metrics = teacher_trainer.evaluate()
        convention_result = metrics["eval_{}".format(self.metric_for_best_model)]
        logger.info("*"*50)
        logger.info("* Conventional fine-tuning metric: {}. *".format(convention_result))
        logger.info("*"*50)
        print("*"*50)
        print("* Conventional fine-tuning metric: {}. *".format(convention_result))
        print("*"*50)
        logger.info("*"*30)
        logger.info("* Starting Self-training ... *")
        logger.info("*"*30)
        print("*"*30)
        print("* Starting Self-training ... *")
        print("*"*30)
        best_test_metric = None
        best_self_training_iteration = None
        best_teacher_model = None
        # Iterative teacher-student training rounds.
        # NOTE(review): the loop variable shadows the builtin ``iter`` and is
        # also read after the loop (post-training block below).
        for iter in range(self.self_training_epoch):
            logger.info("*"*34)
            logger.info("* Self-training {}-th iteration *".format(iter))
            logger.info("*"*34)
            print("*"*34)
            print("* Self-training {}-th iteration *".format(iter))
            print("*"*34)
            # Evaluate the current teacher on the test set.
            if iter > 0:
                teacher_trainer.model = teacher_model
            metrics = teacher_trainer.evaluate()
            # print("metrics=", metrics)
            '''
            e.g., {'eval_loss': 0.6926815509796143, 'eval_accuracy': 0.5234657039711191, 'eval_runtime': 0.7267, 'eval_samples_per_second': 381.161, 'eval_steps_per_second': 48.161, 'epoch': 1.0}
            '''
            logger.info("*"*60)
            logger.info("* The testing result of teacher model is {} result: {} *".format(self.metric_for_best_model, metrics["eval_{}".format(self.metric_for_best_model)]))
            logger.info("*"*60)
            print("*"*60)
            print("* The testing result of teacher model is {} result: {} *".format(self.metric_for_best_model, metrics["eval_{}".format(self.metric_for_best_model)]))
            print("*"*60)
            # Track the best teacher seen so far across iterations.
            if best_test_metric is None or best_test_metric < metrics["eval_{}".format(self.metric_for_best_model)]:
                best_test_metric = metrics["eval_{}".format(self.metric_for_best_model)]
                best_self_training_iteration = iter
                best_teacher_model = teacher_model
                logger.info("The best teacher model at {}-th self-training iteration.".format(best_self_training_iteration))
                logger.info("The best teacher model testing result is {}.".format(best_test_metric))
                print("The best teacher model at {}-th self-training iteration.".format(best_self_training_iteration))
                print("The best teacher model testing result is {}.".format(best_test_metric))
            if iter == self.self_training_epoch - 1:
                break
            # Use the teacher to pseudo-label unlabeled data; sampling is
            # driven by MC-dropout uncertainty estimation.
            logger.info("*"*72)
            logger.info("Obtaining pseudo-labeled data and uncertainty estimation via MC dropout.")
            logger.info("*"*72)
            print("*"*72)
            print("Obtaining pseudo-labeled data and uncertainty estimation via MC dropout.")
            print("*"*72)
            unlabeled_dataset, y_mean, y_var, y_pred, y_T = teacher_trainer.mc_evaluate(
                unlabeled_dataset=self.unlabeled_dataset,
                unlabeled_data_num=self.unlabeled_data_num,
                T=20,
                num_classes=self.num_classes
            )
            logger.info("*"*42)
            logger.info("* Sampling reliable pseudo-labeled data. *")
            logger.info("*"*42)
            print("*"*42)
            print("* Sampling reliable pseudo-labeled data. *")
            print("*"*42)
            X_batch, y_batch, _ = sample_by_bald_class_easiness(
                tokenizer=self.tokenizer,
                X=unlabeled_dataset,
                y_mean=y_mean,
                y_var=y_var,
                y=y_pred,
                num_samples=int(y_pred.shape[0] * self.pseudo_sample_num_or_ratio) if self.pseudo_sample_num_or_ratio <= 1.0 else int(self.pseudo_sample_num_or_ratio),
                num_classes=self.num_classes,
                y_T=y_T)
            pseudo_labeled_examples = X_batch
            pseudo_labeled_examples["label"] = y_batch
            # Build the pseudo-labeled dataset and mix in the labeled data.
            # pseudo_labeled_dataset = DatasetDict()
            pseudo_labeled_dataset = DatasetK.from_dict(pseudo_labeled_examples)
            for i in range(len(self.train_dataset)):
                pseudo_labeled_dataset = pseudo_labeled_dataset.add_item(self.train_dataset[i])
            # Initialize a new student model and train it robustly on the
            # pseudo-labeled data.
            logger.info("*"*56)
            logger.info("* Training a new student model on pseudo-labeled data. *")
            logger.info("*"*56)
            print("*"*56)
            print("* Training a new student model on pseudo-labeled data. *")
            print("*"*56)
            student_model = self.student_base_model
            student_trainer: HugTrainer = self.get_student_trainer(
                base_model=student_model,
                num_train_epochs=self.student_training_epoch,
                student_learning_rate=self.student_learning_rate,
                pseudo_labeled_dataset=pseudo_labeled_dataset,
                output_dir=os.path.join(self.output_dir, "iteration", "student_iter_{}".format(iter))
            )
            student_trainer.train()
            student_model.load_state_dict(torch.load(os.path.join(student_trainer.state.best_model_checkpoint, "pytorch_model.bin")))
            # Hand the student's parameters to the teacher as the next round's
            # initialization.
            logger.info("*"*64)
            logger.info("* Initializing a new teacher model from trained student model. *")
            logger.info("*"*64)
            print("*"*64)
            print("* Initializing a new teacher model from trained student model. *")
            print("*"*64)
            teacher_model = student_model
            # teacher_trainer = student_trainer
            teacher_trainer: HugTrainer = self.get_teacher_trainer(
                base_model=student_model,
                num_train_epochs=self.teacher_tuning_epoch,
                output_dir=os.path.join(self.output_dir, "iteration", "teacher_iter_{}".format(iter))
            )
        logger.info("********** Finishing Self-training **********")
        logger.info("The best teacher model at {}-th self-training iteration.".format(best_self_training_iteration))
        logger.info("The best teacher model testing result is {}.".format(best_test_metric))
        print("********** Finishing Self-training **********")
        print("The best teacher model at {}-th self-training iteration.".format(best_self_training_iteration))
        print("The best teacher model testing result is {}.".format(best_test_metric))
        # Using the best teacher found so far, pseudo-label the unlabeled data
        # once more with MC dropout (capped at 20480 examples) and train a
        # final "post" student on the result.
        if self.semi_training_args.post_student_train:
            logger.info("********** Post training **********")
            print("********** Post training **********")
            teacher_trainer: HugTrainer = self.get_teacher_trainer(
                base_model=best_teacher_model,
                num_train_epochs=self.teacher_tuning_epoch,
                output_dir=os.path.join(self.output_dir, "teacher_iter_post")
            )
            unlabeled_dataset, y_mean, y_var, y_pred, y_T = teacher_trainer.mc_evaluate(
                unlabeled_dataset=self.unlabeled_dataset,
                unlabeled_data_num=20480,
                T=5,
                num_classes=self.num_classes
            )
            post_sample_num = int(y_pred.shape[0] * 0.5)
            X_batch, y_batch, _ = sample_by_bald_class_easiness(
                tokenizer=self.tokenizer,
                X=unlabeled_dataset,
                y_mean=y_mean,
                y_var=y_var,
                y=y_pred,
                num_samples=post_sample_num,
                num_classes=self.num_classes,
                y_T=y_T)
            pseudo_labeled_examples = X_batch
            pseudo_labeled_examples["label"] = y_batch
            # Build the pseudo-labeled dataset (labeled data is NOT mixed in here).
            # pseudo_labeled_dataset = DatasetDict()
            pseudo_labeled_dataset = DatasetK.from_dict(pseudo_labeled_examples)
            # Initialize a new student model and train it robustly on the
            # pseudo-labeled data.
            logger.info("*"*56)
            logger.info("* Training a new student model on pseudo-labeled data. *")
            logger.info("*"*56)
            print("*"*56)
            print("* Training a new student model on pseudo-labeled data. *")
            print("*"*56)
            student_model = self.student_base_model
            student_trainer: HugTrainer = self.get_student_trainer(
                base_model=student_model,
                num_train_epochs=self.student_training_epoch if len(pseudo_labeled_dataset) <= 4096 else int(self.student_training_epoch / 2),
                student_learning_rate=self.student_learning_rate,
                pseudo_labeled_dataset=pseudo_labeled_dataset,
                # NOTE(review): reuses ``iter`` leaked from the loop above for
                # the directory name — confirm intended.
                output_dir=os.path.join(self.output_dir, "student_iter_{}".format(iter))
            )
            student_trainer.train()
            student_model.load_state_dict(torch.load(os.path.join(student_trainer.state.best_model_checkpoint, "pytorch_model.bin")))
            metrics = student_trainer.evaluate()
            post_metric = metrics["eval_{}".format(self.metric_for_best_model)]
            self.student_trainer = student_trainer
        print("*"*68)
        print("Finishing all the processes, the results are shown in the following:")
        print("Conventional fine-tuning {} metric: {}".format(self.metric_for_best_model, convention_result))
        print("Best self-training {} metric: {}".format(self.metric_for_best_model, best_test_metric))
        if self.semi_training_args.post_student_train:
            print("Post training {} metric: {}".format(self.metric_for_best_model, post_metric))
        print("*"*68)
        return TrainOutput(teacher_trainer.state.global_step, 0.0, metrics)