Skip to content

Commit

Permalink
Merge pull request #1 from ZhangHanDong/patch-1
Browse files Browse the repository at this point in the history
Add tests for BPE
  • Loading branch information
YdrMaster authored Aug 10, 2024
2 parents 751f2f8 + faecc3b commit 96db746
Showing 1 changed file with 120 additions and 10 deletions.
130 changes: 120 additions & 10 deletions src/bpe/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//! b-p-e for Byte Pair Encoding
//! b-p-e for Byte Pair Encoding
mod algorithm;

Expand Down Expand Up @@ -247,16 +247,126 @@ fn rank(scores: &[f32]) -> impl IntoIterator<Item = u32> + '_ {
scores.iter().map(move |&f| map[&FloatOrd(f)])
}

#[test]
fn test() {
if let Ok(buf) = std::fs::read("tokenizer.model") {
let bpe = Bpe::from_tokenizer_model(&buf);
#[cfg(test)]
mod bpe_tests {
use super::*;

/// Smoke test against an on-disk model file.
///
/// Reads `tokenizer.model` via a relative path; when the file cannot be read
/// the test returns without asserting anything, so it never fails merely
/// because the file is absent.
#[test]
fn test() {
    let model_bytes = match std::fs::read("tokenizer.model") {
        Ok(bytes) => bytes,
        // No model available: nothing to inspect.
        Err(_) => return,
    };

    let bpe = Bpe::from_tokenizer_model(&model_bytes);
    let inaccessible = bpe.inaccessible();

    // Report the vocabulary size next to the length of the internal
    // `_vocabs` buffer, then dump the tokens reported as inaccessible.
    let token_count = bpe.vocab_size();
    let packed_len = bpe._vocabs.len();
    println!(
        "bpe: detected {} tokens, compressed to {} bytes",
        token_count, packed_len,
    );
    println!("inaccessible: {inaccessible:#?}");
}

/// Building a `Bpe` from five vocabulary entries must report a vocabulary
/// size of five.
#[test]
fn test_bpe_new() {
    // Three single-character entries followed by two two-character entries;
    // none of them is flagged as a byte token.
    let tokens = vec!["a", "b", "c", "ab", "bc"];
    let weights = vec![1.0, 1.0, 1.0, 2.0, 2.0];
    let byte_flags = vec![false; 5];

    let tokenizer = Bpe::new(tokens, weights, byte_flags, 0);

    assert_eq!(tokenizer.vocab_size(), 5);
}

/// Encoding "abc" with a vocabulary containing "ab" and "bc" is expected to
/// yield exactly two tokens: the one for "ab" followed by the one for "c".
#[test]
fn test_bpe_encode() {
    let tokenizer = Bpe::new(
        vec!["a", "b", "c", "ab", "bc"],
        vec![1.0, 1.0, 1.0, 2.0, 2.0],
        vec![false; 5],
        0,
    );

    let token_ids: Vec<_> = tokenizer.encode("abc").into_iter().collect();

    // "ab" is merged and "c" is left on its own.
    assert_eq!(token_ids.len(), 2);
    // The expected ids assume tokens are numbered by their position in the
    // vocabulary list: "ab" -> 3, "c" -> 2.
    assert_eq!(token_ids[0], 3);
    assert_eq!(token_ids[1], 2);
}

/// Decoding a single token id must return the bytes of the matching
/// vocabulary entry.
#[test]
fn test_bpe_decode() {
    let tokenizer = Bpe::new(
        vec!["a", "b", "c", "ab", "bc"],
        vec![1.0, 1.0, 1.0, 2.0, 2.0],
        vec![false; 5],
        0,
    );

    // Id 3 is the fourth vocabulary entry ("ab"), id 2 the third ("c").
    assert_eq!(tokenizer.decode(3), b"ab");
    assert_eq!(tokenizer.decode(2), b"c");
}

/// Round trip: concatenating the decoded bytes of every encoded token must
/// reproduce the original text.
#[test]
fn test_bpe_encode_decode() {
    let tokenizer = Bpe::new(
        vec!["a", "b", "c", "ab", "bc"],
        vec![1.0, 1.0, 1.0, 2.0, 2.0],
        vec![false; 5],
        0,
    );

    let text = "abcbc";
    let token_ids: Vec<_> = tokenizer.encode(text).into_iter().collect();

    // Stitch the per-token byte slices back together in encoding order.
    let mut bytes: Vec<u8> = Vec::new();
    for &id in &token_ids {
        bytes.extend(tokenizer.decode(id).iter().copied());
    }

    assert_eq!(String::from_utf8(bytes).unwrap(), text);
}

/// The unknown-token id handed to `Bpe::new` must be returned unchanged by
/// `unk_token()`.
#[test]
fn test_bpe_unk_token() {
    let unk_token = 100;
    let tokenizer = Bpe::new(
        vec!["a", "b", "c"],
        vec![1.0; 3],
        vec![false; 3],
        unk_token,
    );

    assert_eq!(tokenizer.unk_token(), unk_token);
}

/// Checks which vocabulary entries `Bpe::inaccessible` reports for a small
/// hand-built vocabulary.
///
/// The vocabulary holds the single characters "a", "b", "c", "d", the pair
/// "ab", and the triple "bcd". No two-character entry ("bc" or "cd") exists
/// from which "bcd" could be assembled, so it is the one entry expected to be
/// reported as inaccessible.
#[test]
fn test_bpe_inaccessible() {
    let vocabs = vec!["a", "b", "c", "ab", "bcd", "d"];
    let scores = vec![1.0, 1.0, 1.0, 2.0, 1.5, 1.0];
    let is_byte = vec![false, false, false, false, false, false];
    let bpe = Bpe::new(vocabs, scores, is_byte, 0);

    let inaccessible = bpe.inaccessible();
    // Diagnostic output only; the `println!` invocation is now properly
    // closed before the next statement begins.
    println!(
        "bpe: detected {} tokens, compressed to {} bytes",
        bpe.vocab_size(),
        bpe._vocabs.len(),
    );
    println!("Inaccessible tokens: {:?}", inaccessible);

    // 'd' is a single character, so it should be accessible
    assert!(
        !inaccessible.contains_key("d"),
        "Token 'd' should be accessible"
    );

    // 'bcd' cannot be formed by merging other tokens, so it should be inaccessible
    assert!(
        inaccessible.contains_key("bcd"),
        "Token 'bcd' should be inaccessible"
    );

    // 'ab' can be formed by merging 'a' and 'b', so it should be accessible
    assert!(
        !inaccessible.contains_key("ab"),
        "Token 'ab' should be accessible"
    );
    println!("inaccessible: {inaccessible:#?}");
}

/// Encodes a short string with a vocabulary that mixes ordinary entries with
/// entries flagged as byte tokens, and expects one token per input character.
#[test]
fn test_bpe_with_byte_tokens() {
    // "<0x41>" and "<0x42>" are the entries flagged as byte tokens
    // (presumably standing for the raw bytes 'A' and 'B' — confirm against
    // `Bpe::new`).
    let tokenizer = Bpe::new(
        vec!["a", "b", "<0x41>", "<0x42>"],
        vec![1.0; 4],
        vec![false, false, true, true],
        0,
    );

    let input = "aAB";
    let token_ids: Vec<_> = tokenizer.encode(input).into_iter().collect();

    // Diagnostics shown when the test fails or is run with `--nocapture`.
    println!("Input: {:?}", input);
    println!("Encoded tokens: {:?}", token_ids);
    println!("Vocabulary size: {}", tokenizer.vocab_size());

    assert_eq!(token_ids.len(), 3, "Expected 3 tokens for input 'aAB'");
}
}

0 comments on commit 96db746

Please sign in to comment.