train.py
import logging
import os
import sys

from transformers import AutoTokenizer
from transformers import (
    HfArgumentParser,
    set_seed,
    is_torch_xla_available,
    EarlyStoppingCallback,
    EvalPrediction
)

from tevatron.retriever.arguments import ModelArguments
from arguments import PromptRepsDataArguments as DataArguments, \
    PromptRepsTrainingArguments as TrainingArguments
from dataset import PromptRepsTrainDataset as TrainDataset, \
    PromptRepsTrainCollator as TrainCollator
from modeling import PromptRepsLLM, EncoderModel
from tevatron.retriever.gc_trainer import GradCacheTrainer as GCTrainer
from trainer import EarlyCheckpointCallback, PromptRepsTrainer
import torch

logger = logging.getLogger(__name__)

# def make_compute_metrics(model, training_args):  # for eval dense and sparse loss
#     def compute_metrics(eval_preds):
#         predictions = eval_preds.predictions
#         labels = eval_preds.label_ids  # not used for now
#         q_reps, p_reps, q_logits, p_logits = predictions
#         scores_dense = model.compute_similarity(torch.tensor(q_reps, requires_grad=False),
#                                                 torch.tensor(p_reps, requires_grad=False))
#         scores_dense = scores_dense.view(q_reps.shape[0], -1)
#         target = torch.arange(scores_dense.shape[0], device=scores_dense.device, dtype=torch.long)
#         target = target * (p_reps.shape[0] // q_reps.shape[0])
#         loss_dense = model.compute_loss(scores_dense / model.temperature, target).item()
#
#         if training_args.hybrid_training:
#             # sparse loss
#             scores_sparse = model.compute_similarity(torch.tensor(q_logits, requires_grad=False),
#                                                      torch.tensor(p_logits, requires_grad=False))
#             scores_sparse = scores_sparse.view(q_logits.shape[0], -1)
#             loss_sparse = model.compute_loss(scores_sparse, target).item()
#         else:
#             loss_sparse = 0.0
#         return {
#             'loss_dense': loss_dense,
#             'loss_sparse': loss_sparse,
#         }
#     return compute_metrics


def main():
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
        model_args: ModelArguments
        data_args: DataArguments
        training_args: TrainingArguments

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )
    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)
    logger.info("MODEL parameters %s", model_args)

    set_seed(training_args.seed)

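    # Load the tokenizer, falling back to the base model's tokenizer when no separate
    # tokenizer name is given, and make sure a pad token exists for right-padded batches.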
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.unk_token_id if tokenizer.unk_token_id else tokenizer.eos_token_id
    tokenizer.padding_side = 'right'

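    # Build the PromptReps retrieval model from the pretrained checkpoint.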
    model = PromptRepsLLM.build(
        model_args,
        training_args,
        cache_dir=model_args.cache_dir,
    )

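    # Training data: optionally cap it at max_train_samples and, when evaluation is
    # enabled, split off eval_data_percentage of the examples as a held-out eval set.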
    train_dataset = TrainDataset(data_args)
    if training_args.max_train_samples is not None:
        max_train_samples = min(len(train_dataset), training_args.max_train_samples)
        train_dataset = train_dataset.train_data.select(range(max_train_samples))
        train_dataset = TrainDataset(data_args, dataset=train_dataset)

    eval_dataset = None
    if training_args.do_eval:
        datasets = train_dataset.train_data.train_test_split(training_args.eval_data_percentage)
        train_dataset = datasets['train']
        eval_dataset = datasets['test']
        if training_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), training_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        train_dataset = TrainDataset(data_args, dataset=train_dataset)
        eval_dataset = TrainDataset(data_args, dataset=eval_dataset)

    collator = TrainCollator(data_args, tokenizer)

    # TODO: eval for GCTrainer is not implemented
    if training_args.grad_cache:
        eval_dataset = None
        training_args.do_eval = False

    callbacks = []
    if training_args.early_stopping_patience is not None:
        callbacks.append(EarlyStoppingCallback(early_stopping_patience=training_args.early_stopping_patience))
    if training_args.save_early_checkpoints:
        callbacks.append(EarlyCheckpointCallback())

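    # Use the gradient-cache trainer when grad_cache is enabled, otherwise the
    # standard PromptReps trainer.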
    trainer_cls = GCTrainer if training_args.grad_cache else PromptRepsTrainer
    trainer = trainer_cls(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collator,
        callbacks=callbacks,
        # compute_metrics=make_compute_metrics(model, training_args)
        # if training_args.do_eval and not is_torch_xla_available() else None,
    )

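    # Hand the datasets a reference to the trainer; the dataset implementation can
    # read trainer state through it (likely for epoch-dependent sampling).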
    train_dataset.trainer = trainer
    if eval_dataset:
        eval_dataset.trainer = trainer

    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    trainer.save_model()

    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    main()
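
# Example launch (a sketch only; the checkpoint name and hyperparameter values are
# illustrative assumptions, not values prescribed by this repository):
#
#   python train.py \
#       --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
#       --output_dir runs/promptreps-train \
#       --do_train \
#       --per_device_train_batch_size 8 \
#       --learning_rate 1e-4 \
#       --num_train_epochs 1
#
# All flags are generated by HfArgumentParser from ModelArguments, DataArguments and
# TrainingArguments; alternatively, a single .json file containing the same fields can
# be passed as the sole argument.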