Skip to content

Commit

Permalink
perf(bpe): 压缩词表内容以降低空间占用并提升局部性
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Aug 7, 2024
1 parent 8a9c60d commit 25973a8
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 22 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/target
/Cargo.lock
/tokenizer.model
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]

[dependencies]
regex = "1.10"
memchr = "2.7"
78 changes: 56 additions & 22 deletions src/bpe/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
mod algorithm;

use crate::{as_byte_token, utok, Method};
use crate::{utok, Method};
use std::{
collections::{HashMap, HashSet},
iter::zip,
Expand Down Expand Up @@ -82,41 +82,62 @@ impl Bpe {
}
});
// 构造分词器
Self::new(vocabs, scores, is_byte, 0, offsets.len())
Self::new(vocabs, scores, is_byte, 0)
}

pub fn new<'a>(
vocabs: impl IntoIterator<Item = &'a str>,
scores: impl IntoIterator<Item = f32>,
is_byte: impl IntoIterator<Item = bool>,
unk: utok,
vocab_size_hint: usize,
) -> Self {
let mut text_buf = Vec::with_capacity(vocab_size_hint * 4);
let mut bytes = Box::new([unk; 256]);
// 重新编排词表
// 将字符串的内容和元信息分离
// 内容全部保存到 text_buf 以实现缓存友好性
// 字符串起始位置在 text_buf 中的偏移量和字符串长度保存到 meta 中
let meta = vocabs
.into_iter()
.map(str::as_bytes)
.zip(is_byte)
let mut total_len = 0;
// 收集词表字符内容和字节 token,同时计算内容总长度
let vocabs = zip(vocabs, is_byte)
.enumerate()
.map(|(t, (piece, is_byte))| {
let off = text_buf.len();
let len = if is_byte {
let b = as_byte_token(piece).unwrap();
text_buf.push(b);
bytes[b as usize] = t as utok;
1
.map(|(i, (piece, is_byte))| {
let piece = if is_byte {
const BYTES: [u8; 256] = {
let mut bytes = [0u8; 256];
let mut i = 0usize;
while i < 256 {
bytes[i] = i as _;
i += 1;
}
bytes
};

let b = crate::as_byte_token(piece.as_bytes()).unwrap() as usize;
bytes[b] = i as utok;
std::slice::from_ref(&BYTES[b])
} else {
text_buf.extend_from_slice(piece);
piece.len()
piece.as_bytes()
};
(off, len)
total_len += piece.len();
piece
})
.collect::<Vec<_>>();
// 创建字符内容缓存
let mut text_buf = Vec::<u8>::with_capacity(total_len);
let meta = {
let mut indices = (0..vocabs.len()).collect::<Vec<_>>();
// 对词按内容长度从长到短排序,因为短的内容有可能是长内容的子串,可以避免重复存储相同内容
indices.sort_unstable_by_key(|&i| -(vocabs[i].len() as isize));
indices
}
.into_iter()
.map(|i| vocabs[i])
.map(|v| {
// 查找子串,若存在则复用,否则将新的内容追加到缓存
let off = memchr::memmem::find(&text_buf, v).unwrap_or_else(|| {
let off = text_buf.len();
text_buf.extend(v);
off
});
(off, v.len())
})
.collect::<Vec<_>>();
// 锁定字符串内容的位置,以实现安全的自引用
let _vocab = unsafe { Pin::new_unchecked(text_buf.into_boxed_slice()) };
// 对分词评分重新赋权,转换为整型
Expand All @@ -142,6 +163,12 @@ impl Bpe {
.collect::<Box<[_]>>();
sorted_pieces.sort_unstable_by_key(|&i| &*tokens[i as usize]);

// println!(
// "Building BPE vocab, detected {} tokens, compressed to {} bytes from {len} bytes",
// tokens.len(),
// _vocab.len(),
// );

Self {
_vocab,
tokens,
Expand Down Expand Up @@ -251,3 +278,10 @@ fn rank(scores: &[f32]) -> Vec<u32> {

scores.iter().map(|f| map[&FloatOrd(*f)]).collect()
}

#[test]
fn test() {
    // Smoke test: exercises vocab construction against a real model file.
    // Deliberately a no-op when `tokenizer.model` is absent (the file is
    // git-ignored), so CI without the fixture still passes.
    let Ok(buf) = std::fs::read("tokenizer.model") else {
        return;
    };
    let _bpe = Bpe::from_tokenizer_model(&buf);
}

0 comments on commit 25973a8

Please sign in to comment.