Skip to content

Commit bf506cc

Browse files
committed
open-set UoT
1 parent 47e229f commit bf506cc

11 files changed

+196
-92
lines changed

run.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,20 @@ def run(args):
2020
args.task_end_index = min(args.task_end_index, len(task.data))
2121

2222
if args.naive_run:
23-
log_file = f'./logs/{args.task}/{args.guesser_model}_as_guesser/{args.dataset}_{args.temperature}_naive_{"" if args.inform else "un"}inform_EXAMINER{args.examiner_model}_{args.task_start_index}-{args.task_end_index}.json'
23+
log_file = (f'./logs/{args.task}/{args.guesser_model}_as_guesser/{args.dataset}_{args.temperature}'
24+
f'_naive_{"" if args.inform else "un"}inform_EXAMINER{args.examiner_model}'
25+
f'_{args.task_start_index}-{args.task_end_index}.json')
2426
else:
25-
log_file = f'./logs/{args.task}/{args.guesser_model}_as_guesser/{args.dataset}_{args.temperature}_lambda{args.reward_lambda}_L{args.n_extend_layers}_K{args.n_potential_actions}_PRUN{args.n_pruned_nodes}_EXAMINER{args.examiner_model}_{args.task_start_index}-{args.task_end_index}.json'
26-
root_file = f'./roots/{args.task}/{args.guesser_model}_{args.dataset}_{args.temperature}_root.pickle'
27+
log_file = (f'./logs/{args.task}/{args.guesser_model}_as_guesser/'
28+
f'{f"OS_init{args.open_set_size}_renew{args.size_to_renew}_" if args.open_set_size > 0 else ""}'
29+
f'{f"pre{args.n_pre_ask}_" if args.n_pre_ask > 0 else ""}'
30+
f'{args.dataset}_{args.temperature}_lambda{args.reward_lambda}_acc{not args.none_acc_reward}'
31+
f'_exp{args.expected_reward_method}_L{args.n_extend_layers}_K{args.n_potential_actions}'
32+
f'_PRUN{args.n_pruned_nodes}_EXAMINER{args.examiner_model}'
33+
f'_{args.task_start_index}-{args.task_end_index}.json')
34+
root_file = (f'./roots/{args.task}/{args.guesser_model}'
35+
f'{f"OS_init{args.open_set_size}_" if args.open_set_size > 0 else ""}'
36+
f'_{args.dataset}_{args.temperature}_root.pickle')
2737
if os.path.exists(root_file):
2838
r = open(root_file, 'rb')
2939
root = pickle.load(r)
@@ -55,7 +65,7 @@ def run(args):
5565

