-
Notifications
You must be signed in to change notification settings - Fork 515
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5751d23
commit d475416
Showing
7 changed files
with
416 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
|
||
## 造训练数据 | ||
|
||
### 数据生成框架 | ||
本数据集使用OpenAI API接口生成,流程: | ||
|
||
- **种子特征集和基础设定**: | ||
- 手工编写的种子集包含基本角色特征。 | ||
- LLM从这个种子集生成角色的基础设定。 | ||
- **角色设定的进化**: | ||
- 第二个种子集包含指导角色设定进化的指令Prompt。 | ||
- 这些进化角色的指令Prompt被放到一个指令池中。基于这些进化Prompt,LLM对基础设定实施进化。 | ||
- **反馈循环**: | ||
- 由人类评估者和GPT-4组成的混合评价系统。此系统对进化后的设定给出反馈。 | ||
- 反馈用于迭代更新种子集。如此迭代,我们最终得到一个细致的角色设定数据集。 | ||
- **角色扮演和对话生成**: | ||
- 使用self-instruction框架基于角色设定生成角色的对话数据。 | ||
|
||
|
||
1. 生成角色设定,分别生成护士角色和患者角色 | ||
```bash | ||
cd role_play_data | ||
|
||
python role_generate.py | ||
``` | ||
|
||
|
||
2. 生成医患之间的多轮对话 | ||
LLM选择:分别用gpt-4o的api和豆包的doubao-character-pro-32k的api生成对话 | ||
```bash | ||
python roleplay_data_generate_gpt4.py | ||
|
||
python roleplay_data_generate_doubao.py | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import json | ||
import random | ||
|
||
from openai import OpenAI | ||
from tqdm import tqdm | ||
|
||
client = OpenAI() | ||
print(client) | ||
|
||
|
||
def generate(prompt): | ||
print(prompt) | ||
messages = [ | ||
{"role": "user", "content": prompt} | ||
] | ||
r = client.chat.completions.create( | ||
model='gpt-4o', | ||
temperature=1, | ||
messages=messages, ) | ||
response = r.choices[0].message.content | ||
print("回答:", response) | ||
return response | ||
|
||
|
||
def generate_role(input_file, save_file, total_lines): | ||
with open(input_file, "r", encoding="utf-8") as file: | ||
lines = file.readlines() | ||
with tqdm(total=total_lines, desc="指令进度") as pbar: | ||
while pbar.n < total_lines: | ||
random.shuffle(lines) | ||
i = 0 | ||
sum_str = "" | ||
for line in lines: | ||
i += 1 | ||
try: | ||
data = json.loads(line.strip()) | ||
except: | ||
print("error:", line.strip()) | ||
continue | ||
question = data["system_prompt"] | ||
|
||
sum_str += f"{i}.{question}\n\n" | ||
|
||
if i == 5: | ||
res = generate(f'请续写下面内容,不少于10条,增加些多样性。\n\n{sum_str}') | ||
res = res.split("\n\n") | ||
for result in res: | ||
result = result.strip() | ||
prefix_length = len(result.split(".", 1)[0]) + 1 # 获取前缀数字的长度,包括后面的点号 | ||
result = result[prefix_length:] | ||
if result == "": | ||
continue | ||
json_data = {'system_prompt': result} | ||
# 将数据写入文件 | ||
with open(save_file, 'a', encoding='utf-8') as f: | ||
f.write(json.dumps(json_data, ensure_ascii=False) + '\n') | ||
|
||
pbar.update(1) | ||
if pbar.n >= total_lines: | ||
break | ||
|
||
|
||
if __name__ == '__main__': | ||
total_lines = 50 | ||
input_file = "seed_nurse_role.jsonl" | ||
save_file = "seed_nurse_role_output.jsonl" | ||
generate_role(input_file, save_file, total_lines) | ||
|
||
total_lines = 50 | ||
input_file = "seed_patient_role.jsonl" | ||
save_file = "seed_patient_role_output.jsonl" | ||
generate_role(input_file, save_file, total_lines) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import json | ||
import random | ||
|
||
from openai import OpenAI | ||
from tqdm import tqdm | ||
|
||
client = OpenAI( | ||
api_key="xxx", | ||
base_url="https://ark.cn-beijing.volces.com/api/v3", | ||
) | ||
print(client) | ||
|
||
|
||
def generate(prompt, system_prompt=''): | ||
print('提示:', prompt) | ||
messages = [ | ||
{"role": "system", "content": system_prompt}, | ||
{"role": "user", "content": prompt} | ||
] | ||
completion = client.chat.completions.create( | ||
#pro-32k: ep-20240623141021-r77gl | ||
#lite-4k:ep-20240623140948-92n2g | ||
model="ep-20240623141021-r77gl", # your model endpoint ID | ||
messages=messages, | ||
max_tokens=3048, | ||
) | ||
response = completion.choices[0].message.content | ||
print("生成的对话:", response) | ||
return response | ||
|
||
|
||
file_role1 = "seed_nurse_role.jsonl" | ||
file_role2 = "seed_patient_role.jsonl" | ||
with open(file_role1, "r", encoding="utf-8") as file: | ||
role1s = file.readlines() | ||
with open(file_role2, "r", encoding="utf-8") as file: | ||
role2s = file.readlines() | ||
|
||
save_file = "roleplay_train_data_v2.jsonl" | ||
total_lines = 1000 # 10000 | ||
max_history_len = 10 | ||
|
||
with tqdm(total=total_lines, desc="生成对话") as pbar: | ||
while pbar.n < total_lines: | ||
role1 = random.choice(role1s) | ||
role2 = random.choice(role2s) | ||
data1 = json.loads(role1.strip())['system_prompt'] | ||
data2 = json.loads(role2.strip())['system_prompt'] | ||
p = "你是护士,跟患者对话。\n\n护士角色:" + data1 + '\n患者角色:' + data2 | ||
conversation = {"id": str(pbar.n), "system_prompt": p, "conversations": []} | ||
|
||
system_prompt = f"护士角色:{data1}\n患者角色:{data2}\n" | ||
print('------' * 10) | ||
print('system_prompt:', system_prompt) | ||
history = [] | ||
|
||
for i in range(6): | ||
patient_prompt = f"要求你扮演患者,并且根据角色的设定内容模仿 角色相应的对话口吻和风格。你说一句话,完成本轮对话即可。" | ||
for history_turn in history[-max_history_len:]: | ||
patient_prompt += history_turn + '\n' | ||
patient_prompt += "患者:" | ||
|
||
patient_response = generate(patient_prompt, system_prompt) | ||
conversation["conversations"].append({"from": "human", "value": patient_response.strip()}) | ||
history.append("患者:" + patient_response.strip()) | ||
|
||
nurse_prompt = f"要求你扮演护士,并且根据角色的设定内容模仿 角色相应的对话口吻和风格。你说一句话,完成本轮对话即可。\n" | ||
for history_turn in history[-max_history_len:]: | ||
nurse_prompt += history_turn + '\n' | ||
nurse_prompt += "护士:" | ||
|
||
nurse_response = generate(nurse_prompt, system_prompt) | ||
conversation["conversations"].append({"from": "gpt", "value": nurse_response.strip()}) | ||
history.append("护士: " + nurse_response.strip()) | ||
|
||
with open(save_file, 'a', encoding='utf-8') as f: | ||
f.write(json.dumps(conversation, ensure_ascii=False) + '\n') | ||
|
||
pbar.update(1) | ||
if pbar.n >= total_lines: | ||
break |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import json | ||
import random | ||
|
||
from openai import OpenAI | ||
from tqdm import tqdm | ||
|
||
client = OpenAI() | ||
print(client) | ||
|
||
|
||
def generate(prompt): | ||
print('提示:', prompt) | ||
messages = [ | ||
{"role": "user", "content": prompt} | ||
] | ||
r = client.chat.completions.create( | ||
model='gpt-4o', | ||
messages=messages, | ||
temperature=1, | ||
max_tokens=3048, # 增加max_tokens以生成更长的对话 | ||
) | ||
response = r.choices[0].message.content | ||
print("生成的对话:", response) | ||
return response | ||
|
||
|
||
file_role1 = "seed_nurse_role.jsonl" | ||
file_role2 = "seed_patient_role.jsonl" | ||
with open(file_role1, "r", encoding="utf-8") as file: | ||
role1s = file.readlines() | ||
with open(file_role2, "r", encoding="utf-8") as file: | ||
role2s = file.readlines() | ||
|
||
save_file = "roleplay_train_data_v1.jsonl" | ||
total_lines = 1000 | ||
|
||
with tqdm(total=total_lines, desc="生成对话") as pbar: | ||
while pbar.n < total_lines: | ||
role1 = random.choice(role1s) | ||
role2 = random.choice(role2s) | ||
data1 = json.loads(role1.strip())['system_prompt'] | ||
data2 = json.loads(role2.strip())['system_prompt'] | ||
p = "你是护士,跟患者对话。\n\n护士角色:" + data1 + '\n患者角色:' + data2 | ||
conversation = {"id": str(pbar.n), "system_prompt": p, "conversations": []} | ||
|
||
combined_prompt = f"你扮演一个护士,以下对话是你和患者之间的对话。\n护士角色:{data1}\n患者角色:{data2}\n" | ||
combined_prompt += "进行多轮问答(6轮以上)。患者说话以`患者:`开头,护士说话以`护士:`开头。患者先提问。\n" | ||
|
||
prompt = combined_prompt + "\n对话开始:\n " | ||
response = generate(prompt) | ||
|
||
# 解析生成的多轮对话 | ||
lines = response.strip().split('\n') | ||
for line in lines: | ||
if line.startswith("患者"): | ||
conversation["conversations"].append({"from": "human", "value": line.split("患者")[1].strip()[1:]}) | ||
elif line.startswith("护士"): | ||
conversation["conversations"].append({"from": "gpt", "value": line.split("护士")[1].strip()[1:]}) | ||
|
||
with open(save_file, 'a', encoding='utf-8') as f: | ||
f.write(json.dumps(conversation, ensure_ascii=False) + '\n') | ||
|
||
pbar.update(1) | ||
if pbar.n >= total_lines: | ||
break |
Oops, something went wrong.