From f0fb9516d8ebc4fed1bacaa11b789be023f003f8 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 29 Nov 2023 12:37:39 +0100 Subject: [PATCH 01/12] ENH: Different initialization methods for LoRA (#1189) This PR adds the possibility to use different initialization methods for LoRA, as is a requirement for a completely backwards compatible adoption of PEFT in diffusers. The default is still the same as always, namely the one from the reference implementation by Microsoft. On top of that, it is now possible to pass `init_lora_weights='gaussian'` to initialize the LoRA weights in the same way as is default for diffusers, namely with a normal distribution which is scaled by 1/r. The init method currently applies to LoRA linear and conv layers, but not embedding layers, which are always initialized from a normal distribution (and are probably irrelevant for diffusers). In the future, similar extensions could be added for other adapter methods. --- setup.py | 4 +- src/peft/tuners/lora/config.py | 12 +- src/peft/tuners/lora/layer.py | 22 +++- tests/test_initialization.py | 232 +++++++++++++++++++++++++++++++++ 4 files changed, 258 insertions(+), 12 deletions(-) create mode 100644 tests/test_initialization.py diff --git a/setup.py b/setup.py index 8e3d60ec7c4..7f5e55524f8 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,9 @@ extras["quality"] = ["black ~= 22.0", "ruff>=0.0.241", "urllib3<=2.0.0"] extras["docs_specific"] = ["hf-doc-builder"] extras["dev"] = extras["quality"] + extras["docs_specific"] -extras["test"] = extras["dev"] + ["pytest", "pytest-cov", "pytest-xdist", "parameterized", "datasets", "diffusers<0.21.0"] +extras["test"] = extras["dev"] + [ + "pytest", "pytest-cov", "pytest-xdist", "parameterized", "datasets", "diffusers<0.21.0", "scipy" +] setup( name="peft", diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index 2412b61a1a8..b1e31d81987 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -13,8 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from dataclasses import dataclass, field -from typing import List, Optional, Union +from typing import List, Literal, Optional, Union from peft.config import PeftConfig from peft.utils import PeftType @@ -76,12 +78,14 @@ class LoraConfig(PeftConfig): "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." }, ) - init_lora_weights: bool = field( + init_lora_weights: bool | Literal["gaussian"] = field( default=True, metadata={ "help": ( - "Whether to initialize the weights of the Lora layers with their default initialization. Don't change " - "this setting, except if you know exactly what you're doing." + "How to initialize the weights of the LoRA layers. Passing True (default) results in the default " + "initialization from the reference implementation from Microsoft. Passing 'gaussian' results " + "in Gaussian initialization scaled by the LoRA rank for linear and layers. Setting the initialization " + "to False leads to completely random initialization and is discouraged." ), }, ) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index c2630531836..5ea726d2ffb 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -84,7 +84,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False) self.scaling[adapter_name] = lora_alpha / r if init_lora_weights: - self.reset_lora_parameters(adapter_name) + self.reset_lora_parameters(adapter_name, init_lora_weights) weight = getattr(self.get_base_layer(), "weight", None) if weight is not None: @@ -116,7 +116,7 @@ def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lo self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False) self.scaling[adapter_name] = lora_alpha / r if init_lora_weights: - self.reset_lora_parameters(adapter_name) + self.reset_lora_parameters(adapter_name, init_lora_weights) weight = getattr(base_layer, "weight", None) if weight is not None: @@ -142,8 +142,7 @@ def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init self.lora_embedding_A[adapter_name] = nn.Parameter(weight_A) self.lora_embedding_B[adapter_name] = nn.Parameter(weight_B) self.scaling[adapter_name] = lora_alpha / r - if init_lora_weights: - self.reset_lora_parameters(adapter_name) + self.reset_lora_parameters(adapter_name, init_lora_weights) base_layer = self.get_base_layer() weight = getattr(base_layer, "weight", None) @@ -152,10 +151,19 @@ def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init self.to(base_layer.weight.device, dtype=weight.dtype) self.set_adapter(self.active_adapters) - def reset_lora_parameters(self, adapter_name): + def reset_lora_parameters(self, adapter_name, init_lora_weights): + if init_lora_weights is False: + return + if adapter_name in self.lora_A.keys(): - # initialize A the same way as the default for nn.Linear and B to zero - nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5)) + if init_lora_weights is True: + # initialize A the same way as the default for nn.Linear and B to zero + # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124 + nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5)) + elif init_lora_weights.lower() == "gaussian": + nn.init.normal_(self.lora_A[adapter_name].weight, std=1 / self.r[adapter_name]) + else: + raise ValueError(f"Unknown initialization {init_lora_weights=}") nn.init.zeros_(self.lora_B[adapter_name].weight) if adapter_name in self.lora_embedding_A.keys(): # initialize a the same way as the default for nn.linear and b to zero diff --git a/tests/test_initialization.py b/tests/test_initialization.py new file mode 100644 index 00000000000..3770b4a74f3 --- /dev/null +++ b/tests/test_initialization.py @@ -0,0 +1,232 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from scipy import stats +from torch import nn + +from peft import LoraConfig, get_peft_model +from peft.utils import infer_device + + +class InitializationTest(unittest.TestCase): + """Test class to check the initialization of adapters.""" + + torch_device = infer_device() + + def get_uniform(self, amin, amax, size=(10000,)): + unif = torch.distributions.uniform.Uniform(amin, amax) + samples = unif.sample(size) + return samples + + def get_normal(self, mean, std, size=(10000,)): + normal = torch.distributions.normal.Normal(mean, std) + samples = normal.sample(size) + return samples + + def get_model(self): + class MyModule(nn.Module): + def __init__(self): + super().__init__() + # choose a large weight so that averages are close to expected values + self.linear = nn.Linear(1000, 1000) + self.embed = nn.Embedding(1000, 1000) + self.conv2d = nn.Conv2d(100, 100, 3) + + def forward(self, x): + return self.linear(x) + + return MyModule().eval().to(self.torch_device) + + def test_lora_linear_init_default(self): + # default is True + torch.manual_seed(0) + + model = self.get_model() + config = LoraConfig(target_modules=["linear"]) + model = get_peft_model(model, config) + weight_A = model.linear.lora_A["default"].weight + weight_B = model.linear.lora_B["default"].weight + + # use statistical test to check if weight A is from a uniform distribution + unif = self.get_uniform(weight_A.min().item(), weight_A.max().item()) + _, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy()) + self.assertGreater(p_value, 0.5) + + # check that weight A is *not* from a normal distribution + normal = self.get_normal(weight_A.mean().item(), weight_A.std().item()) + _, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy()) + self.assertLess(p_value, 0.05) + + # check that weight B is zero + self.assertTrue((weight_B == 0.0).all()) + + def test_lora_linear_init_gaussian(self): + # use gaussian init + torch.manual_seed(0) + + model = self.get_model() + config = LoraConfig(target_modules=["linear"], init_lora_weights="gaussian") + model = get_peft_model(model, config) + weight_A = model.linear.lora_A["default"].weight + weight_B = model.linear.lora_B["default"].weight + + # use statistical test to check if weight A is from a normal distribution + normal = self.get_normal(0.0, 1 / config.r) + _, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy()) + + # import matplotlib.pyplot as plt + # x = weight_A.detach().flatten().cpu().numpy() + # breakpoint() + + self.assertGreater(p_value, 0.5) + + # check that weight A is *not* from a uniform distribution + unif = self.get_uniform(weight_A.min().item(), weight_A.max().item()) + _, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy()) + self.assertLess(p_value, 0.05) + + # check that weight B is zero + self.assertTrue((weight_B == 0.0).all()) + + def test_lora_linear_false(self): + torch.manual_seed(0) + + model = self.get_model() + config = LoraConfig(target_modules=["linear"], init_lora_weights=False) + model = get_peft_model(model, config) + weight_B = model.linear.lora_B["default"].weight + + # with init_lora_weights=False, weight B should *not* be zero. We don't care so much about the actual values + # as long as they are not zero, in order to avoid identity transformation. + self.assertFalse(torch.allclose(weight_B, torch.zeros_like(weight_B))) + + def test_lora_embedding_default(self): + # embedding is initialized as a normal distribution, not kaiming uniform + torch.manual_seed(0) + + model = self.get_model() + config = LoraConfig(target_modules=["embed"]) + model = get_peft_model(model, config) + weight_A = model.embed.lora_embedding_A["default"] + weight_B = model.embed.lora_embedding_B["default"] + + # use statistical test to check if weight B is from a normal distribution + normal = self.get_normal(0.0, 1.0) + _, p_value = stats.kstest(weight_B.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy()) + self.assertGreater(p_value, 0.5) + + # check that weight B is *not* from a uniform distribution + unif = self.get_uniform(weight_B.min().item(), weight_B.max().item()) + _, p_value = stats.kstest(weight_B.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy()) + self.assertLess(p_value, 0.05) + + # check that weight A is zero + self.assertTrue((weight_A == 0.0).all()) + + def test_lora_embedding_gaussian(self): + # embedding does not change with init_lora_weights="gaussian" vs True + torch.manual_seed(0) + + model = self.get_model() + config = LoraConfig(target_modules=["embed"], init_lora_weights="gaussian") + model = get_peft_model(model, config) + weight_A = model.embed.lora_embedding_A["default"] + weight_B = model.embed.lora_embedding_B["default"] + + # use statistical test to check if weight B is from a normal distribution + normal = self.get_normal(0.0, 1.0) + _, p_value = stats.kstest(weight_B.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy()) + self.assertGreater(p_value, 0.5) + + # check that weight B is *not* from a uniform distribution + unif = self.get_uniform(weight_B.min().item(), weight_B.max().item()) + _, p_value = stats.kstest(weight_B.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy()) + self.assertLess(p_value, 0.05) + + # check that weight A is zero + self.assertTrue((weight_A == 0.0).all()) + + def test_lora_embedding_false(self): + torch.manual_seed(0) + + model = self.get_model() + config = LoraConfig(target_modules=["embed"], init_lora_weights=False) + model = get_peft_model(model, config) + weight_A = model.embed.lora_embedding_B["default"] + + # with init_lora_weights=False, weight A should *not* be zero. We don't care so much about the actual values + # as long as they are not zero, in order to avoid identity transformation. + self.assertFalse(torch.allclose(weight_A, torch.zeros_like(weight_A))) + + def test_lora_conv2d_default(self): + # default is True + torch.manual_seed(0) + + model = self.get_model() + config = LoraConfig(target_modules=["conv2d"]) + model = get_peft_model(model, config) + weight_A = model.conv2d.lora_A["default"].weight + weight_B = model.conv2d.lora_B["default"].weight + + # use statistical test to check if weight A is from a uniform distribution + unif = self.get_uniform(weight_A.min().item(), weight_A.max().item()) + _, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy()) + self.assertGreater(p_value, 0.5) + + # check that weight A is *not* from a normal distribution + normal = self.get_normal(weight_A.mean().item(), weight_A.std().item()) + _, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy()) + self.assertLess(p_value, 0.05) + + # check that weight B is zero + self.assertTrue((weight_B == 0.0).all()) + + def test_lora_conv2d_init_gaussian(self): + # use gaussian init + torch.manual_seed(0) + + model = self.get_model() + config = LoraConfig(target_modules=["conv2d"], init_lora_weights="gaussian") + model = get_peft_model(model, config) + weight_A = model.conv2d.lora_A["default"].weight + weight_B = model.conv2d.lora_B["default"].weight + + # use statistical test to check if weight A is from a normal distribution + normal = self.get_normal(0.0, 1 / config.r) + _, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), normal.flatten().cpu().numpy()) + self.assertGreater(p_value, 0.5) + + # check that weight A is *not* from a uniform distribution + unif = self.get_uniform(weight_A.min().item(), weight_A.max().item()) + _, p_value = stats.kstest(weight_A.detach().flatten().cpu().numpy(), unif.flatten().cpu().numpy()) + self.assertLess(p_value, 0.05) + + # check that weight B is zero + self.assertTrue((weight_B == 0.0).all()) + + def test_lora_conv2d_false(self): + torch.manual_seed(0) + + model = self.get_model() + config = LoraConfig(target_modules=["conv2d"], init_lora_weights=False) + model = get_peft_model(model, config) + weight_B = model.conv2d.lora_B["default"].weight + + # with init_lora_weights=False, weight B should *not* be zero. We don't care so much about the actual values + # as long as they are not zero, in order to avoid identity transformation. + self.assertFalse(torch.allclose(weight_B, torch.zeros_like(weight_B))) From 8298f1a3668604ac9bc3f6e28b24e8eb554891a1 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed, 29 Nov 2023 19:28:41 +0530 Subject: [PATCH 02/12] Training PEFT models with new tokens being added to the embedding layers and tokenizer (#1147) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add support for saving base layers weights along with adapter weights * Update save_and_load.py * Add an example showing the usage of the added feature * refactor the functionality * fix * refactoring code 1. Add `is_embedding_layer_resized` parameter to `save_pretrained` 2. Fix the deduplication in README when adding PEFT details. 3. `save_pretrained` should only save the model when `is_main_process=True` which is one of the parameters of `save_pretrained`. * update example * fix the model card * fix model card * πŸ˜… * fix model card * automate setting `is_embedding_layer_resized` * nits * Update peft_lora_clm_with_additional_tokens.ipynb * add test * fix tests * maybe fixes the issue? * address comments Co-Authored-By: Benjamin Bossan * Apply suggestions from code review Co-authored-by: Benjamin Bossan --------- Co-authored-by: Benjamin Bossan --- ...peft_lora_clm_with_additional_tokens.ipynb | 1012 +++++++++++++++++ src/peft/peft_model.py | 52 +- src/peft/utils/other.py | 1 + src/peft/utils/save_and_load.py | 45 +- tests/test_custom_models.py | 76 ++ 5 files changed, 1167 insertions(+), 19 deletions(-) create mode 100644 examples/causal_language_modeling/peft_lora_clm_with_additional_tokens.ipynb diff --git a/examples/causal_language_modeling/peft_lora_clm_with_additional_tokens.ipynb b/examples/causal_language_modeling/peft_lora_clm_with_additional_tokens.ipynb new file mode 100644 index 00000000000..81762de08c6 --- /dev/null +++ b/examples/causal_language_modeling/peft_lora_clm_with_additional_tokens.ipynb @@ -0,0 +1,1012 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5f239612-620e-4430-8685-9fdc6b179b41", + "metadata": {}, + "source": [ + "# Training PEFT models with new tokens being added to the embedding layers and tokenizer\n", + "\n", + "In this example, we will learn how to train a LoRA model when adding new tokens to the tokenizer and model. \n", + "This is a common usecase when doing the following:\n", + "1. Instruction finetuning with new tokens beind added such as `<|user|>`, `<|assistant|>`, `<|system|>`, ``, `` to properly format the conversations\n", + "2. Finetuning on a specific language wherein language spoecific tokens are added, e.g., korean tokens being added to vocabulary for finetuning LLM on Korean datasets.\n", + "3. Instruction finetuning to return outputs in certain format to enable agent behaviour new tokens such as `<|FUNCTIONS|>`, `<|BROWSE|>`, `<|TEXT2IMAGE|>`, `<|ASR|>`, `<|TTS|>`, `<|GENERATECODE|>`, `<|RAG|>`.\n", + "\n", + "In such cases, you add the Embedding modules to the LORA `target_modules`. PEFT will take care of saving the embedding layers with the new added tokens along with the adapter weights that were trained on the specific initialization of the embeddings weights of the added tokens." + ] + }, + { + "cell_type": "markdown", + "id": "b27c55e8-edaa-4059-90bc-d6096d596902", + "metadata": {}, + "source": [ + "Let's import the necessary libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "6f864c90", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n", + "os.environ[\"WANDB_PROJECT\"] = \"PeftExamples\"\n", + "import transformers\n", + "from peft import (\n", + " LoraConfig,\n", + " PeftConfig,\n", + " PeftModel,\n", + " get_peft_model,\n", + " prepare_model_for_int8_training,\n", + ")\n", + "from transformers import (\n", + " AutoModelForCausalLM,\n", + " AutoTokenizer,\n", + " HfArgumentParser,\n", + " TrainingArguments,\n", + " Trainer,\n", + " default_data_collator,\n", + ")\n", + "import torch\n", + "from dataclasses import dataclass, field\n", + "from typing import Optional\n", + "from dataclass_csv import DataclassReader\n", + "from torch.utils.data import Dataset, DataLoader\n", + "\n", + "from enum import Enum" + ] + }, + { + "cell_type": "markdown", + "id": "74950a3f-bb63-4ce5-9e2b-1b83f92b13a2", + "metadata": {}, + "source": [ + "## Prepare Model and Tokenizer" + ] + }, + { + "cell_type": "markdown", + "id": "76763f5e-64b2-409b-8845-ae5589f8a4e0", + "metadata": {}, + "source": [ + "Now, we will be adding 27 new tokens as well as replace the existing pad, bos and eos tokens of the model." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fd0498ea-547e-418d-bf13-c9abafdd5476", + "metadata": {}, + "outputs": [], + "source": [ + "class SpecialTokens(str, Enum):\n", + " begin_target = \"<|begintarget|>\"\n", + " end_target = \"<|endtarget|>\"\n", + " begin_context = \"<|begincontext|>\"\n", + " end_context = \"<|endcontext|>\"\n", + " system = \"<|system|>\"\n", + " user = \"<|user|>\"\n", + " begin_last_user_utterance = \"<|beginlastuserutterance|>\"\n", + " end_last_user_utterance = \"<|endlastuserutterance|>\"\n", + " begin_dsts = \"<|begindsts|>\"\n", + " end_dsts = \"<|enddsts|>\"\n", + " begin_dst = \"<|begindst|>\"\n", + " end_dst = \"<|enddst|>\"\n", + " begin_belief = \"<|beginbelief|>\"\n", + " end_belief = \"<|endbelief|>\"\n", + " begin_response = \"<|beginresponse|>\"\n", + " end_response = \"<|endresponse|>\"\n", + " begin_action = \"<|beginaction|>\"\n", + " end_action = \"<|endaction|>\"\n", + " begin_user_action = \"<|beginuseraction|>\"\n", + " end_user_action = \"<|enduseraction|>\"\n", + " sys_actions = \"<|sysactions|>\"\n", + " begin_intent = \"<|beginintent|>\"\n", + " end_intent = \"<|endintent|>\"\n", + " begin_requested_slots = \"<|beginrequestedslots|>\"\n", + " end_requested_slots = \"<|endrequestedslots|>\"\n", + " pad_token = \"<|pad|>\"\n", + " bos_token = \"<|startoftext|>\"\n", + "\n", + " @classmethod\n", + " def list(cls):\n", + " return [c.value for c in cls]" + ] + }, + { + "cell_type": "markdown", + "id": "ae4a4255-5f13-4eef-a024-4f1de0f2173b", + "metadata": {}, + "source": [ + "We will be finetuning Mistral-7B model. Let's load the tokenizer and add the special tokens followed by loading the base model and resizzing the embedding layers to accomodate the newly added tokens." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f0eedef9", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "91c67b6377fc4dd7977bf544de784d51", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|><|begincontext|><|user|> Can you find me place to eat?<|system|> What kind of food would you like to have and where would you like me to search in?<|user|> Food kind of California will be perfect in SF.<|system|> There are 10 restaurants, Al's Place is one of the good restaurant in San Francisco.<|user|> Can you look for any other restaurant?<|system|> Alta Msp is one of the good restaurant in San Francisco.<|beginlastuserutterance|> Can you find me the address?<|endlastuserutterance|><|endcontext|><|begintarget|><|begindsts|><|begindst|><|beginintent|> FindRestaurants<|endintent|><|beginrequestedslots|> Restaurants^street_address<|endrequestedslots|><|beginbelief|> Restaurants^city->SF~San Francisco|Restaurants^cuisine->California<|endbelief|><|enddst|><|enddsts|><|beginuseraction|> REQUEST->Restaurants^street_address~<|enduseraction|><|beginaction|> INFORM->Restaurants^street_address~1275 Minnesota Street<|endaction|><|beginresponse|> The street address of the restaurant is 1275 Minnesota Street.<|endresponse|><|endtarget|><|endtarget|>\"" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.decode(train_dataset[0][\"input_ids\"])" + ] + }, + { + "cell_type": "markdown", + "id": "239d1c83-196d-471e-9bf7-5f36dafa9894", + "metadata": {}, + "source": [ + "# Train the model" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ec80d6ee", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n", + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33msmangrul\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.16.0" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /raid/sourab/temp/wandb/run-20231128_230934-edod21gq" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run ethereal-eon-1 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/smangrul/PeftExamples" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/smangrul/PeftExamples/runs/edod21gq" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [246/246 05:51, Epoch 2/2]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
105.189800
203.745500
302.371500
401.630200
501.302600
600.999400
700.704100
800.527800
900.509700
1000.382300
1100.318200
1200.323500
1300.263400
1400.290900
1500.277400
1600.232800
1700.223600
1800.229600
1900.233100
2000.210200
2100.245800
2200.197300
2300.210100
2400.209800

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "TrainOutput(global_step=246, training_loss=0.8516577879587809, metrics={'train_runtime': 354.9013, 'train_samples_per_second': 5.556, 'train_steps_per_second': 0.693, 'total_flos': 4.318233532091597e+16, 'train_loss': 0.8516577879587809, 'epoch': 2.0})" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "training_args = TrainingArguments(\n", + " output_dir=\"mistral_lora_clm_with_added_tokens\",\n", + " num_train_epochs=2,\n", + " save_total_limit=5,\n", + " per_device_train_batch_size=8,\n", + " warmup_steps=10,\n", + " weight_decay=0.0001,\n", + " dataloader_drop_last=True,\n", + " bf16=True,\n", + " logging_steps=10,\n", + " learning_rate=1e-5,\n", + " gradient_checkpointing=True,\n", + " gradient_checkpointing_kwargs={\"use_reentrant\": False},\n", + " remove_unused_columns=False,\n", + " hub_model_id=\"smangrul/mistral_lora_clm_with_added_tokens\",\n", + " push_to_hub=True,\n", + " hub_private_repo=True,\n", + ")\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " data_collator=default_data_collator,\n", + ")\n", + "# model.config.use_cache = False\n", + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "id": "7bc1cbed-4eb9-4aaa-ab5f-5b91bf432307", + "metadata": {}, + "source": [ + "# Check the model output on a sample from evaluation dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "71851793", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "context=\"<|begincontext|><|user|>Can you find me a place to eat please?<|system|>Where at? And what kind of cuisine are you craving?<|user|>Somewhere in SF, and I am really craving Thai food at the moment!<|system|>I found a bunch of restaurants, there's actually 10 that you might like in San Francisco, one of them being Baan Thai House & Wine Bar<|user|>How can I reach them? And what's their address?<|system|>You can reach them by phone at 415-379-4505 and visit them at 534 Irving Street<|beginlastuserutterance|>Great, that restaurant sounds good<|endlastuserutterance|><|endcontext|>\" \n", + "\n", + " target_predicted='<|begintarget|><|begindsts|><|begindst|><|beginintent|> FindRestaurants<|endintent|><|beginbelief|> Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|> REQUEST->Restaurants^phone_number~|REQUEST->Restaurants^street_address~<|enduseraction|><|beginaction|> INFORM->Restaurants^phone_number~415-379-4505|INFORM->Restaurants^street_address~534 Irving Street<|endaction|><|beginresponse|> Great, the phone number is 415-379-4505 and the address is 534 Irving Street<|endresponse|><|endtarget|>' \n", + "\n", + " target='<|begintarget|><|begindsts|><|begindst|><|beginintent|>FindRestaurants<|endintent|><|beginbelief|>Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|>SELECT->Restaurants^~<|enduseraction|><|beginaction|>OFFER_INTENT->Restaurants^intent~ReserveRestaurant<|endaction|><|beginresponse|>Want me to book a table?<|endresponse|><|endtarget|>'\n" + ] + } + ], + "source": [ + "import random\n", + "\n", + "i = random.randint(0, len(dataset[\"test\"]))\n", + "context = dataset[\"test\"][i][\"context\"]\n", + "\n", + "batch = tokenizer(context, return_tensors=\"pt\")\n", + "batch = {k: v.to(\"cuda\") for k, v in batch.items()}\n", + "model.eval()\n", + "output_tokens = model.generate(\n", + " **batch,\n", + " max_new_tokens=256,\n", + " do_sample=True,\n", + " temperature=0.2,\n", + " top_p=0.95,\n", + " top_k=50,\n", + " eos_token_id=tokenizer.eos_token_id,\n", + " pad_token_id=tokenizer.pad_token_id,\n", + ")\n", + "target_predicted = tokenizer.decode(output_tokens[0], skip_special_tokens=False).split(\"<|endcontext|>\")[1]\n", + "target = dataset[\"test\"][i][\"target\"]\n", + "print(f\"{context=} \\n\\n {target_predicted=} \\n\\n {target=}\")" + ] + }, + { + "cell_type": "markdown", + "id": "f940a660-2f7c-4a3a-b412-3f037aedb890", + "metadata": {}, + "source": [ + "# Save the Adapter model " + ] + }, + { + "cell_type": "markdown", + "id": "7ebe05e9-9b93-42f6-bba8-46b8cc3d100f", + "metadata": {}, + "source": [ + "When the lora layers are applied to embedding layers, the corresponding base model embedding layers are also saved. " + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "3d7459ba-caa8-4f10-aa70-89be4541cbdf", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/raid/sourab/peft/src/peft/utils/save_and_load.py:128: UserWarning: Setting `is_embedding_layer_resized` to `True` as embedding layers found in `target_modules`\n", + " warnings.warn(\"Setting `is_embedding_layer_resized` to `True` as embedding layers found in `target_modules`\")\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8d23186832014f209939ab83e79da011", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Upload 3 LFS files: 0%| | 0/3 [00:00<|user|>Can you find me a place to eat please?<|system|>Where at? And what kind of cuisine are you craving?<|user|>Somewhere in SF, and I am really craving Thai food at the moment!<|system|>I found a bunch of restaurants, there's actually 10 that you might like in San Francisco, one of them being Baan Thai House & Wine Bar<|user|>How can I reach them? And what's their address?<|system|>You can reach them by phone at 415-379-4505 and visit them at 534 Irving Street<|beginlastuserutterance|>Great, that restaurant sounds good<|endlastuserutterance|><|endcontext|>\" \n", + "\n", + " target_predicted='<|begintarget|><|begindsts|><|begindst|><|beginintent|> FindRestaurant<|endintent|><|beginbelief|> Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|> REQUEST->Restaurants^phone_number~|REQUEST->Restaurants^street_address~<|enduseraction|><|beginaction|> INFORM->Restaurants^phone_number~415-379-4505|INFORM->Restaurants^street_address~534 Irving Street<|endaction|><|beginresponse|> The phone number is 415-379-4505 and the address is 534 Irving Street<|endresponse|><|endtarget|>' \n", + "\n", + " target='<|begintarget|><|begindsts|><|begindst|><|beginintent|>FindRestaurants<|endintent|><|beginbelief|>Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|>SELECT->Restaurants^~<|enduseraction|><|beginaction|>OFFER_INTENT->Restaurants^intent~ReserveRestaurant<|endaction|><|beginresponse|>Want me to book a table?<|endresponse|><|endtarget|>'\n" + ] + } + ], + "source": [ + "from peft import PeftModel\n", + "\n", + "inference_model = AutoModelForCausalLM.from_pretrained(\n", + " model_name,\n", + " low_cpu_mem_usage=True,\n", + " # use_flash_attention_2=True,\n", + ")\n", + "inference_model.resize_token_embeddings(len(tokenizer))\n", + "\n", + "inference_model = PeftModel.from_pretrained(inference_model, \"smangrul/mistral_lora_clm_with_added_tokens\")\n", + "inference_model.to(\"cuda\")\n", + "inference_model.eval()\n", + "\n", + "output_tokens = inference_model.generate(\n", + " **batch,\n", + " max_new_tokens=256,\n", + " do_sample=True,\n", + " temperature=0.2,\n", + " top_p=0.95,\n", + " top_k=50,\n", + " eos_token_id=tokenizer.eos_token_id,\n", + " pad_token_id=tokenizer.pad_token_id,\n", + ")\n", + "\n", + "target_predicted = tokenizer.decode(output_tokens[0], skip_special_tokens=False).split(\"<|endcontext|>\")[1]\n", + "print(f\"{context=} \\n\\n {target_predicted=} \\n\\n {target=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fd57f6e8-761f-4e0b-941c-f6973e13b186", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index c5c7825baaf..24ef48c22e2 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -159,6 +159,8 @@ def save_pretrained( save_directory: str, safe_serialization: bool = True, selected_adapters: Optional[List[str]] = None, + save_embedding_layers: Union[str, bool] = "auto", + is_main_process: bool = True, **kwargs: Any, ): r""" @@ -172,6 +174,14 @@ def save_pretrained( exist). safe_serialization (`bool`, *optional*): Whether to save the adapter files in safetensors format. + selected_adapters (`list(str)`, *optional*): + A list of adapters to be saved. If `None`, will default to all adapters. + save_embedding_layers (`Union[bool, str]`, , *optional*, defaults to `auto`): + If `True`, save the embedding layers in addition to adapter weights. If `auto`, checks the common + embedding layers `peft.utils.other.EMBEDDING_LAYER_NAMES` in config's `target_modules` when available. + Based on it sets the boolean flag. This only works for πŸ€— transformers models. + is_main_process (`bool`, *optional*): + Whether the process calling this is the main process or not. Will default to `True`. kwargs (additional keyword arguments, *optional*): Additional keyword arguments passed along to the `push_to_hub` method. """ @@ -190,19 +200,23 @@ def save_pretrained( f" {list(self.peft_config.keys())} - got {selected_adapters}." ) - os.makedirs(save_directory, exist_ok=True) - self.create_or_update_model_card(save_directory) + if is_main_process: + os.makedirs(save_directory, exist_ok=True) + self.create_or_update_model_card(save_directory) for adapter_name in selected_adapters: peft_config = self.peft_config[adapter_name] # save only the trainable weights output_state_dict = get_peft_model_state_dict( - self, state_dict=kwargs.get("state_dict", None), adapter_name=adapter_name + self, + state_dict=kwargs.get("state_dict", None), + adapter_name=adapter_name, + save_embedding_layers=save_embedding_layers, ) output_dir = os.path.join(save_directory, adapter_name) if adapter_name != "default" else save_directory os.makedirs(output_dir, exist_ok=True) - if safe_serialization: + if is_main_process and safe_serialization: # Section copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2111-L2134 # Safetensors does not allow tensor aliasing. # We're going to remove aliases before saving @@ -230,7 +244,7 @@ def save_pretrained( os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME), metadata={"format": "pt"}, ) - else: + elif is_main_process: torch.save(output_state_dict, os.path.join(output_dir, WEIGHTS_NAME)) # save the config and change the inference mode to `True` @@ -257,7 +271,8 @@ def save_pretrained( else: auto_mapping_dict = None - peft_config.save_pretrained(output_dir, auto_mapping_dict=auto_mapping_dict) + if is_main_process: + peft_config.save_pretrained(output_dir, auto_mapping_dict=auto_mapping_dict) peft_config.inference_mode = inference_mode @classmethod @@ -721,24 +736,27 @@ def create_or_update_model_card(self, output_dir: str): if hasattr(self.config, "quantization_config"): quantization_config = self.config.quantization_config.to_dict() training_config_text = "" + quantization_prefix = "The following `bitsandbytes` quantization config was used during training:" # Adds quantization information if it was used if quantization_config is not None: - training_config_text += "\nThe following `bitsandbytes` quantization config was used during training:\n" + training_config_text += f"\n{quantization_prefix}\n" training_config_text += "\n".join([f"- {name}: {value}" for name, value in quantization_config.items()]) training_config_text += "\n" - training_procedure_heading = "## Training procedure\n" - if training_procedure_heading in lines: - lines.insert(lines.index(training_procedure_heading) + 2, training_config_text) - else: - lines.append(f"{training_procedure_heading}\n{training_config_text}") + training_procedure_heading = "## Training procedure" + if quantization_prefix not in lines and bool(training_config_text): + if training_procedure_heading in lines: + lines.insert(lines.index(training_procedure_heading) + 2, training_config_text) + else: + lines.append(f"{training_procedure_heading}\n{training_config_text}") # Adds peft version - framework_block_heading = "### Framework versions\n" - if framework_block_heading in lines: - lines.insert(lines.index(framework_block_heading) + 2, f"- PEFT {__version__}\n") - else: - lines.append(f"{framework_block_heading}\n\n- PEFT {__version__}\n") + framework_block_heading = "### Framework versions" + if f"- PEFT {__version__}" not in lines: + if framework_block_heading in lines: + lines.insert(lines.index(framework_block_heading) + 2, f"- PEFT {__version__}") + else: + lines.append(f"{framework_block_heading}\n\n- PEFT {__version__}") card.text = "\n".join(lines) card.save(filename) diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index e811bee5bad..1c347017392 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -583,3 +583,4 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]: WEIGHTS_NAME = "adapter_model.bin" SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" CONFIG_NAME = "adapter_config.json" +EMBEDDING_LAYER_NAMES = ["embed_tokens", "lm_head"] diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py index 07e653bef16..97bde0d6fe5 100644 --- a/src/peft/utils/save_and_load.py +++ b/src/peft/utils/save_and_load.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import warnings from typing import Optional import torch @@ -20,11 +21,26 @@ from huggingface_hub.utils import EntryNotFoundError from safetensors.torch import load_file as safe_load_file -from .other import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, infer_device +from .other import EMBEDDING_LAYER_NAMES, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, infer_device from .peft_types import PeftType -def get_peft_model_state_dict(model, state_dict=None, adapter_name="default", unwrap_compiled=False): +def has_valid_embedding_base_layer(layer): + """Check if the layer has an embedding base layer""" + return hasattr(layer, "base_layer") and isinstance(layer.base_layer, (torch.nn.Linear, torch.nn.Embedding)) + + +def get_embedding_layer_name(model, layer, is_prompt_learning): + """Get the name of the embedding module for a given layer.""" + for name, module in model.named_modules(): + if (is_prompt_learning and module == layer) or module == layer.base_layer: + return name + return None + + +def get_peft_model_state_dict( + model, state_dict=None, adapter_name="default", unwrap_compiled=False, save_embedding_layers="auto" +): """ Get the state dict of the Peft model. @@ -37,6 +53,10 @@ def get_peft_model_state_dict(model, state_dict=None, adapter_name="default", un The name of the adapter whose state dict should be returned. unwrap_compiled (`bool`, *optional*, defaults to `False`): Whether to unwrap the model if torch.compile was used. + save_embedding_layers (`Union[bool, str]`, , *optional*, defaults to `auto`): + If `True`, save the embedding layers in addition to adapter weights. If `auto`, checks the common embedding + layers `peft.utils.other.EMBEDDING_LAYER_NAMES` in config's `target_modules` when available. Based on it + sets the boolean flag. This only works for πŸ€— transformers models. """ if unwrap_compiled: model = getattr(model, "_orig_mod", model) @@ -100,6 +120,27 @@ def get_peft_model_state_dict(model, state_dict=None, adapter_name="default", un if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save): to_return[key.replace("modules_to_save.", "")] = value + # check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary + if ( + save_embedding_layers == "auto" + and hasattr(config, "target_modules") + and any(k in config.target_modules for k in EMBEDDING_LAYER_NAMES) + ): + warnings.warn("Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`.") + save_embedding_layers = True + elif save_embedding_layers == "auto": + save_embedding_layers = False + + if save_embedding_layers and hasattr(model, "get_input_embeddings"): + for layer in [model.get_input_embeddings(), model.get_output_embeddings()]: + if config.is_prompt_learning or has_valid_embedding_base_layer(layer): + # support from version >= 0.6.2 + embedding_module_name = get_embedding_layer_name(model, layer, config.is_prompt_learning) + if embedding_module_name: + to_return.update({k: v for k, v in state_dict.items() if embedding_module_name in k}) + elif save_embedding_layers: + warnings.warn("Could not identify embedding layer(s) because the model is not a πŸ€— transformers model.") + to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()} return to_return diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 347df218b2c..b298388a844 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -333,6 +333,33 @@ def forward(self, X): return X +class ModelEmbWithEmbeddingUtils(nn.Module): + # Adds `get_input_embeddings` and `get_output_embeddings` methods to mimic πŸ€— transformers models + def __init__(self): + super().__init__() + self.embed_tokens = nn.Embedding(100, 5) + self.conv1d = Conv1D(1, 5) + self.relu = nn.ReLU() + self.flat = nn.Flatten() + self.lin0 = nn.Linear(10, 2) + self.sm = nn.LogSoftmax(dim=-1) + + def forward(self, X): + X = self.embed_tokens(X) + X = self.conv1d(X) + X = self.relu(X) + X = self.flat(X) + X = self.lin0(X) + X = self.sm(X) + return X + + def get_input_embeddings(self): + return self.embed_tokens + + def get_output_embeddings(self): + return None + + class ModelConv2D(nn.Module): def __init__(self): super().__init__() @@ -750,6 +777,55 @@ def test_non_existing_model_card(self): # rough check that the model card is pre-filled self.assertGreater(len(model_card), 1000) + @parameterized.expand(["auto", True, False]) + def test_targeting_lora_to_embedding_layer(self, save_embedding_layers): + model = ModelEmbWithEmbeddingUtils() + config = LoraConfig(target_modules=["embed_tokens", "lin0"], init_lora_weights=False) + model = get_peft_model(model, config) + + with tempfile.TemporaryDirectory() as tmp_dirname: + if save_embedding_layers == "auto": + # assert warning + msg_start = "Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`." + with self.assertWarns(UserWarning, msg=msg_start): + model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers) + else: + model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers) + from safetensors.torch import load_file as safe_load_file + + state_dict = safe_load_file(os.path.join(tmp_dirname, "adapter_model.safetensors")) + if save_embedding_layers in ["auto", True]: + self.assertTrue("base_model.model.embed_tokens.base_layer.weight" in state_dict) + self.assertTrue( + torch.allclose( + model.base_model.model.embed_tokens.base_layer.weight, + state_dict["base_model.model.embed_tokens.base_layer.weight"], + ) + ) + else: + self.assertFalse("base_model.model.embed_tokens.base_layer.weight" in state_dict) + del state_dict + + @parameterized.expand(["auto", True, False]) + def test_targeting_lora_to_embedding_layer_non_transformers(self, save_embedding_layers): + model = ModelEmbConv1D() + config = LoraConfig(target_modules=["emb", "lin0"], init_lora_weights=False) + model = get_peft_model(model, config) + + with tempfile.TemporaryDirectory() as tmp_dirname: + if save_embedding_layers is True: + # assert warning + msg_start = "Could not identify embedding layer(s) because the model is not a πŸ€— transformers model." + with self.assertWarns(UserWarning, msg=msg_start): + model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers) + else: + model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers) + from safetensors.torch import load_file as safe_load_file + + state_dict = safe_load_file(os.path.join(tmp_dirname, "adapter_model.safetensors")) + self.assertFalse("base_model.model.emb.base_layer.weight" in state_dict) + del state_dict + @parameterized.expand( [ LoraConfig(target_modules=["lin0"], init_lora_weights=False), From 2b901ee57230559aaf39867c7698f6aca3617162 Mon Sep 17 00:00:00 2001 From: yxli2123 <69247082+yxli2123@users.noreply.github.com> Date: Wed, 29 Nov 2023 11:08:17 -0500 Subject: [PATCH 03/12] Add LoftQ initialization method for LoRA (#1150) --------- Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Co-authored-by: Benjamin Bossan --- README.md | 1 + examples/loftq_finetuning/README.md | 69 ++ .../loftq_finetuning/quantize_save_load.py | 244 +++++ .../loftq_finetuning/train_gsm8k_llama.py | 866 ++++++++++++++++++ requirements.txt | 15 + src/peft/__init__.py | 1 + src/peft/tuners/__init__.py | 2 +- src/peft/tuners/lora/__init__.py | 4 +- src/peft/tuners/lora/config.py | 45 +- src/peft/tuners/lora/layer.py | 48 +- src/peft/tuners/lora/model.py | 4 + src/peft/utils/loftq_utils.py | 227 +++++ 12 files changed, 1514 insertions(+), 12 deletions(-) create mode 100644 examples/loftq_finetuning/README.md create mode 100644 examples/loftq_finetuning/quantize_save_load.py create mode 100644 examples/loftq_finetuning/train_gsm8k_llama.py create mode 100644 requirements.txt create mode 100644 src/peft/utils/loftq_utils.py diff --git a/README.md b/README.md index 445cb265394..79259f98ee9 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ Supported methods: 7. MultiTask Prompt Tuning: [Multitask Prompt Tuning Enables Parameter-Efficient Transfer Learning](https://arxiv.org/abs/2303.02861) 8. LoHa: [FedPara: Low-Rank Hadamard Product for Communication-Efficient Federated Learning](https://arxiv.org/abs/2108.06098) 9. LoKr: [KronA: Parameter Efficient Tuning with Kronecker Adapter](https://arxiv.org/abs/2212.10650) based on [Navigating Text-To-Image Customization:From LyCORIS Fine-Tuning to Model Evaluation](https://arxiv.org/abs/2309.14859) implementation +10. LoftQ: [LoftQ: LoRA-Fine-Tuning-aware Quantization for Large Language Models](https://arxiv.org/abs/2310.08659) ## Getting started diff --git a/examples/loftq_finetuning/README.md b/examples/loftq_finetuning/README.md new file mode 100644 index 00000000000..726f544e854 --- /dev/null +++ b/examples/loftq_finetuning/README.md @@ -0,0 +1,69 @@ +# LoftQ: LoRA-fine-tuning-aware Quantization + +## Introduction + +LoftQ provides better initialization for LoRA adapters A and B, +and the Quantization of pre-trained weights W. + +## Quantization +We recommend to save the quantized backbone model as fp16/fp32 +and load it as [NormalFloat4](https://arxiv.org/abs/2305.14314). + +We provide a simple example to show how to quantize llama-2-7b model and save/load it. + +```sh +python quantize_save_load.py \ + --model_name_or_path meta-llama/Llama-2-7b-hf \ + --token HF_TOKEN \ + --bits 4 --iter 5 --rank 16 \ + --save_dir model_zoo/loftq/ +``` + +- `HF_TOKEN` is the token used to access to [LLAMA models](https://huggingface.co/meta-llama). +- `quantize_and_save()` function will quantize the backbone and initialize LoRA adapters. +It creates 2 folders under `$save_dir`. The quantized backbone is at `Llama-2-7b-hf-4bit-16rank`, +and the LoRA adapters are at the sub-folder `Llama-2-7b-hf-4bit-16rank/loftq_init`. + +## Fine-tuning + +Here is an example to load the quantized backbone and LoRA adapters: + +```python +import os + +from transformers import AutoModelForCausalLM +from peft import PeftModel + + +base_model = AutoModelForCausalLM.from_pretrained( + os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank"), + load_in_4bit=True, +) +peft_model = PeftModel.from_pretrained( + base_model, + os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank", "loftq_init"), + is_trainable=True, +) +``` + +We also provide an example to fine-tune LoftQ on GSM8K. +We load the quantized backbone and LoRA adapters from the [LoftQ Huggingface hub](https://huggingface.co/LoftQ). + +```sh +python train_gsm8k_llama.py \ + --model_name_or_path LoftQ/Llama-2-7b-hf-4bit-64rank \ + --output_dir exp_results/gsm8k/llama-2-7b/bit4-rank64/lr3e-4 \ + --learning_rate 3e-4 \ + --seed 202 \ + --dataset_name gsm8k \ + --dataset_config main \ + --pad_to_max_length \ + --max_source_length 128 \ + --max_target_length 256 \ + --num_train_epochs 5 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 4 \ + --with_tracking \ + --report_to tensorboard +``` diff --git a/examples/loftq_finetuning/quantize_save_load.py b/examples/loftq_finetuning/quantize_save_load.py new file mode 100644 index 00000000000..3c47fa45cdd --- /dev/null +++ b/examples/loftq_finetuning/quantize_save_load.py @@ -0,0 +1,244 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import torch +import torch.nn as nn +from transformers import ( + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + AutoModelForSequenceClassification, + AutoTokenizer, + BitsAndBytesConfig, +) + +from peft import LoftQConfig, LoraConfig, PeftModel, TaskType, get_peft_model + + +class Shell(nn.Module): + def __init__(self, weight, bias=None): + super().__init__() + self.weight = nn.Parameter(weight, requires_grad=False) + if bias is not None: + self.bias = nn.Parameter(bias, requires_grad=False) + + +def unwarap_model(model, sub_module_name=".base_layer"): + sub_module_name_list = [k.split(sub_module_name)[0] for k in model.state_dict().keys() if sub_module_name in k] + sub_module_name_set = set(sub_module_name_list) + for name in sub_module_name_set: + # get the parent of the submodule + name_parent = ".".join(name.split(".")[:-1]) + name_child = name.split(".")[-1] + sub_module = model.get_submodule(name_parent) + print(sub_module) + + # replace with shell + child = getattr(sub_module, name_child) + weight = getattr(child.base_layer, "weight", None) + bias = getattr(child.base_layer, "bias", None) + shell = Shell(weight, bias) + + setattr(sub_module, name_child, shell) + + print("You have unwrapped the model. Use it on your own risk.") + + +def print_model(model, name): + print("=" * 10 + name + "=" * 10) + print(model) + for name, param in model.named_parameters(): + if torch.is_tensor(param): + if param.dtype in [torch.float32, torch.float16]: + print( + name, + param.shape, + param.device, + param.dtype, + param.requires_grad, + param.mean().item(), + param.max().item(), + ) + else: + print(name, param.shape, param.device, param.dtype, param.requires_grad) + + +def arg_parse(): + parser = argparse.ArgumentParser(description="Quantize a model with LoftQ.") + parser.add_argument( + "--model_name_or_path", + type=str, + default=None, + required=True, + help="The name or path of the fp32/16 model.", + ) + parser.add_argument( + "--token", + type=str, + default=None, + help="The access token to download model from HuggingFace Hub.", + ) + parser.add_argument( + "--bits", + type=int, + default=4, + help="The quantized bits", + ) + parser.add_argument( + "--iter", + type=int, + default=1, + help="The alternating steps in LoftQ", + ) + parser.add_argument( + "--rank", + type=int, + default=16, + help="The rank of the LoRA adapter", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./model_zoo/loftq/", + help="The rank of the LoRA adapter", + ) + args = parser.parse_args() + return args + + +def quantize_and_save(): + args = arg_parse() + + # Download weights and configure LoRA + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, token=args.token, trust_remote_code=True) + if any(name in args.model_name_or_path.lower() for name in ["llama", "mistral", "falcon"]): + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, token=args.token, trust_remote_code=True, device_map="auto" + ) + task_type = TaskType.CAUSAL_LM + target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"] + + elif any(name in args.model_name_or_path.lower() for name in ["bart", "t5"]): + model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path, token=args.token, device_map="auto") + task_type = TaskType.SEQ_2_SEQ_LM + target_modules = ["q_proj", "k_proj", "v_proj", "fc1", "fc2", "out_proj"] + + elif any(name in args.model_name_or_path.lower() for name in ["deberta", "roberta", "bert"]): + model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, token=args.token) + model = model.cuda() + task_type = TaskType.SEQ_CLS + target_modules = ["query_proj", "key_proj", "value_proj", "dense"] # embeddings not supported by peft + else: + raise NotImplementedError("Other models not supported yet.") + + # Config of LoftQ + loftq_config = LoftQConfig(loftq_bits=args.bits, loftq_iter=args.iter) + + lora_config = LoraConfig( + task_type=task_type, + inference_mode=True, + r=args.rank, + lora_alpha=16 if task_type is TaskType.CAUSAL_LM else args.rank, + lora_dropout=0.1, + target_modules=target_modules, + init_lora_weights="loftq", + loftq_config=loftq_config, + ) + + # Obtain LoftQ model + lora_model = get_peft_model(model, lora_config) + base_model = lora_model.get_base_model() + + # Save LoftQ model + model_name = args.model_name_or_path.split("/")[-1] + f"-{args.bits}bit" + f"-{args.rank}rank" + base_model_dir = os.path.join(args.save_dir, model_name) + lora_model_dir = os.path.join(args.save_dir, model_name, "loft_init") + + # save lora adapters first + lora_model.base_model.peft_config[ + "default" + ].base_model_name_or_path = base_model_dir # This can be a local path or Hub model id + lora_model.base_model.peft_config["default"].init_lora_weights = True # Don't apply LoftQ when loading again + + lora_model.save_pretrained(lora_model_dir) + print_model(lora_model, "lora_model") + + # remove lora adapters and save the backbone + unwarap_model(base_model) + base_model.save_pretrained(base_model_dir) + tokenizer.save_pretrained(base_model_dir) + + print_model(base_model, "base_model") + + return base_model_dir, lora_model_dir + + +def load_loftq(base_model_path, lora_adapter_path): + if any(name in base_model_path.lower() for name in ["llama", "mistral", "falcon"]): + model = AutoModelForCausalLM.from_pretrained( + base_model_path, + device_map="auto", + low_cpu_mem_usage=True, + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=False, + bnb_4bit_quant_type="nf4", + ), + ) + elif any(name in base_model_path.lower() for name in ["bart", "t5"]): + model = AutoModelForSeq2SeqLM.from_pretrained( + base_model_path, + device_map="auto", + low_cpu_mem_usage=True, + load_in_4bit=True, + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=False, + bnb_4bit_quant_type="nf4", + ), + ) + elif any(name in base_model_path.lower() for name in ["deberta", "roberta", "bert"]): + model = AutoModelForSequenceClassification.from_pretrained( + base_model_path, + low_cpu_mem_usage=True, + load_in_4bit=True, + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=False, + bnb_4bit_quant_type="nf4", + ), + ) + else: + raise NotImplementedError("Other models not supported yet.") + + lora_model = PeftModel.from_pretrained(model, lora_adapter_path, is_trainable=True) + + # Do training or inference below + print_model(lora_model, "lora_model") + print_model(model, "base_model") + + +if __name__ == "__main__": + base_dir, lora_dir = quantize_and_save() + load_loftq(base_dir, lora_dir) + +# example command: +# python quantize_save_load.py \ +# --model_name_or_path meta-llama/Llama-2-7b-hf \ +# --token XXX \ +# --bits 4 --iter 5 --rank 16 \ +# --save_dir ./model_zoo/loftq/ diff --git a/examples/loftq_finetuning/train_gsm8k_llama.py b/examples/loftq_finetuning/train_gsm8k_llama.py new file mode 100644 index 00000000000..e8c3580d2e1 --- /dev/null +++ b/examples/loftq_finetuning/train_gsm8k_llama.py @@ -0,0 +1,866 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import copy +import logging +import math +import os +import random +import re +from pathlib import Path + +import datasets +import torch +import transformers +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from datasets import load_dataset +from huggingface_hub import Repository, create_repo +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + SchedulerType, + default_data_collator, + get_scheduler, +) +from transformers.utils import send_example_telemetry +from transformers.utils.versions import require_version + +from peft import PeftModel + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +# check_min_version("4.32.0.dev0") + +logger = get_logger(__name__) + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") + +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) +HF_TOKEN = "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" + + +def parse_args(): + parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--train_file", type=str, default=None, help="A csv, txt or a json file containing the training data." + ) + parser.add_argument( + "--validation_file", type=str, default=None, help="A csv, txt or a json file containing the validation data." + ) + parser.add_argument( + "--validation_split_percentage", + default=5, + help="The percentage of the train set used as validation set in case there's no validation split", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the πŸ€— Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument( + "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--model_type", + type=str, + default=None, + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "--ignore_pad_token_for_loss", + type=bool, + default=True, + help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.", + ) + parser.add_argument( + "--max_source_length", + type=int, + default=128, + help=( + "The maximum total input sequence length after " + "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded." + ), + ) + parser.add_argument( + "--max_target_length", + type=int, + default=128, + help=( + "The maximum total sequence length for target text after " + "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded." + "during ``evaluate`` and ``predict``." + ), + ) + parser.add_argument( + "--pad_to_max_length", + action="store_true", + help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.", + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=None, + help="The number of processes to use for the preprocessing.", + ) + parser.add_argument( + "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument( + "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files." + ) + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument( + "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." + ) + parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--trust_remote_code", + type=bool, + default=False, + help=( + "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option" + "should only be set to `True` for repositories you trust and in which you have read the code, as it will" + "execute code present on the Hub on your local machine." + ), + ) + parser.add_argument( + "--checkpointing_steps", + type=str, + default=None, + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + parser.add_argument( + "--with_tracking", + action="store_true", + help="Whether to enable experiment trackers for logging.", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed." + ), + ) + parser.add_argument( + "--low_cpu_mem_usage", + action="store_true", + help=( + "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded." + "If passed, LLM loading time and RAM consumption will be benefited." + ), + ) + ########################## + # Generation Config # + ########################## + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="temperature of 1.0 has no effect, lower tend toward greedy sampling", + ) + parser.add_argument("--k", type=int, default=40, help="Choose k candidate words") + parser.add_argument("--p", type=float, default=0.95, help="The sum of probability of candidate words is 0.9 ") + + ########################## + # Exp Args # + ########################## + parser.add_argument( + "--adapter_name_or_path", + type=str, + default=None, + help=( + "The LoRA adapter checkpoint. Set None if you want to fine-tune from LoftQ." + "Specify a path if you want to evaluate." + ), + ) + + args = parser.parse_args() + + # Sanity checks + if args.dataset_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + + return args + + +def main(): + args = parse_args() + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_clm_no_trainer", args) + + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers + # in the environment + accelerator_log_kwargs = {} + + if args.with_tracking: + accelerator_log_kwargs["log_with"] = args.report_to + accelerator_log_kwargs["project_dir"] = args.output_dir + + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + # Retrieve of infer repo_name + repo_name = args.hub_model_id + if repo_name is None: + repo_name = Path(args.output_dir).absolute().name + # Create repo and retrieve repo_id + repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id + # Clone repo locally + repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[:{args.validation_split_percentage}%]", + ) + raw_datasets["train"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[{args.validation_split_percentage}%:]", + ) + else: + data_files = {} + dataset_args = {} + if args.train_file is not None: + data_files["train"] = args.train_file + if args.validation_file is not None: + data_files["validation"] = args.validation_file + extension = args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks + raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args) + # If no validation data is there, validation_split_percentage will be used to divide the dataset. + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{args.validation_split_percentage}%]", + **dataset_args, + ) + raw_datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{args.validation_split_percentage}%:]", + **dataset_args, + ) + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + if args.config_name: + config = AutoConfig.from_pretrained( + args.config_name, + trust_remote_code=args.trust_remote_code, + ) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained( + args.model_name_or_path, + trust_remote_code=args.trust_remote_code, + ) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, use_fast=not args.use_slow_tokenizer, trust_remote_code=args.trust_remote_code + ) + elif args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.model_name_or_path, + use_fast=not args.use_slow_tokenizer, + trust_remote_code=args.trust_remote_code, + ) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + ########################## + # Tokenizer # + ########################## + tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token + tokenizer.padding_side = "left" # Allow batched inference + tokenizer.truncation_side = "left" + + if args.model_name_or_path: + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + low_cpu_mem_usage=True, + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=False, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=config.torch_dtype, + ), + ) + else: + logger.info("Training new model from scratch") + model = AutoModelForCausalLM.from_config(config, trust_remote_code=args.trust_remote_code) + + ########################## + # Peft Model # + ########################## + if args.adapter_name_or_path is None: + model = PeftModel.from_pretrained(model, args.model_name_or_path, subfolder="loftq_init", is_trainable=True) + else: + model = PeftModel.from_pretrained(model, args.adapter_name_or_path, is_trainable=True) + model.print_trainable_parameters() + + # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch + # on a small vocab and want a smaller embedding size, remove this test. + embedding_size = model.get_input_embeddings().weight.shape[0] + if len(tokenizer) > embedding_size: + model.resize_token_embeddings(len(tokenizer)) + + # Preprocessing the datasets. + # First we tokenize all the texts. + ########################## + # GSM8K dataset # + ########################## + + # Preprocessing the datasets. + # First we tokenize all the texts. + column_names = raw_datasets["train"].column_names + + # Get the column names for source/target. + source_column, target_column = "question", "answer" + + # Temporarily set max_target_length for training. + padding = "max_length" if args.pad_to_max_length else False + task_prompt = "\nAnswer the above question. First think step by step and then answer the final number.\n" + + def prompt_process(sent_1, sent_2, prompt_1="", prompt_2="", prompt_3=""): + sent_2 = sent_2.replace("####", "The final answer is") + return prompt_1 + sent_1 + prompt_2 + sent_2 + prompt_3 + + def preprocess_function_train(examples): + sources = examples[source_column] + targets = examples[target_column] + + inputs = [prompt_process(source, target, prompt_2=task_prompt) for (source, target) in zip(sources, targets)] + + model_inputs = tokenizer( + inputs, + max_length=args.max_source_length + args.max_target_length, + padding=padding, + truncation=True, + return_tensors="pt", + ) + + labels = copy.deepcopy(model_inputs) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. + if padding == "max_length" and args.ignore_pad_token_for_loss: + # get the length of the target tokens. -1 to kick out the token + target_tokens = tokenizer(targets, padding=False) + target_len = [len(label) - 1 for label in target_tokens["input_ids"]] + + # don't calculate the loss from source and padding (left padding) + for i in range(len(labels["input_ids"])): + labels["input_ids"][i, : -target_len[i]] = -100 + + model_inputs["labels"] = labels["input_ids"] + return model_inputs + + def preprocess_function_test(examples): + sources = examples[source_column] + labels = examples[target_column] + + inputs = [source + task_prompt for source in sources] + + model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True) + labels = tokenizer(labels, max_length=args.max_target_length, padding=padding, truncation=True) + + model_inputs["labels"] = labels["input_ids"] + + return model_inputs + + with accelerator.main_process_first(): + train_dataset = raw_datasets["train"].map( + preprocess_function_train, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on training dataset", + ) + + eval_dataset = raw_datasets["test"].map( + preprocess_function_test, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on test dataset", + ) + + # Log a few random samples from the set: + for index in random.sample(range(len(train_dataset)), 2): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + for index in random.sample(range(len(eval_dataset)), 2): + logger.info(f"Sample {index} of the validation set: {eval_dataset[index]}.") + + # DataLoaders creation: + train_dataloader = DataLoader( + train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size + ) + eval_dataloader = DataLoader( + eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size + ) + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. + no_decay = ["bias", "layer_norm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and "lora" in n], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties. + if accelerator.distributed_type == DistributedType.TPU: + model.tie_weights() + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Figure out how many steps we should save the Accelerator states + checkpointing_steps = args.checkpointing_steps + if checkpointing_steps is not None and checkpointing_steps.isdigit(): + checkpointing_steps = int(checkpointing_steps) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if args.with_tracking: + experiment_config = vars(args) + # TensorBoard cannot log Enums, need the raw value + experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value + accelerator.init_trackers("clm_no_trainer", experiment_config) + + # Train! + total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + completed_steps = 0 + starting_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": + checkpoint_path = args.resume_from_checkpoint + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] + dirs.sort(key=os.path.getctime) + path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last + checkpoint_path = path + path = os.path.basename(checkpoint_path) + + accelerator.print(f"Resumed from checkpoint: {checkpoint_path}") + accelerator.load_state(path) + # Extract `epoch_{i}` or `step_{i}` + training_difference = os.path.splitext(path)[0] + + if "epoch" in training_difference: + starting_epoch = int(training_difference.replace("epoch_", "")) + 1 + resume_step = None + completed_steps = starting_epoch * num_update_steps_per_epoch + else: + # need to multiply `gradient_accumulation_steps` to reflect real steps + resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps + starting_epoch = resume_step // len(train_dataloader) + resume_step -= starting_epoch * len(train_dataloader) + completed_steps = resume_step // args.gradient_accumulation_steps + + # update the progress_bar if load from checkpoint + progress_bar.update(completed_steps) + + for epoch in range(starting_epoch, args.num_train_epochs): + model.train() + if args.with_tracking: + total_loss = 0 + if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: + # We skip the first `n` batches in the dataloader when resuming from a checkpoint + active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) + else: + active_dataloader = train_dataloader + for step, batch in enumerate(active_dataloader): + with accelerator.accumulate(model): + outputs = model(**batch) + loss = outputs.loss + # We keep track of the loss at each epoch + if args.with_tracking: + total_loss += loss.detach().float() + accelerator.backward(loss) + accelerator.print(f"Epoch: {epoch} | Step: {step} | Loss: {loss}") + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + completed_steps += 1 + + if isinstance(checkpointing_steps, int): + if completed_steps % checkpointing_steps == 0: + output_dir = f"step_{completed_steps}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + if completed_steps >= args.max_train_steps: + break + + model.eval() + gen_kwargs = { + "max_new_tokens": args.max_target_length, + "temperature": args.temperature, + "top_k": args.k, + "top_p": args.p, + "do_sample": True, + } + ans_pred_list = [] + ans_gold_list = [] + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + gen_kwargs["input_ids"] = batch["input_ids"] + gen_kwargs["attention_mask"] = batch["attention_mask"] + generated_tokens = accelerator.unwrap_model(model).generate(**gen_kwargs) + + pred_tokens = generated_tokens[:, args.max_source_length :] + pred_tokens = accelerator.pad_across_processes(pred_tokens, dim=1, pad_index=tokenizer.pad_token_id) + gold_tokens = batch["labels"] + + if not args.pad_to_max_length: + # If we did not pad to max length, we need to pad the labels too + gold_tokens = accelerator.pad_across_processes( + batch["labels"], dim=1, pad_index=tokenizer.pad_token_id + ) + + pred_tokens, gold_tokens = accelerator.gather_for_metrics((pred_tokens, gold_tokens)) + pred_tokens, gold_tokens = pred_tokens.cpu().numpy(), gold_tokens.cpu().numpy() + + if isinstance(pred_tokens, tuple): + pred_tokens = pred_tokens[0] + decoded_pred = tokenizer.batch_decode(pred_tokens, skip_special_tokens=True) + decoded_gold = tokenizer.batch_decode(gold_tokens, skip_special_tokens=True) + + # Extract the numbers in sentences + accelerator.print(decoded_pred) + ans_pred_list += [extract_answer_number(sentence_pred) for sentence_pred in decoded_pred] + ans_gold_list += [extract_answer_number(sentence_gold) for sentence_gold in decoded_gold] + + accelerator.print(ans_pred_list) + accelerator.print(ans_gold_list) + accuracy = compute_accuracy(ans_gold_list, ans_pred_list) + + logger.info(f"epoch {epoch}: accuracy: {accuracy}") + + if args.with_tracking: + accelerator.log( + { + "accuracy": accuracy, + "train_loss": total_loss.item() / len(train_dataloader), + "epoch": epoch, + "step": completed_steps, + }, + step=completed_steps, + ) + + if args.push_to_hub and epoch < args.num_train_epochs - 1: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save + ) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir) + repo.push_to_hub( + commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True + ) + + if args.checkpointing_steps == "epoch": + output_dir = f"epoch_{epoch}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + + if args.with_tracking: + accelerator.end_training() + + if args.output_dir is not None: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save + ) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir) + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) + + +PATTERN_NUMBER = re.compile(r"-?\d+\.?\d*") + + +def extract_answer_number(sentence: str) -> float: + sentence = sentence.replace(",", "") + pred = PATTERN_NUMBER.findall(sentence) + if not pred: + return float("inf") + segment = sentence.split("The final answer is ") + if len(segment) > 1: + pred_answer = segment[1] + pred_answer = PATTERN_NUMBER.findall(pred_answer) + if len(pred_answer) > 0: + pred_answer = pred_answer[0] + else: + pred_answer = float(pred[-1]) + else: + pred_answer = float(pred[-1]) + + if isinstance(pred_answer, str): + try: + pred_answer = float(pred_answer) + except ValueError: + pred_answer = float("inf") + return pred_answer + + +def compute_accuracy(pred: list, gold: list): + acc = 0.0 + for p, g in zip(pred, gold): + if p == g: + acc += 1 + + return acc / len(pred) + + +if __name__ == "__main__": + main() + +# example command + +# python train_gsm8k_llama.py \ +# --model_name_or_path LoftQ/Llama-2-7b-hf-bit4-rank64-backbone \ +# --adapter_name_or_path LoftQ/Llama-2-7b-hf-bit4-rank64-adapters \ +# --output_dir exp_results/gsm8k/llama-2-7b/bit4-rank64/lr3e-4 \ +# --learning_rate 1e-4 \ +# --seed 202 \ +# --dataset_name gsm8k \ +# --dataset_config main \ +# --pad_to_max_length \ +# --max_source_length 128 \ +# --max_target_length 256 \ +# --num_train_epochs 5 \ +# --per_device_train_batch_size 4 \ +# --per_device_eval_batch_size 4 \ +# --gradient_accumulation_steps 4 \ +# --with_tracking \ +# --report_to tensorboard diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000000..dca857de324 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,15 @@ +accelerate +torch +safetensors +bitsandbytes +scipy +peft +transformers +tqdm +packaging +pytest +numpy +pyyaml +datasets +psutil +setuptools \ No newline at end of file diff --git a/src/peft/__init__.py b/src/peft/__init__.py index a3ce332f247..4d9380e6978 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -48,6 +48,7 @@ AdaptionPromptConfig, AdaptionPromptModel, LoraConfig, + LoftQConfig, LoraModel, LoHaConfig, LoHaModel, diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index b357d47dc18..666e29d9973 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -18,7 +18,7 @@ # limitations under the License. from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel -from .lora import LoraConfig, LoraModel +from .lora import LoraConfig, LoraModel, LoftQConfig from .loha import LoHaConfig, LoHaModel from .lokr import LoKrConfig, LoKrModel from .ia3 import IA3Config, IA3Model diff --git a/src/peft/tuners/lora/__init__.py b/src/peft/tuners/lora/__init__.py index d02bf2d9481..ddc81d53cdd 100644 --- a/src/peft/tuners/lora/__init__.py +++ b/src/peft/tuners/lora/__init__.py @@ -15,13 +15,13 @@ from peft.import_utils import is_bnb_4bit_available, is_bnb_available -from .config import LoraConfig +from .config import LoftQConfig, LoraConfig from .gptq import QuantLinear from .layer import Conv2d, Embedding, Linear, LoraLayer from .model import LoraModel -__all__ = ["LoraConfig", "Conv2d", "Embedding", "LoraLayer", "Linear", "LoraModel", "QuantLinear"] +__all__ = ["LoraConfig", "LoftQConfig", "Conv2d", "Embedding", "LoraLayer", "Linear", "LoraModel", "QuantLinear"] if is_bnb_available(): diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index b1e31d81987..0dcca5c1e66 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -22,6 +22,25 @@ from peft.utils import PeftType +@dataclass +class LoftQConfig: + """ + This is the sub-configuration class to store the configuration of a [`LoraModel`]. + + Args: + bits_pattern (`dict`): The mapping from layer names or regexp expression to bits which are different from the + default bits specified by `bits`. For example, `{model.decoder.layers.0.encoder_attn.k_proj: 2`}. + bits (`int`): Quantization bits for LoftQ. + iter (`int`): Alternating iterations for LoftQ. + fake (`bool`): True: use fp16/fp32; used for first time to save weights. False: use bitsandbytes 4bit linear + models. weights can't be saved. Recommend to set to True, save the weights and load the saved weights in 4 + bits. + """ + + loftq_bits: int = field(default=4, metadata={"help": "Quantization bits for LoftQ"}) + loftq_iter: int = field(default=1, metadata={"help": "Alternating iterations for LoftQ"}) + + @dataclass class LoraConfig(PeftConfig): """ @@ -78,7 +97,7 @@ class LoraConfig(PeftConfig): "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." }, ) - init_lora_weights: bool | Literal["gaussian"] = field( + init_lora_weights: bool | Literal["gaussian", "loftq"] = field( default=True, metadata={ "help": ( @@ -86,6 +105,7 @@ class LoraConfig(PeftConfig): "initialization from the reference implementation from Microsoft. Passing 'gaussian' results " "in Gaussian initialization scaled by the LoRA rank for linear and layers. Setting the initialization " "to False leads to completely random initialization and is discouraged." + "Pass `'loftq'` to use LoftQ initialization" ), }, ) @@ -121,6 +141,16 @@ class LoraConfig(PeftConfig): ) }, ) + # dict type is used when loading config.json + loftq_config: Union[LoftQConfig, dict] = field( + default_factory=dict, + metadata={ + "help": ( + "The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone " + "weights and initialize Lora layers." + ) + }, + ) def __post_init__(self): self.peft_type = PeftType.LORA @@ -134,3 +164,16 @@ def __post_init__(self): # if target_modules is a regex expression, then layers_pattern should be None if isinstance(self.target_modules, str) and self.layers_pattern is not None: raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.") + + # handle init_lora_weights and loftq_config + if self.init_lora_weights == "loftq": + import importlib + + if not importlib.util.find_spec("scipy"): + raise ImportError("The required package 'scipy' is not installed. Please install it to continue.") + if self.loftq_config is None: + raise ValueError("`loftq_config` must be specified when `init_lora_weights` is 'loftq'.") + + # convert loftq_config to dict + if self.loftq_config is not None and not isinstance(self.loftq_config, dict): + self.loftq_config = vars(self.loftq_config) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 5ea726d2ffb..cf97108c871 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -15,7 +15,7 @@ import math import warnings -from typing import Any, List, Optional +from typing import Any, List, Optional, Union import torch import torch.nn as nn @@ -46,6 +46,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: # Mark the weight as unmerged self._disable_adapters = False self.merged_adapters = [] + self.kwargs = kwargs base_layer = self.get_base_layer() if isinstance(base_layer, nn.Linear): @@ -83,7 +84,10 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False) self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False) self.scaling[adapter_name] = lora_alpha / r - if init_lora_weights: + + if init_lora_weights == "loftq": + self.loftq_init(adapter_name) + elif init_lora_weights: self.reset_lora_parameters(adapter_name, init_lora_weights) weight = getattr(self.get_base_layer(), "weight", None) @@ -115,7 +119,10 @@ def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lo self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False) self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False) self.scaling[adapter_name] = lora_alpha / r - if init_lora_weights: + + if init_lora_weights == "loftq": + self.loftq_init(adapter_name) + elif init_lora_weights: self.reset_lora_parameters(adapter_name, init_lora_weights) weight = getattr(base_layer, "weight", None) @@ -142,7 +149,11 @@ def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init self.lora_embedding_A[adapter_name] = nn.Parameter(weight_A) self.lora_embedding_B[adapter_name] = nn.Parameter(weight_B) self.scaling[adapter_name] = lora_alpha / r - self.reset_lora_parameters(adapter_name, init_lora_weights) + + if init_lora_weights == "loftq": + self.loftq_init(adapter_name) + elif init_lora_weights: + self.reset_lora_parameters(adapter_name, init_lora_weights) base_layer = self.get_base_layer() weight = getattr(base_layer, "weight", None) @@ -170,6 +181,27 @@ def reset_lora_parameters(self, adapter_name, init_lora_weights): nn.init.zeros_(self.lora_embedding_A[adapter_name]) nn.init.normal_(self.lora_embedding_B[adapter_name]) + def loftq_init(self, adapter_name): + from peft.utils.loftq_utils import loftq_init + + weight = self.get_base_layer().weight + kwargs = { + "num_bits": self.kwargs.get("loftq_bits", 4), + "reduced_rank": self.r[adapter_name], + "num_iter": self.kwargs.get("loftq_iter", 1), + } + + qweight, lora_A, lora_B = loftq_init(weight, **kwargs) + if adapter_name in self.lora_A.keys(): + # initialize A the same way as the default for nn.Linear and B to zero + self.lora_A[adapter_name].weight.data = lora_A + self.lora_B[adapter_name].weight.data = lora_B + if adapter_name in self.lora_embedding_A.keys(): + # initialize a the same way as the default for nn.linear and b to zero + self.lora_embedding_A[adapter_name].weight.data = lora_A + self.lora_embedding_B[adapter_name].weight.data = lora_B + self.get_base_layer().weight.data = qweight + def set_scale(self, adapter, scale): if adapter not in self.scaling: # Ignore the case where the adapter is not in the layer @@ -218,11 +250,11 @@ def __init__( lora_dropout: float = 0.0, fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) is_target_conv_1d_layer: bool = False, - init_lora_weights: bool = True, + init_lora_weights: Union[bool, str] = True, **kwargs, ) -> None: super().__init__() - LoraLayer.__init__(self, base_layer) + LoraLayer.__init__(self, base_layer, **kwargs) self.fan_in_fan_out = fan_in_fan_out self._active_adapter = adapter_name @@ -351,7 +383,7 @@ def __init__( r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, - init_lora_weights: bool = True, + init_lora_weights: Union[bool, str] = True, **kwargs, ) -> None: super().__init__() @@ -491,7 +523,7 @@ def __init__( r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, - init_lora_weights: bool = True, + init_lora_weights: Union[bool, str] = True, **kwargs, ) -> None: super().__init__() diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 653a684276b..6e0a64187af 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -286,8 +286,10 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): elif isinstance(target_base_layer, torch.nn.Embedding): embedding_kwargs = kwargs.copy() embedding_kwargs.pop("fan_in_fan_out", None) + embedding_kwargs.update(lora_config.loftq_config) new_module = Embedding(target, adapter_name, **embedding_kwargs) elif isinstance(target_base_layer, torch.nn.Conv2d): + kwargs.update(lora_config.loftq_config) new_module = Conv2d(target, adapter_name, **kwargs) elif isinstance(target_base_layer, torch.nn.Linear): if kwargs["fan_in_fan_out"]: @@ -296,6 +298,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): "Setting fan_in_fan_out to False." ) kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False + kwargs.update(lora_config.loftq_config) new_module = Linear(target, adapter_name, **kwargs) elif isinstance(target_base_layer, Conv1D): if not kwargs["fan_in_fan_out"]: @@ -304,6 +307,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): "Setting fan_in_fan_out to True." ) kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True + kwargs.update(lora_config.loftq_config) new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs) else: raise ValueError( diff --git a/src/peft/utils/loftq_utils.py b/src/peft/utils/loftq_utils.py new file mode 100644 index 00000000000..81ff1e2c34d --- /dev/null +++ b/src/peft/utils/loftq_utils.py @@ -0,0 +1,227 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Reference code: https://github.com/yxli2123/LoftQ/blob/main/utils.py +# Reference paper: https://arxiv.org/abs/2310.08659 + +import logging +from typing import Union + +import torch + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available + + +if is_bnb_available(): + import bitsandbytes as bnb + + +class NFQuantizer: + def __init__(self, num_bits=2, device="cuda", method="normal", block_size=64, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_bits = num_bits + self.device = device + self.method = method + self.block_size = block_size + if self.method == "normal": + self.norm_lookup_table = self.create_normal_map(num_bits=self.num_bits) + self.norm_lookup_table = self.norm_lookup_table.to(device) + elif self.method == "uniform": + self.norm_lookup_table = self.create_uniform_map(num_bits=self.num_bits) + self.norm_lookup_table = self.norm_lookup_table.to(device) + else: + raise NotImplementedError("Other quantization methods not supported yet.") + + @staticmethod + def create_uniform_map(symmetric=False, num_bits=4): + if symmetric: + # print("symmetric uniform quantization") + negative = torch.linspace(-1, 0, 2 ** (num_bits - 1)) + positive = torch.linspace(0, 1, 2 ** (num_bits - 1)) + table = torch.cat([negative, positive[1:]]) + else: + # print("asymmetric uniform quantization") + table = torch.linspace(-1, 1, 2**num_bits) + return table + + @staticmethod + def create_normal_map(offset=0.9677083, symmetric=False, num_bits=2): + try: + from scipy.stats import norm + except ImportError: + raise ImportError("The required package 'scipy' is not installed. Please install it to continue.") + + variations = 2**num_bits + if symmetric: + v = norm.ppf(torch.linspace(1 - offset, offset, variations + 1)).tolist() + values = [] + for index in range(len(v) - 1): + values.append(0.5 * v[index] + 0.5 * v[index + 1]) + v = values + else: + # one more positive value, this is an asymmetric type + v1 = norm.ppf(torch.linspace(offset, 0.5, variations // 2 + 1)[:-1]).tolist() + v2 = [0] + v3 = (-norm.ppf(torch.linspace(offset, 0.5, variations // 2)[:-1])).tolist() + v = v1 + v2 + v3 + + values = torch.Tensor(v) + values = values.sort().values + values /= values.max() + return values + + def quantize_tensor(self, weight): + max_abs = torch.abs(weight).max() + weight_normed = weight / max_abs + + weight_normed_expanded = weight_normed.unsqueeze(-1) + + # Reshape L to have the same number of dimensions as X_expanded + L_reshaped = torch.tensor(self.norm_lookup_table).reshape(1, -1) + + # Calculate the absolute difference between X_expanded and L_reshaped + abs_diff = torch.abs(weight_normed_expanded - L_reshaped) + + # Find the index of the minimum absolute difference for each element + qweight = torch.argmin(abs_diff, dim=-1) + return qweight, max_abs + + def dequantize_tensor(self, qweight, max_abs): + qweight_flatten = qweight.flatten() + + weight_normed = self.norm_lookup_table[qweight_flatten] + weight = weight_normed * max_abs + + weight = weight.reshape(qweight.shape) + + return weight + + def quantize_block(self, weight): + if len(weight.shape) != 2: + raise ValueError(f"Only support 2D matrix, but your input has {len(weight.shape)} dimensions.") + if weight.shape[0] * weight.shape[1] % self.block_size != 0: + raise ValueError( + f"Weight with shape ({weight.shape[0]} x {weight.shape[1]}) " + f"is not dividable by block size {self.block_size}." + ) + + M, N = weight.shape + device = weight.device + + # Quantization + weight_flatten = weight.flatten() # (M*N, ) + weight_block = weight_flatten.reshape(-1, self.block_size) # (L, B), L = M * N / B + if self.method == "normal": + weight_max = weight_block.abs().max(dim=-1)[0] # (L, 1) + elif self.method == "uniform": + weight_max = weight_block.mean(dim=-1) + 2.5 * weight_block.std(dim=-1) + else: + raise NotImplementedError("Method not supported yet.") + weight_max = weight_max.unsqueeze(-1) + weight_divabs = weight_block / weight_max # (L, B) + weight_divabs = weight_divabs.unsqueeze(-1) # (L, B, 1) + L_reshaped = self.norm_lookup_table.reshape(1, -1) # (1, 2**K) + + abs_diff = torch.abs(weight_divabs - L_reshaped) # (L, B, 2**K) + qweight = torch.argmin(abs_diff, dim=-1) # (L, B) + + # Pack multiple k-bit into uint8 + qweight = qweight.reshape(-1, 8 // self.num_bits) + qweight_pack = torch.zeros((M * N // 8 * self.num_bits, 1), dtype=torch.uint8, device=device) + + # data format example: + # [1, 0, 3, 2] or [01, 00, 11, 10] -> [10110001], LIFO + for i in range(8 // self.num_bits): + qweight[:, i] = qweight[:, i] << i * self.num_bits + qweight_pack[:, 0] |= qweight[:, i] + + return qweight_pack, weight_max, weight.shape + + def dequantize_block(self, qweight, weight_max, weight_shape): + # unpack weight + device = qweight.device + weight = torch.zeros((qweight.shape[0], 8 // self.num_bits), dtype=torch.float32, device=device) + for i in range(8 // self.num_bits): + lookup_table_idx = qweight.to(torch.long) % 2**self.num_bits # get the most right 2 bits + lookup_table_idx = lookup_table_idx.to(torch.int) + weight[:, i] = self.norm_lookup_table[lookup_table_idx].squeeze() + qweight = qweight >> self.num_bits # right shift 2 bits of the original data + + weight_block = weight.reshape(-1, self.block_size) + weight = weight_block * weight_max + weight = weight.reshape(weight_shape) + + return weight + + +def _low_rank_decomposition(weight, reduced_rank=32): + """ + :param weight: The matrix to decompose, of shape (H, W) :param reduced_rank: the final rank :return: + """ + matrix_dimension = len(weight.size()) + if matrix_dimension != 2: + raise ValueError(f"Only support 2D matrix, but your input has {matrix_dimension} dimensions.") + + # Use SVD to decompose a matrix, default full_matrices is False to save parameters + U, S, Vh = torch.linalg.svd(weight, full_matrices=False) + + L = U @ (torch.sqrt(torch.diag(S)[:, 0:reduced_rank])) + R = torch.sqrt(torch.diag(S)[0:reduced_rank, :]) @ Vh + + return {"L": L, "R": R, "U": U, "S": S, "Vh": Vh, "reduced_rank": reduced_rank} + + +@torch.no_grad() +def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, reduced_rank: int, num_iter=1): + if num_bits not in [2, 4, 8]: + raise ValueError("Only support 2, 4, 8 bits quantization") + if num_iter <= 0: + raise ValueError("Number of iterations must be greater than 0") + + out_feature, in_feature = weight.size() + device = weight.device + dtype = weight.dtype + + logging.info( + f"Weight: ({out_feature}, {in_feature}) | Rank: {reduced_rank} " + f"| Num Iter: {num_iter} | Num Bits: {num_bits}" + ) + if not is_bnb_4bit_available(): + quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64) + + weight = weight.to(torch.float32) + res = weight.clone() + for i in range(num_iter): + torch.cuda.empty_cache() + # Quantization + if num_bits == 4 and is_bnb_4bit_available(): + qweight = bnb.nn.Params4bit( + res.to("cpu"), requires_grad=False, compress_statistics=False, quant_type="nf4" + ).to(device) + dequantized_weight = bnb.functional.dequantize_4bit(qweight.data, qweight.quant_state) + else: + quantized_weight, max_abs, shape = quantizer.quantize_block(res) + dequantized_weight = quantizer.dequantize_block(quantized_weight, max_abs, shape) + + res = weight - dequantized_weight + + # Decompose the residual by SVD + output = _low_rank_decomposition(res, reduced_rank=reduced_rank) + L, R, reduced_rank = output["L"], output["R"], output["reduced_rank"] + res = weight - torch.mm(L, R) + + lora_A, lora_B = R, L + + return dequantized_weight.to(dtype), lora_A, lora_B From 2674f5ea66b43e07f08dabe2634aa9542d979211 Mon Sep 17 00:00:00 2001 From: zhangshengdong29 <435878393@qq.com> Date: Thu, 30 Nov 2023 23:24:58 +0800 Subject: [PATCH 04/12] Megatron distributed parallel linear LoRA (#1092) Adds option to use Megatron's ColumnParallelLinear and RowParallelLinear for LoRA linear layers, leading to improved performance when using LoRA with Megatron. --- src/peft/tuners/lora/config.py | 26 +++++ src/peft/tuners/lora/layer.py | 3 + src/peft/tuners/lora/model.py | 27 +++++ src/peft/tuners/lora/tp_layer.py | 158 +++++++++++++++++++++++++++++ tests/test_lora_megatron.py | 167 +++++++++++++++++++++++++++++++ 5 files changed, 381 insertions(+) create mode 100644 src/peft/tuners/lora/tp_layer.py create mode 100644 tests/test_lora_megatron.py diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index 0dcca5c1e66..53269ebb8db 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -141,6 +141,32 @@ class LoraConfig(PeftConfig): ) }, ) + megatron_config: Optional[dict] = field( + default=None, + metadata={ + "help": ( + "The TransformerConfig from Megatron, it is used to create LoRA's parallel linear layer." + "You can get it like this, `core_transformer_config_from_args(get_args())`, " + "this two functions are from Megatron." + "You need to specify this parameter when you want to loraize the ColumnParallelLinear and " + "RowParallelLinear layers of megatron." + "It should be noted that we may not be able to use the `save_pretrained` and `from_pretrained` " + "functions, because TransformerConfig may not necessarily be serialized." + "But when using megatron, we can use `get_peft_model_state_dict` function and " + "megatron's framework, they can also save and load models and configurations." + ) + }, + ) + megatron_core: Optional[str] = field( + default="megatron.core", + metadata={ + "help": ( + "The core module from Megatron, it is used to judge and create LoRA's parallel linear layer. " + "It only needs to be passed in when you need to use your own modified megatron core module. " + "Otherwise, it will use the default value `megatron.core`. " + ) + }, + ) # dict type is used when loading config.json loftq_config: Union[LoftQConfig, dict] = field( default_factory=dict, diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index cf97108c871..3219ca1e47b 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -62,6 +62,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"): # QuantLinear in_features, out_features = base_layer.infeatures, base_layer.outfeatures + elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"): + # Megatron ColumnParallelLinear,RowParallelLinear + in_features, out_features = base_layer.input_size, base_layer.output_size else: raise ValueError(f"Unsupported layer type {type(base_layer)}") diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 6e0a64187af..4f6538e9122 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib import math import operator import re @@ -259,6 +260,10 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): else: target_base_layer = target + megatron_core = None + if lora_config.megatron_config: + megatron_core = importlib.import_module(lora_config.megatron_core) + if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt): eightbit_kwargs = kwargs.copy() eightbit_kwargs.update( @@ -300,6 +305,28 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False kwargs.update(lora_config.loftq_config) new_module = Linear(target, adapter_name, **kwargs) + elif megatron_core and isinstance( + target_base_layer, + (megatron_core.tensor_parallel.ColumnParallelLinear, megatron_core.tensor_parallel.RowParallelLinear), + ): + from .tp_layer import LoraParallelLinear + + megatron_kwargs = kwargs.copy() + megatron_config = lora_config.megatron_config + if isinstance(megatron_config, dict): + transformer_config_class = megatron_core.transformer.transformer_config.TransformerConfig + megatron_config = transformer_config_class(**lora_config.megatron_config) + megatron_kwargs["megatron_config"] = megatron_config + if megatron_kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `ColumnParallelLinear` " + "or `RowParallelLinear`. " + "Setting fan_in_fan_out to False." + ) + megatron_kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False + new_module = LoraParallelLinear( + base_layer=target, adapter_name=adapter_name, backend=megatron_core.tensor_parallel, **megatron_kwargs + ) elif isinstance(target_base_layer, Conv1D): if not kwargs["fan_in_fan_out"]: warnings.warn( diff --git a/src/peft/tuners/lora/tp_layer.py b/src/peft/tuners/lora/tp_layer.py new file mode 100644 index 00000000000..676430cf38c --- /dev/null +++ b/src/peft/tuners/lora/tp_layer.py @@ -0,0 +1,158 @@ +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.init as init + +from .layer import LoraLayer + + +class LoraParallelLinear(nn.Module, LoraLayer): + """ + When the target layer parallel_linear is RowParallelLinear, in order to keep the input and output shapes + consistent, we need to split the lora matrix A into rows, and the lora_B at this time should be a complete linear + layer; In the same way, when the target layer is ColumnParallelLinear, we perform column segmentation on lora_B, + while lora_A is still a complete linear layer. + """ + + def __init__( + self, + base_layer, + adapter_name: str, + backend, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + fan_in_fan_out: bool = False, + init_lora_weights: bool = True, + **kwargs, + ): + super().__init__() + LoraLayer.__init__(self, base_layer=base_layer) + + self.backend = backend + self.is_paralle_a = isinstance(base_layer, backend.RowParallelLinear) + self.fan_in_fan_out = fan_in_fan_out + self._active_adapter = adapter_name + + megatron_config = kwargs["megatron_config"] + parallel_linear_kwargs = {"megatron_config": megatron_config} + init_method = init.xavier_normal_ + if hasattr(megatron_config, "init_method"): + init_method = megatron_config.init_method + input_is_parallel = True + gather_output = False + if isinstance(base_layer, self.backend.RowParallelLinear): + input_is_parallel = base_layer.input_is_parallel + else: + gather_output = base_layer.gather_output + self.update_layer( + adapter_name, + r, + lora_alpha, + lora_dropout, + init_lora_weights, + init_method, + input_is_parallel, + gather_output, + **parallel_linear_kwargs, + ) + + self.is_target_conv_1d_layer = False + + def update_layer( + self, + adapter_name, + r, + lora_alpha, + lora_dropout, + init_lora_weights, + init_method=init.xavier_normal_, + input_is_parallel=True, + gather_output=False, + **parallel_linear_kwargs, + ): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout[adapter_name] = lora_dropout_layer + + megatron_config = parallel_linear_kwargs["megatron_config"] + # lora needs to be forced to upgrade to 32-bit precision, otherwise it will overflow + megatron_config.params_dtype = torch.float32 + if self.is_paralle_a: + lora_a = self.backend.RowParallelLinear( + input_size=self.in_features, + output_size=r, + bias=False, + input_is_parallel=input_is_parallel, + skip_bias_add=True, + init_method=init_method, + config=megatron_config, + ) + lora_b = nn.Linear(in_features=r, out_features=self.out_features, bias=False, dtype=torch.float32) + else: + lora_a = nn.Linear(in_features=self.in_features, out_features=r, bias=False, dtype=torch.float32) + lora_b = self.backend.ColumnParallelLinear( + input_size=r, + output_size=self.out_features, + bias=False, + gather_output=gather_output, + init_method=init_method, + config=megatron_config, + ) + self.lora_A[adapter_name] = lora_a + self.lora_B[adapter_name] = lora_b + self.scaling[adapter_name] = lora_alpha / r + if init_lora_weights: + self.reset_lora_parameters(adapter_name) + + weight = getattr(self.get_base_layer(), "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + self.set_adapter(self.active_adapters) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any): + previous_dtype = x.dtype + # If weight is used for matrix multiplication here, the final aggregation operation of the original + # parallel_linear layer will be missing, so we need to directly call its forward function to obtain the + # output of the original parallel_linear layer. + if self.disable_adapters: + if self.merged: + self.unmerge() + result, bias = self.base_layer(x, *args, **kwargs) + elif self.merged: + result, bias = self.base_layer(x, *args, **kwargs) + else: + result, bias = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(lora_A.weight.dtype) + + lora_result = lora_A(dropout(x)) + if isinstance(lora_result, tuple): + lora_result = lora_result[0] + lora_result = lora_B(lora_result) + if isinstance(lora_result, tuple): + lora_result = lora_result[0] + lora_result = lora_result * scaling + + result = result + lora_result + + result = result.to(previous_dtype) + return result, bias diff --git a/tests/test_lora_megatron.py b/tests/test_lora_megatron.py new file mode 100644 index 00000000000..80d0f43010e --- /dev/null +++ b/tests/test_lora_megatron.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 + +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import importlib +import os +import unittest + +import torch +import torch.nn.init as init + +from peft import LoraConfig, PeftModel, get_peft_model, get_peft_model_state_dict + + +def is_megatron_available() -> bool: + return importlib.util.find_spec("megatron") is not None + + +if is_megatron_available(): + from megatron.core import parallel_state, tensor_parallel + from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed + from megatron.core.transformer.module import MegatronModule + from megatron.core.transformer.transformer_config import TransformerConfig + + world_size = 1 + rank = 0 + + def initialize_distributed(): + print(f"Initializing torch.distributed with rank: {rank}, world_size: {world_size}") + torch.cuda.set_device(0) + init_method = "tcp://" + master_ip = os.getenv("MASTER_ADDR", "localhost") + master_port = os.getenv("MASTER_PORT", "6001") + init_method += master_ip + ":" + master_port + torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank, init_method=init_method) + + def destroy_model_parallel(): + parallel_state.destroy_model_parallel() + torch.distributed.barrier() + + def initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + virtual_pipeline_model_parallel_size=None, + pipeline_model_parallel_split_rank=None, + ): + parallel_state.destroy_model_parallel() + if not torch.distributed.is_initialized(): + initialize_distributed() + parallel_state.initialize_model_parallel( + tensor_model_parallel_size, + pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size, + pipeline_model_parallel_split_rank, + ) + + class DummyModule(MegatronModule): + def __init__(self, config: TransformerConfig): + super().__init__(config) + self.linear = tensor_parallel.ColumnParallelLinear( + input_size=10, + output_size=10, + config=config, + init_method=init.xavier_normal_, + bias=False, + gather_output=False, + ) + self.lm_head = tensor_parallel.RowParallelLinear( + input_size=10, + output_size=10, + config=config, + init_method=init.xavier_normal_, + bias=False, + input_is_parallel=True, + ) + + def forward(self, input): + x = self.linear(input)[0] + x = self.lm_head(x)[0] + return x + + class TestMegatronLora(unittest.TestCase): + def setUp(self): + initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + transformer_config = { + "num_layers": 2, + "hidden_size": 12, + "num_attention_heads": 4, + "use_cpu_initialization": True, + } + config = TransformerConfig(**transformer_config) + self.megatron_module = DummyModule(config=config).cuda() + self.dummy_module = copy.deepcopy(self.megatron_module).cuda() + + lora_config = LoraConfig( + lora_alpha=16, + lora_dropout=0.1, + r=64, + bias="none", + target_modules=["linear", "lm_head"], + megatron_config=config, + megatron_core="megatron.core", + ) + self.megatron_module = get_peft_model(self.megatron_module, lora_config) + + def tearDown(self): + destroy_model_parallel() + + def test_megatron_lora_module(self): + megatron_module = self.megatron_module + self.assertTrue(isinstance(megatron_module, PeftModel)) + + for name, module in megatron_module.named_modules(): + if name.endswith("linear"): + self.assertTrue(hasattr(module, "lora_A")) + self.assertTrue(hasattr(module, "lora_B")) + if name.endswith("linear.lora_A.default"): + self.assertTrue(isinstance(module, torch.nn.Linear)) + if name.endswith("linear.lora_B.default"): + self.assertTrue(isinstance(module, tensor_parallel.ColumnParallelLinear)) + + if name.endswith("lm_head.lora_A.default"): + self.assertTrue(isinstance(module, tensor_parallel.RowParallelLinear)) + if name.endswith("lm_head.lora_B.default"): + self.assertTrue(isinstance(module, torch.nn.Linear)) + + def test_forward(self): + x = torch.ones((2, 4, 10)).cuda() + megatron_module_result = self.megatron_module(x) + dummt_module_result = self.dummy_module(x) + + # Because lora_B is initialized with 0, the forward results of two models should be equal before backward. + self.assertTrue(megatron_module_result.equal(dummt_module_result)) + + def test_backward(self): + optimizer = torch.optim.AdamW(self.megatron_module.parameters()) + loss_fn = torch.nn.CrossEntropyLoss() + + x = torch.randn(2, 4, 10, requires_grad=True).cuda() + label = torch.randint(10, (2 * 4,)).cuda() + + output = self.megatron_module(x) + output = output.reshape(2 * 4, 10) + loss = loss_fn(output, label) + + loss.backward() + optimizer.step() + + def test_get_peft_model_state_dict(self): + peft_state_dict = get_peft_model_state_dict(self.megatron_module) + + for key in peft_state_dict.keys(): + self.assertTrue("lora" in key) From da17ac0f484b28a8471004b47bddfc408969ae04 Mon Sep 17 00:00:00 2001 From: takuoko Date: Fri, 1 Dec 2023 00:58:42 +0900 Subject: [PATCH 05/12] [Feature] Support OFT (#1160) * Support OFT * add test * Update README * fix code quality * fix test * Skip 1 test * fix eps rule and add more test * feat: added examples to new OFT method * fix: removed wrong arguments from model example * fix: changed name of inference file * fix: changed prompt variable * fix docs * fix: dreambooth inference revision based on feedback * fix: review from BenjaminBossan * apply safe merge * del partially * refactor oft * refactor oft * del unused line * del unused line * fix skip in windows * skip test * Add comments about bias added place * rename orig_weights to new_weights * use inverse instead of linalg.inv * delete alpha and scaling --------- Co-authored-by: Lukas Kuhn Co-authored-by: Lukas Kuhn --- README.md | 7 +- .../oft_dreambooth_inference.ipynb | 89 ++ examples/oft_dreambooth/train_dreambooth.py | 1112 +++++++++++++++++ src/peft/__init__.py | 2 + src/peft/mapping.py | 4 + src/peft/peft_model.py | 2 + src/peft/tuners/__init__.py | 1 + src/peft/tuners/oft/__init__.py | 21 + src/peft/tuners/oft/config.py | 109 ++ src/peft/tuners/oft/layer.py | 375 ++++++ src/peft/tuners/oft/model.py | 108 ++ src/peft/utils/peft_types.py | 1 + src/peft/utils/save_and_load.py | 5 +- tests/test_config.py | 4 +- tests/test_custom_models.py | 103 +- tests/test_stablediffusion.py | 23 +- tests/testing_common.py | 6 +- 17 files changed, 1959 insertions(+), 13 deletions(-) create mode 100644 examples/oft_dreambooth/oft_dreambooth_inference.ipynb create mode 100644 examples/oft_dreambooth/train_dreambooth.py create mode 100644 src/peft/tuners/oft/__init__.py create mode 100644 src/peft/tuners/oft/config.py create mode 100644 src/peft/tuners/oft/layer.py create mode 100644 src/peft/tuners/oft/model.py diff --git a/README.md b/README.md index 79259f98ee9..09846dc61cd 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ Supported methods: 8. LoHa: [FedPara: Low-Rank Hadamard Product for Communication-Efficient Federated Learning](https://arxiv.org/abs/2108.06098) 9. LoKr: [KronA: Parameter Efficient Tuning with Kronecker Adapter](https://arxiv.org/abs/2212.10650) based on [Navigating Text-To-Image Customization:From LyCORIS Fine-Tuning to Model Evaluation](https://arxiv.org/abs/2309.14859) implementation 10. LoftQ: [LoftQ: LoRA-Fine-Tuning-aware Quantization for Large Language Models](https://arxiv.org/abs/2310.08659) +11. OFT: [Controlling Text-to-Image Diffusion by Orthogonal Finetuning](https://arxiv.org/abs/2306.07280) ## Getting started @@ -278,9 +279,9 @@ Find models that are supported out of the box below. Note that PEFT works with a ### Text-to-Image Generation -| Model | LoRA | LoHa | LoKr | Prefix Tuning | P-Tuning | Prompt Tuning | IA3 | -| --------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | -| Stable Diffusion | βœ… | βœ… | βœ… | | | | +| Model | LoRA | LoHa | LoKr | OFT | Prefix Tuning | P-Tuning | Prompt Tuning | IA3 | +| --------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | +| Stable Diffusion | βœ… | βœ… | βœ… | βœ… | | | | ### Image Classification diff --git a/examples/oft_dreambooth/oft_dreambooth_inference.ipynb b/examples/oft_dreambooth/oft_dreambooth_inference.ipynb new file mode 100644 index 00000000000..4a28c4040e4 --- /dev/null +++ b/examples/oft_dreambooth/oft_dreambooth_inference.ipynb @@ -0,0 +1,89 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "acd7b15e", + "metadata": {}, + "source": [ + "# Dreambooth with OFT\n", + "This Notebook assumes that you already ran the train_dreambooth.py script to create your own adapter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "acab479f", + "metadata": {}, + "outputs": [], + "source": [ + "from diffusers import DiffusionPipeline\n", + "from diffusers.utils import check_min_version, get_logger\n", + "from peft import PeftModel\n", + "\n", + "# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\n", + "check_min_version(\"0.10.0.dev0\")\n", + "\n", + "logger = get_logger(__name__)\n", + "\n", + "BASE_MODEL_NAME = \"stabilityai/stable-diffusion-2-1-base\"\n", + "ADAPTER_MODEL_PATH = \"INSERT MODEL PATH HERE\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipe = DiffusionPipeline.from_pretrained(\n", + " BASE_MODEL_NAME,\n", + ")\n", + "pipe.to('cuda')\n", + "pipe.unet = PeftModel.from_pretrained(pipe.unet, ADAPTER_MODEL_PATH + \"/unet\", adapter_name=\"default\")\n", + "pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, ADAPTER_MODEL_PATH + \"/text_encoder\", adapter_name=\"default\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prompt = \"A photo of a sks dog\"\n", + "image = pipe(\n", + " prompt,\n", + " num_inference_steps=50,\n", + " height=512,\n", + " width=512,\n", + ").images[0]\n", + "image" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + }, + "vscode": { + "interpreter": { + "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/oft_dreambooth/train_dreambooth.py b/examples/oft_dreambooth/train_dreambooth.py new file mode 100644 index 00000000000..cacce706474 --- /dev/null +++ b/examples/oft_dreambooth/train_dreambooth.py @@ -0,0 +1,1112 @@ +import argparse +import gc +import hashlib +import itertools +import logging +import math +import os +import threading +import warnings +from contextlib import nullcontext +from pathlib import Path +from typing import Optional + +import datasets +import diffusers +import numpy as np +import psutil +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from huggingface_hub import HfFolder, Repository, whoami +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +from peft import get_peft_model +from peft.tuners.oft.config import OFTConfig + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.10.0.dev0") + +logger = get_logger(__name__) + +UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] # , "ff.net.0.proj"] +TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" + ) + parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") + + # oft args + parser.add_argument("--use_oft", action="store_true", help="Whether to use OFT for parameter efficient tuning") + parser.add_argument("--oft_r", type=int, default=8, help="OFT rank, only used if use_oft is True") + parser.add_argument("--oft_alpha", type=int, default=32, help="OFT alpha, only used if use_oft is True") + parser.add_argument("--oft_dropout", type=float, default=0.0, help="OFT dropout, only used if use_oft is True") + parser.add_argument( + "--oft_use_coft", action="store_true", help="Using constrained OFT, only used if use_oft is True" + ) + parser.add_argument( + "--oft_eps", + type=float, + default=0.0, + help="The control strength of COFT. Only has an effect if `oft_use_coft` is set to True.", + ) + + parser.add_argument( + "--oft_text_encoder_r", + type=int, + default=8, + help="OFT rank for text encoder, only used if `use_oft` and `train_text_encoder` are True", + ) + parser.add_argument( + "--oft_text_encoder_alpha", + type=int, + default=32, + help="OFT alpha for text encoder, only used if `use_oft` and `train_text_encoder` are True", + ) + parser.add_argument( + "--oft_text_encoder_dropout", + type=float, + default=0.0, + help="OFT dropout for text encoder, only used if `use_oft` and `train_text_encoder` are True", + ) + parser.add_argument( + "--oft_text_encoder_use_coft", + action="store_true", + help="Using constrained OFT on the text encoder, only used if use_oft is True", + ) + parser.add_argument( + "--oft_text_encoder_eps", + type=float, + default=0.0, + help="The control strength of COFT on the text encoder. Only has an effect if `oft_text_encoder_use_coft` is set to True.", + ) + + parser.add_argument( + "--num_dataloader_workers", type=int, default=1, help="Num of workers for the training dataloader." + ) + + parser.add_argument( + "--no_tracemalloc", + default=False, + action="store_true", + help="Flag to stop memory allocation tracing during training. This could speed up training on Windows.", + ) + + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--wandb_key", + type=str, + default=None, + help=("If report to option is set to wandb, api-key for wandb used for login to wandb "), + ) + parser.add_argument( + "--wandb_project_name", + type=str, + default=None, + help=("If report to option is set to wandb, project name in wandb for log tracking "), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +# Converting Bytes to Megabytes +def b2mb(x): + return int(x / 2**20) + + +# This context manager is used to track the peak memory usage of the process +class TorchTracemalloc: + def __enter__(self): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero + self.begin = torch.cuda.memory_allocated() + self.process = psutil.Process() + + self.cpu_begin = self.cpu_mem_used() + self.peak_monitoring = True + peak_monitor_thread = threading.Thread(target=self.peak_monitor_func) + peak_monitor_thread.daemon = True + peak_monitor_thread.start() + return self + + def cpu_mem_used(self): + """get resident set size memory for the current process""" + return self.process.memory_info().rss + + def peak_monitor_func(self): + self.cpu_peak = -1 + + while True: + self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak) + + # can't sleep or will not catch the peak right (this comment is here on purpose) + # time.sleep(0.001) # 1msec + + if not self.peak_monitoring: + break + + def __exit__(self, *exc): + self.peak_monitoring = False + + gc.collect() + torch.cuda.empty_cache() + self.end = torch.cuda.memory_allocated() + self.peak = torch.cuda.max_memory_allocated() + self.used = b2mb(self.end - self.begin) + self.peaked = b2mb(self.peak - self.begin) + + self.cpu_end = self.cpu_mem_used() + self.cpu_used = b2mb(self.cpu_end - self.cpu_begin) + self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin) + # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}") + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + return example + + +def collate_fn(examples, with_prior_preservation=False): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.cat(input_ids, dim=0) + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_dir=logging_dir, + ) + if args.report_to == "wandb": + import wandb + + wandb.login(key=args.wandb_key) + wandb.init(project=args.wandb_project_name) + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) # noqa: F841 + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + ) # DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + + if args.use_oft: + config = OFTConfig( + r=args.oft_r, + alpha=args.oft_alpha, + target_modules=UNET_TARGET_MODULES, + module_dropout=args.oft_dropout, + init_weights=True, + coft=args.oft_use_coft, + eps=args.oft_eps, + ) + unet = get_peft_model(unet, config) + unet.print_trainable_parameters() + print(unet) + + vae.requires_grad_(False) + if not args.train_text_encoder: + text_encoder.requires_grad_(False) + elif args.train_text_encoder and args.use_oft: + config = OFTConfig( + r=args.oft_text_encoder_r, + alpha=args.oft_text_encoder_alpha, + target_modules=TEXT_ENCODER_TARGET_MODULES, + module_dropout=args.oft_text_encoder_dropout, + init_weights=True, + coft=args.oft_text_encoder_use_coft, + eps=args.oft_text_encoder_eps, + ) + text_encoder = get_peft_model(text_encoder, config) + text_encoder.print_trainable_parameters() + print(text_encoder) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + # below fails when using oft so commenting it out + if args.train_text_encoder and not args.use_oft: + text_encoder.gradient_checkpointing_enable() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + ) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.num_dataloader_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae and text_encoder to device and cast to weight_dtype + vae.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = resume_global_step // num_update_steps_per_epoch + resume_step = resume_global_step % num_update_steps_per_epoch + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder.train() + with TorchTracemalloc() if not args.no_tracemalloc else nullcontext() as tracemalloc: + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + if args.report_to == "wandb": + accelerator.print(progress_bar) + continue + + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + if args.report_to == "wandb": + accelerator.print(progress_bar) + global_step += 1 + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if ( + args.validation_prompt is not None + and (step + num_update_steps_per_epoch * epoch) % args.validation_steps == 0 + ): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + safety_checker=None, + revision=args.revision, + ) + # set `keep_fp32_wrapper` to True because we do not want to remove + # mixed precision hooks while we are still training + pipeline.unet = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) + pipeline.text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + if args.seed is not None: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + else: + generator = None + images = [] + for _ in range(args.num_validation_images): + image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + import wandb + + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + if global_step >= args.max_train_steps: + break + # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage + + if not args.no_tracemalloc: + accelerator.print("GPU Memory before entering the train : {}".format(b2mb(tracemalloc.begin))) + accelerator.print("GPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.used)) + accelerator.print("GPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.peaked)) + accelerator.print( + "GPU Total Peak Memory consumed during the train (max): {}".format( + tracemalloc.peaked + b2mb(tracemalloc.begin) + ) + ) + + accelerator.print("CPU Memory before entering the train : {}".format(b2mb(tracemalloc.cpu_begin))) + accelerator.print( + "CPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.cpu_used) + ) + accelerator.print( + "CPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.cpu_peaked) + ) + accelerator.print( + "CPU Total Peak Memory consumed during the train (max): {}".format( + tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin) + ) + ) + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + if args.use_oft: + unwarpped_unet = accelerator.unwrap_model(unet) + unwarpped_unet.save_pretrained( + os.path.join(args.output_dir, "unet"), state_dict=accelerator.get_state_dict(unet) + ) + if args.train_text_encoder: + unwarpped_text_encoder = accelerator.unwrap_model(text_encoder) + unwarpped_text_encoder.save_pretrained( + os.path.join(args.output_dir, "text_encoder"), + state_dict=accelerator.get_state_dict(text_encoder), + ) + else: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + revision=args.revision, + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 4d9380e6978..75ddda498cb 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -68,6 +68,8 @@ PromptTuningInit, MultitaskPromptTuningConfig, MultitaskPromptTuningInit, + OFTConfig, + OFTModel, ) from .utils import ( TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, diff --git a/src/peft/mapping.py b/src/peft/mapping.py index f69e89ec3e5..60503fa985b 100644 --- a/src/peft/mapping.py +++ b/src/peft/mapping.py @@ -42,6 +42,8 @@ LoraConfig, LoraModel, MultitaskPromptTuningConfig, + OFTConfig, + OFTModel, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig, @@ -73,6 +75,7 @@ "ADALORA": AdaLoraConfig, "IA3": IA3Config, "MULTITASK_PROMPT_TUNING": MultitaskPromptTuningConfig, + "OFT": OFTConfig, } PEFT_TYPE_TO_TUNER_MAPPING = { @@ -81,6 +84,7 @@ "LOKR": LoKrModel, "ADALORA": AdaLoraModel, "IA3": IA3Model, + "OFT": OFTModel, } diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 24ef48c22e2..79bf8e46102 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -44,6 +44,7 @@ LoKrModel, LoraModel, MultitaskPromptEmbedding, + OFTModel, PrefixEncoder, PromptEmbedding, PromptEncoder, @@ -77,6 +78,7 @@ PeftType.ADALORA: AdaLoraModel, PeftType.ADAPTION_PROMPT: AdaptionPromptModel, PeftType.IA3: IA3Model, + PeftType.OFT: OFTModel, } diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index 666e29d9973..f5f665dd994 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -27,3 +27,4 @@ from .prefix_tuning import PrefixEncoder, PrefixTuningConfig from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit +from .oft import OFTConfig, OFTModel diff --git a/src/peft/tuners/oft/__init__.py b/src/peft/tuners/oft/__init__.py new file mode 100644 index 00000000000..456c46ee076 --- /dev/null +++ b/src/peft/tuners/oft/__init__.py @@ -0,0 +1,21 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config import OFTConfig +from .layer import Conv2d, Linear, OFTLayer +from .model import OFTModel + + +__all__ = ["OFTConfig", "OFTModel", "Conv2d", "Linear", "OFTLayer"] diff --git a/src/peft/tuners/oft/config.py b/src/peft/tuners/oft/config.py new file mode 100644 index 00000000000..6b43255d1d4 --- /dev/null +++ b/src/peft/tuners/oft/config.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import List, Optional, Union + +from peft.tuners.lycoris_utils import LycorisConfig +from peft.utils import PeftType + + +@dataclass +class OFTConfig(LycorisConfig): + """ + This is the configuration class to store the configuration of a [`OFTModel`]. + + Args: + r (`int`): OFT rank. + module_dropout (`int`): The dropout probability for disabling OFT modules during training. + target_modules (`Union[List[str],str]`): The names of the modules to apply OFT to. + init_weights (`bool`): Whether to perform initialization of OFT weights. + layers_to_transform (`Union[List[int],int]`): + The layer indexes to transform, if this argument is specified, it will apply the OFT transformations on the + layer indexes that are specified in this list. If a single integer is passed, it will apply the OFT + transformations on the layer at this index. + layers_pattern (`str`): + The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer + pattern is not in the common layers pattern. + rank_pattern (`dict`): + The mapping from layer names or regexp expression to ranks which are different from the default rank + specified by `r`. + modules_to_save (`List[str]`): The names of modules to be set as trainable except OFT parameters. + coft (`bool`): Whether to use the constrainted variant of OFT or not. + eps (`float`): + The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True. + block_share (`bool`): Whether to share the OFT parameters between blocks or not. + """ + + r: int = field(default=8, metadata={"help": "OFT rank"}) + module_dropout: float = field( + default=0.0, metadata={"help": "The dropout probability for disabling OFT modules during training"} + ) + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to replace with OFT." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " + }, + ) + init_weights: bool = field( + default=True, + metadata={ + "help": ( + "Whether to initialize the weights of the OFT layers with their default initialization. Don't change " + "this setting, except if you know exactly what you're doing." + ), + }, + ) + layers_to_transform: Optional[Union[List[int], int]] = field( + default=None, + metadata={ + "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index." + }, + ) + layers_pattern: Optional[str] = field( + default=None, + metadata={ + "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern." + }, + ) + modules_to_save: Optional[List[str]] = field( + default=None, + metadata={ + "help": "List of modules apart from OFT layers to be set as trainable and saved in the final checkpoint. " + "For example, in Sequence Classification or Token Classification tasks, " + "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." + }, + ) + coft: bool = field( + default=False, + metadata={"help": "Whether to use the constrainted variant of OFT or not."}, + ) + eps: float = field( + default=6e-5, + metadata={ + "help": "The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True." + }, + ) + block_share: bool = field( + default=False, + metadata={"help": "Whether to share the OFT parameters between blocks or not."}, + ) + + def __post_init__(self): + self.peft_type = PeftType.OFT + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) diff --git a/src/peft/tuners/oft/layer.py b/src/peft/tuners/oft/layer.py new file mode 100644 index 00000000000..b9e0d011b3c --- /dev/null +++ b/src/peft/tuners/oft/layer.py @@ -0,0 +1,375 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from typing import Any, List, Optional, Set, Tuple + +import torch +import torch.nn as nn + +from peft.tuners.lycoris_utils import LycorisLayer + + +class OFTLayer(nn.Module, LycorisLayer): + # All names of layers that may contain adapter weights + adapter_layer_names = ("oft_r",) + # other_param_names is defined on parent class + + def __init__(self, base_layer: nn.Module): + super().__init__() + LycorisLayer.__init__(self, base_layer) + + # OFT info + self.oft_r = nn.ParameterDict({}) + self.coft = {} + self.eps = {} + self.block_share = {} + + @property + def _available_adapters(self) -> Set[str]: + return {*self.oft_r} + + def create_adapter_parameters(self, adapter_name: str, r: int, shape: Tuple[int, ...], block_share: bool): + if block_share: + self.oft_r[adapter_name] = nn.Parameter(torch.empty(1, math.ceil(shape[0] / r), math.ceil(shape[0] / r))) + else: + self.oft_r[adapter_name] = nn.Parameter(torch.empty(r, math.ceil(shape[0] / r), math.ceil(shape[0] / r))) + + def reset_adapter_parameters(self, adapter_name: str): + nn.init.zeros_(self.oft_r[adapter_name]) + + def reset_adapter_parameters_random(self, adapter_name: str): + nn.init.kaiming_uniform_(self.oft_r[adapter_name], a=math.sqrt(5)) + + def update_layer( + self, + adapter_name: str, + r: int, + module_dropout: float, + init_weights: bool, + coft: bool = False, + eps: float = 6e-5, + block_share: bool = False, + **kwargs, + ) -> None: + """Internal function to create oft adapter + + Args: + adapter_name (`str`): Name for the adapter to add. + r (`int`): Rank for the added adapter. + module_dropout (`float`): The dropout probability for disabling adapter during training. + init_weights (`bool`): Whether to initialize weights. + coft (`bool`): Whether to use the constrainted variant of OFT or not. + eps (`float`): + The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True. + block_share (`bool`): Whether to share the OFT parameters between blocks or not. + """ + + self.r[adapter_name] = r + self.module_dropout[adapter_name] = module_dropout + self.coft[adapter_name] = coft + self.block_share[adapter_name] = block_share + + # Determine shape of OFT weights + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + shape = tuple(base_layer.weight.shape) + elif isinstance(base_layer, nn.Conv2d): + shape = ( + base_layer.out_channels, + base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1], + ) + else: + raise TypeError(f"OFT is not implemented for base layers of type {type(base_layer).__name__}") + + self.eps[adapter_name] = eps * math.ceil(shape[0] / r) * math.ceil(shape[0] / r) + + # Create weights with provided shape + self.create_adapter_parameters(adapter_name, r, shape, block_share) + + # Initialize weights + if init_weights: + self.reset_adapter_parameters(adapter_name) + else: + self.reset_adapter_parameters_random(adapter_name) + + # Move new weights to device + weight = getattr(self.get_base_layer(), "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + self.set_adapter(self.active_adapters) + + def unscale_layer(self, scale=None) -> None: + # scale is not used + pass + + def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: + if self.merged: + warnings.warn( + f"Already following adapters were merged {','.join(self.merged_adapters)}. " + f"You are now additionally merging {','.join(self.active_adapters)}." + ) + if adapter_names is None: + adapter_names = self.active_adapters + + for active_adapter in adapter_names: + if active_adapter in self._available_adapters: + base_layer = self.get_base_layer() + + orig_weights = base_layer.weight.data + if isinstance(base_layer, nn.Linear): + orig_weights = torch.transpose(orig_weights, 0, 1) + elif isinstance(base_layer, nn.Conv2d): + orig_weights = orig_weights.view( + [ + base_layer.out_channels, + base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1], + ] + ) + orig_weights = torch.transpose(orig_weights, 0, 1) + delta_weight = self.get_delta_weight(active_adapter) + if orig_weights.shape[1] != delta_weight.shape[1]: + # when in channels is not divisible by r + delta_weight = delta_weight[: orig_weights.shape[1], : orig_weights.shape[1]] + new_weights = torch.mm(orig_weights, delta_weight) + if isinstance(base_layer, nn.Linear): + new_weights = torch.transpose(new_weights, 0, 1) + elif isinstance(base_layer, nn.Conv2d): + new_weights = torch.transpose(new_weights, 0, 1) + new_weights = new_weights.view( + [ + base_layer.out_channels, + base_layer.in_channels, + base_layer.kernel_size[0], + base_layer.kernel_size[1], + ] + ) + + if safe_merge and not torch.isfinite(new_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = new_weights + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self._available_adapters: + base_layer = self.get_base_layer() + new_weights = base_layer.weight.data + if isinstance(base_layer, nn.Linear): + new_weights = torch.transpose(new_weights, 0, 1) + elif isinstance(base_layer, nn.Conv2d): + new_weights = new_weights.view( + [ + base_layer.out_channels, + base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1], + ] + ) + new_weights = torch.transpose(new_weights, 0, 1) + delta_weight = self.get_delta_weight(active_adapter) + if new_weights.shape[1] != delta_weight.shape[1]: + # when in channels is not divisible by r + delta_weight = delta_weight[: new_weights.shape[1], : new_weights.shape[1]] + delta_inv = torch.inverse(delta_weight) + orig_weights = torch.mm(new_weights, delta_inv) + + if isinstance(base_layer, nn.Linear): + orig_weights = torch.transpose(orig_weights, 0, 1) + elif isinstance(base_layer, nn.Conv2d): + orig_weights = torch.transpose(orig_weights, 0, 1) + orig_weights = orig_weights.reshape( + [ + base_layer.out_channels, + base_layer.in_channels, + base_layer.kernel_size[0], + base_layer.kernel_size[1], + ] + ) + base_layer.weight.data = orig_weights + + def get_delta_weight(self, adapter_name: str) -> torch.Tensor: + rank = self.r[adapter_name] + coft = self.coft[adapter_name] + eps = self.eps[adapter_name] + opt_r = self.oft_r[adapter_name] + + if coft: + with torch.no_grad(): + opt_r.copy_(self._project_batch(opt_r, eps=eps)) + + orth_rotate = self._cayley_batch(opt_r) + weight = self._block_diagonal(orth_rotate, rank) + + return weight + + # Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L144 + def _cayley_batch(self, data: torch.Tensor) -> torch.Tensor: + b, r, c = data.shape + # Ensure the input matrix is skew-symmetric + skew = 0.5 * (data - data.transpose(1, 2)) + I = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c) + + # Perform the Cayley parametrization + Q = torch.bmm(I - skew, torch.inverse(I + skew)) + + return Q + + # Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L155 + def _block_diagonal(self, oft_r: torch.Tensor, rank: int) -> torch.Tensor: + if oft_r.shape[0] == 1: + # block share + blocks = [oft_r[0, ...] for i in range(rank)] + else: + blocks = [oft_r[i, ...] for i in range(rank)] + + # Use torch.block_diag to create the block diagonal matrix + A = torch.block_diag(*blocks) + + return A + + # Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L52 + def _project_batch(self, oft_r, eps=1e-5): + # scaling factor for each of the smaller block matrix + eps = eps * 1 / torch.sqrt(torch.tensor(oft_r.shape[0])) + I = ( + torch.zeros((oft_r.size(1), oft_r.size(1)), device=oft_r.device, dtype=oft_r.dtype) + .unsqueeze(0) + .expand_as(oft_r) + ) + diff = oft_r - I + norm_diff = torch.norm(oft_r - I, dim=(1, 2), keepdim=True) + mask = (norm_diff <= eps).bool() + out = torch.where(mask, oft_r, I + eps * (diff / norm_diff)) + return out + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + if len(result.shape) == 4: + result = result.permute(0, 2, 3, 1) + + base_layer = self.get_base_layer() + base_bias = base_layer.bias + if base_bias is not None: + # Bias should be added after OFT forward + result = result - base_bias.data + + # Execute all the adapters + for active_adapter in self.active_adapters: + if active_adapter not in self._available_adapters: + continue + + module_dropout = self.module_dropout[active_adapter] + + # Modify current execution weights + if (not self.training) or (self.training and torch.rand(1) > module_dropout): + result = self._get_delta_activations(active_adapter, result, *args, **kwargs) + + if base_bias is not None: + result = result + base_bias.data + if len(result.shape) == 4: + result = result.permute(0, 3, 1, 2) + + result = result.to(previous_dtype) + return result + + +class Linear(OFTLayer): + """OFT implemented in Linear layer""" + + def __init__( + self, + base_layer: nn.Module, + adapter_name: str = "default", + r: int = 0, + module_dropout: float = 0.0, + init_weights: bool = True, + **kwargs, + ): + super().__init__(base_layer) + + # Create adapter and set it active + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, module_dropout, init_weights, **kwargs) + + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + + base_layer = self.get_base_layer() + base_weight = base_layer.weight.data + delta_weight = delta_weight[: base_weight.shape[0], : base_weight.shape[0]] + + # don't add bias here, because the bias will be added after OFT forward + return torch.matmul(input, delta_weight) + + def __repr__(self) -> str: + rep = super().__repr__() + return "oft." + rep + + +class Conv2d(OFTLayer): + """OFT implemented in Conv2d layer""" + + def __init__( + self, + base_layer: nn.Module, + adapter_name: str = "default", + r: int = 0, + module_dropout: float = 0.0, + init_weights: bool = True, + **kwargs, + ): + super().__init__(base_layer) + + # Create adapter and set it active + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, module_dropout, init_weights, **kwargs) + + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + + base_layer = self.get_base_layer() + base_weight = base_layer.weight.data + delta_weight = delta_weight[: base_weight.shape[0], : base_weight.shape[0]] + + # don't add bias here, because the bias will be added after OFT forward + return torch.matmul(input, delta_weight) + + def __repr__(self) -> str: + rep = super().__repr__() + return "oft." + rep diff --git a/src/peft/tuners/oft/model.py b/src/peft/tuners/oft/model.py new file mode 100644 index 00000000000..4b7953daa92 --- /dev/null +++ b/src/peft/tuners/oft/model.py @@ -0,0 +1,108 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Dict, Type, Union + +import torch +from torch import nn + +from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner + +from .layer import Conv2d, Linear, OFTLayer + + +class OFTModel(LycorisTuner): + """ + Creates Orthogonal Finetuning model from a pretrained model. The method is described in + https://arxiv.org/abs/2306.07280 + + Args: + model (`torch.nn.Module`): The model to which the adapter tuner layers will be attached. + config ([`OFTConfig`]): The configuration of the OFT model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + `torch.nn.Module`: The OFT model. + + Example: + ```py + >>> from diffusers import StableDiffusionPipeline + >>> from peft import OFTModel, OFTConfig + + >>> config_te = OFTConfig( + ... r=8, + ... target_modules=["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"], + ... module_dropout=0.0, + ... init_weights=True, + ... ) + >>> config_unet = OFTConfig( + ... r=8, + ... target_modules=[ + ... "proj_in", + ... "proj_out", + ... "to_k", + ... "to_q", + ... "to_v", + ... "to_out.0", + ... "ff.net.0.proj", + ... "ff.net.2", + ... ], + ... module_dropout=0.0, + ... init_weights=True, + ... ) + + >>> model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> model.text_encoder = OFTModel(model.text_encoder, config_te, "default") + >>> model.unet = OFTModel(model.unet, config_unet, "default") + ``` + + **Attributes**: + - **model** ([`~torch.nn.Module`]) -- The model to be adapted. + - **peft_config** ([`OFTConfig`]): The configuration of the OFT model. + """ + + prefix: str = "oft_" + layers_mapping: Dict[Type[torch.nn.Module], Type[OFTLayer]] = { + torch.nn.Conv2d: Conv2d, + torch.nn.Linear: Linear, + } + + def _create_and_replace( + self, + config: LycorisConfig, + adapter_name: str, + target: Union[OFTLayer, nn.Module], + target_name: str, + parent: nn.Module, + current_key: str, + **optional_kwargs, + ) -> None: + """ + A private method to create and replace the target module with the adapter module. + """ + + # Regexp matching - Find key which matches current target_name in patterns provided + pattern_keys = list(config.rank_pattern.keys()) + target_name_key = next(filter(lambda key: re.match(f"(.*\.)?{key}$", current_key), pattern_keys), target_name) + + kwargs = config.to_dict() + kwargs["r"] = config.rank_pattern.get(target_name_key, config.r) + + if isinstance(target, OFTLayer): + target.update_layer(adapter_name, **kwargs) + else: + new_module = self._create_new_module(config, adapter_name, target, **kwargs) + self._replace_module(parent, target_name, new_module, target) diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index 29c764a08f4..93b892d9e59 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -30,6 +30,7 @@ class PeftType(str, enum.Enum): IA3 = "IA3" LOHA = "LOHA" LOKR = "LOKR" + OFT = "OFT" class TaskType(str, enum.Enum): diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py index 97bde0d6fe5..c5da274085c 100644 --- a/src/peft/utils/save_and_load.py +++ b/src/peft/utils/save_and_load.py @@ -113,6 +113,8 @@ def get_peft_model_state_dict( to_return["prompt_embeddings"] = prompt_embeddings elif config.peft_type == PeftType.IA3: to_return = {k: state_dict[k] for k in state_dict if "ia3_" in k} + elif config.peft_type == PeftType.OFT: + to_return = {k: state_dict[k] for k in state_dict if "oft_" in k} else: raise NotImplementedError if getattr(model, "modules_to_save", None) is not None: @@ -166,7 +168,7 @@ def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="defaul else: state_dict = peft_model_state_dict - if config.peft_type in (PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.ADALORA, PeftType.IA3): + if config.peft_type in (PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.ADALORA, PeftType.IA3, PeftType.OFT): peft_model_state_dict = {} parameter_prefix = { PeftType.IA3: "ia3_", @@ -174,6 +176,7 @@ def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="defaul PeftType.ADALORA: "lora_", PeftType.LOHA: "hada_", PeftType.LOKR: "lokr_", + PeftType.OFT: "oft_", }[config.peft_type] for k, v in state_dict.items(): if parameter_prefix in k: diff --git a/tests/test_config.py b/tests/test_config.py index 34f04232a9c..06e72dae8ed 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -30,6 +30,7 @@ LoHaConfig, LoraConfig, MultitaskPromptTuningConfig, + OFTConfig, PeftConfig, PrefixTuningConfig, PromptEncoder, @@ -51,6 +52,7 @@ PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig, + OFTConfig, ) @@ -189,7 +191,7 @@ def test_prompt_encoder_warning_num_layers(self): expected_msg = "for MLP, the argument `encoder_num_layers` is ignored. Exactly 2 MLP layers are used." assert str(record.list[0].message) == expected_msg - @parameterized.expand([LoHaConfig, LoraConfig, IA3Config]) + @parameterized.expand([LoHaConfig, LoraConfig, IA3Config, OFTConfig]) def test_save_pretrained_with_target_modules(self, config_class): # See #1041, #1045 config = config_class(target_modules=["a", "list"]) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index b298388a844..4785526b26a 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -24,7 +24,7 @@ from torch import nn from transformers.pytorch_utils import Conv1D -from peft import AdaLoraConfig, IA3Config, LoHaConfig, LoKrConfig, LoraConfig, PeftModel, get_peft_model +from peft import AdaLoraConfig, IA3Config, LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, PeftModel, get_peft_model from peft.tuners.tuners_utils import BaseTunerLayer from .testing_common import PeftCommonTester @@ -191,6 +191,28 @@ "decompose_factor": 4, }, ), + ######## + # OFT # + ######## + ("Vanilla MLP 1 OFT", "MLP", OFTConfig, {"target_modules": "lin0"}), + ("Vanilla MLP 2 OFT", "MLP", OFTConfig, {"target_modules": ["lin0"]}), + ("Vanilla MLP 5 OFT", "MLP", OFTConfig, {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}), + ( + "Vanilla MLP 6 OFT", + "MLP", + OFTConfig, + { + "target_modules": ["lin0"], + "module_dropout": 0.1, + }, + ), + ("Vanilla MLP 7 OFT", "MLP", OFTConfig, {"target_modules": ["lin0"], "coft": True}), + ("Vanilla MLP 8 OFT", "MLP", OFTConfig, {"target_modules": ["lin0"], "block_share": True}), + ("Vanilla MLP 9 OFT", "MLP", OFTConfig, {"target_modules": ["lin0"], "coft": True, "block_share": True}), + ("Conv2d 1 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"]}), + ("Conv2d 3 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "coft": True}), + ("Conv2d 4 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "block_share": True}), + ("Conv2d 5 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "coft": True, "block_share": True}), ] MULTIPLE_ACTIVE_ADAPTERS_TEST_CASES = [ @@ -258,6 +280,7 @@ LoraConfig: "lora_", LoHaConfig: "hada_", LoKrConfig: "lokr_", + OFTConfig: "oft_", } @@ -833,6 +856,7 @@ def test_targeting_lora_to_embedding_layer_non_transformers(self, save_embedding LoHaConfig(target_modules=["lin0"], init_weights=False), AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), IA3Config(target_modules=["lin0"], feedforward_modules=["lin0"], init_ia3_weights=False), + OFTConfig(target_modules=["lin0"], init_weights=False), ] ) def test_adapter_name_makes_no_difference(self, config0): @@ -1852,3 +1876,80 @@ def test_requires_grad_lokr_same_targets(self): "base_model.model.lin0.lokr_w1.adapter1", "base_model.model.lin0.lokr_w2.adapter1", ) + + def test_requires_grad_oft_different_targets(self): + # test two different OFT adapters that target different modules + config0 = OFTConfig(target_modules=["lin0"]) + peft_model = get_peft_model(MLP(), config0) + + config1 = OFTConfig(target_modules=["lin1"], inference_mode=True) + peft_model.add_adapter("adapter1", config1) + + # active pter is still "default" + self.check_requires_grad( + peft_model, + "base_model.model.lin0.oft_r.default", + ) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.oft_r.default", + ) + + # change activate pter to pter1 + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.oft_r.adapter1", + ) + + # disable all pters + with peft_model.disable_adapter(): + self.check_requires_grad(peft_model) + + # after context is exited, return to the previous state + self.check_requires_grad( + peft_model, + "base_model.model.lin1.oft_r.adapter1", + ) + + def test_requires_grad_oft_same_targets(self): + # same as previous test, except that OFT adapters target the same layer + config0 = OFTConfig(target_modules=["lin0"]) + peft_model = get_peft_model(MLP(), config0) + + config1 = OFTConfig(target_modules=["lin0"], inference_mode=True) + peft_model.add_adapter("adapter1", config1) + + # active adapter is still "default" + self.check_requires_grad( + peft_model, + "base_model.model.lin0.oft_r.default", + ) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.oft_r.default", + ) + + # change activate adapter to adapter1 + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.oft_r.adapter1", + ) + + # disable all adapters + with peft_model.disable_adapter(): + self.check_requires_grad(peft_model) + + # after context is exited, return to the previous state + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.oft_r.adapter1", + ) diff --git a/tests/test_stablediffusion.py b/tests/test_stablediffusion.py index 830614a7aba..660c17caea8 100644 --- a/tests/test_stablediffusion.py +++ b/tests/test_stablediffusion.py @@ -20,7 +20,7 @@ from diffusers import StableDiffusionPipeline from parameterized import parameterized -from peft import LoHaConfig, LoraConfig, get_peft_model +from peft import LoHaConfig, LoraConfig, OFTConfig, get_peft_model from .testing_common import ClassInstantier, PeftCommonTester from .testing_utils import temp_seed @@ -60,11 +60,24 @@ "module_dropout": 0.0, }, }, + { + "text_encoder": { + "r": 8, + "target_modules": ["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"], + "module_dropout": 0.0, + }, + "unet": { + "r": 8, + "target_modules": ["proj_in", "proj_out", "to_k", "to_q", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"], + "module_dropout": 0.0, + }, + }, ) CLASSES_MAPPING = { "lora": (LoraConfig, CONFIG_TESTING_KWARGS[0]), "loha": (LoHaConfig, CONFIG_TESTING_KWARGS[1]), "lokr": (LoHaConfig, CONFIG_TESTING_KWARGS[1]), + "oft": (OFTConfig, CONFIG_TESTING_KWARGS[2]), } @@ -115,13 +128,14 @@ def prepare_inputs_for_testing(self): "model_ids": PEFT_DIFFUSERS_SD_MODELS_TO_TEST, "lora_kwargs": {"init_lora_weights": [False]}, "loha_kwargs": {"init_weights": [False]}, + "oft_kwargs": {"init_weights": [False]}, }, ) ) def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): - if config_cls == LoHaConfig: + if config_cls in [LoHaConfig, OFTConfig]: # TODO: This test is flaky with PyTorch 2.1 on Windows, we need to figure out what is going on - self.skipTest("LoHaConfig test is flaky") + self.skipTest("LoHaConfig and OFTConfig test is flaky") # Instantiate model & adapters model = self.instantiate_sd_peft(model_id, config_cls, config_kwargs) @@ -148,7 +162,7 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): "model_ids": PEFT_DIFFUSERS_SD_MODELS_TO_TEST, "lora_kwargs": {"init_lora_weights": [False]}, }, - filter_params_func=lambda tests: [x for x in tests if all(s not in x[0] for s in ["loha", "lokr"])], + filter_params_func=lambda tests: [x for x in tests if all(s not in x[0] for s in ["loha", "lokr", "oft"])], ) ) def test_add_weighted_adapter_base_unchanged(self, test_name, model_id, config_cls, config_kwargs): @@ -178,6 +192,7 @@ def test_add_weighted_adapter_base_unchanged(self, test_name, model_id, config_c "lora_kwargs": {"init_lora_weights": [False]}, "loha_kwargs": {"init_weights": [False]}, "lokr_kwargs": {"init_weights": [False]}, + "oft_kwargs": {"init_weights": [False]}, }, ) ) diff --git a/tests/testing_common.py b/tests/testing_common.py index 00809c2bc12..0c081cde2c6 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -574,7 +574,7 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): self.assertTrue(torch.allclose(logits_merged, logits_merged_from_pretrained, atol=atol, rtol=rtol)) def _test_merge_layers_multi(self, model_id, config_cls, config_kwargs): - supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3] + supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3, PeftType.OFT] if ("gpt2" in model_id.lower()) and (config_cls == IA3Config): self.skipTest("Merging GPT2 adapters not supported for IAΒ³ (yet)") @@ -886,7 +886,7 @@ def _test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwar self.assertIsNotNone(param.grad) def _test_delete_adapter(self, model_id, config_cls, config_kwargs): - supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3] + supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3, PeftType.OFT] # IA3 does not support deleting adapters yet, but it just needs to be added # AdaLora does not support multiple adapters config = config_cls( @@ -924,7 +924,7 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs): def _test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs): # same as test_delete_adapter, but this time an inactive adapter is deleted - supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3] + supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3, PeftType.OFT] # IA3 does not support deleting adapters yet, but it just needs to be added # AdaLora does not support multiple adapters config = config_cls( From 6a57472665b2b712a84e2bedd98945038283f7cc Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 30 Nov 2023 21:58:16 +0100 Subject: [PATCH 06/12] Mixed adapter models (#1163) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Description This PR allows to add adapters of different types, e.g. LoRA and LoHa: base_model = ... config0 = LoraConfig(...) peft_model = get_peft_model(base_model, config0, mixed=True) config1 = LoHaConfig(...) peft_model.add_adapter(config1, "other") peft_model.set_adapter(["default", "other"]) peft_model(x) At this point, both adapters are active at the same time. Existing code should not be affected by this change, since users need to opt into this behavior by setting mixed=True, and a completely different class is being used (PeftMixedModel). Also interesting is that this method can be used for a single adapter type but with very different configs. Right now, we have limited support for that (e.g. for LoRA, different r values by using rank_pattern), but with this, we don't need to special case the differing arguments anymore. Not implemented - [ ] I'm not yet sure if the same logic can be applied to IAΒ³ or if it may fail because IAΒ³ can apply its scaling to the input, not the output. - [ ] OFT is not supported yet but should work. - [ ] It is currently not possible to represent a mixed adapter model as a single config. I think we can come up with a solution but I don't think it is necessary for a first version of this. - [ ] Saving and loading is not yet implemented for mixed models. Those could potentially be added in a future PR. --------- Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> --- README.md | 31 + docs/source/_toctree.yml | 2 + docs/source/developer_guides/mixed_models.md | 39 + src/peft/__init__.py | 1 + src/peft/mapping.py | 19 +- src/peft/mixed_model.py | 394 +++++++++ src/peft/tuners/__init__.py | 1 + src/peft/tuners/mixed/__init__.py | 19 + src/peft/tuners/mixed/model.py | 323 ++++++++ tests/test_mixed.py | 794 +++++++++++++++++++ 10 files changed, 1620 insertions(+), 3 deletions(-) create mode 100644 docs/source/developer_guides/mixed_models.md create mode 100644 src/peft/mixed_model.py create mode 100644 src/peft/tuners/mixed/__init__.py create mode 100644 src/peft/tuners/mixed/model.py create mode 100644 tests/test_mixed.py diff --git a/README.md b/README.md index 09846dc61cd..06a757ed905 100644 --- a/README.md +++ b/README.md @@ -367,6 +367,8 @@ any GPU memory savings. Please refer issue [[FSDP] FSDP with CPU offload consume ## πŸ€— PEFT as a utility library +### Injecting adapters directly into the model + Inject trainable adapters on any `torch` model using `inject_adapter_in_model` method. Note the method will make no further change to the model. ```python @@ -403,6 +405,35 @@ dummy_outputs = model(dummy_inputs) Learn more about the [low level API in the docs](https://huggingface.co/docs/peft/developer_guides/low_level_api). +### Mixing different adapter types + +Ususally, it is not possible to combine different adapter types in the same model, e.g. combining LoRA with AdaLoRA, LoHa, or LoKr. Using a mixed model, this can, however, be achieved: + +```python +from peft import PeftMixedModel + +model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM").eval() +peft_model = PeftMixedModel.from_pretrained(model, , "adapter0") +peft_model.load_adapter(, "adapter1") +peft_model.set_adapter(["adapter0", "adapter1"]) +result = peft_model(**inputs) +``` + +The main intent is to load already trained adapters and use this only for inference. However, it is also possible to create a PEFT model for training by passing `mixed=True` to `get_peft_model`: + +```python +from peft import get_peft_model, LoraConfig, LoKrConfig + +base_model = ... +config0 = LoraConfig(...) +config1 = LoKrConfig(...) +peft_model = get_peft_model(base_model, config0, "adapter0", mixed=True) +peft_model.add_adapter(config1, "adapter1") +peft_model.set_adapter(["adapter0", "adapter1"]) +for batch in dataloader: + ... +``` + ## Contributing If you would like to contribute to PEFT, please check out our [contributing guide](https://huggingface.co/docs/peft/developer_guides/contributing). diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 88bedf31d7f..25992b3966e 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -34,6 +34,8 @@ title: Working with custom models - local: developer_guides/low_level_api title: PEFT low level API + - local: developer_guides/mixed_models + title: Mixing different adapter types - local: developer_guides/contributing title: Contributing to PEFT - local: developer_guides/troubleshooting diff --git a/docs/source/developer_guides/mixed_models.md b/docs/source/developer_guides/mixed_models.md new file mode 100644 index 00000000000..93414eee045 --- /dev/null +++ b/docs/source/developer_guides/mixed_models.md @@ -0,0 +1,39 @@ + + +# Working with mixed adapter types + +Normally, it is not possible to mix different adapter types in πŸ€— PEFT. For example, even though it is possible to create a PEFT model that has two different LoRA adapters (that can have different config options), it is not possible to combine a LoRA adapter with a LoHa adapter. However, by using a mixed model, this works as long as the adapter types are compatible. + +## Loading different adapter types into a PEFT model + +To load different adapter types into a PEFT model, proceed the same as if you were loading two adapters of the same type, but use `PeftMixedModel` instead of `PeftModel`: + +```py +from peft import PeftMixedModel + +base_model = ... # load the base model, e.g. from transformers +# load first adapter, which will be called "default" +peft_model = PeftMixedModel.from_pretrained(base_model, ) +peft_model.load_adapter(, adapter_name="other") +peft_model.set_adapter(["default", "other"]) +``` + +The last line is necessary if you want to activate both adapters, otherwise, only the first adapter would be active. Of course, you can add more different adapters by calling `add_adapter` repeatedly. + +Currently, the main purpose of mixed adapter types is to combine trained adapters for inference. Although it is technically also possible to train a mixed adapter model, this has not been tested and is not recommended. + +## Tips + +- Not all adapter types can be combined. See `peft.tuners.mixed.COMPATIBLE_TUNER_TYPES` for a list of compatible types. An error will be raised if you are trying to combine incompatible adapter types. +- It is possible to mix multiple adapters of the same type. This can be useful to combine adapters with very different configs. +- If you want to combine a lot of different adapters, it is most performant to add the same types of adapters consecutively. E.g., add LoRA1, LoRA2, LoHa1, LoHa2 in this order, instead of LoRA1, LoHa1, LoRA2, LoHa2. The order will make a difference for the outcome in most cases, but since no order is better a priori, it is best to choose the order that is most performant. diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 75ddda498cb..2b1883ebd73 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -35,6 +35,7 @@ get_peft_model, inject_adapter_in_model, ) +from .mixed_model import PeftMixedModel from .peft_model import ( PeftModel, PeftModelForCausalLM, diff --git a/src/peft/mapping.py b/src/peft/mapping.py index 60503fa985b..f34bdb51c53 100644 --- a/src/peft/mapping.py +++ b/src/peft/mapping.py @@ -20,6 +20,7 @@ import torch from .config import PeftConfig +from .mixed_model import PeftMixedModel from .peft_model import ( PeftModel, PeftModelForCausalLM, @@ -99,13 +100,21 @@ def get_peft_config(config_dict: Dict[str, Any]) -> PeftConfig: return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict) -def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> PeftModel: +def get_peft_model( + model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default", mixed: bool = False +) -> PeftModel | PeftMixedModel: """ Returns a Peft model object from a model and a config. Args: - model ([`transformers.PreTrainedModel`]): Model to be wrapped. - peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model. + model ([`transformers.PreTrainedModel`]): + Model to be wrapped. + peft_config ([`PeftConfig`]): + Configuration object containing the parameters of the Peft model. + adapter_name (`str`, `optional`, defaults to `"default"`): + The name of the adapter to be injected, if not provided, the default adapter name is used ("default"). + mixed (`bool`, `optional`, defaults to `False`): + Whether to allow mixing different (compatible) adapter types. """ model_config = getattr(model, "config", {"model_type": "custom"}) if hasattr(model_config, "to_dict"): @@ -113,8 +122,12 @@ def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None) + if mixed: + return PeftMixedModel(model, peft_config, adapter_name=adapter_name) + if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning: return PeftModel(model, peft_config, adapter_name=adapter_name) + if peft_config.is_prompt_learning: peft_config = _prepare_prompt_learning_config(peft_config, model_config) return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name) diff --git a/src/peft/mixed_model.py b/src/peft/mixed_model.py new file mode 100644 index 00000000000..55892851e9e --- /dev/null +++ b/src/peft/mixed_model.py @@ -0,0 +1,394 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +from contextlib import contextmanager +from typing import Any, Optional, Union + +import torch +from accelerate.hooks import remove_hook_from_submodules +from torch import nn +from transformers.utils import PushToHubMixin + +from peft.tuners.mixed import COMPATIBLE_TUNER_TYPES + +from .config import PeftConfig +from .peft_model import PeftModel +from .tuners import ( + AdaLoraModel, + IA3Model, + LoHaModel, + LoKrModel, + LoraModel, + MixedModel, +) +from .utils import PeftType, _set_adapter, _set_trainable + + +PEFT_TYPE_TO_MODEL_MAPPING = { + PeftType.LORA: LoraModel, + PeftType.LOHA: LoHaModel, + PeftType.LOKR: LoKrModel, + PeftType.ADALORA: AdaLoraModel, + PeftType.IA3: IA3Model, +} + + +def _prepare_model_for_gradient_checkpointing(model: nn.Module) -> None: + r""" + Prepares the model for gradient checkpointing if necessary + """ + # Note: same as PeftModel._prepare_model_for_gradient_checkpointing + if not getattr(model, "is_gradient_checkpointing", True): + return model + + if not ( + getattr(model, "is_loaded_in_8bit", False) + or getattr(model, "is_loaded_in_4bit", False) + or getattr(model, "is_quantized", False) + ): + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + elif hasattr(model, "get_input_embeddings"): + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + +def _check_config_compatible(peft_config: PeftConfig) -> None: + if peft_config.peft_type not in COMPATIBLE_TUNER_TYPES: + raise ValueError( + f"The provided `peft_type` '{peft_config.peft_type.value}' is not compatible with the `PeftMixedModel`. " + f"Compatible types are: {COMPATIBLE_TUNER_TYPES}" + ) + + +class PeftMixedModel(PushToHubMixin, torch.nn.Module): + """ + Peft model for mixing different types of adapters. + + This class currently does not support saving and loading. Instead, it is assumed that the adapters are already + trained and loading the model requires a script to be run each time. + + Currently, the main purpose of mixed adapter types is to combine trained adapters for inference. Although it is + technically possible to train a mixed adapter model, this has not been tested and is not recommended. + + Note: This class should usually not be initialized directly. Instead, use `get_peft_model` with the argument + `mixed=True`. + + Below is an example that shows how to load a mixed model with two different types of adapters. + + ```py + >>> from peft import get_peft_model + + >>> base_model = ... # load the base model, e.g. from transformers + >>> peft_model = PeftMixedModel.from_pretrained(base_model, path_to_adapter1, "adapter1").eval() + >>> peft_model.load_adapter(path_to_adapter2, "adapter2") + >>> peft_model.set_adapter(["adapter1", "adapter2"]) # activate both adapters + >>> peft_model(data) # forward pass using both adapters + ``` + + Tips: + + - Not all adapter types can be combined. See `peft.tuners.mixed.COMPATIBLE_TUNER_TYPES` for a list of compatible + types. An error will be raised if you are trying to combine incompatible adapter types. + - It is possible to mix multiple adapters of the same type. This can be useful to combine adapters with very + different configs. + - If you want to combine a lot of different adapters, it is most performant to add the same types of adapters + consecutively. E.g., add LoRA1, LoRA2, LoHa1, LoHa2 in this order, instead of LoRA1, LoHa1, LoRA2, LoHa2. As long + as the adapters are commutative, the order does not matter for the final result. + + Args: + model (`torch.nn.Module`): + The model to be tuned. + config (`PeftConfig`): + The config of the model to be tuned. The adapter type must be compatible. + adapter_name (`str`, `optional`, defaults to `"default"`): + The name of the first adapter. + """ + + def __init__(self, model: nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: + super().__init__() + _check_config_compatible(peft_config) + _prepare_model_for_gradient_checkpointing(model) + self.modules_to_save = None + self.base_model = MixedModel(model, {adapter_name: peft_config}, adapter_name) + self.set_modules_to_save(peft_config, adapter_name) + + self.config = getattr(model, "config", {"model_type": "custom"}) + + # the `pretraining_tp` is set for some models to simulate Tensor Parallelism during inference to avoid + # numerical differences, https://github.com/pytorch/pytorch/issues/76232 - to avoid any unexpected + # behavior we disable that in this line. + if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"): + self.base_model.config.pretraining_tp = 1 + + @property + def peft_config(self) -> dict[str, PeftConfig]: + return self.base_model.peft_config + + @property + def active_adapter(self) -> str: + return self.base_model.active_adapter + + @property + def active_adapters(self) -> list[str]: + return self.base_model.active_adapters + + def get_nb_trainable_parameters(self): + r""" + Returns the number of trainable parameters and number of all parameters in the model. + """ + # note: same as PeftModel.get_nb_trainable_parameters + trainable_params = 0 + all_param = 0 + for _, param in self.named_parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + + # Due to the design of 4bit linear layers from bitsandbytes + # one needs to multiply the number of parameters by 2 to get + # the correct number of parameters + if param.__class__.__name__ == "Params4bit": + num_params = num_params * 2 + + all_param += num_params + if param.requires_grad: + trainable_params += num_params + + return trainable_params, all_param + + def print_trainable_parameters(self): + """ + Prints the number of trainable parameters in the model. + """ + # note: same as PeftModel.print_trainable_parameters + trainable_params, all_param = self.get_nb_trainable_parameters() + + print( + f"trainable params: {trainable_params:,d} || " + f"all params: {all_param:,d} || " + f"trainable%: {100 * trainable_params / all_param:.4f}" + ) + + def forward(self, *args: Any, **kwargs: Any): + """ + Forward pass of the model. + """ + return self.base_model(*args, **kwargs) + + def generate(self, *args: Any, **kwargs: Any): + """ + Generate output. + """ + return self.base_model.generate(*args, **kwargs) + + @contextmanager + def disable_adapter(self): + """ + Disables the adapter module. + """ + try: + self.base_model.disable_adapter_layers() + yield + finally: + self.base_model.enable_adapter_layers() + + def add_adapter(self, adapter_name: str, peft_config: PeftConfig): + _check_config_compatible(peft_config) + + try: + self.peft_config[adapter_name] = peft_config + self.base_model.inject_adapter(self, adapter_name) + except Exception: # somthing went wrong, roll back + if adapter_name in self.peft_config: + del self.peft_config[adapter_name] + raise + + self.set_modules_to_save(peft_config, adapter_name) + + def set_modules_to_save(self, peft_config: PeftConfig, adapter_name: str) -> None: + if (modules_to_save := getattr(peft_config, "modules_to_save", None)) is None: + return + + if self.modules_to_save is None: + self.modules_to_save = set(modules_to_save) + else: + self.modules_to_save.update(modules_to_save) + _set_trainable(self, adapter_name) + + def set_adapter(self, adapter_name: Union[str, list[str]]) -> None: + """ + Sets the active adapter(s) for the model. + + Note that the order in which the adapters are applied during the forward pass may not be the same as the order + in which they are passed to this function. Instead, the order during the forward pass is determined by the + order in which the adapters were loaded into the model. The active adapters only determine which adapters are + active during the forward pass, but not the order in which they are applied. + + Args: + adapter_name (`str` or `List[str]`): + The name of the adapter(s) to be activated. + """ + if isinstance(adapter_name, str): + adapter_name = [adapter_name] + + mismatched = set(adapter_name) - set(self.peft_config.keys()) + if mismatched: + raise ValueError( + f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}" + ) + + self.base_model.set_adapter(adapter_name) + _set_adapter(self, adapter_name) + + def delete_adapter(self, adapter_name: Union[str, list[str]]) -> None: + if isinstance(adapter_name, str): + adapter_name = [adapter_name] + + mismatched = set(adapter_name) - set(self.peft_config.keys()) + if mismatched: + raise ValueError( + f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}" + ) + + self.base_model.delete_adapter(adapter_name) + + def merge_and_unload(self, *args: Any, **kwargs: Any): + r""" + This method merges the adapter layers into the base model. This is needed if someone wants to use the base + model as a standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + return self.base_model.merge_and_unload(*args, **kwargs) + + def unload(self, *args: Any, **kwargs: Any): + """ + Gets back the base model by removing all the adapter modules without merging. This gives back the original base + model. + """ + return self.base_model.unload(*args, **kwargs) + + @classmethod + def _split_kwargs(cls, kwargs: dict[str, Any]): + return PeftModel._split_kwargs(kwargs) + + def load_adapter(self, model_id: str, adapter_name: str, *args: Any, **kwargs: Any): + output = PeftModel.load_adapter(self, model_id, adapter_name, *args, **kwargs) + # TODO: not quite clear why this is necessary but tests fail without it + self.set_adapter(self.active_adapters) + return output + + def create_or_update_model_card(self, output_dir: str): + raise NotImplementedError(f"Model card creation is not supported for {self.__class__.__name__} (yet).") + + def save_pretrained( + self, + save_directory: str, + safe_serialization: bool = False, + selected_adapters: Optional[list[str]] = None, + **kwargs: Any, + ): + raise NotImplementedError(f"Saving is not supported for {self.__class__.__name__} (yet).") + + @classmethod + def from_pretrained( + cls, + model: nn.Module, + model_id: str | os.PathLike, + adapter_name: str = "default", + is_trainable: bool = False, + config: Optional[PeftConfig] = None, + **kwargs: Any, + ): + r""" + Instantiate a PEFT mixed model from a pretrained model and loaded PEFT weights. + + Note that the passed `model` may be modified inplace. + + Args: + model (`nn.Module`): + The model to be adapted. + model_id (`str` or `os.PathLike`): + The name of the PEFT configuration to use. Can be either: + - A string, the `model id` of a PEFT configuration hosted inside a model repo on the Hugging Face + Hub. + - A path to a directory containing a PEFT configuration file saved using the `save_pretrained` + method (`./my_peft_config_directory/`). + adapter_name (`str`, *optional*, defaults to `"default"`): + The name of the adapter to be loaded. This is useful for loading multiple adapters. + is_trainable (`bool`, *optional*, defaults to `False`): + Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and use for + inference + config ([`~peft.PeftConfig`], *optional*): + The configuration object to use instead of an automatically loaded configuation. This configuration + object is mutually exclusive with `model_id` and `kwargs`. This is useful when configuration is already + loaded before calling `from_pretrained`. + kwargs: (`optional`): + Additional keyword arguments passed along to the specific PEFT configuration class. + """ + # note: adapted from PeftModel.from_pretrained + from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING + + # load the config + if config is None: + config = PEFT_TYPE_TO_CONFIG_MAPPING[ + PeftConfig._get_peft_type( + model_id, + subfolder=kwargs.get("subfolder", None), + revision=kwargs.get("revision", None), + cache_dir=kwargs.get("cache_dir", None), + use_auth_token=kwargs.get("use_auth_token", None), + ) + ].from_pretrained(model_id, **kwargs) + elif isinstance(config, PeftConfig): + config.inference_mode = not is_trainable + else: + raise ValueError(f"The input config must be a PeftConfig, got {config.__class__}") + + # note: this is different from PeftModel.from_pretrained + if config.peft_type not in PEFT_TYPE_TO_MODEL_MAPPING: + raise ValueError(f"Adapter of type {config.peft_type} is not supported for mixed models.") + + if (getattr(model, "hf_device_map", None) is not None) and len( + set(model.hf_device_map.values()).intersection({"cpu", "disk"}) + ) > 0: + remove_hook_from_submodules(model) + + if config.is_prompt_learning and is_trainable: + # note: should not be possible to reach, but just in case + raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.") + else: + config.inference_mode = not is_trainable + + # note: this is different from PeftModel.from_pretrained, we always return a PeftMixedModel + model = cls(model, config, adapter_name) + model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs) + return model diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index f5f665dd994..9211cfb4f83 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -28,3 +28,4 @@ from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit from .oft import OFTConfig, OFTModel +from .mixed import MixedModel diff --git a/src/peft/tuners/mixed/__init__.py b/src/peft/tuners/mixed/__init__.py new file mode 100644 index 00000000000..f21cff3b293 --- /dev/null +++ b/src/peft/tuners/mixed/__init__.py @@ -0,0 +1,19 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .model import COMPATIBLE_TUNER_TYPES, MixedModel + + +__all__ = ["COMPATIBLE_TUNER_TYPES", "MixedModel"] diff --git a/src/peft/tuners/mixed/model.py b/src/peft/tuners/mixed/model.py new file mode 100644 index 00000000000..5e7acf1cfe7 --- /dev/null +++ b/src/peft/tuners/mixed/model.py @@ -0,0 +1,323 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import warnings +from typing import Any, Optional, Union + +from torch import nn +from tqdm import tqdm + +from peft.tuners import adalora, loha, lokr, lora +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists +from peft.utils import ( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, + ModulesToSaveWrapper, + PeftType, + _get_submodules, + get_auto_gptq_quant_linear, +) + + +# Collection of constants used for all tuners +COMPATIBLE_TUNER_TYPES = (PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.ADALORA) +PREFIXES = [lora.LoraModel.prefix, lokr.LoKrModel.prefix, loha.LoHaModel.prefix] +Configs = Union[lora.LoraConfig, loha.LoHaConfig, lokr.LoKrConfig, adalora.AdaLoraConfig] +Layers = (lora.layer.LoraLayer, loha.layer.LoHaLayer, lokr.layer.LoKrLayer, adalora.layer.AdaLoraLayer) + + +class MixedModel(BaseTuner): + """ + A class that allows to mix different types of adapters in a single model. + + Note: This class should usually not be initialized directly. Instead, use `get_peft_model` with the argument + `mixed=True`. + + Args: + model (:obj:`nn.Module`): + The model to be tuned. + config (:obj:`PeftConfig`): + The config of the model to be tuned. The adapter type must be compatible. + adapter_name (:obj:`str`): + The name of the first adapter. + """ + + def __init__(self, model: nn.Module, config: Configs, adapter_name: str) -> None: + super().__init__(model, config, adapter_name) + + def _check_new_adapter_config(self, config: Configs) -> None: + """ + A helper method to check the config when a new adapter is being added. + + Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. + + """ + if not isinstance(config, Configs.__args__): + raise ValueError( + f"{self.__class__.__name__} only supports {COMPATIBLE_TUNER_TYPES} configs, but got {type(config)}." + ) + + biases = (getattr(config, "bias", None) for config in self.peft_config) + biases = [bias for bias in biases if bias not in (None, "none")] + if len(biases) > 1: + raise ValueError( + f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, " + "set bias to 'none' for all adapters." + ) + + @staticmethod + def _check_target_module_exists(config: Configs, key: str): + return check_target_module_exists(config, key) + + def _create_and_replace( + self, + config: Configs, + *args: Any, + **kwargs: Any, + ) -> None: + if isinstance(config, adalora.AdaLoraConfig): + adalora.AdaLoraModel._create_and_replace(self, config, *args, **kwargs) + elif isinstance(config, lora.LoraConfig): + lora.LoraModel._create_and_replace(self, config, *args, **kwargs) + elif isinstance(config, loha.LoHaConfig): + loha.LoHaModel._create_and_replace(self, config, *args, **kwargs) + elif isinstance(config, lokr.LoKrConfig): + lokr.LoKrModel._create_and_replace(self, config, *args, **kwargs) + else: + raise ValueError(f"Unsupported config type {type(config)}, should be one of {COMPATIBLE_TUNER_TYPES}.") + + def _replace_module(self, parent, child_name, new_module, child) -> None: + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.get_base_layer() + elif hasattr(child, "quant_linear_module"): + # TODO maybe not necessary to have special treatment? + child = child.quant_linear_module + + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + # dispatch to correct device + for name, module in new_module.named_modules(): + if any(prefix in name for prefix in PREFIXES): + module.to(child.weight.device) + if "ranknum" in name: + module.to(child.weight.device) + + def _mark_only_adapters_as_trainable(self) -> None: + for n, p in self.model.named_parameters(): + if not any(prefix in n for prefix in PREFIXES): + p.requires_grad = False + + for active_adapter in self.active_adapters: + bias = getattr(self.peft_config[active_adapter], "bias", "none") + if bias == "none": + continue + + if bias == "all": + for n, p in self.model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias == "lora_only": + # TODO: check if this is needed for other supported types + for m in self.model.modules(): + if isinstance(m, Layers) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise ValueError(f"Requested bias: {bias}, is not implemented.") + + @staticmethod + def _create_new_module(config, adapter_name, target, **kwargs): + gptq_quantization_config = kwargs.get("gptq_quantization_config", None) + AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config) + if (gptq_quantization_config is not None) or (AutoGPTQQuantLinear is not None): + raise ValueError(f"GPTQ quantization not supported for {config.peft_type.value} (yet).") + + loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) + loaded_in_4bit = kwargs.pop("loaded_in_4bit", False) + if loaded_in_8bit or loaded_in_4bit: + raise ValueError(f"8bit and 4bit quantization not supported for {config.peft_type.value} (yet).") + + if isinstance(config, adalora.AdaLoraConfig): + new_module = adalora.AdaLoraModel._create_new_module(config, adapter_name, target, **kwargs) + elif isinstance(config, lora.LoraConfig): + new_module = lora.LoraModel._create_new_module(config, adapter_name, target, **kwargs) + elif isinstance(config, loha.LoHaConfig): + new_module = loha.LoHaModel._create_new_module(config, adapter_name, target, **kwargs) + elif isinstance(config, lokr.LoKrConfig): + new_module = lokr.LoKrModel._create_new_module(config, adapter_name, target, **kwargs) + else: + raise ValueError(f"Unknown config type {type(config)}, should be one of {COMPATIBLE_TUNER_TYPES}.") + return new_module + + def _set_adapter_layers(self, enabled=True): + for module in self.model.modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def enable_adapter_layers(self): + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self): + for active_adapter in self.active_adapters: + val = getattr(self.peft_config[active_adapter], "bias", "none") + if val != "none": + msg = ( + f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same " + "output as the the base model would without adaption." + ) + warnings.warn(msg) + self._set_adapter_layers(enabled=False) + + def set_adapter(self, adapter_name: Union[str, list[str]]) -> None: + for module in self.model.modules(): + if isinstance(module, Layers): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + self.active_adapter = adapter_name + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) + return peft_config + + def _unload_and_optionally_merge( + self, + merge=True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[list[str]] = None, + ): + if merge: + if getattr(self.model, "quantization_method", None) == "gptq": + raise ValueError("Cannot merge layers when the model is gptq quantized") + + def merge_recursively(module): + # helper function to recursively merge the base_layer of the target + path = [] + layer = module + while hasattr(layer, "base_layer"): + path.append(layer) + layer = layer.base_layer + for layer_before, layer_after in zip(path[:-1], path[1:]): + layer_after.merge(safe_merge=safe_merge, adapter_names=adapter_names) + layer_before.base_layer = layer_after.base_layer + module.merge(safe_merge=safe_merge, adapter_names=adapter_names) + + key_list = [key for key, _ in self.model.named_modules() if not any(prefix in key for prefix in PREFIXES)] + desc = "Unloading " + ("and merging " if merge else "") + "model" + + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + + if hasattr(target, "base_layer"): + if merge: + merge_recursively(target) + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` + setattr(parent, target_name, target.modules_to_save[target.active_adapter]) + + return self.model + + def add_weighted_adapter(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError(f"Weighted adapters are not supported for {self.__class__.__name__} (yet).") + + def delete_adapter(self, adapter_name: Union[str, list[str]]) -> None: + """ + Deletes an existing adapter. + + Args: + adapter_name (Union[str, list[str]]): Name of the adapter(s) to delete. + """ + if isinstance(adapter_name, str): + adapter_names = [adapter_name] + else: + adapter_names = adapter_name + + mismatched = set(adapter_names) - set(self.peft_config.keys()) + if mismatched: + raise ValueError( + f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}" + ) + + for adapter_name in adapter_names: + del self.peft_config[adapter_name] + + key_list = [key for key, _ in self.model.named_modules() if not any(prefix in key for prefix in PREFIXES)] + new_adapter = None + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, BaseTunerLayer): + target.delete_adapter(adapter_name) + if new_adapter is None: + new_adapter = target.active_adapters[:] + + self.active_adapter = new_adapter or [] + + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> nn.Module: + r""" + This method merges the layers into the base model. This is needed if someone wants to use the base model as a + standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + return self._unload_and_optionally_merge( + progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names + ) + + def unload(self) -> nn.Module: + """ + Gets back the base model by removing all the lora modules without merging. This gives back the original base + model. + """ + return self._unload_and_optionally_merge(merge=False) + + def generate(self, *args: Any, **kwargs: Any): + return self.model.generate(*args, **kwargs) diff --git a/tests/test_mixed.py b/tests/test_mixed.py new file mode 100644 index 00000000000..bd8f455e990 --- /dev/null +++ b/tests/test_mixed.py @@ -0,0 +1,794 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import itertools +import os +import re +import tempfile +import unittest + +import torch +from parameterized import parameterized +from torch import nn +from transformers import AutoModelForCausalLM + +from peft import AdaLoraConfig, LoHaConfig, LoKrConfig, LoraConfig, PeftMixedModel, PrefixTuningConfig, get_peft_model +from peft.tuners.tuners_utils import BaseTunerLayer +from peft.utils import infer_device + + +class SimpleNet(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.lin0 = nn.Linear(10, 20, bias=bias) + self.relu = nn.ReLU() + self.lin1 = nn.Linear(20, 2, bias=bias) + + def forward(self, X): + X = X.float() + X = self.lin0(X) + X = self.relu(X) + X = self.lin1(X) + return X + + +def _param_name_func(testcase_func, param_num, params): + # for parameterized tests in TextMixedAdapterTypes + config0, config1 = params[0] + name0 = config0.__class__.__name__ + name1 = config1.__class__.__name__ + if name0 != name1: + return f"{testcase_func.__name__}_{param_num}_{name0}_{name1}" + return f"{testcase_func.__name__}_{param_num}_{name0}_x2" + + +class TestMixedAdapterTypes(unittest.TestCase): + torch_device = infer_device() + + def _get_model(self, model_cls, peft_config=None, adapter_name=None, seed=0, mixed=True): + torch.manual_seed(0) # always use seed 0 for base model, seed for adapters may differ + base_model = model_cls().eval().to(self.torch_device) + if peft_config is None: + return base_model + + torch.manual_seed(seed) + assert adapter_name is not None + peft_model = get_peft_model(base_model, peft_config, adapter_name=adapter_name, mixed=mixed) + return peft_model.eval().to(self.torch_device) + + def _check_mixed_outputs(self, model_cls, config0, config1, input, *, is_commutative): + # This test checks different combinations of adapter0, adapter1, or combinations of the two, and whether + # outputs are the same/different, depending on context. If we pass is_commutative=True, it means that the order + # of adapters does not matter, and we expect the same output regardless of the order in which adapters are + # applied. + # We have to very careful with resetting the random seed each time it is used, otherwise the adapters may be + # initialized with different values, and the test will fail. + + atol = 1e-5 + rtol = 1e-5 + seed0 = 0 + seed1 = 1 + + # base model + base_model = self._get_model(model_cls) + output_base = base_model(input) + self.assertTrue(torch.isfinite(output_base).all()) + + # adapter 0 + peft_model_0 = self._get_model(model_cls, config0, "adapter0", seed=seed0) + output_config0 = peft_model_0(input) + + self.assertTrue(torch.isfinite(output_config0).all()) + self.assertFalse(torch.allclose(output_base, output_config0, atol=atol, rtol=rtol)) + + # adapter 1 + peft_model_1 = self._get_model(model_cls, config1, "adapter1", seed=seed1) + output_config1 = peft_model_1(input) + + self.assertTrue(torch.isfinite(output_config1).all()) + self.assertFalse(torch.allclose(output_base, output_config1, atol=atol, rtol=rtol)) + self.assertFalse(torch.allclose(output_config0, output_config1, atol=atol, rtol=rtol)) + + # adapter 0 + 1 + peft_model_01 = self._get_model(model_cls, config0, "adapter0", seed=seed0) + torch.manual_seed(seed1) + peft_model_01.add_adapter("adapter1", config1) + peft_model_01.set_adapter(["adapter0", "adapter1"]) + output_mixed_01 = peft_model_01(input) + + # check the number of tuner layer types + tuner_layers = [mod for mod in peft_model_01.modules() if isinstance(mod, BaseTunerLayer)] + tuner_types = {type(tuner_layer) for tuner_layer in tuner_layers} + if type(config0) == type(config1): + self.assertEqual(len(tuner_types), 1) + else: + self.assertEqual(len(tuner_types), 2) + + self.assertEqual(peft_model_01.active_adapters, ["adapter0", "adapter1"]) + self.assertTrue(torch.isfinite(output_mixed_01).all()) + self.assertFalse(torch.allclose(output_config0, output_mixed_01, atol=atol, rtol=rtol)) + self.assertFalse(torch.allclose(output_config1, output_mixed_01, atol=atol, rtol=rtol)) + if is_commutative: + delta0 = output_config0 - output_base + delta1 = output_config1 - output_base + delta_mixed_01 = output_mixed_01 - output_base + self.assertTrue(torch.allclose(delta0 + delta1, delta_mixed_01, atol=atol, rtol=rtol)) + + # adapter 1 + 0 + peft_model_10 = self._get_model(model_cls, config1, "adapter1", seed=seed1) + torch.manual_seed(seed0) + peft_model_10.add_adapter("adapter0", config0) + peft_model_10.set_adapter(["adapter1", "adapter0"]) + output_mixed_10 = peft_model_10(input) + + # check the number of tuner layer types + tuner_layers = [mod for mod in peft_model_10.modules() if isinstance(mod, BaseTunerLayer)] + tuner_types = {type(tuner_layer) for tuner_layer in tuner_layers} + if type(config0) == type(config1): + self.assertEqual(len(tuner_types), 1) + else: + self.assertEqual(len(tuner_types), 2) + + self.assertEqual(peft_model_10.active_adapters, ["adapter1", "adapter0"]) + self.assertTrue(torch.isfinite(output_mixed_10).all()) + self.assertFalse(torch.allclose(output_config0, output_mixed_10, atol=atol, rtol=rtol)) + self.assertFalse(torch.allclose(output_config1, output_mixed_10, atol=atol, rtol=rtol)) + if is_commutative: + self.assertTrue(torch.allclose(output_mixed_01, output_mixed_10, atol=atol, rtol=rtol)) + + # turn around the order of the adapters of the 0 + 1 mixed model, should behave like the 0 + 1 mixed model + peft_model_10.set_adapter(["adapter0", "adapter1"]) + output_mixed_reversed = peft_model_10(input) + + # check the number of tuner layer types + tuner_layers = [mod for mod in peft_model_10.modules() if isinstance(mod, BaseTunerLayer)] + tuner_types = {type(tuner_layer) for tuner_layer in tuner_layers} + if type(config0) == type(config1): + self.assertEqual(len(tuner_types), 1) + else: + self.assertEqual(len(tuner_types), 2) + + self.assertEqual(peft_model_10.active_adapters, ["adapter0", "adapter1"]) + self.assertTrue(torch.isfinite(output_mixed_reversed).all()) + self.assertTrue(torch.allclose(output_mixed_reversed, output_mixed_01, atol=atol, rtol=rtol)) + self.assertFalse(torch.allclose(output_mixed_reversed, output_config0, atol=atol, rtol=rtol)) + self.assertFalse(torch.allclose(output_mixed_reversed, output_config1, atol=atol, rtol=rtol)) + if is_commutative: + self.assertTrue(torch.allclose(output_mixed_reversed, output_mixed_10, atol=atol, rtol=rtol)) + + def _check_merging(self, model_cls, config0, config1, input): + # Ensure that when merging mixed adapters, the result is the same as when applying the adapters separately. + atol = 1e-5 + rtol = 1e-5 + seed0 = 0 + seed1 = 1 + + # adapter 0 + 1 + peft_model_01 = self._get_model(model_cls, config0, "adapter0", seed=seed0) + torch.manual_seed(seed1) + peft_model_01.add_adapter("adapter1", config1) + peft_model_01.set_adapter(["adapter0", "adapter1"]) + output_mixed_01 = peft_model_01(input) + + model_merged_01 = peft_model_01.merge_and_unload() + output_merged_01 = model_merged_01(input) + self.assertTrue(torch.allclose(output_mixed_01, output_merged_01, atol=atol, rtol=rtol)) + + # adapter 1 + 0 + peft_model_10 = self._get_model(model_cls, config1, "adapter1", seed=seed1) + torch.manual_seed(seed0) + peft_model_10.add_adapter("adapter0", config0) + peft_model_10.set_adapter(["adapter1", "adapter0"]) + output_mixed_10 = peft_model_10(input) + + model_merged_10 = peft_model_10.merge_and_unload() + output_merged_10 = model_merged_10(input) + self.assertTrue(torch.allclose(output_mixed_10, output_merged_10, atol=atol, rtol=rtol)) + + def _check_unload(self, model_cls, config0, config1, input): + # Ensure that we can unload the base model without merging + atol = 1e-5 + rtol = 1e-5 + seed0 = 0 + seed1 = 1 + + base_model = self._get_model(model_cls) + output_base = base_model(input) + + # adapter 0 + 1 + peft_model_01 = self._get_model(model_cls, config0, "adapter0", seed=seed0) + torch.manual_seed(seed1) + peft_model_01.add_adapter("adapter1", config1) + peft_model_01.set_adapter(["adapter0", "adapter1"]) + output_mixed = peft_model_01(input) + + # unload + model_unloaded = peft_model_01.unload() + output_unloaded = model_unloaded(input) + + self.assertFalse(torch.allclose(output_mixed, output_unloaded, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(output_base, output_unloaded, atol=atol, rtol=rtol)) + + def _check_disable(self, model_cls, config0, config1, input): + # Ensure that we can disable adapters + atol = 1e-5 + rtol = 1e-5 + seed0 = 0 + seed1 = 1 + + # base model + base_model = self._get_model(model_cls) + output_base = base_model(input) + + # adapter 0 + peft_model_0 = self._get_model(model_cls, config0, "adapter0", seed=seed0) + output_config0 = peft_model_0(input) + with peft_model_0.disable_adapter(): + output_disabled0 = peft_model_0(input) + + self.assertFalse(torch.allclose(output_base, output_config0, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(output_base, output_disabled0, atol=atol, rtol=rtol)) + + # adapter 1 + peft_model_1 = self._get_model(model_cls, config1, "adapter1", seed=seed1) + output_config1 = peft_model_1(input) + with peft_model_1.disable_adapter(): + output_disabled1 = peft_model_1(input) + + self.assertFalse(torch.allclose(output_base, output_config1, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(output_base, output_disabled1, atol=atol, rtol=rtol)) + + # adapter 0 + 1 + peft_model_01 = self._get_model(model_cls, config0, "adapter0", seed=seed0) + torch.manual_seed(seed1) + peft_model_01.add_adapter("adapter1", config1) + peft_model_01.set_adapter(["adapter0", "adapter1"]) + output_mixed_01 = peft_model_01(input) + with peft_model_01.disable_adapter(): + output_disabled01 = peft_model_01(input) + + self.assertFalse(torch.allclose(output_base, output_mixed_01, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(output_base, output_disabled01, atol=atol, rtol=rtol)) + + # adapter 1 + 0 + peft_model_10 = self._get_model(model_cls, config1, "adapter1", seed=seed1) + torch.manual_seed(seed0) + peft_model_10.add_adapter("adapter0", config0) + peft_model_10.set_adapter(["adapter1", "adapter0"]) + output_mixed_10 = peft_model_10(input) + with peft_model_10.disable_adapter(): + output_disabled10 = peft_model_10(input) + + self.assertFalse(torch.allclose(output_base, output_mixed_10, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(output_base, output_disabled10, atol=atol, rtol=rtol)) + + def _check_loading(self, model_cls, config0, config1, input): + # Check that we can load two adapters into the same model + # Note that we save the adapters using a normal PeftModel because PeftMixModel doesn't support saving yet + atol = 1e-5 + rtol = 1e-5 + seed0 = 0 + seed1 = 1 + + with tempfile.TemporaryDirectory() as tmp_dirname: + # SAVING + # adapter 0: note that we set mixed=False because mixed models don't support saving (yet) + peft_model_0 = self._get_model(model_cls, config0, "adapter0", seed=seed0, mixed=False) + output_config0 = peft_model_0(input) + peft_model_0.save_pretrained(os.path.join(tmp_dirname, "adapter0")) + + # adapter 1: note that we set mixed=False because mixed models don't support saving (yet) + peft_model_1 = self._get_model(model_cls, config1, "adapter1", seed=seed1, mixed=False) + output_config1 = peft_model_1(input) + peft_model_1.save_pretrained(os.path.join(tmp_dirname, "adapter1")) + + # adapter 0 + 1 + peft_model_01 = self._get_model(model_cls, config0, "adapter0", seed=seed0) + torch.manual_seed(seed1) + peft_model_01.add_adapter("adapter1", config1) + peft_model_01.set_adapter(["adapter0", "adapter1"]) + output_mixed_01 = peft_model_01(input) + + # LOADING + # adapter 0 + base_model = self._get_model(model_cls) + # Notes: + # Path is tmp_dirname/adapter0/adapter0 because non-default adapters are saved in a subfolder. + # As a sanity check, we should set a completely different seed here. That way, we ensure that the the + # weights are not just randomly initialized exactly to the same values as before. + torch.manual_seed(123456) + peft_model_loaded0 = PeftMixedModel.from_pretrained( + base_model, os.path.join(tmp_dirname, "adapter0", "adapter0"), "adapter0" + ) + output_loaded0 = peft_model_loaded0(input) + self.assertTrue(torch.allclose(output_config0, output_loaded0, atol=atol, rtol=rtol)) + + # adapter 1 + base_model = self._get_model(model_cls) + torch.manual_seed(654321) # setting a completely different seed here should not affect the result + peft_model_loaded1 = PeftMixedModel.from_pretrained( + base_model, os.path.join(tmp_dirname, "adapter1", "adapter1"), "adapter1" + ) + output_loaded1 = peft_model_loaded1(input) + self.assertTrue(torch.allclose(output_config1, output_loaded1, atol=atol, rtol=rtol)) + + # adapter 0 + 1 + base_model = self._get_model(model_cls) + torch.manual_seed(97531) # setting a completely different seed here should not affect the result + peft_model_loaded_01 = PeftMixedModel.from_pretrained( + base_model, os.path.join(tmp_dirname, "adapter0", "adapter0"), "adapter0" + ) + peft_model_loaded_01.load_adapter(os.path.join(tmp_dirname, "adapter1", "adapter1"), "adapter1") + # at this point, "config0" should still be active + self.assertEqual(peft_model_loaded_01.active_adapters, ["adapter0"]) + output_loaded01_0 = peft_model_loaded_01(input) + self.assertTrue(torch.allclose(output_config0, output_loaded01_0, atol=atol, rtol=rtol)) + # activate adapter1 + peft_model_loaded_01.set_adapter(["adapter1"]) + self.assertEqual(peft_model_loaded_01.active_adapters, ["adapter1"]) + output_loaded01_1 = peft_model_loaded_01(input) + self.assertTrue(torch.allclose(output_config1, output_loaded01_1, atol=atol, rtol=rtol)) + # activate both adapters + peft_model_loaded_01.set_adapter(["adapter0", "adapter1"]) + output_loaded01 = peft_model_loaded_01(input) + self.assertTrue(torch.allclose(output_mixed_01, output_loaded01, atol=atol, rtol=rtol)) + + # adapter 1 + 0 + base_model = self._get_model(model_cls) + torch.manual_seed(445566) # setting a completely different seed here should not affect the result + peft_model_loaded_10 = PeftMixedModel.from_pretrained( + base_model, os.path.join(tmp_dirname, "adapter1", "adapter1"), "adapter1" + ) + peft_model_loaded_10.load_adapter(os.path.join(tmp_dirname, "adapter0", "adapter0"), "adapter0") + # at this point, "config0" should still be active + self.assertEqual(peft_model_loaded_10.active_adapters, ["adapter1"]) + output_loaded10_1 = peft_model_loaded_10(input) + self.assertTrue(torch.allclose(output_config1, output_loaded10_1, atol=atol, rtol=rtol)) + # activate adapter1 + peft_model_loaded_10.set_adapter(["adapter0"]) + self.assertEqual(peft_model_loaded_10.active_adapters, ["adapter0"]) + output_loaded10_0 = peft_model_loaded_10(input) + self.assertTrue(torch.allclose(output_config0, output_loaded10_0, atol=atol, rtol=rtol)) + # activate both adapters + peft_model_loaded_10.set_adapter(["adapter1", "adapter0"]) + output_loaded10 = peft_model_loaded_10(input) + self.assertTrue(torch.allclose(output_mixed_01, output_loaded10, atol=atol, rtol=rtol)) + + @parameterized.expand( + itertools.combinations( + [ + LoraConfig(target_modules=["lin0"], init_lora_weights=False), + LoHaConfig(target_modules=["lin0"], init_weights=False), + LoKrConfig(target_modules=["lin0"], init_weights=False), + AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), + ], + r=2, + ), + name_func=_param_name_func, + ) + def test_target_first_layer(self, config0, config1): + input = torch.arange(90).reshape(9, 10).to(self.torch_device) + self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=False) + self._check_merging(SimpleNet, config0, config1, input) + self._check_unload(SimpleNet, config0, config1, input) + self._check_disable(SimpleNet, config1, config0, input) + self._check_loading(SimpleNet, config0, config1, input) + + @parameterized.expand( + itertools.combinations( + [ + LoraConfig(target_modules=["lin1"], init_lora_weights=False), + LoHaConfig(target_modules=["lin1"], init_weights=False), + LoKrConfig(target_modules=["lin1"], init_weights=False), + AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), + ], + r=2, + ), + name_func=_param_name_func, + ) + def test_target_last_layer(self, config0, config1): + # We are targeting the last layer of the SimpleNet. Therefore, since the adapters only add their activations + # to the output, the results should be commutative. This would *not* work if the adapters do something more + # complex or if we target an earlier layer, because of the non-linearity would destroy the commutativity. + input = torch.arange(90).reshape(9, 10).to(self.torch_device) + self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=True) + self._check_merging(SimpleNet, config0, config1, input) + self._check_unload(SimpleNet, config0, config1, input) + self._check_disable(SimpleNet, config1, config0, input) + self._check_loading(SimpleNet, config0, config1, input) + + @parameterized.expand( + [ + ( + LoraConfig(target_modules=["lin0"], init_lora_weights=False), + LoHaConfig(target_modules=["lin1"], init_weights=False), + ), + ( + LoHaConfig(target_modules=["lin0"], init_weights=False), + LoraConfig(target_modules=["lin1"], init_lora_weights=False), + ), + ( + LoraConfig(target_modules=["lin0"], init_lora_weights=False), + LoKrConfig(target_modules=["lin1"], init_weights=False), + ), + ( + LoKrConfig(target_modules=["lin0"], init_weights=False), + LoraConfig(target_modules=["lin1"], init_lora_weights=False), + ), + ( + LoraConfig(target_modules=["lin0"], init_lora_weights=False), + AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), + ), + ( + AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), + LoraConfig(target_modules=["lin1"], init_lora_weights=False), + ), + ( + LoHaConfig(target_modules=["lin0"], init_weights=False), + LoKrConfig(target_modules=["lin1"], init_weights=False), + ), + ( + LoKrConfig(target_modules=["lin0"], init_weights=False), + LoHaConfig(target_modules=["lin1"], init_weights=False), + ), + ( + LoHaConfig(target_modules=["lin0"], init_weights=False), + AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), + ), + ( + AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), + LoHaConfig(target_modules=["lin1"], init_weights=False), + ), + ( + LoKrConfig(target_modules=["lin0"], init_weights=False), + AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), + ), + ( + AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), + LoKrConfig(target_modules=["lin1"], init_weights=False), + ), + ], + name_func=_param_name_func, + ) + def test_target_different_layers(self, config0, config1): + input = torch.arange(90).reshape(9, 10).to(self.torch_device) + self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=False) + self._check_merging(SimpleNet, config0, config1, input) + self._check_unload(SimpleNet, config0, config1, input) + self._check_disable(SimpleNet, config1, config0, input) + self._check_loading(SimpleNet, config0, config1, input) + + @parameterized.expand( + [ + ( + LoraConfig(target_modules=["lin1"], init_lora_weights=False), + LoraConfig(target_modules=["lin1"], init_lora_weights=False), + ), + ( + LoHaConfig(target_modules=["lin1"], init_weights=False), + LoHaConfig(target_modules=["lin1"], init_weights=False), + ), + ( + LoKrConfig(target_modules=["lin1"], init_weights=False), + LoKrConfig(target_modules=["lin1"], init_weights=False), + ), + ( + AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), + AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), + ), + ], + name_func=_param_name_func, + ) + def test_target_last_layer_same_type(self, config0, config1): + input = torch.arange(90).reshape(9, 10).to(self.torch_device) + self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=True) + self._check_merging(SimpleNet, config0, config1, input) + self._check_unload(SimpleNet, config0, config1, input) + self._check_disable(SimpleNet, config1, config0, input) + + @parameterized.expand( + [ + ( + LoraConfig(target_modules=["lin0"], init_lora_weights=False), + LoraConfig(target_modules=["lin0"], init_lora_weights=False), + ), + ( + LoHaConfig(target_modules=["lin0"], init_weights=False), + LoHaConfig(target_modules=["lin0"], init_weights=False), + ), + ( + LoKrConfig(target_modules=["lin0"], init_weights=False), + LoKrConfig(target_modules=["lin0"], init_weights=False), + ), + ( + AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), + AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), + ), + ], + name_func=_param_name_func, + ) + def test_target_first_layer_same_type(self, config0, config1): + input = torch.arange(90).reshape(9, 10).to(self.torch_device) + self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=False) + self._check_merging(SimpleNet, config0, config1, input) + self._check_unload(SimpleNet, config0, config1, input) + self._check_disable(SimpleNet, config1, config0, input) + self._check_loading(SimpleNet, config0, config1, input) + + def test_deeply_nested(self): + # a somewhat absurdly nested model using different adapter types + atol = 1e-5 + rtol = 1e-5 + torch.manual_seed(0) + + model = SimpleNet().eval().to(self.torch_device) + input = torch.arange(90).reshape(9, 10).to(self.torch_device) + output_base = model(input) + + config0 = LoraConfig(r=4, lora_alpha=4, target_modules=["lin0", "lin1"], init_lora_weights=False) + peft_model = get_peft_model(model, config0, "adapter0", mixed=True) + + config1 = LoHaConfig(r=4, alpha=4, target_modules=["lin0"], init_weights=False) + peft_model.add_adapter("adapter1", config1) + + config2 = AdaLoraConfig(r=4, lora_alpha=4, target_modules=["lin1"], init_lora_weights=False) + peft_model.add_adapter("adapter2", config2) + + config3 = LoKrConfig(r=4, alpha=4, target_modules=["lin0", "lin1"], init_weights=False) + peft_model.add_adapter("adapter3", config3) + + config4 = LoraConfig(r=4, lora_alpha=4, target_modules=["lin0", "lin1"], init_lora_weights=False) + peft_model.add_adapter("adapter4", config4) + + peft_model.set_adapter(["adapter0", "adapter1", "adapter2", "adapter3", "adapter4"]) + output_mixed = peft_model(input) + self.assertTrue(torch.isfinite(output_base).all()) + self.assertFalse(torch.allclose(output_base, output_mixed, atol=atol, rtol=rtol)) + + # test disabling all adapters + with peft_model.disable_adapter(): + output_disabled = peft_model(input) + self.assertTrue(torch.isfinite(output_disabled).all()) + self.assertTrue(torch.allclose(output_base, output_disabled, atol=atol, rtol=rtol)) + self.assertFalse(torch.allclose(output_mixed, output_disabled, atol=atol, rtol=rtol)) + + # merge and unload all adapters + model_copy = copy.deepcopy(peft_model) + model = model_copy.merge_and_unload() + output_merged = model(input) + self.assertTrue(torch.isfinite(output_merged).all()) + self.assertTrue(torch.allclose(output_mixed, output_merged, atol=atol, rtol=rtol)) + + # merge and unload only adapter1 and adapter3 + model_copy = copy.deepcopy(peft_model) + model_copy.set_adapter(["adapter1", "adapter3"]) + output_13 = model_copy(input) + self.assertTrue(torch.isfinite(output_13).all()) + self.assertFalse(torch.allclose(output_mixed, output_13, atol=atol, rtol=rtol)) + + model_copy.set_adapter(["adapter0", "adapter1", "adapter2", "adapter3", "adapter4"]) + model_merged_unloaded = model_copy.merge_and_unload(adapter_names=["adapter1", "adapter3"]) + output_merged_13 = model_merged_unloaded(input) + self.assertTrue(torch.isfinite(output_merged_13).all()) + self.assertTrue(torch.allclose(output_13, output_merged_13, atol=atol, rtol=rtol)) + + # test unloading + model_copy = copy.deepcopy(peft_model) + model_unloaded = model_copy.unload() + output_unloaded = model_unloaded(input) + self.assertTrue(torch.isfinite(output_unloaded).all()) + self.assertTrue(torch.allclose(output_base, output_unloaded, atol=atol, rtol=rtol)) + + def test_delete_adapter(self): + atol = 1e-5 + rtol = 1e-5 + torch.manual_seed(0) + + model = SimpleNet().eval().to(self.torch_device) + input = torch.arange(90).reshape(9, 10).to(self.torch_device) + output_base = model(input) + + # create adapter0 + torch.manual_seed(0) + config0 = LoraConfig(r=4, lora_alpha=4, target_modules=["lin0", "lin1"], init_lora_weights=False) + peft_model = get_peft_model(model, config0, "adapter0", mixed=True) + output_0 = peft_model(input) + self.assertFalse(torch.allclose(output_base, output_0, atol=atol, rtol=rtol)) + + # add adapter1 + torch.manual_seed(1) + config1 = LoHaConfig(r=4, alpha=4, target_modules=["lin0"], init_weights=False) + peft_model.add_adapter("adapter1", config1) + peft_model.set_adapter(["adapter0", "adapter1"]) + output_01 = peft_model(input) + self.assertFalse(torch.allclose(output_base, output_01, atol=atol, rtol=rtol)) + self.assertFalse(torch.allclose(output_0, output_01, atol=atol, rtol=rtol)) + + # delete adapter1 + peft_model.delete_adapter("adapter1") + self.assertEqual(peft_model.active_adapters, ["adapter0"]) + output_deleted_1 = peft_model(input) + self.assertTrue(torch.allclose(output_0, output_deleted_1, atol=atol, rtol=rtol)) + + msg = re.escape("Adapter(s) ['adapter1'] not found, available adapters: ['adapter0']") + with self.assertRaisesRegex(ValueError, expected_regex=msg): + peft_model.set_adapter(["adapter0", "adapter1"]) + + # re-add adapter1 + torch.manual_seed(1) + peft_model.add_adapter("adapter1", config1) + peft_model.set_adapter(["adapter0", "adapter1"]) + output_01_readded = peft_model(input) + self.assertFalse(torch.allclose(output_base, output_01_readded, atol=atol, rtol=rtol)) + + # same as above, but this time delete adapter0 first + torch.manual_seed(0) + model = SimpleNet().eval().to(self.torch_device) + torch.manual_seed(0) + peft_model = get_peft_model(model, config0, "adapter0", mixed=True) + torch.manual_seed(1) + peft_model.add_adapter("adapter1", config1) + peft_model.delete_adapter("adapter0") + self.assertEqual(peft_model.active_adapters, ["adapter1"]) + output_deleted_0 = peft_model(input) + self.assertFalse(torch.allclose(output_deleted_0, output_base, atol=atol, rtol=rtol)) + self.assertFalse(torch.allclose(output_deleted_0, output_01, atol=atol, rtol=rtol)) + + msg = re.escape("Adapter(s) ['adapter0'] not found, available adapters: ['adapter1']") + with self.assertRaisesRegex(ValueError, expected_regex=msg): + peft_model.set_adapter(["adapter0", "adapter1"]) + + peft_model.delete_adapter("adapter1") + self.assertEqual(peft_model.active_adapters, []) + output_deleted_01 = peft_model(input) + self.assertTrue(torch.allclose(output_deleted_01, output_base, atol=atol, rtol=rtol)) + + def test_modules_to_save(self): + model = SimpleNet().eval().to(self.torch_device) + config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"]) + peft_model = get_peft_model(model, config0, "adapter0", mixed=True) + + # adding a second adapter with same modules_to_save is not allowed + # TODO: theoretically, we could allow this if it's the same target layer + config1 = LoHaConfig(target_modules=["lin0"], modules_to_save=["lin1"]) + peft_model.add_adapter("adapter1", config1) + msg = "Only one adapter can be set at a time for modules_to_save" + with self.assertRaisesRegex(ValueError, expected_regex=msg): + peft_model.set_adapter(["adapter0", "adapter1"]) + + def test_get_nb_trainable_parameters(self): + model = SimpleNet().eval().to(self.torch_device) + config0 = LoraConfig(target_modules=["lin0"]) + peft_model = get_peft_model(model, config0, "adapter0", mixed=True) + trainable_params0, all_param0 = peft_model.get_nb_trainable_parameters() + + params_base = 262 + params_lora = sum(p.numel() for n, p in model.named_parameters() if "adapter0" in n) + self.assertEqual(trainable_params0, params_lora) + self.assertEqual(all_param0, params_base + params_lora) + + config1 = LoHaConfig(target_modules=["lin1"]) + peft_model.add_adapter("adapter1", config1) + peft_model.set_adapter(["adapter0", "adapter1"]) + params_loha = sum(p.numel() for n, p in model.named_parameters() if "adapter1" in n) + trainable_params1, all_param1 = peft_model.get_nb_trainable_parameters() + self.assertEqual(trainable_params1, params_lora + params_loha) + self.assertEqual(all_param1, params_base + params_lora + params_loha) + + config2 = AdaLoraConfig(target_modules=["lin0", "lin1"]) + peft_model.add_adapter("adapter2", config2) + peft_model.set_adapter(["adapter0", "adapter1", "adapter2"]) + params_adalora = sum(p.numel() for n, p in model.named_parameters() if "adapter2" in n) + trainable_params2, all_param2 = peft_model.get_nb_trainable_parameters() + # remove 2 params because we need to exclude "ranknum" for AdaLora trainable params + self.assertEqual(trainable_params2, params_lora + params_loha + params_adalora - 2) + self.assertEqual(all_param2, params_base + params_lora + params_loha + params_adalora) + + def test_incompatible_config_raises(self): + model = SimpleNet().eval().to(self.torch_device) + config0 = LoraConfig(target_modules=["lin0"]) + peft_model = get_peft_model(model, config0, "adapter0", mixed=True) + + config1 = PrefixTuningConfig() + msg = "The provided `peft_type` 'PREFIX_TUNING' is not compatible with the `PeftMixedModel`." + with self.assertRaisesRegex(ValueError, expected_regex=msg): + peft_model.add_adapter("adapter1", config1) + + def test_decoder_model(self): + # test a somewhat realistic model instead of a toy model + torch.manual_seed(0) + + model_id = "hf-internal-testing/tiny-random-OPTForCausalLM" + model = AutoModelForCausalLM.from_pretrained(model_id).eval().to(self.torch_device) + input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(self.torch_device) + attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device) + input_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + output_base = model.generate(**input_dict) + + torch.manual_seed(0) + config0 = LoraConfig(task_type="CAUSAL_LM", init_lora_weights=False) + peft_model = get_peft_model(model, config0, "adapter0", mixed=True) + output0 = peft_model.generate(**input_dict) + self.assertTrue(torch.isfinite(output0).all()) + self.assertFalse(torch.allclose(output_base, output0)) + + torch.manual_seed(1) + config1 = LoHaConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"], init_weights=False) + peft_model.add_adapter("adapter1", config1) + peft_model.set_adapter(["adapter0", "adapter1"]) + output1 = peft_model.generate(**input_dict) + self.assertTrue(torch.isfinite(output1).all()) + self.assertFalse(torch.allclose(output0, output1)) + + torch.manual_seed(2) + config2 = AdaLoraConfig(task_type="CAUSAL_LM", init_lora_weights=False) + peft_model.add_adapter("adapter2", config2) + peft_model.set_adapter(["adapter0", "adapter1", "adapter2"]) + output2 = peft_model.generate(**input_dict) + self.assertTrue(torch.isfinite(output2).all()) + self.assertFalse(torch.allclose(output1, output2)) + + torch.manual_seed(3) + config3 = LoKrConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"], init_weights=False) + peft_model.add_adapter("adapter3", config3) + peft_model.set_adapter(["adapter0", "adapter1", "adapter2", "adapter3"]) + output3 = peft_model.generate(**input_dict) + self.assertTrue(torch.isfinite(output3).all()) + self.assertFalse(torch.allclose(output2, output3)) + + with peft_model.disable_adapter(): + output_disabled = peft_model.generate(**input_dict) + self.assertTrue(torch.isfinite(output_disabled).all()) + self.assertTrue(torch.allclose(output_base, output_disabled)) + + model_unloaded = peft_model.merge_and_unload() + output_unloaded = model_unloaded.generate(**input_dict) + self.assertTrue(torch.isfinite(output_unloaded).all()) + self.assertTrue(torch.allclose(output3, output_unloaded)) + + with tempfile.TemporaryDirectory() as tmp_dir: + # save adapter0 (use normal PeftModel, because PeftMixedModel does not support saving) + torch.manual_seed(0) + model = AutoModelForCausalLM.from_pretrained(model_id).eval().to(self.torch_device) + torch.manual_seed(0) + peft_model = get_peft_model(model, config0, "adapter0") + output0_save = peft_model(**input_dict).logits + self.assertTrue(torch.isfinite(output0_save).all()) + peft_model.save_pretrained(tmp_dir) + + # save adapter1 + torch.manual_seed(0) + model = AutoModelForCausalLM.from_pretrained(model_id).eval().to(self.torch_device) + torch.manual_seed(1) + peft_model = get_peft_model(model, config1, "adapter1") + output1_save = peft_model(**input_dict).logits + self.assertTrue(torch.isfinite(output1_save).all()) + peft_model.save_pretrained(tmp_dir) + + # load adapter0 and adapter1 + model = AutoModelForCausalLM.from_pretrained(model_id).eval().to(self.torch_device) + peft_model = PeftMixedModel.from_pretrained(model, os.path.join(tmp_dir, "adapter0"), "adapter0") + peft_model.load_adapter(os.path.join(tmp_dir, "adapter1"), "adapter1") + peft_model.set_adapter(["adapter0", "adapter1"]) + output01_loaded = peft_model(**input_dict).logits + + atol, rtol = 1e-3, 1e-3 + self.assertTrue(torch.isfinite(output01_loaded).all()) + self.assertFalse(torch.allclose(output0_save, output01_loaded, atol=atol, rtol=rtol)) + self.assertFalse(torch.allclose(output1_save, output01_loaded, atol=atol, rtol=rtol)) From 5bad88ba042d2c45404d309a859759dd73a3eac5 Mon Sep 17 00:00:00 2001 From: Akash Kundu <112017800+Akash190104@users.noreply.github.com> Date: Mon, 4 Dec 2023 16:23:40 +0530 Subject: [PATCH 07/12] [DOCS] README.md (#1054) minor fixes --- docs/README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/README.md b/docs/README.md index 5955736ee6d..0b76173a66d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -33,7 +33,7 @@ pip install git+https://github.com/huggingface/doc-builder **NOTE** You only need to generate the documentation to inspect it locally (if you're planning changes and want to -check how they look before committing for instance). You don't have to commit the built documentation. +check how they look before committing for instance). You don't have to commit to the built documentation. --- @@ -46,7 +46,7 @@ typing the following command: doc-builder build peft docs/source/ --build_dir ~/tmp/test-build ``` -You can adapt the `--build_dir` to set any temporary folder that you prefer. This command will create it and generate +You can adapt the `--build_dir` to set any temporary folder you prefer. This command will create it and generate the MDX files that will be rendered as the documentation on the main website. You can inspect them in your favorite Markdown editor. @@ -124,7 +124,7 @@ Adding a new tutorial or section is done in two steps: - Link that file in `./source/_toctree.yml` on the correct toc-tree. Make sure to put your new file under the proper section. It's unlikely to go in the first section (*Get Started*), so -depending on the intended targets (beginners, more advanced users, or researchers) it should go in sections two, three, or +depending on the intended targets (beginners, more advanced users, or researchers) it should go into sections two, three, or four. ### Writing source documentation @@ -188,7 +188,7 @@ then its documentation should look like this: ``` Note that we always omit the "defaults to \`None\`" when None is the default for any argument. Also note that even -if the first line describing your argument type and its default gets long, you can't break it on several lines. You can +if the first line describing your argument type and its default gets long, you can't break it into several lines. You can however write as many lines as you want in the indented description (see the example above with `input_ids`). #### Writing a multi-line code block @@ -234,13 +234,13 @@ We have an automatic script running with the `make style` comment that will make - the docstrings fully take advantage of the line width - all code examples are formatted using black, like the code of the Transformers library -This script may have some weird failures if you made a syntax mistake or if you uncover a bug. Therefore, it's +This script may have some weird failures if you make a syntax mistake or if you uncover a bug. Therefore, it's recommended to commit your changes before running `make style`, so you can revert the changes done by that script easily. ## Writing documentation examples -The syntax for Example docstrings can look as follows: +The syntax, for example, docstrings can look as follows: ``` Example: @@ -264,4 +264,4 @@ is to be used in inference and also include the expected (ideally sensible) output. Often, readers will try out the example before even going through the function or class definitions. Therefore, it is of utmost importance that the example -works as expected. \ No newline at end of file +works as expected. From 5ed46e4f043b421401da1919f1037277b7d9356a Mon Sep 17 00:00:00 2001 From: zhangshengdong29 <435878393@qq.com> Date: Mon, 4 Dec 2023 19:16:58 +0800 Subject: [PATCH 08/12] FIX Issue with megatron parallel linear lora (#1202) --- src/peft/tuners/lora/tp_layer.py | 2 +- tests/test_lora_megatron.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/peft/tuners/lora/tp_layer.py b/src/peft/tuners/lora/tp_layer.py index 676430cf38c..d55731dce5a 100644 --- a/src/peft/tuners/lora/tp_layer.py +++ b/src/peft/tuners/lora/tp_layer.py @@ -111,7 +111,7 @@ def update_layer( self.lora_B[adapter_name] = lora_b self.scaling[adapter_name] = lora_alpha / r if init_lora_weights: - self.reset_lora_parameters(adapter_name) + self.reset_lora_parameters(adapter_name, init_lora_weights) weight = getattr(self.get_base_layer(), "weight", None) if weight is not None: diff --git a/tests/test_lora_megatron.py b/tests/test_lora_megatron.py index 80d0f43010e..4244dd9735a 100644 --- a/tests/test_lora_megatron.py +++ b/tests/test_lora_megatron.py @@ -85,6 +85,7 @@ def __init__(self, config: TransformerConfig): init_method=init.xavier_normal_, bias=False, input_is_parallel=True, + skip_bias_add=True, ) def forward(self, input): From e05b2670c52959a6b4f0aecb8ec4d3417fb0f0e8 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 4 Dec 2023 12:18:49 +0100 Subject: [PATCH 09/12] ENH: Enable OFT adapter for mixed adapter models (#1204) This PR makes it possible to use the newly added OFT adapter in mixed adapter type models, similar to LoRA, LoHa, etc. Notes Adding the integration was pretty straightforward, which is a good sign. The difficult part was actually about the tests. This stems from the fact that OFT is (if my understanding is correct) never commutative. What I mean is that even if the adapters are applied to the last layer of a model, it makes a difference whether we apply, say, first LoRA, then OFT vs first OFT, then LoRA. This is different for the other adapters that were added so far for mixed models, as they basically do: - Xa = X + dXa - Xab = Xa + dXb = X + dXa + dXb = X + dXb + dXa = Xb + dXa = Xba This is not true for OFT, so when OFT is used, I had to ensure that no test was applied that (implicitly) assumes commutativity. Furthermore, I had to increase the model size, see this comment: https://github.com/huggingface/peft/pull/1160#issuecomment-1836107235 --- src/peft/mixed_model.py | 2 + src/peft/tuners/mixed/model.py | 14 +-- tests/test_mixed.py | 159 ++++++++++++++++++--------------- 3 files changed, 100 insertions(+), 75 deletions(-) diff --git a/src/peft/mixed_model.py b/src/peft/mixed_model.py index 55892851e9e..b22c1ad1859 100644 --- a/src/peft/mixed_model.py +++ b/src/peft/mixed_model.py @@ -35,6 +35,7 @@ LoKrModel, LoraModel, MixedModel, + OFTModel, ) from .utils import PeftType, _set_adapter, _set_trainable @@ -45,6 +46,7 @@ PeftType.LOKR: LoKrModel, PeftType.ADALORA: AdaLoraModel, PeftType.IA3: IA3Model, + PeftType.OFT: OFTModel, } diff --git a/src/peft/tuners/mixed/model.py b/src/peft/tuners/mixed/model.py index 5e7acf1cfe7..e8c98bc4f76 100644 --- a/src/peft/tuners/mixed/model.py +++ b/src/peft/tuners/mixed/model.py @@ -20,7 +20,7 @@ from torch import nn from tqdm import tqdm -from peft.tuners import adalora, loha, lokr, lora +from peft.tuners import adalora, loha, lokr, lora, oft from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists from peft.utils import ( TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, @@ -32,10 +32,10 @@ # Collection of constants used for all tuners -COMPATIBLE_TUNER_TYPES = (PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.ADALORA) -PREFIXES = [lora.LoraModel.prefix, lokr.LoKrModel.prefix, loha.LoHaModel.prefix] -Configs = Union[lora.LoraConfig, loha.LoHaConfig, lokr.LoKrConfig, adalora.AdaLoraConfig] -Layers = (lora.layer.LoraLayer, loha.layer.LoHaLayer, lokr.layer.LoKrLayer, adalora.layer.AdaLoraLayer) +COMPATIBLE_TUNER_TYPES = (PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.ADALORA, PeftType.OFT) +PREFIXES = [lora.LoraModel.prefix, lokr.LoKrModel.prefix, loha.LoHaModel.prefix, oft.OFTModel.prefix] +Configs = Union[lora.LoraConfig, loha.LoHaConfig, lokr.LoKrConfig, adalora.AdaLoraConfig, oft.OFTConfig] +Layers = (lora.layer.LoraLayer, loha.layer.LoHaLayer, lokr.layer.LoKrLayer, adalora.layer.AdaLoraLayer, oft.OFTLayer) class MixedModel(BaseTuner): @@ -95,6 +95,8 @@ def _create_and_replace( loha.LoHaModel._create_and_replace(self, config, *args, **kwargs) elif isinstance(config, lokr.LoKrConfig): lokr.LoKrModel._create_and_replace(self, config, *args, **kwargs) + elif isinstance(config, oft.OFTConfig): + oft.OFTModel._create_and_replace(self, config, *args, **kwargs) else: raise ValueError(f"Unsupported config type {type(config)}, should be one of {COMPATIBLE_TUNER_TYPES}.") @@ -171,6 +173,8 @@ def _create_new_module(config, adapter_name, target, **kwargs): new_module = loha.LoHaModel._create_new_module(config, adapter_name, target, **kwargs) elif isinstance(config, lokr.LoKrConfig): new_module = lokr.LoKrModel._create_new_module(config, adapter_name, target, **kwargs) + elif isinstance(config, oft.OFTConfig): + new_module = oft.OFTModel._create_new_module(config, adapter_name, target, **kwargs) else: raise ValueError(f"Unknown config type {type(config)}, should be one of {COMPATIBLE_TUNER_TYPES}.") return new_module diff --git a/tests/test_mixed.py b/tests/test_mixed.py index bd8f455e990..ea35df391f7 100644 --- a/tests/test_mixed.py +++ b/tests/test_mixed.py @@ -25,7 +25,16 @@ from torch import nn from transformers import AutoModelForCausalLM -from peft import AdaLoraConfig, LoHaConfig, LoKrConfig, LoraConfig, PeftMixedModel, PrefixTuningConfig, get_peft_model +from peft import ( + AdaLoraConfig, + LoHaConfig, + LoKrConfig, + LoraConfig, + OFTConfig, + PeftMixedModel, + PrefixTuningConfig, + get_peft_model, +) from peft.tuners.tuners_utils import BaseTunerLayer from peft.utils import infer_device @@ -33,9 +42,10 @@ class SimpleNet(nn.Module): def __init__(self, bias=True): super().__init__() + # note: out_features must be > rank or else OFT will be an identity transform self.lin0 = nn.Linear(10, 20, bias=bias) self.relu = nn.ReLU() - self.lin1 = nn.Linear(20, 2, bias=bias) + self.lin1 = nn.Linear(20, 16, bias=bias) def forward(self, X): X = X.float() @@ -48,8 +58,8 @@ def forward(self, X): def _param_name_func(testcase_func, param_num, params): # for parameterized tests in TextMixedAdapterTypes config0, config1 = params[0] - name0 = config0.__class__.__name__ - name1 = config1.__class__.__name__ + name0 = config0.__class__.__name__[: -len("Config")] + name1 = config1.__class__.__name__[: -len("Config")] if name0 != name1: return f"{testcase_func.__name__}_{param_num}_{name0}_{name1}" return f"{testcase_func.__name__}_{param_num}_{name0}_x2" @@ -163,16 +173,17 @@ def _check_mixed_outputs(self, model_cls, config0, config1, input, *, is_commuta self.assertEqual(peft_model_10.active_adapters, ["adapter0", "adapter1"]) self.assertTrue(torch.isfinite(output_mixed_reversed).all()) - self.assertTrue(torch.allclose(output_mixed_reversed, output_mixed_01, atol=atol, rtol=rtol)) self.assertFalse(torch.allclose(output_mixed_reversed, output_config0, atol=atol, rtol=rtol)) self.assertFalse(torch.allclose(output_mixed_reversed, output_config1, atol=atol, rtol=rtol)) if is_commutative: + self.assertTrue(torch.allclose(output_mixed_reversed, output_mixed_01, atol=atol, rtol=rtol)) self.assertTrue(torch.allclose(output_mixed_reversed, output_mixed_10, atol=atol, rtol=rtol)) def _check_merging(self, model_cls, config0, config1, input): # Ensure that when merging mixed adapters, the result is the same as when applying the adapters separately. - atol = 1e-5 - rtol = 1e-5 + # Merging requires a bit higher tolerance for some adapters, which can also vary depending on CPU vs GPU. + atol = 1e-4 + rtol = 1e-4 seed0 = 0 seed1 = 1 @@ -275,7 +286,7 @@ def _check_disable(self, model_cls, config0, config1, input): self.assertFalse(torch.allclose(output_base, output_mixed_10, atol=atol, rtol=rtol)) self.assertTrue(torch.allclose(output_base, output_disabled10, atol=atol, rtol=rtol)) - def _check_loading(self, model_cls, config0, config1, input): + def _check_loading(self, model_cls, config0, config1, input, *, is_commutative): # Check that we can load two adapters into the same model # Note that we save the adapters using a normal PeftModel because PeftMixModel doesn't support saving yet atol = 1e-5 @@ -302,6 +313,13 @@ def _check_loading(self, model_cls, config0, config1, input): peft_model_01.set_adapter(["adapter0", "adapter1"]) output_mixed_01 = peft_model_01(input) + # adapter 1 + 0 + peft_model_10 = self._get_model(model_cls, config1, "adapter1", seed=seed1) + torch.manual_seed(seed0) + peft_model_10.add_adapter("adapter0", config0) + peft_model_10.set_adapter(["adapter1", "adapter0"]) + output_mixed_10 = peft_model_10(input) + # LOADING # adapter 0 base_model = self._get_model(model_cls) @@ -332,7 +350,7 @@ def _check_loading(self, model_cls, config0, config1, input): base_model, os.path.join(tmp_dirname, "adapter0", "adapter0"), "adapter0" ) peft_model_loaded_01.load_adapter(os.path.join(tmp_dirname, "adapter1", "adapter1"), "adapter1") - # at this point, "config0" should still be active + # at this point, "adapter0" should still be active self.assertEqual(peft_model_loaded_01.active_adapters, ["adapter0"]) output_loaded01_0 = peft_model_loaded_01(input) self.assertTrue(torch.allclose(output_config0, output_loaded01_0, atol=atol, rtol=rtol)) @@ -353,7 +371,7 @@ def _check_loading(self, model_cls, config0, config1, input): base_model, os.path.join(tmp_dirname, "adapter1", "adapter1"), "adapter1" ) peft_model_loaded_10.load_adapter(os.path.join(tmp_dirname, "adapter0", "adapter0"), "adapter0") - # at this point, "config0" should still be active + # at this point, "adapter1" should still be active self.assertEqual(peft_model_loaded_10.active_adapters, ["adapter1"]) output_loaded10_1 = peft_model_loaded_10(input) self.assertTrue(torch.allclose(output_config1, output_loaded10_1, atol=atol, rtol=rtol)) @@ -365,7 +383,11 @@ def _check_loading(self, model_cls, config0, config1, input): # activate both adapters peft_model_loaded_10.set_adapter(["adapter1", "adapter0"]) output_loaded10 = peft_model_loaded_10(input) - self.assertTrue(torch.allclose(output_mixed_01, output_loaded10, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(output_mixed_10, output_loaded10, atol=atol, rtol=rtol)) + + if is_commutative: + self.assertTrue(torch.allclose(output_loaded01, output_loaded10, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(output_loaded10, output_mixed_01, atol=atol, rtol=rtol)) @parameterized.expand( itertools.combinations( @@ -374,6 +396,7 @@ def _check_loading(self, model_cls, config0, config1, input): LoHaConfig(target_modules=["lin0"], init_weights=False), LoKrConfig(target_modules=["lin0"], init_weights=False), AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), + OFTConfig(target_modules=["lin0"], init_weights=False), ], r=2, ), @@ -385,7 +408,7 @@ def test_target_first_layer(self, config0, config1): self._check_merging(SimpleNet, config0, config1, input) self._check_unload(SimpleNet, config0, config1, input) self._check_disable(SimpleNet, config1, config0, input) - self._check_loading(SimpleNet, config0, config1, input) + self._check_loading(SimpleNet, config0, config1, input, is_commutative=False) @parameterized.expand( itertools.combinations( @@ -394,6 +417,7 @@ def test_target_first_layer(self, config0, config1): LoHaConfig(target_modules=["lin1"], init_weights=False), LoKrConfig(target_modules=["lin1"], init_weights=False), AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), + OFTConfig(target_modules=["lin1"], init_weights=False), ], r=2, ), @@ -404,72 +428,47 @@ def test_target_last_layer(self, config0, config1): # to the output, the results should be commutative. This would *not* work if the adapters do something more # complex or if we target an earlier layer, because of the non-linearity would destroy the commutativity. input = torch.arange(90).reshape(9, 10).to(self.torch_device) - self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=True) + # OFT is not commutative, as it's not a linear operation on the inputs + is_commutative = not any(isinstance(config, OFTConfig) for config in [config0, config1]) + + self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=is_commutative) self._check_merging(SimpleNet, config0, config1, input) self._check_unload(SimpleNet, config0, config1, input) self._check_disable(SimpleNet, config1, config0, input) - self._check_loading(SimpleNet, config0, config1, input) + self._check_loading(SimpleNet, config0, config1, input, is_commutative=is_commutative) @parameterized.expand( - [ - ( - LoraConfig(target_modules=["lin0"], init_lora_weights=False), - LoHaConfig(target_modules=["lin1"], init_weights=False), - ), - ( - LoHaConfig(target_modules=["lin0"], init_weights=False), - LoraConfig(target_modules=["lin1"], init_lora_weights=False), - ), - ( - LoraConfig(target_modules=["lin0"], init_lora_weights=False), - LoKrConfig(target_modules=["lin1"], init_weights=False), - ), - ( - LoKrConfig(target_modules=["lin0"], init_weights=False), - LoraConfig(target_modules=["lin1"], init_lora_weights=False), - ), - ( - LoraConfig(target_modules=["lin0"], init_lora_weights=False), - AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), - ), - ( - AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), - LoraConfig(target_modules=["lin1"], init_lora_weights=False), - ), - ( - LoHaConfig(target_modules=["lin0"], init_weights=False), - LoKrConfig(target_modules=["lin1"], init_weights=False), - ), - ( - LoKrConfig(target_modules=["lin0"], init_weights=False), - LoHaConfig(target_modules=["lin1"], init_weights=False), - ), - ( - LoHaConfig(target_modules=["lin0"], init_weights=False), - AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), - ), - ( - AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), - LoHaConfig(target_modules=["lin1"], init_weights=False), - ), - ( - LoKrConfig(target_modules=["lin0"], init_weights=False), - AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), - ), - ( - AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), - LoKrConfig(target_modules=["lin1"], init_weights=False), - ), - ], + itertools.combinations( + [ + LoraConfig(init_lora_weights=False), + LoHaConfig(init_weights=False), + LoKrConfig(init_weights=False), + AdaLoraConfig(init_lora_weights=False), + OFTConfig(init_weights=False), + ], + r=2, + ), name_func=_param_name_func, ) def test_target_different_layers(self, config0, config1): input = torch.arange(90).reshape(9, 10).to(self.torch_device) + + config0.target_modules = ["lin0"] + config1.target_modules = ["lin1"] self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=False) self._check_merging(SimpleNet, config0, config1, input) self._check_unload(SimpleNet, config0, config1, input) + self._check_disable(SimpleNet, config0, config1, input) + self._check_loading(SimpleNet, config0, config1, input, is_commutative=False) + + # same, but switch target_modules around + config0.target_modules = ["lin1"] + config1.target_modules = ["lin0"] + self._check_mixed_outputs(SimpleNet, config1, config0, input, is_commutative=False) + self._check_merging(SimpleNet, config1, config0, input) + self._check_unload(SimpleNet, config1, config0, input) self._check_disable(SimpleNet, config1, config0, input) - self._check_loading(SimpleNet, config0, config1, input) + self._check_loading(SimpleNet, config1, config0, input, is_commutative=False) @parameterized.expand( [ @@ -489,12 +488,19 @@ def test_target_different_layers(self, config0, config1): AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), ), + ( + OFTConfig(target_modules=["lin1"], init_weights=False), + OFTConfig(target_modules=["lin1"], init_weights=False), + ), ], name_func=_param_name_func, ) def test_target_last_layer_same_type(self, config0, config1): input = torch.arange(90).reshape(9, 10).to(self.torch_device) - self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=True) + # OFT is not commutative, as it's not a linear operation on the inputs + is_commutative = not any(isinstance(config, OFTConfig) for config in [config0, config1]) + + self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=is_commutative) self._check_merging(SimpleNet, config0, config1, input) self._check_unload(SimpleNet, config0, config1, input) self._check_disable(SimpleNet, config1, config0, input) @@ -517,6 +523,10 @@ def test_target_last_layer_same_type(self, config0, config1): AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), ), + ( + OFTConfig(target_modules=["lin0"], init_weights=False), + OFTConfig(target_modules=["lin0"], init_weights=False), + ), ], name_func=_param_name_func, ) @@ -526,7 +536,7 @@ def test_target_first_layer_same_type(self, config0, config1): self._check_merging(SimpleNet, config0, config1, input) self._check_unload(SimpleNet, config0, config1, input) self._check_disable(SimpleNet, config1, config0, input) - self._check_loading(SimpleNet, config0, config1, input) + self._check_loading(SimpleNet, config0, config1, input, is_commutative=False) def test_deeply_nested(self): # a somewhat absurdly nested model using different adapter types @@ -550,7 +560,7 @@ def test_deeply_nested(self): config3 = LoKrConfig(r=4, alpha=4, target_modules=["lin0", "lin1"], init_weights=False) peft_model.add_adapter("adapter3", config3) - config4 = LoraConfig(r=4, lora_alpha=4, target_modules=["lin0", "lin1"], init_lora_weights=False) + config4 = OFTConfig(r=8, target_modules=["lin0", "lin1"], init_weights=False) peft_model.add_adapter("adapter4", config4) peft_model.set_adapter(["adapter0", "adapter1", "adapter2", "adapter3", "adapter4"]) @@ -671,11 +681,12 @@ def test_modules_to_save(self): def test_get_nb_trainable_parameters(self): model = SimpleNet().eval().to(self.torch_device) + params_base = sum(p.numel() for p in model.parameters()) + config0 = LoraConfig(target_modules=["lin0"]) peft_model = get_peft_model(model, config0, "adapter0", mixed=True) trainable_params0, all_param0 = peft_model.get_nb_trainable_parameters() - params_base = 262 params_lora = sum(p.numel() for n, p in model.named_parameters() if "adapter0" in n) self.assertEqual(trainable_params0, params_lora) self.assertEqual(all_param0, params_base + params_lora) @@ -752,6 +763,14 @@ def test_decoder_model(self): self.assertTrue(torch.isfinite(output3).all()) self.assertFalse(torch.allclose(output2, output3)) + torch.manual_seed(4) + config4 = OFTConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"], init_weights=False) + peft_model.add_adapter("adapter4", config4) + peft_model.set_adapter(["adapter0", "adapter1", "adapter2", "adapter3", "adapter4"]) + output4 = peft_model.generate(**input_dict) + self.assertTrue(torch.isfinite(output4).all()) + self.assertFalse(torch.allclose(output3, output4)) + with peft_model.disable_adapter(): output_disabled = peft_model.generate(**input_dict) self.assertTrue(torch.isfinite(output_disabled).all()) @@ -760,7 +779,7 @@ def test_decoder_model(self): model_unloaded = peft_model.merge_and_unload() output_unloaded = model_unloaded.generate(**input_dict) self.assertTrue(torch.isfinite(output_unloaded).all()) - self.assertTrue(torch.allclose(output3, output_unloaded)) + self.assertTrue(torch.allclose(output4, output_unloaded)) with tempfile.TemporaryDirectory() as tmp_dir: # save adapter0 (use normal PeftModel, because PeftMixedModel does not support saving) From c456d55216c8f95b5510f05632791e872b1ac7c1 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 4 Dec 2023 12:22:03 +0100 Subject: [PATCH 10/12] DOC: Update & improve docstrings and type annotations for common methods and classes (#1201) The docstrings of the most user-exposed methods and classes have been updated, or added if not already present. Furthermore, type annotations have been updated or added for those methods and classes. --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- src/peft/peft_model.py | 127 +++++++++++++++++++++---------- src/peft/tuners/adalora/layer.py | 3 + src/peft/tuners/ia3/layer.py | 6 ++ src/peft/tuners/ia3/model.py | 26 +++++-- src/peft/tuners/lora/bnb.py | 10 ++- src/peft/tuners/lora/layer.py | 9 +++ src/peft/tuners/lora/model.py | 37 ++++++--- src/peft/tuners/lycoris_utils.py | 41 ++++++++-- src/peft/tuners/oft/layer.py | 15 ++++ src/peft/tuners/tuners_utils.py | 35 ++++++--- src/peft/utils/peft_types.py | 4 + 11 files changed, 238 insertions(+), 75 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 79bf8e46102..176cec87a43 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -89,24 +89,24 @@ class PeftModel(PushToHubMixin, torch.nn.Module): Args: model ([`~transformers.PreTrainedModel`]): The base transformer model used for Peft. peft_config ([`PeftConfig`]): The configuration of the Peft model. - adapter_name (`str`): The name of the adapter, defaults to `"default"`. + adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`. **Attributes**: - - **base_model** ([`~transformers.PreTrainedModel`]) -- The base transformer model used for Peft. + - **base_model** ([`torch.nn.Module`]) -- The base transformer model used for Peft. - **peft_config** ([`PeftConfig`]) -- The configuration of the Peft model. - **modules_to_save** (`list` of `str`) -- The list of sub-module names to save when - saving the model. + saving the model. - **prompt_encoder** ([`PromptEncoder`]) -- The prompt encoder used for Peft if - using [`PromptLearningConfig`]. + using [`PromptLearningConfig`]. - **prompt_tokens** (`torch.Tensor`) -- The virtual prompt tokens used for Peft if - using [`PromptLearningConfig`]. + using [`PromptLearningConfig`]. - **transformer_backbone_name** (`str`) -- The name of the transformer - backbone in the base model if using [`PromptLearningConfig`]. + backbone in the base model if using [`PromptLearningConfig`]. - **word_embeddings** (`torch.nn.Embedding`) -- The word embeddings of the transformer backbone - in the base model if using [`PromptLearningConfig`]. + in the base model if using [`PromptLearningConfig`]. """ - def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default"): + def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> None: super().__init__() self.modules_to_save = None self.active_adapter = adapter_name @@ -140,7 +140,7 @@ def peft_config(self) -> Dict[str, PeftConfig]: return self.base_model.peft_config @property - def active_adapters(self): + def active_adapters(self) -> list[str]: try: adapters = self.base_model.active_adapters except AttributeError: @@ -164,7 +164,7 @@ def save_pretrained( save_embedding_layers: Union[str, bool] = "auto", is_main_process: bool = True, **kwargs: Any, - ): + ) -> None: r""" This function saves the adapter model and the adapter configuration files to a directory, so that it can be reloaded using the [`PeftModel.from_pretrained`] class method, and also used by the [`PeftModel.push_to_hub`] @@ -175,15 +175,16 @@ def save_pretrained( Directory where the adapter model and configuration files will be saved (will be created if it does not exist). safe_serialization (`bool`, *optional*): - Whether to save the adapter files in safetensors format. - selected_adapters (`list(str)`, *optional*): + Whether to save the adapter files in safetensors format, defaults to `True`. + selected_adapters (`List[str]`, *optional*): A list of adapters to be saved. If `None`, will default to all adapters. - save_embedding_layers (`Union[bool, str]`, , *optional*, defaults to `auto`): + save_embedding_layers (`Union[bool, str]`, *optional*, defaults to `"auto"`): If `True`, save the embedding layers in addition to adapter weights. If `auto`, checks the common embedding layers `peft.utils.other.EMBEDDING_LAYER_NAMES` in config's `target_modules` when available. - Based on it sets the boolean flag. This only works for πŸ€— transformers models. + and automatically sets the boolean flag. This only works for πŸ€— transformers models. is_main_process (`bool`, *optional*): - Whether the process calling this is the main process or not. Will default to `True`. + Whether the process calling this is the main process or not. Will default to `True`. Will not save the + checkpoint if not on the main process, which is important for multi device setups (e.g. DDP). kwargs (additional keyword arguments, *optional*): Additional keyword arguments passed along to the `push_to_hub` method. """ @@ -280,22 +281,22 @@ def save_pretrained( @classmethod def from_pretrained( cls, - model: PreTrainedModel, + model: torch.nn.Module, model_id: Union[str, os.PathLike], adapter_name: str = "default", is_trainable: bool = False, config: Optional[PeftConfig] = None, **kwargs: Any, - ): + ) -> "PeftModel": r""" Instantiate a PEFT model from a pretrained model and loaded PEFT weights. Note that the passed `model` may be modified inplace. Args: - model ([`~transformers.PreTrainedModel`]): - The model to be adapted. The model should be initialized with the - [`~transformers.PreTrainedModel.from_pretrained`] method from the πŸ€— Transformers library. + model ([`torch.nn.Module`]): + The model to be adapted. For πŸ€— Transformers models, the model should be initialized with the + [`~transformers.PreTrainedModel.from_pretrained`]. model_id (`str` or `os.PathLike`): The name of the PEFT configuration to use. Can be either: - A string, the `model id` of a PEFT configuration hosted inside a model repo on the Hugging Face @@ -305,8 +306,8 @@ def from_pretrained( adapter_name (`str`, *optional*, defaults to `"default"`): The name of the adapter to be loaded. This is useful for loading multiple adapters. is_trainable (`bool`, *optional*, defaults to `False`): - Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and use for - inference + Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be + used for inference. config ([`~peft.PeftConfig`], *optional*): The configuration object to use instead of an automatically loaded configuation. This configuration object is mutually exclusive with `model_id` and `kwargs`. This is useful when configuration is already @@ -421,10 +422,10 @@ def make_inputs_require_grad(module, input, output): model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) return model - def get_prompt_embedding_to_save(self, adapter_name: str): + def get_prompt_embedding_to_save(self, adapter_name: str) -> torch.Tensor: """ - Returns the prompt embedding to save when saving the model. Only applicable when `peft_config.peft_type != - PeftType.LORA`. + Returns the prompt embedding to save when saving the model. Only applicable when using a prompt learning + method. """ prompt_encoder = self.prompt_encoder[adapter_name] prompt_tokens = ( @@ -440,9 +441,9 @@ def get_prompt_embedding_to_save(self, adapter_name: str): return prompt_embeddings[0].detach().cpu() - def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None): + def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) -> torch.Tensor: """ - Returns the virtual prompts to use for Peft. Only applicable when `peft_config.peft_type != PeftType.LORA`. + Returns the virtual prompts to use for Peft. Only applicable when using a prompt learning method. """ peft_config = self.active_peft_config prompt_encoder = self.prompt_encoder[self.active_adapter] @@ -486,9 +487,9 @@ def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None): prompts = prompt_encoder(prompt_tokens) return prompts - def get_nb_trainable_parameters(self): + def get_nb_trainable_parameters(self) -> tuple[int, int]: r""" - Returns the number of trainable parameters and number of all parameters in the model. + Returns the number of trainable parameters and the number of all parameters in the model. """ trainable_params = 0 all_param = 0 @@ -510,7 +511,7 @@ def get_nb_trainable_parameters(self): return trainable_params, all_param - def print_trainable_parameters(self): + def print_trainable_parameters(self) -> None: """ Prints the number of trainable parameters in the model. """ @@ -544,7 +545,14 @@ def _get_base_model_class(self, is_prompt_tuning=False): @contextmanager def disable_adapter(self): """ - Disables the adapter module. + Context manager that disables the adapter module. Use this to run inference on the base model. + + Example: + + ```py + >>> with model.disable_adapter(): + ... model(inputs) + ``` """ try: if self.peft_config[self.active_adapter].is_prompt_learning: @@ -564,13 +572,27 @@ def disable_adapter(self): else: self.base_model.enable_adapter_layers() - def get_base_model(self): + def get_base_model(self) -> torch.nn.Module: """ Returns the base model. """ return self.base_model if self.active_peft_config.is_prompt_learning else self.base_model.model - def add_adapter(self, adapter_name: str, peft_config: PeftConfig): + def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None: + """ + Add an adapter to the model based on the passed configuration. + + The name for the new adapter should be unique. + + The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active + adapter. + + Args: + adapter_name (`str`): + The name of the adapter to be added. + peft_config ([`PeftConfig`]): + The configuration of the adapter to be added. + """ if peft_config.peft_type != self.peft_type: raise ValueError( f"Cannot combine adapters with different peft types. " @@ -622,6 +644,25 @@ def _split_kwargs(cls, kwargs: Dict[str, Any]): return hf_hub_download_kwargs, other_kwargs def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = False, **kwargs: Any): + """ + Load a trained adapter into the model. + + The name for the new adapter should be unique. + + The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active + adapter. + + Args: + adapter_name (`str`): + The name of the adapter to be added. + peft_config ([`PeftConfig`]): + The configuration of the adapter to be added. + is_trainable (`bool`, *optional*, defaults to `False`): + Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be + used for inference. + kwargs: (`optional`): + Additional arguments to modify the way the adapter is loaded, e.g. the token for Hugging Face Hub. + """ from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING hf_hub_download_kwargs, kwargs = self._split_kwargs(kwargs) @@ -693,9 +734,15 @@ def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = Fa self.eval() return load_result - def set_adapter(self, adapter_name: str): + def set_adapter(self, adapter_name: str) -> None: """ Sets the active adapter. + + Only one adapter can be active at a time. + + Args: + adapter_name (`str`): + The name of the adapter to be set as active. The adapter must be loaded first. """ if adapter_name not in self.peft_config: raise ValueError(f"Adapter {adapter_name} not found.") @@ -804,7 +851,7 @@ class PeftModelForSequenceClassification(PeftModel): ``` """ - def __init__(self, model, peft_config: PeftConfig, adapter_name="default"): + def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: super().__init__(model, peft_config, adapter_name) if self.modules_to_save is None: self.modules_to_save = {"classifier", "score"} @@ -990,7 +1037,7 @@ class PeftModelForCausalLM(PeftModel): ``` """ - def __init__(self, model, peft_config: PeftConfig, adapter_name="default"): + def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: super().__init__(model, peft_config, adapter_name) self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation @@ -1158,7 +1205,7 @@ class PeftModelForSeq2SeqLM(PeftModel): ``` """ - def __init__(self, model, peft_config: PeftConfig, adapter_name="default"): + def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: super().__init__(model, peft_config, adapter_name) self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation self.base_model_prepare_encoder_decoder_kwargs_for_generation = ( @@ -1407,7 +1454,7 @@ class PeftModelForTokenClassification(PeftModel): ``` """ - def __init__(self, model, peft_config: PeftConfig = None, adapter_name="default"): + def __init__(self, model: torch.nn.Module, peft_config: PeftConfig = None, adapter_name: str = "default") -> None: super().__init__(model, peft_config, adapter_name) if self.modules_to_save is None: self.modules_to_save = {"classifier", "score"} @@ -1578,7 +1625,7 @@ class PeftModelForQuestionAnswering(PeftModel): ``` """ - def __init__(self, model, peft_config: PeftConfig = None, adapter_name="default"): + def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: super().__init__(model, peft_config, adapter_name) if self.modules_to_save is None: self.modules_to_save = {"qa_outputs"} @@ -1767,7 +1814,7 @@ class PeftModelForFeatureExtraction(PeftModel): ``` """ - def __init__(self, model, peft_config: PeftConfig = None, adapter_name="default"): + def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default"): super().__init__(model, peft_config, adapter_name) def forward( diff --git a/src/peft/tuners/adalora/layer.py b/src/peft/tuners/adalora/layer.py index b4a98de0396..f0e197bb264 100644 --- a/src/peft/tuners/adalora/layer.py +++ b/src/peft/tuners/adalora/layer.py @@ -138,6 +138,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N self.merged_adapters.append(active_adapter) def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return diff --git a/src/peft/tuners/ia3/layer.py b/src/peft/tuners/ia3/layer.py index e682c2bdd54..4ed0bf2664a 100644 --- a/src/peft/tuners/ia3/layer.py +++ b/src/peft/tuners/ia3/layer.py @@ -146,6 +146,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N self.merged_adapters.append(active_adapter) def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return @@ -271,6 +274,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N self.merged_adapters.append(active_adapter) def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py index 725faf084b9..d5ad48cdad4 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import re import warnings @@ -250,13 +251,26 @@ def _set_adapter_layers(self, enabled=True): if isinstance(module, (IA3Layer, ModulesToSaveWrapper)): module.enable_adapters(enabled) - def enable_adapter_layers(self): + def enable_adapter_layers(self) -> None: + """Enable all adapters. + + Call this if you have previously disabled all adapters and want to re-enable them. + """ self._set_adapter_layers(enabled=True) - def disable_adapter_layers(self): + def disable_adapter_layers(self) -> None: + """Disable all adapters. + + When disabling all adapters, the model output corresponds to the output of the base model. + """ self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name): + def set_adapter(self, adapter_name: str | list[str]) -> None: + """Set the active adapter(s). + + Args: + adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. + """ for module in self.model.modules(): if isinstance(module, IA3Layer): if module.merged: @@ -316,7 +330,7 @@ def _unload_and_optionally_merge( return self.model - def merge_and_unload(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None): + def merge_and_unload(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> torch.nn.Module: r""" This method merges the IAΒ³ layers into the base model. This is needed if someone wants to use the base model as a standalone model. @@ -343,14 +357,14 @@ def merge_and_unload(self, safe_merge: bool = False, adapter_names: Optional[Lis """ return self._unload_and_optionally_merge(safe_merge=safe_merge, adapter_names=adapter_names) - def unload(self): + def unload(self) -> torch.nn.Module: """ Gets back the base model by removing all the IAΒ³ modules without merging. This gives back the original base model. """ return self._unload_and_optionally_merge(merge=False) - def delete_adapter(self, adapter_name: str): + def delete_adapter(self, adapter_name: str) -> None: """ Deletes an existing adapter. diff --git a/src/peft/tuners/lora/bnb.py b/src/peft/tuners/lora/bnb.py index dd672adfcc9..7781bee0931 100644 --- a/src/peft/tuners/lora/bnb.py +++ b/src/peft/tuners/lora/bnb.py @@ -101,7 +101,10 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N state.reset_grads() self.merged_adapters.append(active_adapter) - def unmerge(self): + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return @@ -242,7 +245,10 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N ) self.merged_adapters.append(active_adapter) - def unmerge(self): + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 3219ca1e47b..668f1144ffc 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -306,6 +306,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N self.merged_adapters.append(active_adapter) def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return @@ -437,6 +440,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N self.merged_adapters.append(active_adapter) def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return @@ -576,6 +582,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N self.merged_adapters.append(active_adapter) def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 4f6538e9122..47536529a44 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import importlib import math import operator @@ -54,10 +56,12 @@ class LoraModel(BaseTuner): """ - Creates Low Rank Adapter (Lora) model from a pretrained transformers model. + Creates Low Rank Adapter (LoRA) model from a pretrained transformers model. + + The method is described in detail in https://arxiv.org/abs/2106.09685. Args: - model ([`~transformers.PreTrainedModel`]): The model to be adapted. + model ([`torch.nn.Module`]): The model to be adapted. config ([`LoraConfig`]): The configuration of the Lora model. adapter_name (`str`): The name of the adapter, defaults to `"default"`. @@ -360,15 +364,23 @@ def get_peft_config_as_dict(self, inference: bool = False): config_dict[key] = config return config - def _set_adapter_layers(self, enabled=True): + def _set_adapter_layers(self, enabled: bool = True) -> None: for module in self.model.modules(): if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): module.enable_adapters(enabled) - def enable_adapter_layers(self): + def enable_adapter_layers(self) -> None: + """Enable all adapters. + + Call this if you have previously disabled all adapters and want to re-enable them. + """ self._set_adapter_layers(enabled=True) - def disable_adapter_layers(self): + def disable_adapter_layers(self) -> None: + """Disable all adapters. + + When disabling all adapters, the model output corresponds to the output of the base model. + """ for active_adapter in self.active_adapters: val = self.peft_config[active_adapter].bias if val != "none": @@ -379,7 +391,12 @@ def disable_adapter_layers(self): warnings.warn(msg) self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name): + def set_adapter(self, adapter_name: str | list[str]) -> None: + """Set the active adapter(s). + + Args: + adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. + """ for module in self.model.modules(): if isinstance(module, LoraLayer): if module.merged: @@ -437,7 +454,7 @@ def add_weighted_adapter( svd_clamp=None, svd_full_matrices=True, svd_driver=None, - ): + ) -> None: """ This method adds a new adapter by merging the given adapters with the given weights. @@ -637,7 +654,7 @@ def _svd_weighted_adapter( Vh = Vh.reshape(target_lora_A.data.shape) return Vh, U - def delete_adapter(self, adapter_name: str): + def delete_adapter(self, adapter_name: str) -> None: """ Deletes an existing adapter. @@ -661,7 +678,7 @@ def delete_adapter(self, adapter_name: str): def merge_and_unload( self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[List[str]] = None - ): + ) -> torch.nn.Module: r""" This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model as a standalone model. @@ -691,7 +708,7 @@ def merge_and_unload( progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names ) - def unload(self): + def unload(self) -> torch.nn.Module: """ Gets back the base model by removing all the lora modules without merging. This gives back the original base model. diff --git a/src/peft/tuners/lycoris_utils.py b/src/peft/tuners/lycoris_utils.py index 8f68ce6bd60..432c7abfd92 100644 --- a/src/peft/tuners/lycoris_utils.py +++ b/src/peft/tuners/lycoris_utils.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import warnings from abc import abstractmethod @@ -106,6 +107,18 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor: ... def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If `None`, all active adapters will be merged. + Defaults to `None`. + """ if self.merged: warnings.warn( f"Already following adapters were merged {','.join(self.merged_adapters)}. " @@ -153,6 +166,9 @@ def scale_layer(self, scale: float) -> None: self.scaling[active_adapter] *= scale def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return @@ -321,15 +337,23 @@ def _unload_and_optionally_merge( return self.model - def enable_adapter_layers(self): + def enable_adapter_layers(self) -> None: + """Enable all adapters. + + Call this if you have previously disabled all adapters and want to re-enable them. + """ self._set_adapter_layers(enabled=True) - def disable_adapter_layers(self): + def disable_adapter_layers(self) -> None: + """Disable all adapters. + + When disabling all adapters, the model output corresponds to the output of the base model. + """ self._set_adapter_layers(enabled=False) def merge_and_unload( self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[List[str]] = None - ): + ) -> torch.nn.Module: r""" This method merges the adapter layers into the base model. This is needed if someone wants to use the base model as a standalone model. @@ -349,14 +373,19 @@ def merge_and_unload( progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names ) - def unload(self): + def unload(self) -> torch.nn.Module: """ Gets back the base model by removing all the lora modules without merging. This gives back the original base model. """ return self._unload_and_optionally_merge(merge=False) - def set_adapter(self, adapter_name): + def set_adapter(self, adapter_name: str | list[str]) -> None: + """Set the active adapter(s). + + Args: + adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. + """ for module in self.model.modules(): if isinstance(module, LycorisLayer): if module.merged: @@ -364,7 +393,7 @@ def set_adapter(self, adapter_name): module.unmerge() module.set_adapter(adapter_name) - def delete_adapter(self, adapter_name: str): + def delete_adapter(self, adapter_name: str) -> None: """ Deletes an existing adapter. diff --git a/src/peft/tuners/oft/layer.py b/src/peft/tuners/oft/layer.py index b9e0d011b3c..2e5f8763e2d 100644 --- a/src/peft/tuners/oft/layer.py +++ b/src/peft/tuners/oft/layer.py @@ -121,6 +121,18 @@ def unscale_layer(self, scale=None) -> None: pass def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If `None`, all active adapters will be merged. + Defaults to `None`. + """ if self.merged: warnings.warn( f"Already following adapters were merged {','.join(self.merged_adapters)}. " @@ -171,6 +183,9 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N self.merged_adapters.append(active_adapter) def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index d9616d29d6d..8fe9837cf29 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -18,7 +18,7 @@ import re import warnings from abc import ABC, abstractmethod -from typing import Any, Union +from typing import Any, Optional, Union import torch from torch import nn @@ -270,17 +270,30 @@ def inject_adapter(self, model: nn.Module, adapter_name: str): else: model.modules_to_save.update(set(peft_config.modules_to_save)) - def merge_adapter(self): + def merge_adapter(self, adapter_names: Optional[list[str]] = None) -> None: """ - This method merges the LoRa layers into the base model. + This method merges the adapter layers into the base model. + + Merging adapters can lead to a speed up of the forward pass. A copy of the adapter weights is still kept in + memory, which is required to unmerge the adapters. In order to merge the adapter weights without keeping them + in memory, please call `merge_and_unload`. + + Args: + safe_merge (`bool`, *optional*): + If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If `None`, all active adapters will be merged. + Defaults to `None`. """ for module in self.model.modules(): if isinstance(module, BaseTunerLayer): - module.merge() + module.merge(adapter_names=adapter_names) def unmerge_adapter(self): """ - This method unmerges the LoRa layers from the base model. + This method unmerges all merged adapter layers from the base model. """ for module in self.model.modules(): if isinstance(module, BaseTunerLayer): @@ -341,10 +354,10 @@ def weight(self) -> torch.Tensor: weight = base_layer.weight return weight - def merge(self, *args) -> None: + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: raise NotImplementedError - def unmerge(self, *args) -> None: + def unmerge(self) -> None: raise NotImplementedError @property @@ -368,7 +381,7 @@ def active_adapters(self): # is already a list of str return self.active_adapter - def enable_adapters(self, enabled: bool): + def enable_adapters(self, enabled: bool) -> None: """Toggle the enabling and disabling of adapters Takes care of setting the requires_grad flag for the adapter weights. @@ -386,11 +399,11 @@ def enable_adapters(self, enabled: bool): layer.requires_grad_(False) self._disable_adapters = True - def set_adapter(self, adapter_names: str | list[str]): - """Set the active adapter + def set_adapter(self, adapter_names: str | list[str]) -> None: + """Set the active adapter(s). Args: - adapter_name (str): The name of the adapter to set as active + adapter_name (`str` or `List[str]`): Name of the adapter(s) to be activated. """ if isinstance(adapter_names, str): adapter_names = [adapter_names] diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index 93b892d9e59..aaec1908555 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -20,6 +20,8 @@ class PeftType(str, enum.Enum): + """Enum class for the different types of adapters in PEFT.""" + PROMPT_TUNING = "PROMPT_TUNING" MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING" P_TUNING = "P_TUNING" @@ -34,6 +36,8 @@ class PeftType(str, enum.Enum): class TaskType(str, enum.Enum): + """Enum class for the different types of tasks supported by PEFT.""" + SEQ_CLS = "SEQ_CLS" SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM" CAUSAL_LM = "CAUSAL_LM" From 1b1091c15835f5112d4c5460c80883f0136ea7af Mon Sep 17 00:00:00 2001 From: yxli2123 <69247082+yxli2123@users.noreply.github.com> Date: Mon, 4 Dec 2023 09:15:19 -0500 Subject: [PATCH 11/12] remove HF tokens (#1207) --- examples/loftq_finetuning/train_gsm8k_llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/loftq_finetuning/train_gsm8k_llama.py b/examples/loftq_finetuning/train_gsm8k_llama.py index e8c3580d2e1..1d9331d6018 100644 --- a/examples/loftq_finetuning/train_gsm8k_llama.py +++ b/examples/loftq_finetuning/train_gsm8k_llama.py @@ -58,7 +58,6 @@ MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) -HF_TOKEN = "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" def parse_args(): From f7cf460f7c99b7c19fcf1b1874a733941562fa2c Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Mon, 4 Dec 2023 11:00:29 -0800 Subject: [PATCH 12/12] [docs] Update index and quicktour (#1191) * first draft * fix toctree * lora subby section * feedback * iframe height * feedback --- docs/source/_toctree.yml | 28 ++-- docs/source/index.md | 109 +------------ docs/source/package_reference/auto_class.md | 48 ++++++ docs/source/quicktour.md | 164 +++++++++++--------- 4 files changed, 164 insertions(+), 185 deletions(-) create mode 100644 docs/source/package_reference/auto_class.md diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 25992b3966e..6b31749f0a8 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -9,24 +9,26 @@ - title: Task guides sections: - - local: task_guides/image_classification_lora - title: Image classification using LoRA - local: task_guides/seq2seq-prefix-tuning title: Prefix tuning for conditional generation - local: task_guides/clm-prompt-tuning title: Prompt tuning for causal language modeling - - local: task_guides/semantic_segmentation_lora - title: Semantic segmentation using LoRA - local: task_guides/ptuning-seq-classification title: P-tuning for sequence classification - - local: task_guides/dreambooth_lora - title: Dreambooth fine-tuning with LoRA - - local: task_guides/token-classification-lora - title: LoRA for token classification - - local: task_guides/int8-asr - title: int8 training for automatic speech recognition - - local: task_guides/semantic-similarity-lora - title: Semantic similarity with LoRA + - title: LoRA + sections: + - local: task_guides/image_classification_lora + title: Image classification + - local: task_guides/semantic_segmentation_lora + title: Semantic segmentation + - local: task_guides/token-classification-lora + title: Token classification + - local: task_guides/semantic-similarity-lora + title: Semantic similarity + - local: task_guides/int8-asr + title: int8 training for automatic speech recognition + - local: task_guides/dreambooth_lora + title: DreamBooth - title: Developer guides sections: @@ -59,6 +61,8 @@ - title: Reference sections: + - local: package_reference/auto_class + title: AutoPeftModel - local: package_reference/peft_model title: PEFT model - local: package_reference/config diff --git a/docs/source/index.md b/docs/source/index.md index 5faf706e50e..cfb57c0678d 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -16,11 +16,9 @@ rendered properly in your Markdown viewer. # PEFT -πŸ€— PEFT, or Parameter-Efficient Fine-Tuning (PEFT), is a library for efficiently adapting pre-trained language models (PLMs) to various downstream applications without fine-tuning all the model's parameters. -PEFT methods only fine-tune a small number of (extra) model parameters, significantly decreasing computational and storage costs because fine-tuning large-scale PLMs is prohibitively costly. -Recent state-of-the-art PEFT techniques achieve performance comparable to that of full fine-tuning. +πŸ€— PEFT (Parameter-Efficient Fine-Tuning) is a library for efficiently adapting large pretrained models to various downstream applications without fine-tuning all of a model's parameters because it is prohibitively costly. PEFT methods only fine-tune a small number of (extra) model parameters - significantly decreasing computational and storage costs - while yielding performance comparable to a fully fine-tuned model. This makes it more accessible to train and store large language models (LLMs) on consumer hardware. -PEFT is seamlessly integrated with πŸ€— Accelerate for large-scale models leveraging DeepSpeed and [Big Model Inference](https://huggingface.co/docs/accelerate/usage_guides/big_modeling). +PEFT is integrated with the Transformers, Diffusers, and Accelerate libraries to provide a faster and easier way to load, train, and use large models for inference.