5666
def parse_args():
5767
args = argparse.ArgumentParser()
58-
args.add_argument('--guesser_model', type=str, default='gemini-1.0-pro',
68+
args.add_argument('--guesser_model', type=str, default='gpt-3.5-turbo',
5969
choices=['gpt-4', 'gpt-3.5-turbo',
6070
'_claude-2', 'claude-3-opus-20240229', 'claude-3-sonnet-20240229',
6171
'palm-2', 'cohere', 'llama-2-70b-chat',
@@ -66,10 +76,13 @@ def parse_args():
6676

6777
args.add_argument('--task', type=str, default='20q',
6878
choices=['20q', 'md', 'tb'])
69-
args.add_argument('--dataset', type=str, default='bigbench',
70-
choices=['bigbench', 'common', 'DX', 'MedDG', 'FloDial'])
79+
args.add_argument('--dataset', type=str, default='common',
80+
choices=['bigbench', 'common', 'thing', 'DX', 'MedDG', 'FloDial'])
7181
args.add_argument('--task_start_index', type=int, default=-1)
7282
args.add_argument('--task_end_index', type=int, default=-1)
83+
args.add_argument('--open_set_size', type=int, default=-1)
84+
args.add_argument('--size_to_renew', type=int, default=-1) # only used when open_set_size > 0
85+
args.add_argument('--n_pre_ask', type=int, default=0) # only used when open_set_size > 0 and data doesn't contain self-repo
7386

7487
args.add_argument('--naive_run', action='store_true', default=False)
7588
args.add_argument('--inform', action='store_true', default=False) # only used when naive_run
@@ -85,6 +98,9 @@ def parse_args():
8598
args.add_argument('--expected_action_tokens', type=int, default=50)
8699
args.add_argument('--expected_target_tokens', type=int, default=10)
87100

101+
args.add_argument('--none_acc_reward', action='store_true', default=False)
102+
args.add_argument('--expected_reward_method', type=str, default='avg', choices=['avg', 'max'])
103+
88104
args = args.parse_args()
89105
return args
90106

src/uot/chat_utils.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import importlib
23

34
from uot.models import get_response_method
@@ -79,7 +80,7 @@ def cls_given_repo(task, items: list, repo, translate=False, self_repo=True):
7980
message = [{"role": "user", "content": f"Translate to English: {repo}"}]
8081
gpt3_response = get_response_method("gpt-3.5-turbo")
8182
repo = gpt3_response(message, model="gpt-3.5-turbo", max_tokens=500)
82-
repo = task.prompts.self_report_prompt.format(repo=repo)
83+
repo = task.prompts.self_repo_prompt.format(repo=repo)
8384
else:
8485
repo = task.prompts.free_answer_prompt.format(repo=repo)
8586
message = [{"role": "user", "content": task.prompts.classify_prompt.format(item_list_str=', '.join(items), repo=repo)}]
@@ -106,3 +107,35 @@ def extract_items(rsp, keyword):
106107
except Exception as e:
107108
print(e)
108109
return cls_given_repo(task, items, repo, translate, self_repo)
110+
111+
112+
def initialize_open_set(task, repo=""):
113+
response = get_response_method(task.guesser_model)
114+
size = task.open_set_size
115+
116+
if isinstance(repo, str):
117+
message = [{"role": "user", "content": task.prompts.init_open_set_prompt.format(repo=repo, size=size)}]
118+
else:
119+
message = repo + [{"role": "user", "content": task.prompts.init_open_set_prompt.format(size=size)}]
120+
rsp = response(message, model=task.guesser_model, max_tokens=15*size)
121+
print([rsp])
122+
try:
123+
rsp = set(eval(rsp))
124+
return list(rsp)
125+
except Exception as e:
126+
print(e)
127+
return initialize_open_set(task, repo)
128+
129+
130+
def renew_open_set(task, history, items):
131+
response = get_response_method(task.guesser_model)
132+
size = task.open_set_size
133+
message = copy.deepcopy(history) + [{"role": "user", "content": task.prompts.renew_open_set_prompt.format(size=size, item_list=str(items))}]
134+
rsp = response(message, model=task.guesser_model, max_tokens=15*size)
135+
print([rsp])
136+
try:
137+
rsp = set(eval(rsp))
138+
return list(rsp)
139+
except Exception as e:
140+
print(e)
141+
return renew_open_set(task, history, items)

src/uot/data/data_20q.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
Objects = ['Hula hoop', 'Calendar', "King Tut's mask", 'CD-ROM', 'Pajamas', 'Treehouse', 'Rocking chair', 'The Mona Lisa', 'T-Rex', 'Light bulb', 'Palm tree', 'Balloon', 'The Crown Jewels', 'Wrapping paper', 'Penny', 'Notebook', 'Fire extinguisher', 'Napkin', 'Beret', 'The Titanic', 'Blender', 'Stamp', 'Yacht', 'Volleyball', 'Tissues', 'Comet', 'Hairbrush', 'Mittens', 'Chopsticks', 'Magazine', 'Piccolo', 'Northern Lights', 'Chessboard', 'Christmas tree', 'Stained glass', 'Hollywood sign', 'Tennis court']
77
COMMON = Animals + Food + Objects + Places
88

9+
THING200 = ['trombone', 'monkey', 'quad', 'speedometer', 'wreck', 'cockroach', 'butterfly', 'cookie', 'hat', 'uniform', 'ferry', 'yarn', 'razor blade', 'cigarette holder', 'rope', 'knife', 'snowboard', 'bone', 'book', 'vest', 'easter egg', 'panda', 'crepe', 'sandal', 'sandpaper', 'brussels sprouts', 'wick', 'wax', 'bullet', 'screw', 'holster', 'train set', 'crayfish', 'needle', 'elephant', 'paint', 'sweater', 'book', 'mussel', 'dandelion', 'seagull', 'float', 'shutter', 'altar', 'bagel', 'coil', 'funnel', 'pie', 'lemon', 'pasta', 'magnifier', 'cornucopia', 'muffin', 'scarecrow', 'whiteboard', 'scraper', 'gargoyle', 'copier', 'rose', 'banner', 'braid', 'dumbwaiter', 'cat', 'gargoyle', 'pepper mill', 'squirrel', 'air conditioner', 'chariot', 'chessboard', 'ice-cream cone', 'bread', 'motherboard', 'bug', 'space shuttle', 'barcode', 'plate', 'box', 'counter', 'breakfast', 'lavender', 'slug', 'coral', 'lipstick', 'soccer ball', 'wick', 'trowel', 'olive', 'gate', 'ship', 'scarecrow', 'cello', 'man', 'barrel', 'lip balm', 'armor', 'flamingo', 'rock', 'sloth', 'buggy', 'cooler', 'coffee', 'basketball', 'bulldozer', 'whoopee cushion', 'breakfast', 'golf cart', 'album', 'milk', 'cash machine', 'potpie', 'potato', 'granite', 'slot machine', 'footprint', 'suit', 'jeep', 'mop', 'garter', 'wine cooler', 'box', 'gas mask', 'spool', 'hookah', 'razor', 'chips', 'pet food', 'canvas', 'polo shirt', 'shield', 'boy', 'plunger', 'treasure', 'cabinet', 'stew', 'dolly', 'frog', 'knitting needle', 'gyroscope', 'satellite', 'clasp', 'anklet', 'lasagna', 'crane', 'pigeon', 'grape', 'carrot', 'mold', 'denture', 'highlighter', 'trigger', 'furnace', 'spur', 'pantsuit', 'spareribs', 'bull', 'shell', 'footrest', 'scuba', 'cooker', 'plum', 'accordion', 'saw', 'gourd', 'hail', 'knitting needle', 'stake', 'bone', 'ruby', 'wax', 'pants', 'chive', 'skin', 'wok', 'toga', 'torpedo', 'straw', 'wire', 'spoon', 'costume', 'trigger', 'rabbit', 'flashlight', 'doormat', 'horseshoe', 'mongoose', 'shortbread', 'sleeping bag', 'pine needle', 'toast', 'swizzle stick', 'snowplow', 'beanie', 'sundae', 'life jacket', 'ivy', 'ivy', 'cotton candy', 'pin', 'zebra', 'boy']

src/uot/method.py

+31-12
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import copy
22

3+
from uot.chat_utils import renew_open_set
34
from uot.models import get_response_method
4-
from uot.uot import select
5+
from uot.uot import select, renew_node_to_root
56

67

78
def get_examiner_response(task, history):
@@ -35,7 +36,7 @@ def simplify_rsp(rsp):
3536
return n, n.question, True
3637

3738
targeting_prompt_set = task.prompts.targeting_prompt_set_FA if task.free_answer else task.prompts.targeting_prompt_set
38-
msg = [{"role": "user", "content": targeting_prompt_set.format(item_list_str=', '.join(node.items))}]
39+
msg = copy.deepcopy(history) + [{"role": "user", "content": targeting_prompt_set.format(item_list_str=', '.join(node.items))}]
3940
return node, simplify_rsp(response(msg, model=task.guesser_model)), False
4041

4142

@@ -64,23 +65,37 @@ def converse(task, i):
6465
item = task.data[i]["target"]
6566
target_decl = task.prompts.target_declaration.format(target=item)
6667
print(target_decl)
68+
print("------ DIALOGUE START ------")
69+
count = 0
70+
71+
if not task.free_answer:
72+
history_e = [{'role': 'user', 'content': task.prompts.examiner_prologue.format(item=item)}]
73+
else:
74+
history_e = [{'role': 'user', 'content': task.prompts.simulator_prologue.format(item=item, conv_hist=task.data[i]["conv_hist"])}]
6775

6876
if "self_repo" in task.data[i]:
6977
guesser_prologue = task.prompts.guesser_prologue_FA if task.free_answer else task.prompts.guesser_prologue
7078
history_g = [{'role': 'user', 'content': guesser_prologue.format(repo=task.data[i]["self_repo"])}]
7179
print("Self-report:", task.data[i]["self_repo"])
80+
node = task.root.handle_self_repo(task, task.data[i]["self_repo"])
7281
else:
7382
history_g = [{'role': 'user', 'content': task.prompts.guesser_prologue}]
74-
75-
if not task.free_answer:
76-
history_e = [{'role': 'user', 'content': task.prompts.examiner_prologue.format(item=item)}]
77-
else:
78-
history_e = [{'role': 'user', 'content': task.prompts.simulator_prologue.format(item=item, conv_hist=task.data[i]["conv_hist"])}]
79-
80-
print("------ DIALOGUE START ------")
81-
count = 0
82-
node, bot1_response, flag = get_guesser_response(task, history_g, count + 1, task.root)
83-
node.print()
83+
# !! for openset uot !!
84+
if task.open_set_size > 0 and task.n_pre_ask > 0:
85+
for _ in range(task.n_pre_ask):
86+
bot1_response = get_guesser_naive_response(task, history_g, count+1)
87+
print("Bot 2:", bot1_response)
88+
history_g.append({'role': 'system', 'content': bot1_response})
89+
history_e.append({'role': 'user', 'content': bot1_response})
90+
bot2_response = get_examiner_response(task, history_e)
91+
print("Bot 1:", bot2_response)
92+
history_g.append({'role': 'user', 'content': bot2_response})
93+
history_e.append({'role': 'system', 'content': bot2_response})
94+
count += 1
95+
print('------', count, '-------------')
96+
node = task.root.handle_self_repo(task, history_g) if task.open_set_size > 0 else task.root
97+
98+
node, bot1_response, flag = get_guesser_response(task, history_g, count + 1, node)
8499
print("Bot 2:", bot1_response)
85100

86101
history_g.append({'role': 'system', 'content': bot1_response})
@@ -110,6 +125,10 @@ def converse(task, i):
110125
state = -1
111126
break
112127

128+
# renew
129+
if count <= int(task.max_turn*0.3) + task.n_pre_ask and task.open_set_size > 0 and len(node.items) < task.size_to_renew:
130+
node = renew_node_to_root(task, node, history_g)
131+
113132
node, bot1_response, flag = get_guesser_response(task, history_g, count + 1, node)
114133
print("Bot 2:", bot1_response)
115134
history_g.append({'role': 'system', 'content': bot1_response})

src/uot/tasks/medical_diagnosis.py

+10-37
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import json
23

34
from uot.chat_utils import import_prompts_by_task
45
from uot.uot import UoTNode
@@ -11,56 +12,28 @@ def __init__(self, args):
1112
self.max_turn = 5
1213
self.prompts = import_prompts_by_task("md")
1314
self.set = []
14-
self.data = self.load_dataset(args.dataset)
15+
self.data = json.loads(args.dataset)
1516
self.root = None
1617

1718
def load_dataset(self, name):
1819
if name == "DX":
1920
self.set = ['Allergic rhinitis', 'upper respiratory tract infection (URTI)', 'pneumonia',
20-
'Hand foot and mouth disease in children', 'Infantile diarrhea']
21-
return load_dx_dataset(os.path.join(os.path.dirname(__file__), "../data/DX_dialog.txt"))
21+
'Hand foot and mouth disease in children', 'Infantile diarrhea']\
22+
if self.open_set_size <= 0 else self.set
2223
elif name == "MedDG":
2324
self.free_answer = True
2425
self.set = ['Enteritis', 'Gastritis', 'Gastroenteritis', 'Esophagitis',
25-
'Cholecystitis', 'Appendicitis', 'Pancreatitis', 'Gastric ulcer']
26-
return load_meddg_dataset(os.path.join(os.path.dirname(__file__), "../data/MedDG_dialog.txt"))
26+
'Cholecystitis', 'Appendicitis', 'Pancreatitis', 'Gastric ulcer',
27+
'Constipation', 'Cold', 'Irritable bowel syndrome', 'Diarrhea',
28+
'Allergic rhinitis', 'Upper respiratory tract infection', 'Pneumonia']\
29+
if self.open_set_size <= 0 else self.set
2730
else:
2831
raise NotImplementedError
32+
return json.loads(os.path.join(os.path.dirname(__file__), f"../data/{name}.json").read())
2933

3034
def create_root(self, root=None):
3135
if not root:
3236
self.root = UoTNode("ROOT", True, self.set, None, self.guesser_model)
3337
else:
34-
root.n_extend_layers = self.n_extend_layers
38+
root.set_config(self.n_extend_layers, not self.none_acc_reward, self.expected_reward_method)
3539
self.root = root
36-
37-
38-
def load_dx_dataset(file_path):
39-
dic = {"过敏性鼻炎": 'Allergic rhinitis', "肺炎": 'pneumonia', "小儿腹泻": 'Infantile diarrhea',
40-
"上呼吸道感染": 'upper respiratory tract infection (URTI)',
41-
"小儿手足口病": 'Hand foot and mouth disease in children'}
42-
with open(file_path, 'r', encoding='utf-8') as f:
43-
data = eval(f.read())
44-
repo_dataset = [{'self_repo': dialog['self_repo_en'], 'target': dic[dialog['disease_tag']]} for dialog in data]
45-
return repo_dataset
46-
47-
48-
def load_meddg_dataset(file_path):
49-
repo_dataset = []
50-
flag = 0
51-
disease, self_repo, dialog = "", "", ""
52-
with open(file_path, 'r', encoding='utf-8') as f:
53-
for line in f:
54-
if line.startswith("dialog"):
55-
flag = 1
56-
disease = line.split("|")[1][:-1]
57-
elif line.startswith("{"):
58-
content = eval(line)
59-
dialog += f"{content['id']}: {content['Sentence']}\n"
60-
if flag:
61-
self_repo = content['self_repo_en']
62-
flag = 0
63-
else:
64-
repo_dataset.append({'self_repo': self_repo, 'target': disease, 'conv_hist': dialog})
65-
disease, self_repo, dialog = "", "", ""
66-
return repo_dataset

src/uot/tasks/prompts/medical_diagnosis.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,20 @@
2828
'''
2929

3030
# conversation
31-
target_question = "Are you a '{target}'?"
31+
target_question = "Are you experiencing '{target}'?"
3232
target_question_FA = "Are you experiencing '{target}'?"
3333

3434
targeting_prompt_free = """Note that you should point out and ask what disease the patient suffer from now.
35-
The question must be 'Are you a [disease name]?' You must reply me with 1 question to ask only."""
35+
The question must be 'Are you experiencing [disease name]?' You must reply me with 1 question to ask only."""
3636
targeting_prompt_free_FA = """Note that you should point out and ask what disease the patient suffer from now.
3737
The question must be 'Are you experiencing [disease name]?' You must reply me with 1 question to ask only."""
3838

3939
targeting_prompt_set = """Note that you should point out and ask what disease the patient suffer from now. The patient may suffer from one of diseases below:
4040
{item_list_str}, or other.
41-
The question must be 'Are you a [disease name]?' You must reply me with 1 question."""
41+
The question must be 'Are you experiencing [disease name]?' You must reply me with 1 question."""
4242
targeting_prompt_set_FA = """Note that you should point out and ask what disease the patient suffer from now. The patient may suffer from one of diseases below:
4343
{item_list_str}, or other.
44-
The question must be 'Are you a [disease name]?' You must reply me with 1 question."""
44+
The question must be 'Are you experiencing [disease name]?' You must reply me with 1 question."""
4545

4646
guesser_prologue = '''You are a doctor and your patient self-reports that: {repo}.
4747
You should ask your patient question in English with symptoms which can only be answered by 'Yes' or 'No', in order to find what disease this patient suffers.
@@ -52,7 +52,7 @@
5252
Let us begin. Ask me the first question.
5353
'''
5454

55-
urge_prompt = "Based on the symptons above, if you find out the disease, please ask 'Are you a [disease name]?'"
55+
urge_prompt = "Based on the symptons above, if you find out the disease, please ask 'Are you experiencing [disease name]?'"
5656

5757
inform_prompt = "The patient may suffer from one of diseases below:\n{item_list_str}"
5858

@@ -86,3 +86,11 @@
8686
Note that never directly tell me what disease is all the time.
8787
Let us begin. Here is my first question.
8888
'''
89+
90+
# open set
91+
init_open_set_prompt = '''You are a doctor and your patient self-reports that: {repo}. Please propose {size} diseases that you think your patient may suffer from.
92+
Your response should be: ["disease1", "disease2", ...]'''
93+
94+
renew_open_set_prompt = '''Based on the conversation history, please propose {size} diseases that your patient may suffer from.
95+
The list of {size} diseases should contains {item_list}
96+
Your response should be: ["disease1", "disease2", ...]'''

0 commit comments

Comments
 (0)