Best practice to train on multiple datasets with different prompts #2945
Hello! Good question - this isn't clearly mentioned in the documentation anywhere. The simplest approach is to prepend the prompts to your texts while loading the dataset, before it reaches the training pipeline. Best of luck!
Thank you for the quick response! After tracing the code, I noticed that the encode function is never called in the training pipeline; the forward function of nn.Sequential is what actually gets called in the loss function. So I will take your suggestion and implement the prompt logic while loading the dataset.
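As an illustration (not from the thread), one way that prompt logic could be implemented while loading each dataset. The dataset name, column names, and prompt strings below are only examples:

```python
from datasets import load_dataset

# Hypothetical task -> prompt mapping; replace with your models' actual prompts.
prompts = {
    "nli": "Represent the sentence for classifying its relation: ",
    "retrieval": "Represent the query for retrieving supporting documents: ",
}

def prepend_prompt(example, prompt):
    # Hypothetical column names; adjust to your dataset's schema.
    example["anchor"] = prompt + example["anchor"]
    example["positive"] = prompt + example["positive"]
    return example

nli = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
nli = nli.map(prepend_prompt, fn_kwargs={"prompt": prompts["nli"]})
```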
A quick follow-up question on this: how do I exclude the prompts when computing the mean embedding in the above scenario?
Hmm, I hadn't considered that yet. For training, I think the easiest will be to write a custom Pooling module, e.g.:

```python
from __future__ import annotations

import json
import os
from typing import Any

import torch
from torch import Tensor, nn


class PoolingExcludingPrompts(nn.Module):
    """
    A pooling layer that computes the mean sentence embedding from a sequence of token embeddings,
    excluding the prompt tokens.
    """

    def __init__(self, word_embedding_dimension: int) -> None:
        super().__init__()
        self.word_embedding_dimension = word_embedding_dimension

    def forward(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
        token_embeddings = features["token_embeddings"]
        attention_mask = (
            features["attention_mask"]
            if "attention_mask" in features
            else torch.ones(token_embeddings.shape[:-1], device=token_embeddings.device, dtype=torch.int64)
        )

        # Detect your model's prompt(s) and remove them from the attention_mask
        ...

        # Mean pooling over the remaining (non-prompt) tokens
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
        )
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

        # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
        if "token_weights_sum" in features:
            sum_mask = features["token_weights_sum"].unsqueeze(-1).expand(sum_embeddings.size())
        else:
            sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)

        features["sentence_embedding"] = sum_embeddings / sum_mask
        return features

    def get_sentence_embedding_dimension(self) -> int:
        return self.word_embedding_dimension

    def get_config_dict(self) -> dict[str, Any]:
        return {"word_embedding_dimension": self.word_embedding_dimension}

    def save(self, output_path: str) -> None:
        with open(os.path.join(output_path, "config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

    @staticmethod
    def load(input_path: str) -> "PoolingExcludingPrompts":
        with open(os.path.join(input_path, "config.json")) as fIn:
            config = json.load(fIn)
        return PoolingExcludingPrompts(**config)
```

And then after the model is trained, you should be able to use the "normal" Pooling with `include_prompt=False`. Otherwise, you can also keep your custom Pooling in the final trained model, but then your users will have to load the model with `trust_remote_code=True`.
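To make that concrete, a rough sketch of wiring the custom module into a model and, after training, swapping it for the built-in Pooling. The base model name and module file name are placeholders, and this assumes a sentence-transformers version whose `models.Pooling` supports `include_prompt`:

```python
from sentence_transformers import SentenceTransformer, models

# Assumes the class above is saved in pooling_excluding_prompts.py (hypothetical file name).
from pooling_excluding_prompts import PoolingExcludingPrompts

transformer = models.Transformer("intfloat/e5-base-v2")  # placeholder base model
pooling = PoolingExcludingPrompts(transformer.get_word_embedding_dimension())
model = SentenceTransformer(modules=[transformer, pooling])

# ... train as usual ...

# After training, rebuild the model with the standard Pooling so users don't need
# the custom module; include_prompt=False keeps prompt tokens out of the mean.
standard_pooling = models.Pooling(
    transformer.get_word_embedding_dimension(),
    pooling_mode="mean",
    include_prompt=False,
)
final_model = SentenceTransformer(modules=[model[0], standard_pooling])
final_model.save("my-prompted-model")
```

The rebuilt model only uses built-in modules, so it can be loaded without any custom code.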
Thank you! I will try out the custom pooling module you provided.
Never mind, I figured it out. For anyone curious about the solution: the evaluators call model.encode, so setting a default prompt on the model will automatically apply the instruction.
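For anyone who lands here later, a minimal sketch of that workaround; the model path and prompt text are placeholders:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "path/to/my-prompted-model",  # placeholder
    prompts={"query": "Represent the query for retrieving supporting documents: "},
    default_prompt_name="query",
)

# Evaluators call model.encode internally, so the default prompt is applied automatically.
embeddings = model.encode(["How do prompts interact with mean pooling?"])
```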
Apologies, I missed your last question! Yes indeed, and some evaluators don't yet support a prompt/prompt_name argument. #2951 should improve that.
@ShengYun-Peng @tomaarsen, I just created #2964, which adds prompts to the trainer and masks them accordingly. Let me know what you think!
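For readers arriving later: if that PR (or follow-up work) is in your installed version, the prompts may be configurable on the training arguments directly instead of being baked into the dataset. A hedged sketch, assuming a recent sentence-transformers release where `SentenceTransformerTrainingArguments` accepts a `prompts` argument; the output directory, dataset keys, and prompt strings are placeholders:

```python
from sentence_transformers import SentenceTransformerTrainingArguments

args = SentenceTransformerTrainingArguments(
    output_dir="models/prompted-model",  # placeholder
    # Hypothetical mapping from training dataset name to its prompt; the keys
    # would have to match the dataset dictionary passed to the trainer.
    prompts={
        "nli": "Represent the sentence for classifying its relation: ",
        "retrieval": "Represent the query for retrieving supporting documents: ",
    },
)
```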
Thanks for the great work supporting the text embedding community!
I plan to train Instructor and other LLM-based encoder models on multiple datasets. Since all of these models rely on different prompts for different embedding tasks, I'm curious what the best way is to prepend the prompts to the training datasets.