diff --git a/src/bpe/mod.rs b/src/bpe/mod.rs index a49ff26..c482474 100644 --- a/src/bpe/mod.rs +++ b/src/bpe/mod.rs @@ -247,16 +247,114 @@ fn rank(scores: &[f32]) -> impl IntoIterator + '_ { scores.iter().map(move |&f| map[&FloatOrd(f)]) } -#[test] -fn test() { - if let Ok(buf) = std::fs::read("tokenizer.model") { - let bpe = Bpe::from_tokenizer_model(&buf); +#[cfg(test)] +mod bpe_tests { + use super::*; + + #[test] + fn test() { + if let Ok(buf) = std::fs::read("tokenizer.model") { + let bpe = Bpe::from_tokenizer_model(&buf); + let inaccessible = bpe.inaccessible(); + println!( + "bpe: detected {} tokens, compressed to {} bytes", + bpe.vocab_size(), + bpe._vocabs.len(), + ); + println!("inaccessible: {inaccessible:#?}"); + } + } + + #[test] + fn test_bpe_new() { + let vocabs = vec!["a", "b", "c", "ab", "bc"]; + let scores = vec![1.0, 1.0, 1.0, 2.0, 2.0]; + let is_byte = vec![false, false, false, false, false]; + let bpe = Bpe::new(vocabs, scores, is_byte, 0); + assert_eq!(bpe.vocab_size(), 5); + } + + #[test] + fn test_bpe_encode() { + let vocabs = vec!["a", "b", "c", "ab", "bc"]; + let scores = vec![1.0, 1.0, 1.0, 2.0, 2.0]; + let is_byte = vec![false, false, false, false, false]; + let bpe = Bpe::new(vocabs, scores, is_byte, 0); + + let encoded: Vec<_> = bpe.encode("abc").into_iter().collect(); + assert_eq!(encoded.len(), 2); // Should merge "ab" and leave "c" + assert_eq!(encoded[0], 3); // Assuming "ab" is assigned token ID 3 + assert_eq!(encoded[1], 2); // Assuming "c" is assigned token ID 2 + } + + #[test] + fn test_bpe_decode() { + let vocabs = vec!["a", "b", "c", "ab", "bc"]; + let scores = vec![1.0, 1.0, 1.0, 2.0, 2.0]; + let is_byte = vec![false, false, false, false, false]; + let bpe = Bpe::new(vocabs, scores, is_byte, 0); + + assert_eq!(bpe.decode(3), b"ab"); + assert_eq!(bpe.decode(2), b"c"); + } + + #[test] + fn test_bpe_encode_decode() { + let vocabs = vec!["a", "b", "c", "ab", "bc"]; + let scores = vec![1.0, 1.0, 1.0, 2.0, 2.0]; + let is_byte = vec![false, false, false, false, false]; + let bpe = Bpe::new(vocabs, scores, is_byte, 0); + + let text = "abcbc"; + let encoded: Vec<_> = bpe.encode(text).into_iter().collect(); + let decoded: Vec = encoded.iter().flat_map(|&t| bpe.decode(t).iter().copied()).collect(); + assert_eq!(String::from_utf8(decoded).unwrap(), text); + } + + #[test] + fn test_bpe_unk_token() { + let vocabs = vec!["a", "b", "c"]; + let scores = vec![1.0, 1.0, 1.0]; + let is_byte = vec![false, false, false]; + let unk_token = 100; + let bpe = Bpe::new(vocabs, scores, is_byte, unk_token); + + assert_eq!(bpe.unk_token(), unk_token); + } + + #[test] + fn test_bpe_inaccessible() { + let vocabs = vec!["a", "b", "c", "ab", "bcd", "d"]; + let scores = vec![1.0, 1.0, 1.0, 2.0, 1.5, 1.0]; + let is_byte = vec![false, false, false, false, false, false]; + let bpe = Bpe::new(vocabs, scores, is_byte, 0); + let inaccessible = bpe.inaccessible(); - println!( - "bpe: detected {} tokens, compressed to {} bytes", - bpe.vocab_size(), - bpe._vocabs.len(), - ); - println!("inaccessible: {inaccessible:#?}"); + println!("Inaccessible tokens: {:?}", inaccessible); + + // 'd' is a single character, so it should be accessible + assert!(!inaccessible.contains_key("d"), "Token 'd' should be accessible"); + + // 'bcd' cannot be formed by merging other tokens, so it should be inaccessible + assert!(inaccessible.contains_key("bcd"), "Token 'bcd' should be inaccessible"); + + // 'ab' can be formed by merging 'a' and 'b', so it should be accessible + assert!(!inaccessible.contains_key("ab"), "Token 'ab' should be accessible"); + } + + #[test] + fn test_bpe_with_byte_tokens() { + let vocabs = vec!["a", "b", "<0x41>", "<0x42>"]; + let scores = vec![1.0, 1.0, 1.0, 1.0]; + let is_byte = vec![false, false, true, true]; + let bpe = Bpe::new(vocabs, scores, is_byte, 0); + + let input = "aAB"; + let encoded: Vec<_> = bpe.encode(input).into_iter().collect(); + println!("Input: {:?}", input); + println!("Encoded tokens: {:?}", encoded); + println!("Vocabulary size: {}", bpe.vocab_size()); + + assert_eq!(encoded.len(), 3, "Expected 3 tokens for input 'aAB'"); } }