only classifier head is trained in tweet sentiment classification LoRA finetuning blog #1824

Open
yao-matrix opened this issue Feb 19, 2024 · 0 comments

yao-matrix commented Feb 19, 2024

@mehdiir,

We tried to reproduce your work in our environment and ran into a strange issue: with your code, gradient_checkpointing=True runs much faster than gradient_checkpointing=False (2 hr vs. 6 hr in our CPU environment), which contradicted our intuition. So we did some analysis, as below:

  1. In this case, when gradient_checkpointing=True is set (and PyTorch's use_reentrant=True is used implicitly), the LoRA weights sit inside transformer blocks whose inputs and outputs both have requires_grad=False. With reentrant checkpointing, those blocks never run backpropagation, so in practice only the classifier head is trained; the LoRA weights are never updated and stay at their identity initialization (a quick way to verify this is sketched after the fix below).

  2. After upgrading transformers to 4.37.2 and adding the two lines below in get_lora_model to set use_reentrant to False, things went back to normal and the LoRA weights were trained.

def get_lora_model(model_checkpoints, num_labels=2, rank=4, alpha=16, lora_dropout=0.1, bias='none'):
    ...
    # Disable reentrant checkpointing so the checkpointed blocks still run
    # backpropagation and the LoRA weights receive gradients.
    + gradient_checkpointing_kwargs = {"use_reentrant": False}
    + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    return model
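
To confirm whether the LoRA weights are actually being updated, a minimal sketch (not from the blog post; the helper name lora_b_norms and the trainer variable are placeholders) is below. It relies on PEFT initializing every lora_B matrix to zero, which is why the LoRA delta starts out as identity: if these norms are still zero after training, only the classifier head was trained.

def lora_b_norms(model):
    # PEFT initializes every lora_B matrix to zero, so a norm that stays at
    # zero after training means the LoRA weights never received an update.
    return {
        name: param.detach().norm().item()
        for name, param in model.named_parameters()
        if "lora_B" in name
    }

# Usage sketch (trainer is a placeholder for however training is launched):
# before = lora_b_norms(model)
# trainer.train()
# after = lora_b_norms(model)
# print(any(v > 0 for v in after.values()))  # False => only the classifier head was trained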

FYI in case other people run into a similar issue.
