-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
47 lines (40 loc) · 1.93 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import transformers
import torch
import numpy as np
import os
from preprocessing.dataset import *
def train_model(dataset='files/unified/prompts/toked_train_set.pt',
                save_model='files/unified/models/unipred.pt',
                save_state='files/unified/models/unipred_state.pt',
                *,
                base_model='gpt2',
                output_dir='files/checkpoints',
                batch_size=4,
                save_steps=10000,
                save_total_limit=3):
    """Fine-tune a model on a pre-tokenized dataset and save the result.

    Args:
        dataset: Path to a ``torch.save``-d mapping whose entries are passed to
            ``Trainer`` as keyword arguments (presumably ``train_dataset`` /
            ``data_collator`` etc. -- confirm against the preprocessing step).
        save_model: Output path for the whole pickled model (``torch.save(model)``).
        save_state: Output path for the model's ``state_dict``.
        base_model: Model name passed to ``setup_model_and_tokenizer``.
        output_dir: Directory the Trainer writes checkpoints to.
        batch_size: Per-device training batch size.
        save_steps: Save a checkpoint every this many update steps.
        save_total_limit: Maximum number of checkpoints kept on disk.

    Returns:
        The trained model.
    """
    # setup_model_and_tokenizer, TrainingArguments and Trainer are presumably
    # provided by ``from preprocessing.dataset import *`` -- confirm there.
    model, tokenizer = setup_model_and_tokenizer(base_model)

    # The file contains pickled Python objects (not just tensors), so
    # weights_only=False is required on torch>=2.6, where it defaults to True.
    # NOTE(review): this unpickles the file -- only load files you produced.
    data_module = torch.load(dataset, weights_only=False)

    # The HF Trainer also places the model on its own device; only move it
    # eagerly when CUDA exists so CPU-only machines don't crash here.
    if torch.cuda.is_available():
        model = model.cuda()
    print('data loaded!')

    training_args = TrainingArguments(
        output_dir,
        per_device_train_batch_size=batch_size,
    )
    training_args = training_args.set_save(
        strategy="steps", steps=save_steps, total_limit=save_total_limit
    )

    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    # Public properties; reading the private ``_n_gpu`` is unnecessary.
    print(trainer.args.n_gpu)
    print(trainer.args.parallel_mode)
    trainer.train()
    print('training finished!')

    # Create the destination directories up front so a finished run isn't
    # lost to a FileNotFoundError at save time.
    for path in (save_model, save_state):
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
    torch.save(model, save_model)
    torch.save(model.state_dict(), save_state)
    print('model saved!')
    return model
# Inputs/outputs for each trainable variant, keyed by the --model-type value.
MODEL_CONFIGS = {
    'unipred': {
        'dataset': 'files/unified/prompts/toked_train_set.pt',
        'save_model': 'files/unified/models/unipred.pt',
        'save_state': 'files/unified/models/unipred_state.pt',
    },
    'light': {
        'dataset': 'files/unified/prompts/toked_light_train_set.pt',
        'save_model': 'files/unified/models/light.pt',
        'save_state': 'files/unified/models/light_state.pt',
    },
    'ablation': {
        'dataset': 'files/unified/prompts/toked_abl_aug_train_set.pt',
        'save_model': 'files/unified/models/abl_aug.pt',
        'save_state': 'files/unified/models/abl_aug_state.pt',
    },
}


def parse_args(argv=None):
    """Parse command-line arguments (``argv`` defaults to ``sys.argv[1:]``).

    An unknown ``--model-type`` makes argparse print an error and exit,
    instead of silently training nothing.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--model-type',
                        type=str,
                        required=True,
                        choices=sorted(MODEL_CONFIGS),
                        help='The input model type. Must be a value among "unipred", "light", and "ablation".')
    return parser.parse_args(argv)


def main(argv=None):
    """Train the variant selected on the command line and return the model."""
    args = parse_args(argv)
    return train_model(**MODEL_CONFIGS[args.model_type])


if __name__ == '__main__':
    main()