@@ -43,100 +41,9 @@ PEFT is seamlessly integrated with πŸ€— Accelerate for large-scale models levera
-## Supported methods - -1. LoRA: [LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS](https://arxiv.org/pdf/2106.09685.pdf) -2. Prefix Tuning: [Prefix-Tuning: Optimizing Continuous Prompts for Generation](https://aclanthology.org/2021.acl-long.353/), [P-Tuning v2: Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks](https://arxiv.org/pdf/2110.07602.pdf) -3. P-Tuning: [GPT Understands, Too](https://arxiv.org/pdf/2103.10385.pdf) -4. Prompt Tuning: [The Power of Scale for Parameter-Efficient Prompt Tuning](https://arxiv.org/pdf/2104.08691.pdf) -5. AdaLoRA: [Adaptive Budget Allocation for Parameter-Efficient Fine-Tuning](https://arxiv.org/abs/2303.10512) -6. [LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention](https://github.com/ZrrSkywalker/LLaMA-Adapter) -7. IA3: [Infused Adapter by Inhibiting and Amplifying Inner Activations](https://arxiv.org/abs/2205.05638) - -## Supported models - -The tables provided below list the PEFT methods and models supported for each task. To apply a particular PEFT method for -a task, please refer to the corresponding Task guides. - -### Causal Language Modeling - -| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning | IA3 | -|--------------| ---- | ---- | ---- | ---- | ---- | -| GPT-2 | βœ… | βœ… | βœ… | βœ… | βœ… | -| Bloom | βœ… | βœ… | βœ… | βœ… | βœ… | -| OPT | βœ… | βœ… | βœ… | βœ… | βœ… | -| GPT-Neo | βœ… | βœ… | βœ… | βœ… | βœ… | -| GPT-J | βœ… | βœ… | βœ… | βœ… | βœ… | -| GPT-NeoX-20B | βœ… | βœ… | βœ… | βœ… | βœ… | -| LLaMA | βœ… | βœ… | βœ… | βœ… | βœ… | -| ChatGLM | βœ… | βœ… | βœ… | βœ… | βœ… | - -### Conditional Generation - -| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning | IA3 | -| --------- | ---- | ---- | ---- | ---- | ---- | -| T5 | βœ… | βœ… | βœ… | βœ… | βœ… | -| BART | βœ… | βœ… | βœ… | βœ… | βœ… | - -### Sequence Classification - -| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning | IA3 | -| --------- | ---- | ---- | ---- | ---- | ---- | -| BERT | βœ… | βœ… | βœ… | βœ… | βœ… | -| RoBERTa | βœ… | βœ… | βœ… | βœ… | βœ… | -| GPT-2 | βœ… | βœ… | βœ… | βœ… | | -| Bloom | βœ… | βœ… | βœ… | βœ… | | -| OPT | βœ… | βœ… | βœ… | βœ… | | -| GPT-Neo | βœ… | βœ… | βœ… | βœ… | | -| GPT-J | βœ… | βœ… | βœ… | βœ… | | -| Deberta | βœ… | | βœ… | βœ… | | -| Deberta-v2 | βœ… | | βœ… | βœ… | | - -### Token Classification - -| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning | IA3 | -| --------- | ---- | ---- | ---- | ---- | --- | -| BERT | βœ… | βœ… | | | | -| RoBERTa | βœ… | βœ… | | | | -| GPT-2 | βœ… | βœ… | | | | -| Bloom | βœ… | βœ… | | | | -| OPT | βœ… | βœ… | | | | -| GPT-Neo | βœ… | βœ… | | | | -| GPT-J | βœ… | βœ… | | | | -| Deberta | βœ… | | | | | -| Deberta-v2 | βœ… | | | | | - -### Text-to-Image Generation - -| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning | IA3 | -| --------- | ---- | ---- | ---- | ---- | ---- | -| Stable Diffusion | βœ… | | | | | - - -### Image Classification - -| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning | IA3 | -| --------- | ---- | ---- | ---- | ---- | ---- | ---- | -| ViT | βœ… | | | | | -| Swin | βœ… | | | | | - -### Image to text (Multi-modal models) - -We have tested LoRA for [ViT](https://huggingface.co/docs/transformers/model_doc/vit) and [Swin](https://huggingface.co/docs/transformers/model_doc/swin) for fine-tuning on image classification. -However, it should be possible to use LoRA for any [ViT-based model](https://huggingface.co/models?pipeline_tag=image-classification&sort=downloads&search=vit) from πŸ€— Transformers. -Check out the [Image classification](/task_guides/image_classification_lora) task guide to learn more. If you run into problems, please open an issue. - -| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning | IA3 | -| --------- | ---- | ---- | ---- | ---- | ---- | -| Blip-2 | βœ… | | | | | - - -### Semantic Segmentation - -As with image-to-text models, you should be able to apply LoRA to any of the [segmentation models](https://huggingface.co/models?pipeline_tag=image-segmentation&sort=downloads). -It's worth noting that we haven't tested this with every architecture yet. Therefore, if you come across any issues, kindly create an issue report. - -| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning | IA3 | -| --------- | ---- | ---- | ---- | ---- | ---- | -| SegFormer | βœ… | | | | | - + diff --git a/docs/source/package_reference/auto_class.md b/docs/source/package_reference/auto_class.md new file mode 100644 index 00000000000..c1b78a2c342 --- /dev/null +++ b/docs/source/package_reference/auto_class.md @@ -0,0 +1,48 @@ + + +# AutoPeftModels + +The `AutoPeftModel` classes loads the appropriate PEFT model for the task type by automatically inferring it from the configuration file. They are designed to quickly and easily load a PEFT model in a single line of code without having to worry about which exact model class you need or manually loading a [`PeftConfig`]. + +## AutoPeftModel + +[[autodoc]] auto.AutoPeftModel + - from_pretrained + +## AutoPeftModelForCausalLM + +[[autodoc]] auto.AutoPeftModelForCausalLM + +## AutoPeftModelForSeq2SeqLM + +[[autodoc]] auto.AutoPeftModelForSeq2SeqLM + +## AutoPeftModelForSequenceClassification + +[[autodoc]] auto.AutoPeftModelForSequenceClassification + +## AutoPeftModelForTokenClassification + +[[autodoc]] auto.AutoPeftModelForTokenClassification + +## AutoPeftModelForQuestionAnswering + +[[autodoc]] auto.AutoPeftModelForQuestionAnswering + +## AutoPeftModelForFeatureExtraction + +[[autodoc]] auto.AutoPeftModelForFeatureExtraction diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index a6678f59a88..d7dae7b7ad4 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -16,21 +16,19 @@ rendered properly in your Markdown viewer. # Quicktour -πŸ€— PEFT contains parameter-efficient finetuning methods for training large pretrained models. The traditional paradigm is to finetune all of a model's parameters for each downstream task, but this is becoming exceedingly costly and impractical because of the enormous number of parameters in models today. Instead, it is more efficient to train a smaller number of prompt parameters or use a reparametrization method like low-rank adaptation (LoRA) to reduce the number of trainable parameters. +PEFT offers parameter-efficient methods for finetuning large pretrained models. The traditional paradigm is to finetune all of a model's parameters for each downstream task, but this is becoming exceedingly costly and impractical because of the enormous number of parameters in models today. Instead, it is more efficient to train a smaller number of prompt parameters or use a reparametrization method like low-rank adaptation (LoRA) to reduce the number of trainable parameters. -This quicktour will show you πŸ€— PEFT's main features and help you train large pretrained models that would typically be inaccessible on consumer devices. You'll see how to train the 1.2B parameter [`bigscience/mt0-large`](https://huggingface.co/bigscience/mt0-large) model with LoRA to generate a classification label and use it for inference. +This quicktour will show you PEFT's main features and how you can train or run inference on large models that would typically be inaccessible on consumer devices. -## PeftConfig +## Train -Each πŸ€— PEFT method is defined by a [`PeftConfig`] class that stores all the important parameters for building a [`PeftModel`]. +Each PEFT method is defined by a [`PeftConfig`] class that stores all the important parameters for building a [`PeftModel`]. For example, to train with LoRA, load and create a [`LoraConfig`] class and specify the following parameters: -Because you're going to use LoRA, you'll need to load and create a [`LoraConfig`] class. Within `LoraConfig`, specify the following parameters: - -- the `task_type`, or sequence-to-sequence language modeling in this case -- `inference_mode`, whether you're using the model for inference or not -- `r`, the dimension of the low-rank matrices -- `lora_alpha`, the scaling factor for the low-rank matrices -- `lora_dropout`, the dropout probability of the LoRA layers +- `task_type`: the task to train for (sequence-to-sequence language modeling in this case) +- `inference_mode`: whether you're using the model for inference or not +- `r`: the dimension of the low-rank matrices +- `lora_alpha`: the scaling factor for the low-rank matrices +- `lora_dropout`: the dropout probability of the LoRA layers ```python from peft import LoraConfig, TaskType @@ -40,25 +38,21 @@ peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, -πŸ’‘ See the [`LoraConfig`] reference for more details about other parameters you can adjust. +See the [`LoraConfig`] reference for more details about other parameters you can adjust, such as the modules to target or the bias type. -## PeftModel - -A [`PeftModel`] is created by the [`get_peft_model`] function. It takes a base model - which you can load from the πŸ€— Transformers library - and the [`PeftConfig`] containing the instructions for how to configure a model for a specific πŸ€— PEFT method. +Once the [`LoraConfig`] is setup, create a [`PeftModel`] with the [`get_peft_model`] function. It takes a base model - which you can load from the Transformers library - and the [`LoraConfig`] containing the parameters for how to configure a model for training with LoRA. -Start by loading the base model you want to finetune. +Load the base model you want to finetune. ```python from transformers import AutoModelForSeq2SeqLM -model_name_or_path = "bigscience/mt0-large" -tokenizer_name_or_path = "bigscience/mt0-large" -model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path) +model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/mt0-large") ``` -Wrap your base model and `peft_config` with the `get_peft_model` function to create a [`PeftModel`]. To get a sense of the number of trainable parameters in your model, use the [`print_trainable_parameters`] method. In this case, you're only training 0.19% of the model's parameters! 🀏 +Wrap the base model and `peft_config` with the [`get_peft_model`] function to create a [`PeftModel`]. To get a sense of the number of trainable parameters in your model, use the [`print_trainable_parameters`] method. ```python from peft import get_peft_model @@ -68,83 +62,109 @@ model.print_trainable_parameters() "output: trainable params: 2359296 || all params: 1231940608 || trainable%: 0.19151053100118282" ``` -That is it πŸŽ‰! Now you can train the model using the πŸ€— Transformers [`~transformers.Trainer`], πŸ€— Accelerate, or any custom PyTorch training loop. +Out of [bigscience/mt0-large's](https://huggingface.co/bigscience/mt0-large) 1.2B parameters, you're only training 0.19% of them! -## Save and load a model +That is it πŸŽ‰! Now you can train the model with the Transformers [`~transformers.Trainer`], Accelerate, or any custom PyTorch training loop. -After your model is finished training, you can save your model to a directory using the [`~transformers.PreTrainedModel.save_pretrained`] function. You can also save your model to the Hub (make sure you log in to your Hugging Face account first) with the [`~transformers.PreTrainedModel.push_to_hub`] function. +For example, to train with the [`~transformers.Trainer`] class, setup a [`~transformers.TrainingArguments`] class with some training hyperparameters. -```python +```py +training_args = TrainingArguments( + output_dir="your-name/bigscience/mt0-large-lora", + learning_rate=1e-3, + per_device_train_batch_size=32, + per_device_eval_batch_size=32, + num_train_epochs=2, + weight_decay=0.01, + evaluation_strategy="epoch", + save_strategy="epoch", + load_best_model_at_end=True, +) +``` + +Pass the model, training arguments, dataset, tokenizer, and any other necessary component to the [`~transformers.Trainer`], and call [`~transformers.Trainer.train`] to start training. + +```py +trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized_datasets["train"], + eval_dataset=tokenized_datasets["test"], + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics, +) + +trainer.train() +``` + +### Save model + +After your model is finished training, you can save your model to a directory using the [`~transformers.PreTrainedModel.save_pretrained`] function. + +```py model.save_pretrained("output_dir") +``` + +You can also save your model to the Hub (make sure you're logged in to your Hugging Face account first) with the [`~transformers.PreTrainedModel.push_to_hub`] function. -# if pushing to Hub +```python from huggingface_hub import notebook_login notebook_login() -model.push_to_hub("my_awesome_peft_model") +model.push_to_hub("your-name/bigscience/mt0-large-lora") ``` -This only saves the incremental πŸ€— PEFT weights that were trained, meaning it is super efficient to store, transfer, and load. For example, this [`bigscience/T0_3B`](https://huggingface.co/smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM) model trained with LoRA on the [`twitter_complaints`](https://huggingface.co/datasets/ought/raft/viewer/twitter_complaints/train) subset of the RAFT [dataset](https://huggingface.co/datasets/ought/raft) only contains two files: `adapter_config.json` and `adapter_model.bin`. The latter file is just 19MB! +Both methods only save the extra PEFT weights that were trained, meaning it is super efficient to store, transfer, and load. For example, this [facebook/opt-350m](https://huggingface.co/ybelkada/opt-350m-lora) model trained with LoRA only contains two files: `adapter_config.json` and `adapter_model.safetensors`. The `adapter_model.safetensors` file is just 6.3MB! -Easily load your model for inference using the [`~transformers.PreTrainedModel.from_pretrained`] function: +
+ +
The adapter weights for a opt-350m model stored on the Hub are only ~6MB compared to the full size of the model weights, which can be ~700MB.
+
-```diff - from transformers import AutoModelForCausalLM, AutoTokenizer -+ from peft import PeftModel, PeftConfig +## Inference -+ peft_model_id = "merve/Mistral-7B-Instruct-v0.2" -+ config = PeftConfig.from_pretrained(peft_model_id) - model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path) -+ model = PeftModel.from_pretrained(model, peft_model_id) - tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path) + - model = model.to(device) - model.eval() - inputs = tokenizer("Tell me the recipe for chocolate chip cookie", return_tensors="pt") +Take a look at the [AutoPeftModel](package_reference/auto_class) API reference for a complete list of available `AutoPeftModel` classes. - with torch.no_grad(): - outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), max_new_tokens=10) - print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0]) - 'Tell me the recipe for chocolate chip cookie dough. - 1. Preheat oven' -``` + -## Easy loading with Auto classes +Easily load any PEFT-trained model for inference with the [`AutoPeftModel`] class and the [`~transformers.PreTrainedModel.from_pretrained`] method: -If you have saved your adapter locally or on the Hub, you can leverage the `AutoPeftModelForxxx` classes and load any PEFT model with a single line of code: +```py +from peft import AutoPeftModelForCausalLM +from transformers import AutoTokenizer +import torch -```diff -- from peft import PeftConfig, PeftModel -- from transformers import AutoModelForCausalLM -+ from peft import AutoPeftModelForCausalLM +model = AutoPeftModelForCausalLM.from_pretrained("ybelkada/opt-350m-lora") +tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") -- peft_config = PeftConfig.from_pretrained("ybelkada/opt-350m-lora") -- base_model_path = peft_config.base_model_name_or_path -- transformers_model = AutoModelForCausalLM.from_pretrained(base_model_path) -- peft_model = PeftModel.from_pretrained(transformers_model, peft_config) -+ peft_model = AutoPeftModelForCausalLM.from_pretrained("ybelkada/opt-350m-lora") -``` +model = model.to("cuda") +model.eval() +inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors="pt") + +outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), max_new_tokens=50) +print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0]) -Currently, supported auto classes are: `AutoPeftModelForCausalLM`, `AutoPeftModelForSequenceClassification`, `AutoPeftModelForSeq2SeqLM`, `AutoPeftModelForTokenClassification`, `AutoPeftModelForQuestionAnswering` and `AutoPeftModelForFeatureExtraction`. For other tasks (e.g. Whisper, StableDiffusion), you can load the model with: +"Preheat the oven to 350 degrees and place the cookie dough in the center of the oven. In a large bowl, combine the flour, baking powder, baking soda, salt, and cinnamon. In a separate bowl, combine the egg yolks, sugar, and vanilla." +``` -```diff -- from peft import PeftModel, PeftConfig, AutoPeftModel -+ from peft import AutoPeftModel -- from transformers import WhisperForConditionalGeneration +For other tasks that aren't explicitly supported with an `AutoPeftModelFor` class - such as automatic speech recognition - you can still use the base [`AutoPeftModel`] class to load a model for the task. -- model_id = "smangrul/openai-whisper-large-v2-LORA-colab" +```py +from peft import AutoPeftModel -peft_model_id = "smangrul/openai-whisper-large-v2-LORA-colab" -- peft_config = PeftConfig.from_pretrained(peft_model_id) -- model = WhisperForConditionalGeneration.from_pretrained( -- peft_config.base_model_name_or_path, load_in_8bit=True, device_map="auto" -- ) -- model = PeftModel.from_pretrained(model, peft_model_id) -+ model = AutoPeftModel.from_pretrained(peft_model_id) +model = AutoPeftModel.from_pretrained("smangrul/openai-whisper-large-v2-LORA-colab") ``` ## Next steps -Now that you've seen how to train a model with one of the πŸ€— PEFT methods, we encourage you to try out some of the other methods like prompt tuning. The steps are very similar to the ones shown in this quickstart; prepare a [`PeftConfig`] for a πŸ€— PEFT method, and use the `get_peft_model` to create a [`PeftModel`] from the configuration and base model. Then you can train it however you like! +Now that you've seen how to train a model with one of the PEFT methods, we encourage you to try out some of the other methods like prompt tuning. The steps are very similar to the ones shown in the quicktour: + +1. prepare a [`PeftConfig`] for a PEFT method +2. use the [`get_peft_model`] method to create a [`PeftModel`] from the configuration and base model + +Then you can train it however you like! To load a PEFT model for inference, you can use the [`AutoPeftModel`] class. -Feel free to also take a look at the task guides if you're interested in training a model with a πŸ€— PEFT method for a specific task such as semantic segmentation, multilingual automatic speech recognition, DreamBooth, and token classification. +Feel free to also take a look at the task guides if you're interested in training a model with another PEFT method for a specific task such as semantic segmentation, multilingual automatic speech recognition, DreamBooth, token classification, and more.