-
Notifications
You must be signed in to change notification settings - Fork 50
/
conversation.py
203 lines (180 loc) · 6.89 KB
/
conversation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
DOLLY = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
# Used for gradio server
skip_next: bool = False
conv_id: Any = None
def get_prompt(self):
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system
for role, message in self.messages:
if message:
ret += self.sep + " " + role + ": " + message
else:
ret += self.sep + " " + role + ":"
return ret
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.DOLLY:
seps = [self.sep, self.sep2]
ret = self.system
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ":\n" + message + seps[i % 2]
if i % 2 == 1:
ret += "\n\n"
else:
ret += role + ":\n"
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def append_message(self, role, message):
self.messages.append([role, message])
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
conv_id=self.conv_id)
def dict(self):
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
"conv_id": self.conv_id,
}
conv_one_shot = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
("Assistant",
"Renewable energy sources are those that can be replenished naturally in a relatively "
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
"renewable and non-renewable energy sources:\n"
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
"energy sources are finite and will eventually run out.\n"
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
"and other negative effects.\n"
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
"have lower operational costs than non-renewable sources.\n"
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
"locations than non-renewable sources.\n"
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
"non-renewable sources are not, and their depletion can lead to economic and social instability.")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_gar = Conversation(
system="A chat between a curious human and a question-answering system. "
"The system gives helpful, detailed, and polite answers to the human's questions.",
roles=("Question", "Answer"),
messages=(),
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep="</s>",
)
conv_gar_test = Conversation(
system="A chat between a curious human and a question-answering system. "
"The system gives helpful, detailed, and polite answers to the human's questions.",
roles=("Question", "Good Response (rank 1/100)"),
messages=(),
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep="</s>",
)
conv_vicuna_v1_1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_koala_v1 = Conversation(
system="BEGINNING OF CONVERSATION:",
roles=("USER", "GPT"),
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_dolly = Conversation(
system=
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
roles=('### Instruction', '### Response'),
messages=(),
offset=0,
sep_style=SeparatorStyle.DOLLY,
sep="\n\n",
sep2="### End",
)
conv_templates = {
"conv_one_shot": conv_one_shot,
"vicuna_v1.1": conv_vicuna_v1_1,
"koala_v1": conv_koala_v1,
"dolly": conv_dolly,
"gar": conv_gar,
"gar_test": conv_gar_test,
}
def get_default_conv_template(model_name):
model_name = model_name.lower()
if "vicuna" in model_name or "output" in model_name:
return conv_vicuna_v1_1
elif "koala" in model_name:
return conv_koala_v1
elif "dolly" in model_name:
return conv_dolly
elif "gar" in model_name:
return conv_gar
return conv_one_shot
if __name__ == "__main__":
print(default_conversation.get_prompt())