Skip to content

[Findings of ACL'2023] Improving Contrastive Learning of Sentence Embeddings from AI Feedback

Notifications You must be signed in to change notification settings

xiami2019/CLAIF

Repository files navigation

Improving Contrastive Learning of Sentence Embeddings from AI Feedback

This repository contains the code, data and model checkpoints for our paper Improving Contrastive Learning of Sentence Embeddings from AI Feedback (CLAIF).
Accepted to Findings of ACL 2023.

Preparing the Environment

conda create -n claif python=3.8
conda activate claif
pip install -r requirements.txt

Data Generation

You can choose your own data as original sentences to construct datasets for sentence embeddings learning. Here we use a small set of sentences as an example to show the whole data generation process.
We use text-davinci-003 as the default engine.

Overview

Step 1: Sentence Pair Generation

python data_generation.py --generation_stage stage-1 --output_dir sentence_pairs --input_file demo_sentences.csv --input_file_type stsb --batch_size 2 --openai_api_key <your_openai_api_key>

After step1, you will get sentence pairs in 'sentence_pairs/generated-dataset.jsonl' with a jsonl format.

Step 2: Semantic Similarity Labeling

We use text-davinci-003 as default.

python data_generation.py --generation_stage stage-2\
--output_dir sentence_pairs_with_labels\
--input_file ./sentence_pairs/generated-dataset.jsonl\
--input_file_type jsonl\
--batch_size 5\
--openai_api_key <your_openai_api_key>

After step2, you will get sentence pairs with similarity scores and explainations from AI feedback in 'sentence_pairs_with_labels/generated-dataset.jsonl' with a jsonl format.

Post Processing

We refer to the post-processing pipeline in dino: https://github.com/timoschick/dino/blob/main/scripts/sts/postprocess_dataset.py

python postprocess_dataset.py --input_file ./sentence_pairs_with_labels/generated-dataset.jsonl\
--output_file demo_sentence_pairs_post.jsonl

After post processing, you will get the final data 'demo_sentence_pairs_post.jsonl', which can be used for sentence embeddings learning.

Generated Data

Here wo provide our generated data, which are used in our experiments: https://huggingface.co/datasets/fnlp/CLAIF-data
CLAIF: claif_data.jsonl
CLAIF_scaled: claif_scaled_data.jsonl
NLI_data_with_similarity_scores: nli_data_with_similarity_scores.csv

Model Training

Download Generated Data

cd generated_data
bash download_claif_data.sh

Prepare STS Data

cd SentEval/data/downstream/
bash download_dataset.sh

CLAIF

python run_training.py \
--input_file ./generated_data/claif_data.jsonl \
--output_dir result_model \
--model_name roberta-base \
--num_epochs 3 \
--lr 2e-5 \
--using_stsb_dev

CLHAIF

For the training of CLHAIF, you should use the same environment as the SimCSE, since the version variants of transformers and pytorch may cause some bugs.

CUDA_VISIBLE_DEVICES=0,1,2,3 bash run_clhaif_simcse.sh

Before evaluation the saved checkpoint, you need to convert it to the huggingface format (the same step as SimCSE):

python simcse_to_huggingface.py --path {PATH_TO_CHECKPOINT_FOLDER}

After that, you can evaluate it by our evaluation code.

Inference

Model List

Our released models are listed as following. You can import these models by using Sentence Transformers or using HuggingFace's Transformers.

Model Avg. STS
fnlp/claif-bert-base 79.63
fnlp/claif-roberta-base 79.90
fnlp/claif-scaled-bert-base 82.37
fnlp/claif-scaled-roberta-base 81.88
fnlp/clhaif-simcse-bert-base 82.08
fnlp/clhaif-simcse-roberta-base 82.85

Use CLAIF with Sentence-Transformers

from sentence_transformers import SentenceTransformer
sentences = ["This is an example sentence", "Each sentence is converted"]

model = SentenceTransformer('fnlp/claif-scaled-bert-base')
embeddings = model.encode(sentences)
print(embeddings)

Use CLAIF with HuggingFace Transformers

from transformers import AutoTokenizer, AutoModel
import torch


#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


# Sentences we want sentence embeddings for
sentences = ['This is an example sentence', 'Each sentence is converted']

# Load model from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained('fnlp/claif-scaled-bert-base')
model = AutoModel.from_pretrained('fnlp/claif-scaled-bert-base')

# Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

# Compute token embeddings
with torch.no_grad():
    model_output = model(**encoded_input)

# Perform pooling. In this case, mean pooling.
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

print("Sentence embeddings:")
print(sentence_embeddings)

Evaluation of CLAIF

You can run the evaluation script for claif like:

python evaluation_sts.py --model_name_or_path 'fnlp/claif-roberta-base'\
--mode test\
--task_set sts

which is expected to output the results in a tubular format:

------ test ------
+-------+-------+-------+-------+-------+--------------+-----------------+-------+
| STS12 | STS13 | STS14 | STS15 | STS16 | STSBenchmark | SICKRelatedness |  Avg. |
+-------+-------+-------+-------+-------+--------------+-----------------+-------+
| 68.33 | 82.26 | 77.00 | 85.18 | 83.43 |    85.05     |      78.02      | 79.90 |
+-------+-------+-------+-------+-------+--------------+-----------------+-------+

Evaluation of CLHAIF

You can run the evaluation script for clhaif-bert-base like:

python evaluation_clhaif.py \
--model_name_or_path fnlp/clhaif-simcse-bert-base \
--pooler cls \
--task_set sts \
--mode test

which is expected to output the results in a tubular format:

------ test ------
+-------+-------+-------+-------+-------+--------------+-----------------+-------+
| STS12 | STS13 | STS14 | STS15 | STS16 | STSBenchmark | SICKRelatedness |  Avg. |
+-------+-------+-------+-------+-------+--------------+-----------------+-------+
| 74.86 | 85.09 | 81.24 | 85.96 | 81.33 |    84.69     |      81.36      | 82.08 |
+-------+-------+-------+-------+-------+--------------+-----------------+-------+

You can run the evaluation script for clhaif-roberta-base like:

python evaluation_clhaif.py \
--model_name_or_path fnlp/clhaif-simcse-roberta-base \
--pooler avg \
--task_set sts \
--mode test

which is expected to output the results in a tubular format:

------ test ------
+-------+-------+-------+-------+-------+--------------+-----------------+-------+
| STS12 | STS13 | STS14 | STS15 | STS16 | STSBenchmark | SICKRelatedness |  Avg. |
+-------+-------+-------+-------+-------+--------------+-----------------+-------+
| 76.23 | 85.46 | 81.48 | 86.47 | 83.40 |    85.93     |      80.95      | 82.85 |
+-------+-------+-------+-------+-------+--------------+-----------------+-------+

Citation

Our implementation is built on the source code from dino and SimCSE. Thanks for their work.

@inproceedings{DBLP:conf/acl/ChengYSLQ23,
  author       = {Qinyuan Cheng and
                  Xiaogui Yang and
                  Tianxiang Sun and
                  Linyang Li and
                  Xipeng Qiu},
  title        = {Improving Contrastive Learning of Sentence Embeddings from {AI} Feedback},
  booktitle    = {Findings of the Association for Computational Linguistics: {ACL} 2023,
                  Toronto, Canada, July 9-14, 2023},
  pages        = {11122--11138},
  publisher    = {Association for Computational Linguistics},
  year         = {2023},
}

About

[Findings of ACL'2023] Improving Contrastive Learning of Sentence Embeddings from AI Feedback

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published