Skip to content

Commit

Permalink
Add tests for BPE
Browse files (browse the repository at this point in the history)
Add tests for BPE
  • Loading branch information
ZhangHanDong authored Aug 9, 2024
1 parent 751f2f8 commit 5322c28
Showing 1 changed file with 108 additions and 10 deletions.
118 changes: 108 additions & 10 deletions src/bpe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,16 +247,114 @@ fn rank(scores: &[f32]) -> impl IntoIterator<Item = u32> + '_ {
scores.iter().map(move |&f| map[&FloatOrd(f)])
}

#[test]
fn test() {
if let Ok(buf) = std::fs::read("tokenizer.model") {
let bpe = Bpe::from_tokenizer_model(&buf);
#[cfg(test)]
mod bpe_tests {
use super::*;

/// Smoke test against a real `tokenizer.model` in the working directory.
///
/// The model file is optional: when it cannot be read the test returns
/// without asserting anything, so it still passes where no model is present.
#[test]
fn test() {
    let buf = match std::fs::read("tokenizer.model") {
        Ok(bytes) => bytes,
        // No readable model file: nothing to inspect.
        Err(_) => return,
    };

    let bpe = Bpe::from_tokenizer_model(&buf);
    let inaccessible = bpe.inaccessible();

    // Gather the summary numbers first, then report them in one line.
    let vocab_size = bpe.vocab_size();
    let compressed_len = bpe._vocabs.len();
    println!(
        "bpe: detected {} tokens, compressed to {} bytes",
        vocab_size, compressed_len,
    );
    println!("inaccessible: {inaccessible:#?}");
}

/// Builds a `Bpe` from a five-entry vocabulary and checks that
/// `vocab_size()` reports all five entries.
#[test]
fn test_bpe_new() {
    let pieces = vec!["a", "b", "c", "ab", "bc"];
    let piece_scores = vec![1.0, 1.0, 1.0, 2.0, 2.0];
    // None of the entries is flagged as a byte token.
    let byte_flags = vec![false; 5];

    let tokenizer = Bpe::new(pieces, piece_scores, byte_flags, 0);

    assert_eq!(tokenizer.vocab_size(), 5);
}

/// Encodes `"abc"` with a vocabulary that contains both `"ab"` and `"bc"`
/// and checks the resulting token sequence.
#[test]
fn test_bpe_encode() {
    let pieces = vec!["a", "b", "c", "ab", "bc"];
    let piece_scores = vec![1.0, 1.0, 1.0, 2.0, 2.0];
    let byte_flags = vec![false; 5];
    let tokenizer = Bpe::new(pieces, piece_scores, byte_flags, 0);

    let tokens: Vec<_> = tokenizer.encode("abc").into_iter().collect();

    // Expected split is "ab" + "c".
    assert_eq!(tokens.len(), 2);
    // The ids below assume tokens are numbered by their position in `pieces`:
    // "ab" -> 3, "c" -> 2.
    assert_eq!(tokens[0], 3);
    assert_eq!(tokens[1], 2);
}

/// Checks that `decode` maps individual token ids back to the bytes of the
/// vocabulary entry at the same position.
#[test]
fn test_bpe_decode() {
    let pieces = vec!["a", "b", "c", "ab", "bc"];
    let piece_scores = vec![1.0, 1.0, 1.0, 2.0, 2.0];
    let byte_flags = vec![false; 5];
    let tokenizer = Bpe::new(pieces, piece_scores, byte_flags, 0);

    // Id 3 is the merged piece, id 2 a single character.
    assert_eq!(tokenizer.decode(3), b"ab");
    assert_eq!(tokenizer.decode(2), b"c");
}

/// Round-trip check: encoding a string and concatenating the decoded bytes
/// of every produced token must give back the original text.
#[test]
fn test_bpe_encode_decode() {
    let pieces = vec!["a", "b", "c", "ab", "bc"];
    let piece_scores = vec![1.0, 1.0, 1.0, 2.0, 2.0];
    let byte_flags = vec![false; 5];
    let tokenizer = Bpe::new(pieces, piece_scores, byte_flags, 0);

    let text = "abcbc";
    let tokens: Vec<_> = tokenizer.encode(text).into_iter().collect();

    // Stitch the per-token byte strings back together in order.
    let mut bytes: Vec<u8> = Vec::new();
    for &token in &tokens {
        bytes.extend(tokenizer.decode(token).iter().copied());
    }

    assert_eq!(String::from_utf8(bytes).unwrap(), text);
}

/// Checks that the value passed as the fourth argument of `Bpe::new` is the
/// one later returned by `unk_token()`.
#[test]
fn test_bpe_unk_token() {
    let expected_unk = 100;

    let pieces = vec!["a", "b", "c"];
    let piece_scores = vec![1.0; 3];
    let byte_flags = vec![false; 3];
    let tokenizer = Bpe::new(pieces, piece_scores, byte_flags, expected_unk);

    assert_eq!(tokenizer.unk_token(), expected_unk);
}

/// Checks which vocabulary entries `inaccessible()` reports for a small
/// vocabulary containing one entry (`"bcd"`) that has no merge path.
///
/// The summary and the reported collection are printed once for debugging;
/// run with `--nocapture` to see them.
#[test]
fn test_bpe_inaccessible() {
    let vocabs = vec!["a", "b", "c", "ab", "bcd", "d"];
    let scores = vec![1.0, 1.0, 1.0, 2.0, 1.5, 1.0];
    let is_byte = vec![false; 6];
    let bpe = Bpe::new(vocabs, scores, is_byte, 0);

    let inaccessible = bpe.inaccessible();
    println!(
        "bpe: detected {} tokens, compressed to {} bytes",
        bpe.vocab_size(),
        bpe._vocabs.len(),
    );
    // Printed once in pretty form; a second compact dump of the same value
    // added nothing.
    println!("inaccessible: {inaccessible:#?}");

    // 'd' is a single character, so it should be accessible
    assert!(!inaccessible.contains_key("d"), "Token 'd' should be accessible");

    // 'bcd' cannot be formed by merging other tokens, so it should be inaccessible
    assert!(inaccessible.contains_key("bcd"), "Token 'bcd' should be inaccessible");

    // 'ab' can be formed by merging 'a' and 'b', so it should be accessible
    assert!(!inaccessible.contains_key("ab"), "Token 'ab' should be accessible");
}

/// Encodes `"aAB"` with a vocabulary that mixes plain entries with two
/// entries flagged as byte tokens (`<0x41>`, `<0x42>`) and checks that three
/// tokens come out.
#[test]
fn test_bpe_with_byte_tokens() {
    let pieces = vec!["a", "b", "<0x41>", "<0x42>"];
    let piece_scores = vec![1.0; 4];
    // Only the last two entries are marked as byte tokens.
    let byte_flags = vec![false, false, true, true];
    let tokenizer = Bpe::new(pieces, piece_scores, byte_flags, 0);

    let sample = "aAB";
    let tokens: Vec<_> = tokenizer.encode(sample).into_iter().collect();

    // Debug output, visible with `--nocapture`.
    println!("Input: {:?}", sample);
    println!("Encoded tokens: {:?}", tokens);
    println!("Vocabulary size: {}", tokenizer.vocab_size());

    assert_eq!(tokens.len(), 3, "Expected 3 tokens for input 'aAB'");
}
}

0 comments on commit 5322c28

Please sign in to comment.