feat: arguments added for LORA (#67)
wanxinran authored and rohanmodi2810 committed Sep 26, 2024
1 parent 5766624 commit 2e4782b
Showing 3 changed files with 28 additions and 0 deletions.
Binary file added dist/llments-0.0.0a1-py3-none-any.whl
Binary file added dist/llments-0.0.0a1.tar.gz
28 changes: 28 additions & 0 deletions llments/lm/base/hugging_face.py
@@ -216,6 +216,8 @@ def fit(
prediction_loss_only: bool = False,
optim: str = "adamw_torch",
logging_steps: int = 500,
lora_r: int | None = None,
lora_alpha: int | None = None,
) -> LanguageModel:
"""Fit the language model to a target language model's distribution.
@@ -239,6 +241,8 @@
prediction_loss_only: When performing evaluation and generating predictions, only returns the loss.
optim: The optimizer to use. Can only choose from a list of names.
logging_steps: Number of update steps between two logs if logging_strategy="steps".
lora_r: LoRA attention dimension (the rank of the low-rank update matrices).
lora_alpha: The alpha parameter for LoRA scaling.
Returns:
The fitted language model.
@@ -285,6 +289,30 @@ def fit(
)
eval_dataset = Dataset.from_dict(eval_inputs)

# Wrap the base model with PEFT LoRA adapters when both arguments are given.
if lora_r is not None and lora_alpha is not None:
try:
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training,
)
except ImportError:
raise ImportError(
"You need to install 'peft' package to use this LORA functionality."
)
lora_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
# trainable layers: the multi-head attention projections (query/key/value/output)
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = prepare_model_for_kbit_training(base.model)
base.model = get_peft_model(model, lora_config)

training_args = TrainingArguments(
output_dir=output_dir,
do_train=do_train,
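For reference, the PEFT wrapping added in this diff is equivalent to the following self-contained sketch. The model name is illustrative and not part of the commit; the configuration values mirror the diff, and q_proj/k_proj/v_proj/out_proj are the attention projection names used by OPT-style models.

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Illustrative base model whose attention projections match target_modules below.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

lora_config = LoraConfig(
    r=8,              # lora_r: rank of the low-rank update matrices
    lora_alpha=16,    # lora_alpha: updates are scaled by lora_alpha / r
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the small adapter matrices are trainable

With this change, callers opt in to LoRA by passing both lora_r and lora_alpha to fit(); when either is omitted, the base model is fine-tuned in full as before.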
