[NeurIPS 2024] An advanced persona-driven role-playing system with global faithfulness quantification and optimization. In memory of Koishi's Day 2024.

Koishi's Day 2024: Quantifying and Optimizing Global Faithfulness in Persona-driven Role-playing

[English | 中文 | 日本語]

  • Let there be fantasy

[Update] APC is accepted to NeurIPS 2024! 💚

Introduction [Paper]

Persona-driven Role-playing (PRP) is so cool: it allows you to build AI characters from just a few short paragraphs describing the persona (人设/設定)! However, keeping the AI character faithful to ALL persona statements is a hard problem. PRP agents constantly make mistakes, or stay vague about knowledge they should have.

[Figure: Case]

The main reason behind this limitation is the lack of a metric that can quantify global PRP faithfulness. So I decided to build one, following this intuition:

[Figure: APC]

Briefly speaking, whenever a query comes from the user, each persona statement becomes either an active constraint (relevant to the query) or a passive constraint (irrelevant to the query). To satisfy an active constraint, the response needs to be entailed by the statement (i.e., contain the information in the statement). For a passive constraint, the response only needs to not be contradicted by the statement (i.e., contain no information that is incorrect according to it).
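As a toy sketch of this constraint logic (an illustration only; the actual implementation in this repo works with probabilities rather than hard labels):

def constraint_satisfied(relevant: bool, nli: str) -> bool:
    # relevant: whether the statement is relevant to the query (active vs. passive)
    # nli: NLI label of statement -> response: "entailed", "neutral", or "contradicted"
    if relevant:
        # Active constraint: the response must be entailed by the statement.
        return nli == "entailed"
    # Passive constraint: the response only needs to not contradict the statement.
    return nli != "contradicted"

# An irrelevant statement with a neutral response still satisfies its (passive) constraint.
print(constraint_satisfied(False, "neutral"))  # True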

[Figure: DPO]

We traverse all persona statements and check whether each constraint is satisfied. The number of satisfied constraints is used as the metric of global PRP faithfulness, named the Active-Passive-Constraint (APC) score. Direct preference optimization (DPO) is a method that encourages models to behave more like the responses preferred by humans or by a criterion. Thus, we can sample two responses to the same query and apply DPO based on their APC scores, encouraging the PRP agent to be globally more faithful to the persona statements.
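A minimal sketch of the pair construction (in the repo, the two responses are sampled from the PRP model and scored with score_APC; the literals below are placeholders):

def build_dpo_pair(query, response_a, response_b, apc_a, apc_b):
    # The response with the higher APC score becomes "chosen", the other "rejected".
    chosen, rejected = (response_a, response_b) if apc_a >= apc_b else (response_b, response_a)
    return {"prompt": query, "chosen": chosen, "rejected": rejected}

pair = build_dpo_pair("Where do you live?", "I live in Chireiden!", "I live in Hakurei Shrine!", 1.0, 0.03)
print(pair["chosen"])  # "I live in Chireiden!"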

In practice, towards a more accurate estimation, the APC scores are assigned by probabilistic models based on the statement-query relevance probability and the statement-to-response natural language inference (NLI) probability, formalized as follows:

APC(q, r) = Σ_{s ∈ S} [ P(active | s, q) · P(entailed | s, q, r) + (1 − P(active | s, q)) · (1 − P(contradicted | s, q, r)) ]

where S is the set of persona statements, q the query, and r the response; the relevance estimator provides P(active | s, q), and the NLI estimator provides the entailment and contradiction probabilities.

If you hate formulas, the only thing you need to know is that we need two probabilistic estimators: one for relevance and one for NLI.
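To make the expectation concrete, here is a tiny worked example with made-up probabilities (not outputs of the real discriminators):

# Each entry: (P(active | s, q), P(entailed | s, q, r), P(contradicted | s, q, r))
statements = [
    (0.9, 0.8, 0.1),  # highly relevant statement, mostly entailed by the response
    (0.2, 0.1, 0.0),  # mostly passive statement, not contradicted by the response
]

apc = 0.0
for p_active, p_entail, p_contra in statements:
    apc += p_active * p_entail              # active constraint: reward entailment
    apc += (1 - p_active) * (1 - p_contra)  # passive constraint: reward non-contradiction
print(round(apc, 2))  # 1.63, the expected number of satisfied constraints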

[Figure: Distillation]

Thus, we use the pipeline above to build these estimators by distilling from GPT-4 with synthesized datasets. With that, the puzzle of global PRP faithfulness quantification and optimization is complete. Let's begin our journey to build faithful PRP agents, and be of good cheer!
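For intuition, here is a rough sketch of the kind of GPT-4 labeling behind the synthesized relevance dataset (the actual prompts and parsing in this repo are more elaborate; the prompt below is illustrative only):

from openai import OpenAI

client = OpenAI(api_key="YOUR_OPENAI_KEY")

def label_relevance(statement, query):
    # Ask GPT-4 whether the persona statement is relevant to the query.
    prompt = (f'Persona statement: "{statement}"\nQuery: "{query}"\n'
              "Is the statement relevant to answering the query? Answer yes or no.")
    reply = client.chat.completions.create(
        model="gpt-4", messages=[{"role": "user", "content": prompt}]
    ).choices[0].message.content
    return "yes" in reply.lower()

# The labeled (statement, query) pairs are then used to fine-tune a small
# discriminator (DeBERTa), which replaces GPT-4 at scoring time.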

Preparation

Before your journey, you need to prepare the following stuff:

  1. Download MiniConda3 following the instructions on this page.

  2. Create an OpenAI account and get an OpenAI API key following the instructions on this page.

  3. Create a Huggingface account and create a Huggingface Token for reading following the instructions on this page.

  4. My implementation is based on Gemma-1.1-7b-it, so you have to gain access to Gemma models following the instructions on this page.

Then you can create an environment and install the required Python packages:

conda create -n apc python=3.8
conda activate apc
python -m pip install -r requirements.txt

Quick Start

Learning RAG models with APC-based DPO

I have formalized the learning scenario for building the most faithful persona-driven role-playing agent into a simple bash command. You only have to replace openai_key and hf_token in bash_is_all_you_need.sh with your own, and then run

bash bash_is_all_you_need.sh

This script builds an APC-based DPO PRP system with RAG for Alice (persona detailed in wiki) by default. You can find the LoRA weights of the PRP agent in prp_models, the intermediate datasets in statement, and the intermediate discriminators in discriminators (if you set use_pretrained_discriminator to False).

You can build this advanced PRP system for any character you like by simply putting a wiki text (paragraphs separated by "\n\n") in the wiki directory, named {character_name}_wiki.txt. Then replace the character in bash_is_all_you_need.sh and run it. You will find everything you need in the corresponding directories.
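For example, a minimal sketch of preparing such a wiki file (the paragraphs are placeholders for real persona text):

import os

os.makedirs("wiki", exist_ok=True)
paragraphs = [
    "Komeiji Koishi is the younger sister of Komeiji Satori.",
    "Komeiji Koishi closed her Third Eye and manipulates the unconscious.",
]
# Paragraphs must be separated by "\n\n".
with open("wiki/Komeiji Koishi_wiki.txt", "w") as f:
    f.write("\n\n".join(paragraphs))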

For Chinese characters, please use

bash bash_is_all_you_need_for_chinese.sh

We have optimized the GPU usage of the implementation. However, you still need a >40GB GPU to run the bash command.
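A quick way to check your GPU memory before launching (a generic PyTorch snippet, not part of this repo):

import torch

# Print the name and total memory of the first CUDA device.
props = torch.cuda.get_device_properties(0)
print(f"{props.name}: {props.total_memory / 1e9:.1f} GB")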

  • Hyperparameter Suggestions

model_engine: "gpt-4", the prompts are written specifically for GPT-4; using other LLMs might cause bugs.

use_pretrained_discriminator: True, generally enabled to reduce the cost of generating the relevance and NLI datasets. (You still have to generate persona statements and user queries!)

prp_scale: "7b", the "2b" Gemma model always refuses to do role-playing.

max_dpo_data: 100, which generally builds the DPO dataset in about one hour for characters with around 100 persona statements.

lora_rank: >= 32, a lower LoRA rank will hurt role-playing performance.

rag_top_k: 4-6, which our analysis shows to perform best.

Evaluating Responses with APC score

We implement the APC scoring function in score.py, based on the discriminators defined in classifier.py. Using the function score_APC, you can score the expected number of satisfied constraints for different responses based on all persona statements. We provide a use case in evaluation_example.py, shown as follows.

from classifier import get_relevance_discriminator, get_nli_discriminator
from score import score_APC

# Load the pretrained relevance and NLI discriminators from Huggingface.
relevance_discriminator = get_relevance_discriminator(character=None, statement_query_relevance_dataset=None, relevance_finetune_epoch=None, use_pretrained_discriminator=True)
nli_discriminator = get_nli_discriminator(character=None, statement_to_response_nli_v2_dataset=None, nli_finetune_epoch=None, use_pretrained_discriminator=True)

character = "Komeiji Koishi"
statements = ["Komeiji Koishi lives with her sister, Komeiji Satori.", "Komeiji Koishi lives in Chireiden."]
query = "Where do you live, Koishi?"
responses = ["I live in Chireiden with my sister, Satori!", "I live in Chireiden!", "I live in Hakurei Shrine!"]

# Score the expected number of satisfied constraints for each response.
print([score_APC(character, statements, query, response, relevance_discriminator, nli_discriminator).item() for response in responses])

# [1.6079180240631104, 0.9955980777740479, 0.03315635025501251]

Based on the output scores, you can get a rough sense of how the APC score views PRP faithfulness.

Chat with Learned AI Characters!

After running the APC-based DPO, you will get LoRA weights for your character at prp_models/gemma-1.1-7b-it-lora-{character}-rag-dpo, which can be used for chatting with your AI character. We provide an example in chat_example.py, also shown as follows.

import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from prp_model import retrieval_augmented_generate
from classifier import get_relevance_discriminator

character = "Your Character"

# Load the character's persona statements produced by the pipeline.
statements = [data["statement"] for data in json.load(open(f"statement/{character}.json"))]

model_id = f"prp_models/gemma-1.1-7b-it-lora-{character}-rag-dpo"

# Load the DPO-tuned PRP model in 4-bit to save GPU memory.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

prp_tokenizer = AutoTokenizer.from_pretrained(model_id)
prp_model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"": 0})

# The relevance discriminator retrieves the statements most relevant to each query.
relevance_discriminator = get_relevance_discriminator(character=None, statement_query_relevance_dataset=None, relevance_finetune_epoch=None, use_pretrained_discriminator=True)

print(f"You are chatting with {character}!")

with torch.no_grad():
    while True:
        # Retrieve the top-5 relevant statements and generate a response.
        _, response = retrieval_augmented_generate(character, statements, input("User: "), prp_model, prp_tokenizer, relevance_discriminator, rag_top_k=5)
        print(character + ": " + response.replace("<eos>", ""))

The following is an example conversation with Komeiji Koishi:

User: Hi, Koishi! What is your ability?
Komeiji Koishi: I call it the "Silent Whisperer." It allows me to manipulate the unconsciousness of others, making me invisible and granting me control over their actions.
User: Where do you live?
Komeiji Koishi: The Palace of the Earth Spirits serves as my humble abode.
User: Who is your sister?
Komeiji Koishi: Satori Komeiji. The one with all the serious face. 😜

Currently, the system only supports single-turn conversations, matching the scope of topics discussed in our paper. We will put more engineering effort into supporting multi-turn conversations soon!

Datasets and Checkpoints

The synthesized dataset for statement-query relevance: KomeijiForce/role-playing-apc-relevance (English), KomeijiForce/role-playing-apc-multilingual-relevance (Multilingual)

The synthesized dataset for statement-to-response NLI: KomeijiForce/role-playing-apc-nli (English), KomeijiForce/role-playing-apc-multilingual-nli (Multilingual)

The fine-tuned discriminators for statement-query relevance: KomeijiForce/deberta-v3-large-relevance-12character (English, DeBERTa-V3-Large), KomeijiForce/xlm-roberta-large-relevance-multilingual-12character (Multilingual, XLM-RoBERTa-Large)

The fine-tuned discriminators for statement-to-response NLI: KomeijiForce/deberta-v3-large-nli-12character (English, DeBERTa-V3-Large), KomeijiForce/xlm-roberta-large-nli-multilingual-12character (Multilingual, XLM-RoBERTa-Large)
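A hedged sketch of loading these artifacts directly (column and label conventions are assumptions; check the dataset and model cards on Huggingface):

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Synthesized statement-query relevance data distilled from GPT-4.
relevance_data = load_dataset("KomeijiForce/role-playing-apc-relevance")

# Pretrained relevance discriminator (English).
model_id = "KomeijiForce/deberta-v3-large-relevance-12character"
tokenizer = AutoTokenizer.from_pretrained(model_id)
discriminator = AutoModelForSequenceClassification.from_pretrained(model_id)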

Statistics of the PRP datasets

| Character | Persona Statements | Questions | Relevance Data | NLI Data |
|---|---|---|---|---|
| Alice | 8 | 10 | 64 | 144 |
| Bob | 19 | 10 | 152 | 459 |
| Eve | 30 | 10 | 240 | 545 |
| Beethoven | 383 | 77 | 3061 | 6774 |
| Newton | 354 | 90 | 2832 | 6331 |
| Socrates | 324 | 89 | 2591 | 5760 |
| Spartacus | 77 | 89 | 616 | 1368 |
| Hermione | 146 | 118 | 1167 | 2586 |
| Voldemort | 201 | 77 | 1608 | 3546 |
| Cleopatra | 374 | 93 | 2991 | 6660 |
| Caesar | 498 | 87 | 3981 | 8856 |
| Martin Luther King | 599 | 92 | 4789 | 10644 |

Performance

| Setting | Method | Alice ΔAPC (DeB) | Alice ΔAPC (GPT-4) | Alice Human | Bob ΔAPC (DeB) | Bob ΔAPC (GPT-4) | Bob Human | Eve ΔAPC (DeB) | Eve ΔAPC (GPT-4) | Eve Human |
|---|---|---|---|---|---|---|---|---|---|---|
| w/o APC-based DPO | Gemma-7B | 0.7 | 0.3 | 1.8 | 1.1 | 0.4 | 1.8 | 0.7 | -0.2 | 2.0 |
| w/o APC-based DPO | EU | 2.6 | 1.1 | 6.4 | 3.4 | 1.1 | 6.2 | 3.6 | 0.7 | 4.6 |
| w/o APC-based DPO | LCM | 2.6 | 1.4 | 6.8 | 4.5 | 2.2 | 7.2 | 3.9 | 0.7 | 5.0 |
| w/o APC-based DPO | RAG | 2.8 | 1.8 | 6.8 | 4.0 | 1.7 | 6.8 | 4.8 | 2.4 | 5.8 |
| w/ APC-based DPO | EU | 2.7 (+0.1) | 1.4 (+0.3) | 6.8 (+0.4) | 3.8 (+0.4) | 1.8 (+0.7) | 6.8 (+0.6) | 3.9 (+0.3) | 0.9 (+0.2) | 5.2 (+0.6) |
| w/ APC-based DPO | LCM | 2.8 (+0.2) | 2.2 (+0.8) | 7.6 (+0.8) | 5.3 (+0.8) | 2.5 (+0.3) | 7.8 (+0.6) | 5.1 (+1.2) | 3.3 (+2.6) | 6.6 (+1.6) |
| w/ APC-based DPO | RAG | 2.9 (+0.1) | 2.2 (+0.4) | 7.6 (+0.8) | 5.2 (+1.2) | 3.8 (+2.1) | 8.2 (+1.2) | 5.8 (+1.0) | 4.2 (+1.8) | 7.0 (+1.2) |

Todo List

  • Support More Languages (5/15: Chinese characters are supported)
  • Support Multi-turn Conversations
  • Allow More Customized Training Setups

Citation

@article{apc4prp,
  title={Quantifying and Optimizing Global Faithfulness in Persona-driven Role-playing},
  author={Peng, Letian and Shang, Jingbo},
  journal={arXiv preprint arXiv:2405.07726},
  year={2024}
}
