Skip to content

Commit fc37fbb

Browse files
authored
Fix baichuan template (#28)
1 parent 89f8a24 commit fc37fbb

File tree

7 files changed

+172
-73
lines changed

7 files changed

+172
-73
lines changed

example/history_template_baichuan.liquid

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
{% for item in items -%}
22
{%- capture identity -%}
33
{%- case item.identity -%}
4+
{%- when "System", "Tool" -%}
5+
System
46
{%- when "User" -%}
57
<reserved_106>
68
{%- when "Assistant" -%}
@@ -11,4 +13,4 @@
1113

1214
{{- identity }}{% if item.name %} {{ item.name }}{% endif %}: {{ item.content }}
1315
{% endfor -%}
14-
<reserved_107>:
16+
<reserved_107>:

src/config.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,5 @@ pub struct Config {
2626
/// File containing the history template string
2727
#[arg(long)]
2828
#[serde(skip_serializing_if = "Option::is_none")]
29-
pub history_template_file: Option<String>
29+
pub history_template_file: Option<String>,
3030
}

src/history/mod.rs

Lines changed: 143 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
1-
use std::fs::File;
2-
use std::io::Read;
3-
use std::sync::Arc;
1+
use crate::routes::chat::ChatCompletionMessageParams;
42
use anyhow::bail;
53
use liquid::{ParserBuilder, Template};
64
use serde::Serialize;
7-
use crate::routes::chat::ChatCompletionMessageParams;
5+
use std::fs::File;
6+
use std::io::Read;
7+
use std::sync::Arc;
88

9-
const DEFAULT_TEMPLATE: &str =
10-
"{% for item in items %}\
9+
const DEFAULT_TEMPLATE: &str = "{% for item in items %}\
1110
{{ item.identity }}{% if item.name %} {{ item.name }}{% endif %}: {{ item.content }}
1211
{% endfor %}\
1312
ASSISTANT:";
1413

1514
#[derive(Clone)]
1615
pub struct HistoryBuilder {
17-
history_template: Arc<Template>
16+
history_template: Arc<Template>,
1817
}
1918

2019
impl HistoryBuilder {
@@ -24,12 +23,11 @@ impl HistoryBuilder {
2423
}
2524
let mut _ref_holder = None;
2625

27-
2826
let template = match template_file {
2927
None => match template {
30-
None => {DEFAULT_TEMPLATE}
31-
Some(cfg) => {cfg.as_str()}
32-
}
28+
None => DEFAULT_TEMPLATE,
29+
Some(cfg) => cfg.as_str(),
30+
},
3331
Some(filename) => {
3432
_ref_holder = Some(load_template_file(filename)?);
3533
_ref_holder.as_ref().unwrap().as_str()
@@ -38,10 +36,13 @@ impl HistoryBuilder {
3836

3937
let history_template = Arc::new(ParserBuilder::with_stdlib().build()?.parse(template)?);
4038

41-
Ok(HistoryBuilder {history_template})
39+
Ok(HistoryBuilder { history_template })
4240
}
4341

44-
pub fn build_history(&self, messages: &Vec<ChatCompletionMessageParams>) -> anyhow::Result<String> {
42+
pub fn build_history(
43+
&self,
44+
messages: &Vec<ChatCompletionMessageParams>,
45+
) -> anyhow::Result<String> {
4546
let items: Vec<_> = messages.iter().map(|x| HistoryItem::new(x)).collect();
4647
let context = liquid::object!({"items": items});
4748
Ok(self.history_template.render(&context)?)
@@ -59,19 +60,31 @@ fn load_template_file(file: &String) -> anyhow::Result<String> {
5960
struct HistoryItem {
6061
identity: String,
6162
content: String,
62-
name: Option<String>
63+
name: Option<String>,
6364
}
6465

6566
impl HistoryItem {
6667
pub fn new(message: &ChatCompletionMessageParams) -> Self {
6768
let (identity, content, name) = match message {
68-
ChatCompletionMessageParams::System { content, name } => { ("System".into(), content.clone(), name.clone()) }
69-
ChatCompletionMessageParams::User { content, name } => { ("User".into(), content.clone(), name.clone()) }
70-
ChatCompletionMessageParams::Assistant { content } => { ("Assistant".into(), content.clone(), None) }
71-
ChatCompletionMessageParams::Tool { content, .. } => { ("Tool".into(), content.clone(), None) }
69+
ChatCompletionMessageParams::System { content, name } => {
70+
("System".into(), content.clone(), name.clone())
71+
}
72+
ChatCompletionMessageParams::User { content, name } => {
73+
("User".into(), content.clone(), name.clone())
74+
}
75+
ChatCompletionMessageParams::Assistant { content } => {
76+
("Assistant".into(), content.clone(), None)
77+
}
78+
ChatCompletionMessageParams::Tool { content, .. } => {
79+
("Tool".into(), content.clone(), None)
80+
}
7281
};
7382

74-
HistoryItem { identity, content, name }
83+
HistoryItem {
84+
identity,
85+
content,
86+
name,
87+
}
7588
}
7689
}
7790

@@ -83,100 +96,169 @@ mod test {
8396
pub fn test_default_template() {
8497
let template = None;
8598
let template_file = None;
86-
let builder = HistoryBuilder::new(&template, &template_file).expect("default template should build correctly");
99+
let builder = HistoryBuilder::new(&template, &template_file)
100+
.expect("default template should build correctly");
87101

88102
let messages = vec![
89-
ChatCompletionMessageParams::System {content: "test system 1".into(), name: Some("system 1".into())},
90-
ChatCompletionMessageParams::System {content: "test system 2".into(), name: None},
91-
ChatCompletionMessageParams::Assistant {content: "test assistant 1".into()},
92-
ChatCompletionMessageParams::Tool {content: "test tool 1".into(), tool_call_id: "tool_1".into()},
93-
ChatCompletionMessageParams::User {content: "test user 1".into(), name: Some("user 1".into())},
94-
ChatCompletionMessageParams::User {content: "test user 2".into(), name: None}
103+
ChatCompletionMessageParams::System {
104+
content: "test system 1".into(),
105+
name: Some("system 1".into()),
106+
},
107+
ChatCompletionMessageParams::System {
108+
content: "test system 2".into(),
109+
name: None,
110+
},
111+
ChatCompletionMessageParams::Assistant {
112+
content: "test assistant 1".into(),
113+
},
114+
ChatCompletionMessageParams::Tool {
115+
content: "test tool 1".into(),
116+
tool_call_id: "tool_1".into(),
117+
},
118+
ChatCompletionMessageParams::User {
119+
content: "test user 1".into(),
120+
name: Some("user 1".into()),
121+
},
122+
ChatCompletionMessageParams::User {
123+
content: "test user 2".into(),
124+
name: None,
125+
},
95126
];
96127

97-
let result = builder.build_history(&messages).expect("history should build correctly");
128+
let result = builder
129+
.build_history(&messages)
130+
.expect("history should build correctly");
98131

99-
let expected_result: String =
100-
"System system 1: test system 1
132+
let expected_result: String = "System system 1: test system 1
101133
System: test system 2
102134
Assistant: test assistant 1
103135
Tool: test tool 1
104136
User user 1: test user 1
105137
User: test user 2
106-
ASSISTANT:".into();
138+
ASSISTANT:"
139+
.into();
107140

108141
assert_eq!(expected_result, result)
109-
110142
}
111143

112144
#[test]
113145
pub fn test_template_file() {
114146
let template = None;
115-
let template_file = Some(format!("{}/example/history_template.liquid", env!("CARGO_MANIFEST_DIR")));
116-
let builder = HistoryBuilder::new(&template, &template_file).expect("default template should build correctly");
147+
let template_file = Some(format!(
148+
"{}/example/history_template.liquid",
149+
env!("CARGO_MANIFEST_DIR")
150+
));
151+
let builder = HistoryBuilder::new(&template, &template_file)
152+
.expect("default template should build correctly");
117153

118154
let messages = vec![
119-
ChatCompletionMessageParams::System {content: "test system 1".into(), name: Some("system 1".into())},
120-
ChatCompletionMessageParams::System {content: "test system 2".into(), name: None},
121-
ChatCompletionMessageParams::Assistant {content: "test assistant 1".into()},
122-
ChatCompletionMessageParams::Tool {content: "test tool 1".into(), tool_call_id: "tool_1".into()},
123-
ChatCompletionMessageParams::User {content: "test user 1".into(), name: Some("user 1".into())},
124-
ChatCompletionMessageParams::User {content: "test user 2".into(), name: None}
155+
ChatCompletionMessageParams::System {
156+
content: "test system 1".into(),
157+
name: Some("system 1".into()),
158+
},
159+
ChatCompletionMessageParams::System {
160+
content: "test system 2".into(),
161+
name: None,
162+
},
163+
ChatCompletionMessageParams::Assistant {
164+
content: "test assistant 1".into(),
165+
},
166+
ChatCompletionMessageParams::Tool {
167+
content: "test tool 1".into(),
168+
tool_call_id: "tool_1".into(),
169+
},
170+
ChatCompletionMessageParams::User {
171+
content: "test user 1".into(),
172+
name: Some("user 1".into()),
173+
},
174+
ChatCompletionMessageParams::User {
175+
content: "test user 2".into(),
176+
name: None,
177+
},
125178
];
126179

127-
let result = builder.build_history(&messages).expect("history should build correctly");
180+
let result = builder
181+
.build_history(&messages)
182+
.expect("history should build correctly");
128183

129-
let expected_result: String =
130-
"System system 1: test system 1
184+
let expected_result: String = "System system 1: test system 1
131185
System: test system 2
132186
Assistant: test assistant 1
133187
Tool: test tool 1
134188
User user 1: test user 1
135189
User: test user 2
136-
ASSISTANT:".into();
190+
ASSISTANT:"
191+
.into();
137192

138193
assert_eq!(expected_result, result)
139-
140194
}
141195

142196
#[test]
143197
pub fn test_template_file_custom_roles() {
144198
let template = None;
145-
let template_file = Some(format!("{}/example/history_template_custom_roles.liquid", env!("CARGO_MANIFEST_DIR")));
146-
let builder = HistoryBuilder::new(&template, &template_file).expect("default template should build correctly");
199+
let template_file = Some(format!(
200+
"{}/example/history_template_custom_roles.liquid",
201+
env!("CARGO_MANIFEST_DIR")
202+
));
203+
let builder = HistoryBuilder::new(&template, &template_file)
204+
.expect("default template should build correctly");
147205

148206
let messages = vec![
149-
ChatCompletionMessageParams::System {content: "test system 1".into(), name: Some("system 1".into())},
150-
ChatCompletionMessageParams::System {content: "test system 2".into(), name: None},
151-
ChatCompletionMessageParams::Assistant {content: "test assistant 1".into()},
152-
ChatCompletionMessageParams::Tool {content: "test tool 1".into(), tool_call_id: "tool_1".into()},
153-
ChatCompletionMessageParams::User {content: "test user 1".into(), name: Some("user 1".into())},
154-
ChatCompletionMessageParams::User {content: "test user 2".into(), name: None}
207+
ChatCompletionMessageParams::System {
208+
content: "test system 1".into(),
209+
name: Some("system 1".into()),
210+
},
211+
ChatCompletionMessageParams::System {
212+
content: "test system 2".into(),
213+
name: None,
214+
},
215+
ChatCompletionMessageParams::Assistant {
216+
content: "test assistant 1".into(),
217+
},
218+
ChatCompletionMessageParams::Tool {
219+
content: "test tool 1".into(),
220+
tool_call_id: "tool_1".into(),
221+
},
222+
ChatCompletionMessageParams::User {
223+
content: "test user 1".into(),
224+
name: Some("user 1".into()),
225+
},
226+
ChatCompletionMessageParams::User {
227+
content: "test user 2".into(),
228+
name: None,
229+
},
155230
];
156231

157-
let result = builder.build_history(&messages).expect("history should build correctly");
232+
let result = builder
233+
.build_history(&messages)
234+
.expect("history should build correctly");
158235

159-
let expected_result: String =
160-
"Robot system 1: test system 1
236+
let expected_result: String = "Robot system 1: test system 1
161237
Robot: test system 2
162238
Support: test assistant 1
163239
Robot: test tool 1
164240
Customer user 1: test user 1
165241
Customer: test user 2
166-
ASSISTANT:".into();
242+
ASSISTANT:"
243+
.into();
167244

168245
assert_eq!(expected_result, result)
169-
170246
}
171247

172248
#[test]
173249
pub fn test_validations() {
174250
let template = Some("abc".into());
175251
let template_file = Some("abc".into());
176252
match HistoryBuilder::new(&template, &template_file) {
177-
Ok(_) => {assert!(false, "expected err")}
178-
Err(e) => {assert_eq!("cannot set both history-template and history-template-file", e.to_string())}
253+
Ok(_) => {
254+
assert!(false, "expected err")
255+
}
256+
Err(e) => {
257+
assert_eq!(
258+
"cannot set both history-template and history-template-file",
259+
e.to_string()
260+
)
261+
}
179262
};
180-
181263
}
182-
}
264+
}

src/routes/chat.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ use crate::utils::deserialize_bytes_tensor;
3030
#[instrument(name = "chat_completions", skip(grpc_client, history_builder, request))]
3131
pub(crate) async fn compat_chat_completions(
3232
headers: HeaderMap,
33-
State(AppState{ grpc_client, history_builder }): State<AppState>,
33+
State(AppState {
34+
grpc_client,
35+
history_builder,
36+
}): State<AppState>,
3437
request: Json<ChatCompletionCreateParams>,
3538
) -> Response {
3639
tracing::info!("request: {:?}", request);
@@ -46,7 +49,10 @@ pub(crate) async fn compat_chat_completions(
4649
}
4750
}
4851

49-
#[instrument(name = "streaming chat completions", skip(client, history_builder, request))]
52+
#[instrument(
53+
name = "streaming chat completions",
54+
skip(client, history_builder, request)
55+
)]
5056
async fn chat_completions_stream(
5157
headers: HeaderMap,
5258
mut client: GrpcInferenceServiceClient<Channel>,
@@ -203,7 +209,10 @@ async fn chat_completions(
203209
}))
204210
}
205211

206-
fn build_triton_request(request: ChatCompletionCreateParams, history_builder: &HistoryBuilder) -> anyhow::Result<ModelInferRequest> {
212+
fn build_triton_request(
213+
request: ChatCompletionCreateParams,
214+
history_builder: &HistoryBuilder,
215+
) -> anyhow::Result<ModelInferRequest> {
207216
let chat_history = history_builder.build_history(&request.messages)?;
208217
tracing::debug!("chat history after formatting: {}", chat_history);
209218

src/routes/completions.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use crate::utils::{deserialize_bytes_tensor, string_or_seq_string};
2929
#[instrument(name = "completions", skip(grpc_client, request))]
3030
pub(crate) async fn compat_completions(
3131
headers: HeaderMap,
32-
State(AppState{ grpc_client, .. }): State<AppState>,
32+
State(AppState { grpc_client, .. }): State<AppState>,
3333
request: Json<CompletionCreateParams>,
3434
) -> Response {
3535
tracing::info!("request: {:?}", request);
@@ -39,7 +39,9 @@ pub(crate) async fn compat_completions(
3939
.await
4040
.into_response()
4141
} else {
42-
completions(headers, grpc_client, request).await.into_response()
42+
completions(headers, grpc_client, request)
43+
.await
44+
.into_response()
4345
}
4446
}
4547

0 commit comments

Comments
 (0)