diff --git a/dist/llments-0.0.0a1-py3-none-any.whl b/dist/llments-0.0.0a1-py3-none-any.whl
new file mode 100644
index 0000000..a953152
Binary files /dev/null and b/dist/llments-0.0.0a1-py3-none-any.whl differ
diff --git a/dist/llments-0.0.0a1.tar.gz b/dist/llments-0.0.0a1.tar.gz
new file mode 100644
index 0000000..ce28cc6
Binary files /dev/null and b/dist/llments-0.0.0a1.tar.gz differ
diff --git a/llments/lm/base/hugging_face.py b/llments/lm/base/hugging_face.py
index a1d4cc3..a9d9737 100644
--- a/llments/lm/base/hugging_face.py
+++ b/llments/lm/base/hugging_face.py
@@ -216,6 +216,8 @@ def fit(
         prediction_loss_only: bool = False,
         optim: str = "adamw_torch",
         logging_steps: int = 500,
+        lora_r: int | None = None,
+        lora_alpha: int | None = None,
     ) -> LanguageModel:
         """Fit the language model to a target language model's distribution.
 
@@ -239,6 +241,8 @@ def fit(
             prediction_loss_only: When performing evaluation and generating predictions, only returns the loss.
             optim: The optimizer to use. Can only choose from a list of names.
             logging_steps: Number of update steps between two logs if logging_strategy="steps".
+            lora_r: LoRA attention dimension (the "rank").
+            lora_alpha: The alpha parameter for LoRA scaling.
 
         Returns:
             The fitted language model.
@@ -285,6 +289,30 @@ def fit(
         )
         eval_dataset = Dataset.from_dict(eval_inputs)
 
+        # Wrap the base model with PEFT LoRA adapters.
+        if lora_r and lora_alpha:
+            try:
+                from peft import (
+                    LoraConfig,
+                    get_peft_model,
+                    prepare_model_for_kbit_training,
+                )
+            except ImportError:
+                raise ImportError(
+                    "You need to install the 'peft' package to use this LoRA functionality."
+                )
+            lora_config = LoraConfig(
+                r=lora_r,
+                lora_alpha=lora_alpha,
+                # trainable layers: the linear projections of multi-head attention
+                target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+                lora_dropout=0.05,
+                bias="none",
+                task_type="CAUSAL_LM",
+            )
+            model = prepare_model_for_kbit_training(base.model)
+            base.model = get_peft_model(model, lora_config)
+
         training_args = TrainingArguments(
             output_dir=output_dir,
             do_train=do_train,
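
For reviewers, a hedged sketch of how the new options might be exercised from caller code. Only lora_r and lora_alpha come from this diff; the remaining argument names (base, target) and the call form are assumptions, since the diff shows only part of fit's signature (output_dir and base.model appear in the body). Note that both LoRA arguments must be set, otherwise the PEFT-wrapping branch is skipped and fit behaves as before.

# Hypothetical usage sketch; argument names other than lora_r / lora_alpha
# (and output_dir, which appears in the body above) are assumptions for illustration.
fitted_lm = fit(
    base=base_lm,      # assumed name: HuggingFace-backed model to fine-tune
    target=target_lm,  # assumed name: model whose distribution is being fit
    output_dir="out",
    lora_r=8,          # LoRA rank; together with lora_alpha, enables the PEFT branch
    lora_alpha=16,     # LoRA scaling factor
)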