training.py
import argparse

import torch
import yaml
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from trl import SFTTrainer


# Load the training configuration from a YAML file
def load_config(config_file):
    with open(config_file, 'r') as file:
        return yaml.safe_load(file)
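
# For orientation, a minimal, illustrative config.yaml layout matching the keys this script
# reads below. The key names come from the code; the model/dataset names and hyperparameter
# values are assumptions for the sketch, not defaults shipped with this repo:
#
#   model:
#     model_name: "NousResearch/Llama-2-7b-hf"   # any causal LM on the Hub (example value)
#     dataset_name: "data/train.json"            # local JSON file with instruction/input/output fields
#     output_dir: "./results"
#     device_map: "auto"
#     use_4bit: true
#     bnb_4bit_quant_type: "nf4"
#     bnb_4bit_compute_dtype: "float16"
#     use_nested_quant: false
#   lora:
#     lora_r: 64
#     lora_alpha: 16
#     lora_dropout: 0.1
#   training:
#     per_device_train_batch_size: 4
#     gradient_accumulation_steps: 1
#     optim: "paged_adamw_32bit"
#     save_steps: 25
#     logging_steps: 25
#     learning_rate: 2e-4
#     fp16: false
#     bf16: false
#     max_grad_norm: 0.3
#     max_steps: 100
#     warmup_ratio: 0.03
#     group_by_length: true
#     lr_scheduler_type: "cosine"
#     max_seq_length: 512
#     packing: false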

# Load the model and tokenizer based on the configuration
def load_model(config):
    compute_dtype = getattr(torch, config["model"]["bnb_4bit_compute_dtype"])

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=config["model"]["use_4bit"],
        bnb_4bit_quant_type=config["model"]["bnb_4bit_quant_type"],
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=config["model"]["use_nested_quant"],
    )

    model_name = config["model"]["model_name"]
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map=config["model"]["device_map"],
        quantization_config=bnb_config,
    )
    model.config.use_cache = False  # KV caching is not needed during training
    model.config.pretraining_tp = 1  # use the standard (non tensor-parallel) forward pass

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    return model, tokenizer
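
# Optional sanity check (an illustrative snippet, not called by the training flow;
# assumes a config.yaml like the sketch above):
#
#   model, tokenizer = load_model(load_config("config.yaml"))
#   print(model.get_memory_footprint())  # approximate bytes held by the quantized weights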

# Configure LoRA settings
def configure_lora(config):
    lora_params = config["lora"]
    return LoraConfig(
        lora_alpha=lora_params["lora_alpha"],
        lora_dropout=lora_params["lora_dropout"],
        r=lora_params["lora_r"],
        bias="none",
        task_type="CAUSAL_LM",
    )


# Format a sample into the instruction template
def format_dolly(sample):
    instruction = f"<s>[INST] {sample['instruction']}"
    context = f"Here's some context: {sample['input']}" if len(sample["input"]) > 0 else None
    response = f" [/INST] {sample['output']}"
    return "".join([i for i in [instruction, context, response] if i is not None])


# Apply the template to a dataset sample and append the EOS token
def template_dataset(sample, tokenizer):
    sample["text"] = f"{format_dolly(sample)}{tokenizer.eos_token}"
    return sample
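
# For reference, a sample like {"instruction": ..., "input": "", "output": ...} is rendered as
#   "<s>[INST] <instruction> [/INST] <output></s>"
# and, when "input" is non-empty,
#   "<s>[INST] <instruction>Here's some context: <input> [/INST] <output></s>"
# (the trailing "</s>" assumes a Llama-style tokenizer whose eos_token is "</s>").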

def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Fine-tune a model with LoRA and 4-bit precision.")
    parser.add_argument("--config", type=str, default="config.yaml", help="Path to the YAML config file.")
    args = parser.parse_args()

    # Load configuration
    config = load_config(args.config)

    # Load model and tokenizer
    model, tokenizer = load_model(config)

    # Configure LoRA
    peft_config = configure_lora(config)

    # Load and process the dataset
    dataset_name = config["model"]["dataset_name"]
    dataset = load_dataset("json", data_files=dataset_name, split="train")
    dataset = dataset.shuffle(seed=42)
    dataset = dataset.select(range(50))  # Optional: select first 50 rows for demo
    dataset = dataset.map(lambda sample: template_dataset(sample, tokenizer), remove_columns=list(dataset.features))

    # Set up training arguments
    training_arguments = TrainingArguments(
        output_dir=config["model"]["output_dir"],
        per_device_train_batch_size=config["training"]["per_device_train_batch_size"],
        gradient_accumulation_steps=config["training"]["gradient_accumulation_steps"],
        optim=config["training"]["optim"],
        save_steps=config["training"]["save_steps"],
        logging_steps=config["training"]["logging_steps"],
        learning_rate=config["training"]["learning_rate"],
        fp16=config["training"]["fp16"],
        bf16=config["training"]["bf16"],
        max_grad_norm=config["training"]["max_grad_norm"],
        max_steps=config["training"]["max_steps"],
        warmup_ratio=config["training"]["warmup_ratio"],
        group_by_length=config["training"]["group_by_length"],
        lr_scheduler_type=config["training"]["lr_scheduler_type"],
    )

    # Initialize and start training
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        dataset_text_field="text",
        max_seq_length=config["training"]["max_seq_length"],
        tokenizer=tokenizer,
        args=training_arguments,
        packing=config["training"]["packing"],
    )
    trainer.train()
    trainer.model.save_pretrained(config["model"]["output_dir"])


if __name__ == "__main__":
    main()
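
# Example invocation (assuming a config.yaml like the sketch near the top of this file):
#   python training.py --config config.yaml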