@@ -1,5 +1,7 @@
 #!/usr/bin/env python3
 
+from unsloth import is_bfloat16_supported
+from transformers import TrainingArguments, DataCollatorForSeq2Seq
 from unsloth import FastLanguageModel
 import torch
 from trl import SFTTrainer
@@ -71,13 +73,14 @@ def formatting_prompts_func(examples):
     else:
         dataset = load_dataset(source, split = "train")
 
-    dataset = dataset.map(formatting_prompts_func, batched = True)
+    dataset = dataset.map(formatting_prompts_func, batched = True)
 
     trainer = SFTTrainer(
         model = model,
         train_dataset = dataset,
         dataset_text_field = "text",
         max_seq_length = max_seq_length,
+        data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
         tokenizer = tokenizer,
         dataset_num_proc = 2,
         packing = cfg.get('packing'),  # Can make training 5x faster for short sequences.
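
The new data_collator pads each batch to its longest sequence and pads the labels with -100 so the padded positions drop out of the loss. A minimal standalone sketch of what DataCollatorForSeq2Seq does with two uneven examples (the gpt2 tokenizer and toy features are illustrative only; in this script the tokenizer comes from FastLanguageModel):

    from transformers import AutoTokenizer, DataCollatorForSeq2Seq

    # Illustrative tokenizer; gpt2 has no pad token, so reuse EOS for padding.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # Two pre-tokenized examples of different lengths, as the trainer would batch them.
    features = [
        {"input_ids": [10, 11, 12], "labels": [10, 11, 12]},
        {"input_ids": [10, 11, 12, 13, 14], "labels": [10, 11, 12, 13, 14]},
    ]

    batch = DataCollatorForSeq2Seq(tokenizer=tokenizer)(features)
    # input_ids are padded with the pad token, labels with -100 (ignored by the loss).
    print(batch["input_ids"].shape)   # torch.Size([2, 5])
    print(batch["labels"][0])         # tensor([  10,   11,   12, -100, -100])
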
@@ -87,8 +90,8 @@ def formatting_prompts_func(examples):
             warmup_steps = cfg.get('warmupSteps'),
             max_steps = cfg.get('maxSteps'),
             learning_rate = cfg.get('learningRate'),
-            fp16 = not torch.cuda.is_bf16_supported(),
-            bf16 = torch.cuda.is_bf16_supported(),
+            fp16 = not is_bfloat16_supported(),
+            bf16 = is_bfloat16_supported(),
             logging_steps = cfg.get('loggingSteps'),
             optim = cfg.get('optimizer'),
             weight_decay = cfg.get('weightDecay'),
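
Unsloth's is_bfloat16_supported() performs essentially the same capability check as torch.cuda.is_bf16_supported(), so the flags keep their meaning: bf16 on GPUs that support it, fp16 otherwise. A minimal sketch of how the two flags stay mutually exclusive (the import fallback is illustrative only):

    import torch

    try:
        from unsloth import is_bfloat16_supported
    except ImportError:
        # Illustrative fallback: the equivalent capability check straight from PyTorch.
        is_bfloat16_supported = torch.cuda.is_bf16_supported

    use_bf16 = is_bfloat16_supported()
    # Exactly one of the two mixed-precision flags ends up True; these are the
    # kwargs the script passes to TrainingArguments.
    precision_kwargs = {"fp16": not use_bf16, "bf16": use_bf16}
    print(precision_kwargs)
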