train.py

import argparse
import logging
import pathlib
from typing import Tuple

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          HfArgumentParser, PreTrainedModel,
                          PreTrainedTokenizer, Trainer)

from chatllms.configs import DataArguments, ModelArguments, TrainingArguments
from chatllms.data import make_supervised_data_module
from chatllms.utils.model_utils import (add_special_tokens_if_missing,
                                        safe_save_model_for_hf_trainer)


def load_model_tokenizer(args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """
    Load a pre-trained causal language model and its tokenizer.

    Args:
        args: A namespace holding the merged model, data, and training arguments.

    Returns:
        A tuple containing the loaded model and tokenizer.
    """
    # Pick the torch dtype: fp16 takes precedence over bf16; default to fp32.
    torch_dtype = torch.float16 if args.fp16 else (
        torch.bfloat16 if args.bf16 else torch.float32)
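
    # Keyword arguments forwarded to both from_pretrained() calls below.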
    config_kwargs = {
        'cache_dir': args.cache_dir,
        'use_auth_token': args.use_auth_token,
        'trust_remote_code': args.trust_remote_code,
    }

    # Load the pre-trained model.
    print(f'Loading model from {args.model_name_or_path}...')
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=torch_dtype,
        **config_kwargs,
    )

    # Flag the model as already model-parallel so the HF Trainer will not
    # wrap it in DataParallel on multi-GPU machines.
    model.model_parallel = True
    model.is_parallelizable = True

    if args.gradient_checkpointing:
        logging.warning('Using gradient checkpointing...')
        model.enable_input_require_grads()
        # The KV cache is incompatible with gradient checkpointing.
        model.config.use_cache = False

    # Load the tokenizer.
    print(f'Loading tokenizer from {args.model_name_or_path}...')
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        padding_side='right',
        model_max_length=args.model_max_length,
        use_fast=False,
        tokenizer_type='llama' if 'llama' in args.model_name_or_path else None,
        **config_kwargs,
    )

    return model, tokenizer


def train() -> None:
    """
    Train a causal language model with Hugging Face's Trainer.

    Command-line arguments are parsed into ModelArguments, DataArguments,
    and TrainingArguments dataclasses by HfArgumentParser; the function
    itself takes no parameters.

    Returns:
        None
    """
    parser = HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    (model_args, data_args,
     training_args) = parser.parse_args_into_dataclasses()
    data_args.init_for_training()
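
    # Merge the three argument dataclasses into one namespace so the helper
    # functions can read every field from a single object.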
    args = argparse.Namespace(**vars(model_args), **vars(data_args),
                              **vars(training_args))

    # Load the model and tokenizer.
    logging.warning('Loading model and tokenizer...')
    model, tokenizer = load_model_tokenizer(args=args)
    logging.warning('Successfully loaded model and tokenizer.')
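
    # Some checkpoints (LLaMA, Baichuan) may ship without certain special
    # tokens, so add any that are missing before training.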
    if 'llama' in args.model_name_or_path or 'baichuan' in args.model_name_or_path:
        logging.warning(
            f'Adding special tokens for {args.model_name_or_path}.')
        add_special_tokens_if_missing(tokenizer, model)

    if 'baichuan' in args.model_name_or_path:
        # Tie the input and output embedding weights.
        model.tie_weights()

    # Build the supervised dataset and data collator.
    logging.warning('Creating a supervised dataset and DataCollator...')
    data_module = make_supervised_data_module(tokenizer=tokenizer, args=args)

    # Initialize the Trainer object.
    logging.warning('Initializing Trainer object.')
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **data_module,
    )
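
    # Resume from the latest checkpoint when the output directory already
    # contains one; otherwise start training from scratch.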
    logging.warning('Start Training...')
    if list(pathlib.Path(training_args.output_dir).glob('checkpoint-*')):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    logging.warning(f'Saving Model to {training_args.output_dir}')
    trainer.save_state()

    # Save the trained model.
    safe_save_model_for_hf_trainer(trainer=trainer,
                                   output_dir=training_args.output_dir)
    logging.warning('Done.')


if __name__ == '__main__':
    train()
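
# Example invocation (a minimal sketch: only flags referenced above are
# shown; dataset-specific flags from DataArguments are omitted, and the
# model path is a placeholder):
#
#   python train.py \
#       --model_name_or_path /path/to/llama-7b \
#       --output_dir ./work_dir/llama-sft \
#       --bf16 True \
#       --gradient_checkpointing True \
#       --model_max_length 2048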