diff --git a/src/adapters/methods/bottleneck.py b/src/adapters/methods/bottleneck.py
index fa66a095e..ff12a91cd 100644
--- a/src/adapters/methods/bottleneck.py
+++ b/src/adapters/methods/bottleneck.py
@@ -204,13 +204,15 @@ def mean(self, states: List[BottleneckState], weights: torch.Tensor) -> Bottlene
     def compose_single(self, adapter_setup: str, state: BottleneckState, lvl: int = 0) -> BottleneckState:
         adapter_layer = self.adapters[adapter_setup]
         context = ForwardContext.get_context()
+        output_gating = context.output_adapter_gating_scores if context is not None else False
         layer_output = adapter_layer(
             state.hidden_states,
             residual_input=state.adapter_residual,
-            output_gating=context.output_adapter_gating_scores,
+            output_gating=output_gating,
         )
         hidden_states, up = layer_output[0], layer_output[2]
-        self._store_gating_score(adapter_setup, layer_output[-1])
+        if output_gating:
+            self._store_gating_score(adapter_setup, layer_output[-1])
 
         return state._replace(hidden_states=hidden_states, bottleneck_up=up, last=adapter_setup)
 
@@ -246,14 +248,15 @@ def compose_fuse(self, adapter_setup: Fuse, state: BottleneckState, lvl: int = 0
         up_list = torch.stack([state.bottleneck_up for state in children_states])
         up_list = up_list.permute(1, 2, 0, 3)
 
+        output_fusion_attns = context.output_adapter_fusion_attentions if context is not None else False
         fusion_output = self.adapter_fusion_layer[adapter_setup.name](
             query,
             up_list,
             up_list,
             state.adapter_residual,
-            output_attentions=context.output_adapter_fusion_attentions,
+            output_attentions=output_fusion_attns,
         )
-        if context.output_adapter_fusion_attentions:
+        if output_fusion_attns:
             hidden_states = fusion_output[0]
             self._store_fusion_attentions(adapter_setup.name, fusion_output[-1])
         else:
diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py
index 461cdde2b..e9924e696 100644
--- a/src/adapters/models/llama/modeling_llama.py
+++ b/src/adapters/models/llama/modeling_llama.py
@@ -398,10 +398,11 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -437,6 +438,7 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            position_embeddings=position_embeddings,
             **kwargs,
         )
         hidden_states = self.attention_adapters(hidden_states, residual, None)
diff --git a/src/adapters/trainer.py b/src/adapters/trainer.py
index 6be5b3ee7..2896585bc 100644
--- a/src/adapters/trainer.py
+++ b/src/adapters/trainer.py
@@ -42,6 +42,9 @@ def __init__(
         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
         preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
     ):
+        if model is not None:
+            model_quantized = getattr(model, "is_quantized", False)
+            model.is_quantized = False
         super().__init__(
             model,
             args,
@@ -55,6 +58,8 @@ def __init__(
             optimizers=optimizers,
             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
         )
+        if model is not None:
+            model.is_quantized = model_quantized
 
         if adapter_names is not None:
             self.model.set_active_adapters(adapter_names)
diff --git a/tests/test_adapter_trainer.py b/tests/test_adapter_trainer.py
index d313b656e..b86514950 100644
--- a/tests/test_adapter_trainer.py
+++ b/tests/test_adapter_trainer.py
@@ -3,22 +3,27 @@
 from tempfile import TemporaryDirectory
 
 import torch
+from datasets import Dataset
 
 import adapters
 from adapters import AutoAdapterModel
 from adapters.composition import Fuse, Stack
 from adapters.trainer import AdapterTrainer, logger
+from parameterized import parameterized
 from transformers import (
+    AutoModelForCausalLM,
     AutoModelForSequenceClassification,
     AutoTokenizer,
     BertConfig,
     BertForSequenceClassification,
+    BitsAndBytesConfig,
+    DataCollatorForLanguageModeling,
     GlueDataset,
     GlueDataTrainingArguments,
     Trainer,
     TrainingArguments,
 )
-from transformers.testing_utils import require_ray, slow
+from transformers.testing_utils import require_bitsandbytes, require_ray, slow, torch_device
 
 
 class TestAdapterTrainer(unittest.TestCase):
@@ -536,6 +541,85 @@ def model_init(trail=None):
 
             trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, backend="ray", n_trials=2)
 
+    @parameterized.expand(["lora", "seq_bn"])
+    @require_bitsandbytes
+    def test_quantized_training(self, config):
+        model_name = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        dataset = Dataset.from_dict({"text": ["Hello, I'm a single sentence!", "This is another sentence."]})
+
+        def tokenize(element):
+            return tokenizer(
+                element["text"],
+                truncation=True,
+                max_length=512,  # can set to longer values such as 2048
+                add_special_tokens=False,
+            )
+
+        dataset_tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            device_map="auto",
+            quantization_config=BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_compute_dtype=torch.bfloat16,
+            ),
+            torch_dtype=torch.bfloat16,
+        )
+        model.config.use_cache = False
+
+        adapters.init(model)
+        model.add_adapter("task", config=config)
+        model.train_adapter("task")
+
+        model.adapter_to("task", device=torch_device)
+
+        for param in model.parameters():
+            if param.ndim == 1:
+                # cast the small parameters (e.g. layernorm) to fp32 for stability
+                param.data = param.data.to(torch.float32)
+
+        model.gradient_checkpointing_enable()
+        model.enable_input_require_grads()
+
+        class CastOutputToFloat(torch.nn.Sequential):
+            def forward(self, x):
+                return super().forward(x).to(torch.float32)
+
+        model.lm_head = CastOutputToFloat(model.lm_head)
+
+        self.assertEqual(Stack("task"), model.active_adapters)
+        with TemporaryDirectory() as tempdir:
+            training_args = TrainingArguments(
+                output_dir=tempdir,
+                per_device_train_batch_size=1,
+                per_device_eval_batch_size=1,
+                evaluation_strategy="steps",
+                logging_steps=10,
+                max_steps=5,
+                lr_scheduler_type="constant",
+                optim="paged_adamw_32bit",
+                learning_rate=0.0002,
+                group_by_length=True,
+                bf16=True,
+                max_grad_norm=0.3,
+            )
+            trainer = AdapterTrainer(
+                model=model,
+                tokenizer=tokenizer,
+                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
+                train_dataset=dataset_tokenized,
+                args=training_args,
+            )
+
+            trainer.train()
+
 
 if __name__ == "__main__":
     unittest.main()
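
For reference, the workflow that the trainer.py change enables (adapter training on a bitsandbytes-quantized model through AdapterTrainer) can be condensed to the sketch below. It is distilled from the new test_quantized_training test above and is not part of the patch; the model name, the "seq_bn" adapter config, and the TrainingArguments values are illustrative placeholders, and any bitsandbytes-compatible causal LM would do.

    # Minimal sketch: train a bottleneck adapter on top of a 4-bit quantized backbone.
    import torch
    import adapters
    from adapters import AdapterTrainer
    from datasets import Dataset
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        DataCollatorForLanguageModeling,
        TrainingArguments,
    )

    model_name = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Load the frozen base weights in 4-bit NF4. The AdapterTrainer change above lets a
    # model with is_quantized=True pass through the base Trainer.__init__ check.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
        torch_dtype=torch.bfloat16,
    )

    # Add a trainable adapter on top of the quantized backbone and activate it for training.
    adapters.init(model)
    model.add_adapter("task", config="seq_bn")
    model.train_adapter("task")

    # Tiny toy dataset tokenized for causal LM training.
    dataset = Dataset.from_dict({"text": ["Hello, I'm a single sentence!", "This is another sentence."]})
    dataset = dataset.map(
        lambda batch: tokenizer(batch["text"], truncation=True, max_length=512),
        batched=True,
        remove_columns=["text"],
    )

    trainer = AdapterTrainer(
        model=model,
        tokenizer=tokenizer,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
        train_dataset=dataset,
        args=TrainingArguments(output_dir="output", max_steps=5, per_device_train_batch_size=1, bf16=True),
    )
    trainer.train()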