From ea8ffc7aeb2722bedb3cf4ace29e252e24e53067 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Wed, 7 Aug 2024 13:18:58 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20lpe=20tokenizer=20?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- Cargo.toml | 1 + src/bpe/mod.rs | 71 +++++++--------------------- src/functions.rs | 117 +++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 106 ++---------------------------------------- src/lpe/mod.rs | 96 ++++++++++++++++++++++++++++++++++++++ src/tokeneer.rs | 84 ++++++++++++++++++++++++++++++++++ 6 files changed, 319 insertions(+), 156 deletions(-) create mode 100644 src/functions.rs create mode 100644 src/lpe/mod.rs create mode 100644 src/tokeneer.rs diff --git a/Cargo.toml b/Cargo.toml index d027341..eb04949 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,3 +9,4 @@ authors = ["YdrMaster "] [dependencies] regex = "1.10" memchr = "2.7" +patricia_tree = "0.8" diff --git a/src/bpe/mod.rs b/src/bpe/mod.rs index 0ddb5c6..7b0e7f2 100644 --- a/src/bpe/mod.rs +++ b/src/bpe/mod.rs @@ -1,6 +1,11 @@ -mod algorithm; +//! b-p-e for Byte Pair Encoding -use crate::{utok, Method}; +mod algorithm; + +use crate::{ + functions::{collect_vocabs_with_hint, CompressedVocab}, + utok, Method, +}; use std::{ collections::{HashMap, HashSet}, iter::zip, @@ -11,7 +16,7 @@ use std::{ pub struct Bpe { /// 保存所有词的字符串内容,以 u8 为单位所以不需要对齐,占用空间少 - _vocab: Pin>, + _vocabs: Pin>, /// 按 token 顺序保存元信息 tokens: Box<[TokenMeta]>, /// 按字符串的字典序排序的 token 索引,用于从字符串二分查找 token。 @@ -46,7 +51,7 @@ impl Deref for TokenMeta { impl Bpe { /// 解析 tokenizer.model 文件并构造一个 bpe 分词器。 pub fn from_tokenizer_model(model: &[u8]) -> Self { - // 遍历文件,标记所有词汇的位置并记录最大长度 + // 遍历文件,标记所有词汇的位置 let offsets = (0..) .scan(0usize, |offset, _| match &model[*offset..] { [10, total_len, 10, content @ ..] => { @@ -91,51 +96,9 @@ impl Bpe { is_byte: impl IntoIterator, unk: utok, ) -> Self { - let mut bytes = Box::new([unk; 256]); - let mut total_len = 0; - // 收集词表字符内容和字节 token,同时计算内容总长度 - let vocabs = zip(vocabs, is_byte) - .enumerate() - .map(|(i, (piece, is_byte))| { - let piece = if is_byte { - const BYTES: [u8; 256] = { - let mut bytes = [0u8; 256]; - let mut i = 0usize; - while i < 256 { - bytes[i] = i as _; - i += 1; - } - bytes - }; - - let b = crate::as_byte_token(piece.as_bytes()).unwrap() as usize; - bytes[b] = i as utok; - std::slice::from_ref(&BYTES[b]) - } else { - piece.as_bytes() - }; - total_len += piece.len(); - piece - }) - .collect::>(); - // 创建字符内容缓存 - let mut slices = vec![(0usize, 0usize); vocabs.len()]; - let mut text_buf = Vec::::with_capacity(total_len); - let mut indices = (0..vocabs.len()).collect::>(); - // 对词按内容长度从长到短排序,因为短的内容有可能是长内容的子串,可以避免重复存储相同内容 - indices.sort_unstable_by_key(|&i| -(vocabs[i].len() as isize)); - for i in indices { - let v = vocabs[i]; - // 查找子串,若存在则复用,否则将新的内容追加到缓存 - let off = memchr::memmem::find(&text_buf, v).unwrap_or_else(|| { - let off = text_buf.len(); - text_buf.extend(v); - off - }); - slices[i] = (off, v.len()); - } - // 锁定字符串内容的位置,以实现安全的自引用 - let _vocab = unsafe { Pin::new_unchecked(text_buf.into_boxed_slice()) }; + let (vocabs, bytes, total_len) = + collect_vocabs_with_hint(vocabs.into_iter().map(|s| s.as_bytes()), is_byte, unk); + let CompressedVocab { vocabs, slices } = CompressedVocab::new(&vocabs, total_len); // 收集合词评分 let scores = scores.into_iter().collect::>(); assert_eq!( @@ -146,7 +109,7 @@ impl Bpe { // tokens 中直接引用字符串位置,绑定重新赋权并转换为整型的分词评分 let tokens = zip(slices, rank(&scores)) .map(|((off, len), rank)| TokenMeta { - ptr: unsafe { NonNull::new_unchecked(_vocab[off..].as_ptr().cast_mut()) }, + ptr: unsafe { NonNull::new_unchecked(vocabs[off..].as_ptr().cast_mut()) }, len: len as _, rank, }) @@ -160,13 +123,13 @@ impl Bpe { sorted_pieces.sort_unstable_by_key(|&i| &*tokens[i as usize]); // println!( - // "Building BPE vocab, detected {} tokens, compressed to {} bytes from {len} bytes", + // "Building BPE vocab, detected {} tokens, compressed to {} bytes from {total_len} bytes", // tokens.len(), - // _vocab.len(), + // vocabs.len(), // ); Self { - _vocab, + _vocabs: vocabs, tokens, sorted_pieces, bytes, @@ -283,7 +246,7 @@ fn test() { println!( "bpe: detected {} tokens, compressed to {} bytes", bpe.vocab_size(), - bpe._vocab.len(), + bpe._vocabs.len(), ); println!("inaccessible: {inaccessible:#?}"); } diff --git a/src/functions.rs b/src/functions.rs new file mode 100644 index 0000000..31483f9 --- /dev/null +++ b/src/functions.rs @@ -0,0 +1,117 @@ +use crate::utok; +use std::{iter::zip, pin::Pin, slice::from_ref}; + +// 收集词表字符内容和字节 token,同时计算内容总长度 +pub(crate) fn collect_vocabs<'s>( + vocabs: impl IntoIterator, + unk: utok, +) -> (Vec<&'s [u8]>, Box<[utok; 256]>, usize) { + let mut bytes = Box::new([unk; 256]); + let mut total_len = 0; + let vocabs = vocabs + .into_iter() + .enumerate() + .map(|(i, piece)| { + let piece = match as_byte_token(piece) { + Some(b) => { + let b = b as usize; + bytes[b] = i as _; + from_ref(&BYTES[b]) + } + None => piece, + }; + total_len += piece.len(); + piece + }) + .collect(); + (vocabs, bytes, total_len) +} + +// 收集词表字符内容和字节 token,同时计算内容总长度 +pub(crate) fn collect_vocabs_with_hint<'s>( + vocabs: impl IntoIterator, + is_byte: impl IntoIterator, + unk: utok, +) -> (Vec<&'s [u8]>, Box<[utok; 256]>, usize) { + let mut bytes = Box::new([unk; 256]); + let mut total_len = 0; + let vocabs = zip(vocabs, is_byte) + .enumerate() + .map(|(i, (piece, is_byte))| { + if is_byte { + let b = as_byte_token(piece) + .unwrap_or_else(|| panic!("{piece:?} is not a valid byte token")) + as usize; + bytes[b] = i as _; + from_ref(&BYTES[b]) + } else { + piece + }; + total_len += piece.len(); + piece + }) + .collect(); + (vocabs, bytes, total_len) +} + +pub(crate) struct CompressedVocab { + pub vocabs: Pin>, + pub slices: Vec<(usize, usize)>, +} + +impl CompressedVocab { + pub fn new(vocabs: &[&[u8]], total_len: usize) -> Self { + // 创建字符内容缓存 + let mut slices = vec![(0usize, 0usize); vocabs.len()]; + let mut text_buf = Vec::::with_capacity(total_len); + let mut indices = (0..vocabs.len()).collect::>(); + // 对词按内容长度从长到短排序,因为短的内容有可能是长内容的子串,可以避免重复存储相同内容 + indices.sort_unstable_by_key(|&i| -(vocabs[i].len() as isize)); + for i in indices { + let v = vocabs[i]; + // 查找子串,若存在则复用,否则将新的内容追加到缓存 + let off = memchr::memmem::find(&text_buf, v).unwrap_or_else(|| { + let off = text_buf.len(); + text_buf.extend(v); + off + }); + slices[i] = (off, v.len()); + } + Self { + // 锁定字符串内容的位置,以实现安全的自引用 + vocabs: unsafe { Pin::new_unchecked(text_buf.into_boxed_slice()) }, + slices, + } + } +} + +const BYTES: [u8; 256] = { + let mut bytes = [0u8; 256]; + let mut i = 0usize; + while i < 256 { + bytes[i] = i as _; + i += 1; + } + bytes +}; + +const fn as_byte_token(piece: &[u8]) -> Option { + // 按结构分解并转换 + match piece { + &[b'<', b'0', b'x', a, b, b'>'] if a.is_ascii_hexdigit() && b.is_ascii_hexdigit() => { + // ascii 转数字 + #[inline(always)] + const fn to_num(c: u8) -> u8 { + match c { + b'0'..=b'9' => c - b'0', + b'a'..=b'f' => c - b'a' + 10, + b'A'..=b'F' => c - b'A' + 10, + _ => unreachable!(), + } + } + + Some(to_num(a) * 16 + to_num(b)) + } + _ => None, + } +} diff --git a/src/lib.rs b/src/lib.rs index 491b0e7..0876518 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,12 @@ #![deny(warnings)] mod bpe; - -use regex::Regex; -use std::collections::HashMap; +mod functions; +mod lpe; +mod tokeneer; pub use bpe::Bpe; +pub use lpe::Lpe; /// `utok` for token id. #[allow(non_camel_case_types)] @@ -18,102 +19,3 @@ pub trait Method { fn encode(&self, text: &str) -> impl IntoIterator + '_; fn decode(&self, token: utok) -> &[u8]; } - -pub struct Tokeneer { - method: M, - special: HashMap>, - special_regex: regex::Regex, -} - -impl Tokeneer { - pub fn new(method: M) -> Self { - let special = method - .internal_special() - .into_iter() - .map(|(k, v)| (k.to_string(), vec![v])) - .collect::>(); - let special_regex = build_pattern(special.keys()); - Self { - method, - special, - special_regex, - } - } - - pub fn extend_special(&mut self, patterns: impl IntoIterator)>) { - use std::collections::hash_map::Entry::{Occupied, Vacant}; - let mut any = false; - for (k, v) in patterns { - match self.special.entry(k) { - Occupied(entry) => { - assert_eq!(entry.get(), &v); - } - Vacant(entry) => { - entry.insert(v); - any = true; - } - } - } - if any { - self.special_regex = build_pattern(self.special.keys()); - } - } - - pub fn encode(&self, text: &str) -> Vec { - let mut ans = Vec::new(); - let mut start = 0; - for m in self.special_regex.find_iter(text) { - ans.extend(self.method.encode(&text[start..m.start()])); - ans.extend_from_slice(&self.special[m.as_str()]); - start = m.end(); - } - ans.extend(self.method.encode(&text[start..])); - ans - } - - pub fn decode(&self, tokens: &[utok]) -> String { - let mut ans = Vec::new(); - for &t in tokens { - ans.extend_from_slice(self.method.decode(t)); - } - String::from_utf8(ans).unwrap() - } - #[inline] - pub fn internal(&self) -> &M { - &self.method - } -} - -fn build_pattern>(text: impl IntoIterator) -> Regex { - let mut pattern = String::new(); - let mut iter = text.into_iter(); - if let Some(p) = iter.next() { - pattern.push_str(p.as_ref()); - } - for p in iter { - pattern.push('|'); - pattern.push_str(p.as_ref()); - } - regex::Regex::new(&pattern).unwrap() -} - -const fn as_byte_token(piece: &[u8]) -> Option { - // 按结构分解并转换 - match piece { - &[b'<', b'0', b'x', a, b, b'>'] if a.is_ascii_hexdigit() && b.is_ascii_hexdigit() => { - // ascii 转数字 - #[inline(always)] - const fn to_num(c: u8) -> u8 { - match c { - b'0'..=b'9' => c - b'0', - b'a'..=b'f' => c - b'a' + 10, - b'A'..=b'F' => c - b'A' + 10, - _ => unreachable!(), - } - } - - Some(to_num(a) * 16 + to_num(b)) - } - _ => None, - } -} diff --git a/src/lpe/mod.rs b/src/lpe/mod.rs new file mode 100644 index 0000000..d4b782c --- /dev/null +++ b/src/lpe/mod.rs @@ -0,0 +1,96 @@ +//! l-p-e for Longest Prefix Encoding + +use crate::{ + functions::{collect_vocabs, CompressedVocab}, + utok, Method, +}; +use patricia_tree::PatriciaMap; +use std::{collections::HashSet, pin::Pin}; + +pub struct Lpe { + /// 保存所有词的字符串内容,以 u8 为单位所以不需要对齐,占用空间少 + vocabs: Pin>, + /// 按 token 顺序保存元信息 + tokens: Box<[(u32, u32)]>, + /// 词汇的前缀树 + trie: PatriciaMap, + /// 用于索引单字节 token,因此不需要其他元信息 + bytes: Box<[utok; 256]>, + /// token: + unk: utok, +} + +impl Lpe { + pub fn new<'a>(vocabs: impl IntoIterator, unk: utok) -> Self { + let (vocabs, bytes, total_len) = collect_vocabs(vocabs, unk); + let CompressedVocab { vocabs, slices } = CompressedVocab::new(&vocabs, total_len); + let tokens = slices + .into_iter() + .map(|(off, len)| (off as u32, len as u32)) + .collect::>(); + + let bytes_set = bytes.iter().chain(&[unk]).cloned().collect::>(); + let trie = tokens + .iter() + .enumerate() + .filter(|&(i, _)| !bytes_set.contains(&(i as utok))) + .map(|(i, &(off, len))| (&vocabs[off as usize..][..len as usize], i as utok)) + .collect(); + + // println!( + // "Building LPE vocab, detected {} tokens, compressed to {} bytes from {total_len} bytes", + // tokens.len(), + // vocabs.len(), + // ); + + Self { + vocabs, + tokens, + trie, + bytes, + unk, + } + } + + /// token id -> token meta + #[inline(always)] + fn token(&self, token: utok) -> &[u8] { + let (off, len) = self.tokens[token as usize]; + &self.vocabs[off as usize..][..len as usize] + } +} + +impl Method for Lpe { + #[inline] + fn unk_token(&self) -> utok { + self.unk + } + #[inline] + fn vocab_size(&self) -> usize { + self.tokens.len() + } + #[inline] + fn internal_special(&self) -> impl IntoIterator { + [] + } + #[inline] + fn encode(&self, text: &str) -> impl IntoIterator + '_ { + let mut text = text.as_bytes(); + let mut tokens = Vec::::new(); + + while !text.is_empty() { + let (tok, len) = match self.trie.get_longest_common_prefix(text) { + Some((pre, tok)) => (*tok, pre.len()), + None => (self.bytes[text[0] as usize], 1), + }; + tokens.push(tok); + text = &text[len..]; + } + + tokens + } + #[inline] + fn decode(&self, token: utok) -> &[u8] { + self.token(token) + } +} diff --git a/src/tokeneer.rs b/src/tokeneer.rs new file mode 100644 index 0000000..a523d4f --- /dev/null +++ b/src/tokeneer.rs @@ -0,0 +1,84 @@ +use crate::{utok, Method}; +use regex::Regex; +use std::collections::HashMap; + +pub struct Tokeneer { + method: M, + special: HashMap>, + special_regex: Regex, +} + +impl Tokeneer { + pub fn new(method: M) -> Self { + let special = method + .internal_special() + .into_iter() + .map(|(k, v)| (k.to_string(), vec![v])) + .collect::>(); + let special_regex = build_pattern(special.keys()); + Self { + method, + special, + special_regex, + } + } + + pub fn encode(&self, text: &str) -> Vec { + let mut ans = Vec::new(); + let mut start = 0; + for m in self.special_regex.find_iter(text) { + ans.extend(self.method.encode(&text[start..m.start()])); + ans.extend_from_slice(&self.special[m.as_str()]); + start = m.end(); + } + ans.extend(self.method.encode(&text[start..])); + ans + } + + pub fn decode(&self, tokens: &[utok]) -> String { + let mut ans = Vec::new(); + for &t in tokens { + ans.extend_from_slice(self.method.decode(t)); + } + String::from_utf8(ans).unwrap() + } +} + +impl Tokeneer { + pub fn extend_special(&mut self, patterns: impl IntoIterator)>) { + use std::collections::hash_map::Entry::{Occupied, Vacant}; + let mut any = false; + for (k, v) in patterns { + match self.special.entry(k) { + Occupied(entry) => { + assert_eq!(entry.get(), &v); + } + Vacant(entry) => { + entry.insert(v); + any = true; + } + } + } + if any { + self.special_regex = build_pattern(self.special.keys()); + } + } + + #[inline] + pub fn internal(&self) -> &M { + &self.method + } +} + +fn build_pattern>(text: impl IntoIterator) -> Regex { + let mut pattern = String::new(); + let mut iter = text.into_iter(); + if let Some(p) = iter.next() { + pattern.push_str(p.as_ref()); + } + for p in iter { + pattern.push('|'); + pattern.push_str(p.as_ref()); + } + Regex::new(&pattern).unwrap() +}