Commit 94aa897

fix(bpe): strengthen bpe tests, fix unk decoding
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 96db746

2 files changed: +60 −58

src/bpe/algorithm.rs

Lines changed: 10 additions & 2 deletions
@@ -222,7 +222,11 @@ impl Iterator for IntoIter<'_> {
     fn next(&mut self) -> Option<Self::Item> {
         match &self.marks[self.i..] {
             &[Mark { token, .. }, ..] => {
-                self.i += self.bpe.token(token).len();
+                self.i += if token == self.bpe.unk {
+                    1
+                } else {
+                    self.bpe.token(token).len()
+                };
                 Some(token)
             }
             [] => None,
@@ -236,7 +240,11 @@ impl Iterator for Iter<'_> {
     fn next(&mut self) -> Option<Self::Item> {
         match self.marks {
             &[Mark { token, .. }, ref tail @ ..] => {
-                self.marks = &tail[self.bpe.token(token).len() - 1..];
+                self.marks = if token == self.bpe.unk {
+                    tail
+                } else {
+                    &tail[self.bpe.token(token).len() - 1..]
+                };
                 Some(token)
             }
             [] => None,
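
Both hunks fix the same off-by-length bug: the stored text of the unk token ("<unk>" here) is several bytes long, but it stands for a single unrecognized input position, so advancing the iterator by `self.bpe.token(token).len()` skipped or overran the marks that follow an unk. Below is a minimal sketch of the corrected advance, using stand-in names (`Vocab`, `input_len`, `tokens`) rather than the crate's `Bpe` API, and assuming, as the indexing in the diff suggests, that `marks` holds one entry per input byte:

```rust
// Stand-in types; not the crate's API. Assumption: a merged token of byte
// length n covers n consecutive marks, while the unk token represents exactly
// one unrecognized input byte even though its stored text is longer.
#[derive(Clone, Copy)]
struct Mark {
    token: u32,
}

struct Vocab {
    texts: Vec<&'static str>,
    unk: u32,
}

impl Vocab {
    /// How many *input* bytes the token covers (not the length of its text).
    fn input_len(&self, token: u32) -> usize {
        if token == self.unk {
            1 // one unrecognized input byte
        } else {
            self.texts[token as usize].len()
        }
    }
}

fn tokens(vocab: &Vocab, marks: &[Mark]) -> Vec<u32> {
    let mut out = Vec::new();
    let mut i = 0;
    while let Some(&Mark { token }) = marks.get(i) {
        out.push(token);
        i += vocab.input_len(token); // the fix: an unk advances by exactly 1
    }
    out
}

fn main() {
    // Input "xa" with vocab ["<unk>", "a"]: 'x' -> unk (0), 'a' -> token 1.
    let vocab = Vocab { texts: vec!["<unk>", "a"], unk: 0 };
    let marks = [Mark { token: 0 }, Mark { token: 1 }];
    assert_eq!(tokens(&vocab, &marks), [0, 1]);
    // With the old `token(token).len()` advance, the unk would have jumped
    // ahead by the 5 bytes of "<unk>", skipping or indexing past the marks
    // that follow it.
}
```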

src/bpe/mod.rs

Lines changed: 50 additions & 56 deletions
@@ -265,73 +265,71 @@ mod bpe_tests {
         }
     }
 
+    fn test_bpe() -> Bpe {
+        Bpe::new(
+            [
+                "<unk>", //
+                "a", "b", "c", "d", //
+                "ab", "ac", "ad", "bd", //
+                "bcd",
+            ],
+            [
+                0., //
+                1., 1., 1., 1., //
+                1.1, 1.2, 1.3, 1.4, //
+                10.,
+            ],
+            [false; 10],
+            0,
+        )
+    }
+
     #[test]
     fn test_bpe_new() {
-        let vocabs = vec!["a", "b", "c", "ab", "bc"];
-        let scores = vec![1.0, 1.0, 1.0, 2.0, 2.0];
-        let is_byte = vec![false, false, false, false, false];
-        let bpe = Bpe::new(vocabs, scores, is_byte, 0);
-        assert_eq!(bpe.vocab_size(), 5);
+        let bpe = test_bpe();
+        assert_eq!(bpe.vocab_size(), 10);
     }
 
     #[test]
-    fn test_bpe_encode() {
-        let vocabs = vec!["a", "b", "c", "ab", "bc"];
-        let scores = vec![1.0, 1.0, 1.0, 2.0, 2.0];
-        let is_byte = vec![false, false, false, false, false];
-        let bpe = Bpe::new(vocabs, scores, is_byte, 0);
+    fn test_bpe_unk_token() {
+        let bpe = test_bpe();
+        assert_eq!(bpe.unk_token(), 0);
+    }
 
-        let encoded: Vec<_> = bpe.encode("abc").into_iter().collect();
-        assert_eq!(encoded.len(), 2); // Should merge "ab" and leave "c"
-        assert_eq!(encoded[0], 3); // Assuming "ab" is assigned token ID 3
-        assert_eq!(encoded[1], 2); // Assuming "c" is assigned token ID 2
+    #[test]
+    fn test_bpe_encode() {
+        let bpe = test_bpe();
+        let encoded: Vec<_> = bpe.encode("abd").into_iter().collect();
+        assert_eq!(encoded, [1, 8]); // Should merge "bd" and leave "a"
     }
 
     #[test]
     fn test_bpe_decode() {
-        let vocabs = vec!["a", "b", "c", "ab", "bc"];
-        let scores = vec![1.0, 1.0, 1.0, 2.0, 2.0];
-        let is_byte = vec![false, false, false, false, false];
-        let bpe = Bpe::new(vocabs, scores, is_byte, 0);
-
-        assert_eq!(bpe.decode(3), b"ab");
-        assert_eq!(bpe.decode(2), b"c");
+        let bpe = test_bpe();
+        assert_eq!(bpe.decode(3), b"c");
+        assert_eq!(bpe.decode(6), b"ac");
+        assert_eq!(bpe.decode(9), b"bcd");
+        assert_eq!(bpe.decode(0), b"<unk>");
     }
 
     #[test]
     fn test_bpe_encode_decode() {
-        let vocabs = vec!["a", "b", "c", "ab", "bc"];
-        let scores = vec![1.0, 1.0, 1.0, 2.0, 2.0];
-        let is_byte = vec![false, false, false, false, false];
-        let bpe = Bpe::new(vocabs, scores, is_byte, 0);
+        let bpe = test_bpe();
 
-        let text = "abcbc";
+        let text = "abcdx";
         let encoded: Vec<_> = bpe.encode(text).into_iter().collect();
-        let decoded: Vec<u8> = encoded
+        assert_eq!(encoded, [5, 3, 4, 0]);
+
+        let decoded: Vec<_> = encoded
             .iter()
             .flat_map(|&t| bpe.decode(t).iter().copied())
             .collect();
-        assert_eq!(String::from_utf8(decoded).unwrap(), text);
-    }
-
-    #[test]
-    fn test_bpe_unk_token() {
-        let vocabs = vec!["a", "b", "c"];
-        let scores = vec![1.0, 1.0, 1.0];
-        let is_byte = vec![false, false, false];
-        let unk_token = 100;
-        let bpe = Bpe::new(vocabs, scores, is_byte, unk_token);
-
-        assert_eq!(bpe.unk_token(), unk_token);
+        assert_eq!(std::str::from_utf8(&decoded), Ok("abcd<unk>"));
     }
 
     #[test]
     fn test_bpe_inaccessible() {
-        let vocabs = vec!["a", "b", "c", "ab", "bcd", "d"];
-        let scores = vec![1.0, 1.0, 1.0, 2.0, 1.5, 1.0];
-        let is_byte = vec![false, false, false, false, false, false];
-        let bpe = Bpe::new(vocabs, scores, is_byte, 0);
-
+        let bpe = test_bpe();
         let inaccessible = bpe.inaccessible();
         println!("Inaccessible tokens: {:?}", inaccessible);
 
@@ -342,8 +340,9 @@ mod bpe_tests {
         );
 
         // 'bcd' cannot be formed by merging other tokens, so it should be inaccessible
-        assert!(
-            inaccessible.contains_key("bcd"),
+        assert_eq!(
+            inaccessible.get("bcd"),
+            Some(&9),
             "Token 'bcd' should be inaccessible"
         );
 
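For orientation, the expected IDs in the rewritten tests all follow from the shared `test_bpe()` vocabulary: encoding "abcdx" merges "ab" (id 5), keeps "c" (3) and "d" (4) as single tokens, and maps the out-of-vocabulary "x" to "<unk>" (0), so the round trip decodes to "abcd<unk>"; "bcd" (id 9) stays inaccessible because neither "bc" nor "cd" exists as an intermediate merge. The snippet below only replays that id-to-text mapping with a plain lookup table; it is an illustration, not the crate's `Bpe` type:

```rust
// Stand-in lookup table mirroring the test_bpe() fixture above
// (ids are positions in the vocab array; this is not the crate's API).
fn main() {
    let vocab = [
        "<unk>",                // 0
        "a", "b", "c", "d",     // 1..=4
        "ab", "ac", "ad", "bd", // 5..=8
        "bcd",                  // 9
    ];

    // "abcdx" -> [5, 3, 4, 0]: "ab" is the only adjacent pair present in the
    // vocabulary, "c" and "d" stay single tokens, and "x" falls back to <unk>.
    let encoded = [5usize, 3, 4, 0];
    let decoded: String = encoded.iter().map(|&id| vocab[id]).collect();
    assert_eq!(decoded, "abcd<unk>"); // matches test_bpe_encode_decode
}
```
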
@@ -356,17 +355,12 @@ mod bpe_tests {
 
     #[test]
     fn test_bpe_with_byte_tokens() {
-        let vocabs = vec!["a", "b", "<0x41>", "<0x42>"];
-        let scores = vec![1.0, 1.0, 1.0, 1.0];
-        let is_byte = vec![false, false, true, true];
+        let vocabs = ["a", "b", "<0x41>", "<0x42>"];
+        let scores = [1.0, 1.0, 1.0, 1.0];
+        let is_byte = [false, false, true, true];
         let bpe = Bpe::new(vocabs, scores, is_byte, 0);
 
-        let input = "aAB";
-        let encoded: Vec<_> = bpe.encode(input).into_iter().collect();
-        println!("Input: {:?}", input);
-        println!("Encoded tokens: {:?}", encoded);
-        println!("Vocabulary size: {}", bpe.vocab_size());
-
-        assert_eq!(encoded.len(), 3, "Expected 3 tokens for input 'aAB'");
+        let encoded: Vec<_> = bpe.encode("aAB").into_iter().collect();
+        assert_eq!(encoded, [0, 2, 3], "Expected 3 tokens for input 'aAB'");
     }
 }
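
The byte-token test relies on the common convention (SentencePiece-style, assumed here rather than taken from this crate's docs) that a vocab entry flagged in `is_byte` and written "<0xNN>" stands for the single raw byte 0xNN; so 'A' (0x41) and 'B' (0x42) encode to ids 2 and 3, while "a" matches its literal entry at id 0, giving `[0, 2, 3]`. A small illustrative parser for that notation:

```rust
// Illustrative parser for the "<0xNN>" byte-token notation; not the crate's
// implementation.
fn byte_token_value(text: &str) -> Option<u8> {
    let hex = text.strip_prefix("<0x")?.strip_suffix('>')?;
    u8::from_str_radix(hex, 16).ok()
}

fn main() {
    // 'A' is 0x41 and 'B' is 0x42, so "aAB" with vocab ["a", "b", "<0x41>", "<0x42>"]
    // encodes to [0, 2, 3]: "a" literally, then the two byte tokens.
    assert_eq!(byte_token_value("<0x41>"), Some(b'A'));
    assert_eq!(byte_token_value("<0x42>"), Some(b'B'));
    assert_eq!(byte_token_value("ab"), None);
}
```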
