From 1d48cf1ee7ff7c037a261cdf6d1a560975e704d9 Mon Sep 17 00:00:00 2001
From: jordiclive
Date: Tue, 25 Jul 2023 18:40:54 +0100
Subject: [PATCH 1/2] patch fix for lora training

---
 model/model_training/models/peft_modeling.py |  1 -
 model/model_training/trainer_sft.py          | 15 ++++++++++++---
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/model/model_training/models/peft_modeling.py b/model/model_training/models/peft_modeling.py
index 2355f541ab..4c796287b0 100644
--- a/model/model_training/models/peft_modeling.py
+++ b/model/model_training/models/peft_modeling.py
@@ -57,7 +57,6 @@ def peft_model(model, training_config):
         "lora_dropout": 0.05,
         "bias": "none",
         "task_type": "CAUSAL_LM",
-        "modules_to_save": ["wte", "lm_head"],
     }
     kwargs = merge_dicts(default_args, peft_config)
     if kwargs.get("target_modules") == "all":
diff --git a/model/model_training/trainer_sft.py b/model/model_training/trainer_sft.py
index 4c4c820999..1ebe0a40cc 100755
--- a/model/model_training/trainer_sft.py
+++ b/model/model_training/trainer_sft.py
@@ -27,7 +27,7 @@
 from torch import nn
 from torch.utils.data import DataLoader, Subset
 from tqdm import tqdm
-from transformers import PreTrainedModel, Trainer, TrainingArguments
+from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import IterableDatasetShard
 from transformers.trainer_utils import seed_worker
 from transformers.training_args import OptimizerNames
@@ -327,7 +327,10 @@ def main():
 
     init_rng(training_conf)
 
-    tokenizer = get_tokenizer(training_conf)
+    if training_conf.peft_model:
+        tokenizer = AutoTokenizer.from_pretrained(training_conf.model_name)
+    else:
+        tokenizer = get_tokenizer(training_conf)
 
     if not training_conf.deepspeed or training_conf.local_rank == 0:
         tokenizer_sanity_check(tokenizer)
@@ -416,7 +419,13 @@ def main():
         sampler = None
 
     metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
-    model = get_model(training_conf, tokenizer)
+    if training_conf.peft_model:
+        logging.warning("PEFT, make sure this is a basemodel has been adapted to have special tokens!")
+        model = AutoModelForCausalLM.from_pretrained(
+            training_conf.model_name, torch_dtype=torch.bfloat16 if training_conf.dtype == "bf16" else torch.float16
+        )
+    else:
+        model = get_model(training_conf, tokenizer)
 
     superhot = RopePatch.from_config(training_conf) if training_conf.superhot else None
     if superhot:

From d55ff360677255321bf6e194e2a89b48cf61da61 Mon Sep 17 00:00:00 2001
From: jordiclive
Date: Tue, 25 Jul 2023 18:42:18 +0100
Subject: [PATCH 2/2] no message

---
 model/model_training/trainer_sft.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/model/model_training/trainer_sft.py b/model/model_training/trainer_sft.py
index 1ebe0a40cc..56797a8e7f 100755
--- a/model/model_training/trainer_sft.py
+++ b/model/model_training/trainer_sft.py
@@ -420,7 +420,7 @@ def main():
 
     metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
     if training_conf.peft_model:
-        logging.warning("PEFT, make sure this is a basemodel has been adapted to have special tokens!")
+        logging.warning("PEFT model: make sure this is an adapted base model which has added special tokens!")
         model = AutoModelForCausalLM.from_pretrained(
             training_conf.model_name, torch_dtype=torch.bfloat16 if training_conf.dtype == "bf16" else torch.float16
         )