perf(bpe): compress the vocab contents to reduce space usage and improve locality
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Aug 7, 2024
1 parent 8a9c60d commit f78f387
Showing 4 changed files with 75 additions and 32 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/build.yml
@@ -29,6 +29,10 @@ jobs:
       - name: Check format
         run: cargo fmt --check
 
+      - name: Download tokenizer.model
+        run: wget https://huggingface.co/TinyLlama/TinyLlama_v1.1/resolve/main/tokenizer.model
+        # run on windows: wget -Uri https://huggingface.co/TinyLlama/TinyLlama_v1.1/resolve/main/tokenizer.model -OutFile tokenizer.model
+
       - name: Run test
         run: cargo test
 
1 change: 1 addition & 0 deletions .gitignore
@@ -1,2 +1,3 @@
 /target
 /Cargo.lock
+/tokenizer.model
1 change: 1 addition & 0 deletions Cargo.toml
@@ -8,3 +8,4 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]
 
 [dependencies]
 regex = "1.10"
+memchr = "2.7"
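
The new memchr dependency supplies the fast substring search used by the vocab compression below. As a quick illustration (not part of the commit), memchr::memmem::find returns the byte offset of the first occurrence of a needle within a haystack, or None when it is absent:

fn main() {
    let haystack = b"hello world";
    // The first occurrence of the needle, as a byte offset.
    assert_eq!(memchr::memmem::find(haystack, b"world"), Some(6));
    // Absent needles yield None.
    assert_eq!(memchr::memmem::find(haystack, b"moon"), None);
}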
101 changes: 69 additions & 32 deletions src/bpe/mod.rs
@@ -1,6 +1,6 @@
 mod algorithm;
 
-use crate::{as_byte_token, utok, Method};
+use crate::{utok, Method};
 use std::{
     collections::{HashMap, HashSet},
     iter::zip,
@@ -82,66 +82,89 @@ impl Bpe {
             }
         });
         // Build the tokenizer
-        Self::new(vocabs, scores, is_byte, 0, offsets.len())
+        Self::new(vocabs, scores, is_byte, 0)
     }
 
     pub fn new<'a>(
         vocabs: impl IntoIterator<Item = &'a str>,
         scores: impl IntoIterator<Item = f32>,
         is_byte: impl IntoIterator<Item = bool>,
         unk: utok,
-        vocab_size_hint: usize,
     ) -> Self {
-        let mut text_buf = Vec::with_capacity(vocab_size_hint * 4);
         let mut bytes = Box::new([unk; 256]);
-        // Rearrange the vocab,
-        // separating string contents from their metadata:
-        // all contents are stored in text_buf for cache friendliness,
-        // and each string's offset into text_buf plus its length are stored in meta
-        let meta = vocabs
-            .into_iter()
-            .map(str::as_bytes)
-            .zip(is_byte)
+        let mut total_len = 0;
+        // Collect the vocab text contents and byte tokens, computing the total content length
+        let vocabs = zip(vocabs, is_byte)
             .enumerate()
-            .map(|(t, (piece, is_byte))| {
-                let off = text_buf.len();
-                let len = if is_byte {
-                    let b = as_byte_token(piece).unwrap();
-                    text_buf.push(b);
-                    bytes[b as usize] = t as utok;
-                    1
+            .map(|(i, (piece, is_byte))| {
+                let piece = if is_byte {
+                    const BYTES: [u8; 256] = {
+                        let mut bytes = [0u8; 256];
+                        let mut i = 0usize;
+                        while i < 256 {
+                            bytes[i] = i as _;
+                            i += 1;
+                        }
+                        bytes
+                    };
+
+                    let b = crate::as_byte_token(piece.as_bytes()).unwrap() as usize;
+                    bytes[b] = i as utok;
+                    std::slice::from_ref(&BYTES[b])
                 } else {
-                    text_buf.extend_from_slice(piece);
-                    piece.len()
+                    piece.as_bytes()
                 };
-                (off, len)
+                total_len += piece.len();
+                piece
             })
             .collect::<Vec<_>>();
+        // Create the text content cache
+        let mut slices = vec![(0usize, 0usize); vocabs.len()];
+        let mut text_buf = Vec::<u8>::with_capacity(total_len);
+        let mut indices = (0..vocabs.len()).collect::<Vec<_>>();
+        // Sort tokens by content length, longest first: a shorter content may be a substring of a longer one, so identical content need not be stored twice
+        indices.sort_unstable_by_key(|&i| -(vocabs[i].len() as isize));
+        for i in indices {
+            let v = vocabs[i];
+            // Look for the content as a substring; reuse it if found, otherwise append the new content to the cache
+            let off = memchr::memmem::find(&text_buf, v).unwrap_or_else(|| {
+                let off = text_buf.len();
+                text_buf.extend(v);
+                off
+            });
+            slices[i] = (off, v.len());
+        }
         // Pin the string contents in place to make the self-references safe
         let _vocab = unsafe { Pin::new_unchecked(text_buf.into_boxed_slice()) };
-        // Reweight the scores and convert them to integers
-        let rank = rank(&scores.into_iter().collect::<Vec<_>>());
+        // Collect the token scores
+        let scores = scores.into_iter().collect::<Vec<_>>();
         assert_eq!(
-            meta.len(),
-            rank.len(),
+            slices.len(),
+            scores.len(),
             "scores size mismatch with vocab size"
         );
-        // tokens reference the string locations directly, bound to their scores
-        let tokens = zip(meta, rank)
+        // tokens reference the string locations directly, bound to the scores reweighted and converted to integers
+        let tokens = zip(slices, rank(&scores))
             .map(|((off, len), rank)| TokenMeta {
                 ptr: unsafe { NonNull::new_unchecked(_vocab[off..].as_ptr().cast_mut()) },
                 len: len as _,
                 rank,
             })
-            .collect::<Box<[_]>>();
+            .collect::<Box<_>>();
         // Sort tokens by the lexicographic order of their strings, enabling binary search from a string to its token
         // <unk> and the <0xyz> byte tokens must not be reachable through piece lookup; exclude them with a set
         let bytes_set = bytes.iter().chain(&[unk]).cloned().collect::<HashSet<_>>();
         let mut sorted_pieces = (0..tokens.len() as utok)
             .filter(|i| !bytes_set.contains(i))
-            .collect::<Box<[_]>>();
+            .collect::<Box<_>>();
         sorted_pieces.sort_unstable_by_key(|&i| &*tokens[i as usize]);
 
+        // println!(
+        //     "Building BPE vocab, detected {} tokens, compressed to {} bytes from {len} bytes",
+        //     tokens.len(),
+        //     _vocab.len(),
+        // );
+
         Self {
             _vocab,
             tokens,
@@ -214,7 +237,7 @@ impl Method for Bpe {
 }
 
 /// Sort, deduplicate, and reweight a set of scores, converting them into an integer sequence that preserves the same ordering
-fn rank(scores: &[f32]) -> Vec<u32> {
+fn rank(scores: &[f32]) -> impl IntoIterator<Item = u32> + '_ {
     use std::{
         cmp::Ordering,
         collections::{BTreeMap, BTreeSet},
@@ -249,5 +272,19 @@ fn rank(scores: &[f32]) -> Vec<u32> {
         .map(|(i, f)| (f, i as u32))
         .collect::<BTreeMap<_, _>>();
 
-    scores.iter().map(|f| map[&FloatOrd(*f)]).collect()
+    scores.iter().map(move |&f| map[&FloatOrd(f)])
 }
+
+#[test]
+fn test() {
+    if let Ok(buf) = std::fs::read("tokenizer.model") {
+        let bpe = Bpe::from_tokenizer_model(&buf);
+        let inaccessible = bpe.inaccessible();
+        println!(
+            "bpe: detected {} tokens, compressed to {} bytes",
+            bpe.vocab_size(),
+            bpe._vocab.len(),
+        );
+        println!("inaccessible: {inaccessible:#?}");
+    }
+}
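
The heart of the change is the packing loop in Bpe::new: write every piece into one buffer, longest first, and point a piece at an existing occurrence whenever its bytes already appear in what has been written. A self-contained sketch of just that step, with a hypothetical compress helper that is not part of the commit:

// Hypothetical helper illustrating the substring-reuse packing above:
// returns the packed buffer plus one (offset, len) slice per input piece.
fn compress(pieces: &[&[u8]]) -> (Vec<u8>, Vec<(usize, usize)>) {
    let mut slices = vec![(0usize, 0usize); pieces.len()];
    let mut buf = Vec::<u8>::new();
    let mut indices = (0..pieces.len()).collect::<Vec<_>>();
    // Longest first, so every potential container is written before its substrings.
    indices.sort_unstable_by_key(|&i| std::cmp::Reverse(pieces[i].len()));
    for i in indices {
        let p = pieces[i];
        // Reuse an existing occurrence, or append the new content.
        let off = memchr::memmem::find(&buf, p).unwrap_or_else(|| {
            let off = buf.len();
            buf.extend_from_slice(p);
            off
        });
        slices[i] = (off, p.len());
    }
    (buf, slices)
}

#[test]
fn dedups_substrings() {
    let pieces: &[&[u8]] = &[b"low", b"lowest", b"est"];
    let (buf, slices) = compress(pieces);
    assert_eq!(buf, b"lowest"); // "low" and "est" reuse the bytes of "lowest"
    assert_eq!(slices, [(0, 3), (0, 6), (3, 3)]);
}

Sorting longest-first is what makes a single reuse check sufficient: by the time a piece is considered, every longer piece that could contain it is already in the buffer.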

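rank itself only changed its return type, from an allocated Vec<u32> to a lazy iterator, but what it computes is worth restating: each score is replaced by the dense rank of its value among the distinct scores, so equal scores share one integer and the original comparisons are preserved. A sketch under stated assumptions: the commit's FloatOrd ordering lives in the collapsed part of the diff, so the use of f32::total_cmp (and its ascending direction) below is an assumption:

// Sketch of the rank computation: each score maps to its index among the
// sorted, deduplicated score values. f32::total_cmp stands in for the
// commit's FloatOrd wrapper, whose exact ordering is not shown in the diff.
fn rank_sketch(scores: &[f32]) -> impl Iterator<Item = u32> + '_ {
    let mut distinct = scores.to_vec();
    distinct.sort_unstable_by(f32::total_cmp);
    distinct.dedup_by(|a, b| a.total_cmp(b).is_eq());
    scores.iter().map(move |f| {
        // Every score is present in `distinct`, so the search cannot fail.
        distinct.binary_search_by(|d| d.total_cmp(f)).unwrap() as u32
    })
}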