Commit 4e303e6

YdrMaster committed

feat: add lpe tokenizer implementation
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent f78f387 commit 4e303e6

6 files changed, +335 -156 lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -9,3 +9,4 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]
 [dependencies]
 regex = "1.10"
 memchr = "2.7"
+patricia_tree = "0.8"

src/bpe/mod.rs

Lines changed: 17 additions & 54 deletions
@@ -1,6 +1,11 @@
-mod algorithm;
+//! b-p-e for Byte Pair Encoding
 
-use crate::{utok, Method};
+mod algorithm;
+
+use crate::{
+    functions::{collect_vocabs_with_hint, CompressedVocab},
+    utok, Method,
+};
 use std::{
     collections::{HashMap, HashSet},
     iter::zip,
@@ -11,7 +16,7 @@ use std::{
 
 pub struct Bpe {
     /// String contents of all tokens, stored as raw u8 so no alignment is needed and space usage stays small
-    _vocab: Pin<Box<[u8]>>,
+    _vocabs: Pin<Box<[u8]>>,
     /// Token metadata, stored in token order
     tokens: Box<[TokenMeta]>,
     /// Token indices sorted by their strings' lexicographic order, used to binary-search a token from its string.
@@ -46,7 +51,7 @@ impl Deref for TokenMeta {
 impl Bpe {
     /// Parses a tokenizer.model file and builds a bpe tokenizer.
     pub fn from_tokenizer_model(model: &[u8]) -> Self {
-        // Walk the file, recording where every vocab entry is and the maximum length
+        // Walk the file, recording where every vocab entry is
        let offsets = (0..)
            .scan(0usize, |offset, _| match &model[*offset..] {
                [10, total_len, 10, content @ ..] => {
@@ -91,51 +96,9 @@
         is_byte: impl IntoIterator<Item = bool>,
         unk: utok,
     ) -> Self {
-        let mut bytes = Box::new([unk; 256]);
-        let mut total_len = 0;
-        // Collect the vocab string contents and byte tokens, and compute the total content length
-        let vocabs = zip(vocabs, is_byte)
-            .enumerate()
-            .map(|(i, (piece, is_byte))| {
-                let piece = if is_byte {
-                    const BYTES: [u8; 256] = {
-                        let mut bytes = [0u8; 256];
-                        let mut i = 0usize;
-                        while i < 256 {
-                            bytes[i] = i as _;
-                            i += 1;
-                        }
-                        bytes
-                    };
-
-                    let b = crate::as_byte_token(piece.as_bytes()).unwrap() as usize;
-                    bytes[b] = i as utok;
-                    std::slice::from_ref(&BYTES[b])
-                } else {
-                    piece.as_bytes()
-                };
-                total_len += piece.len();
-                piece
-            })
-            .collect::<Vec<_>>();
-        // Create the string content buffer
-        let mut slices = vec![(0usize, 0usize); vocabs.len()];
-        let mut text_buf = Vec::<u8>::with_capacity(total_len);
-        let mut indices = (0..vocabs.len()).collect::<Vec<_>>();
-        // Sort entries by content length, longest first: shorter contents may be substrings of longer ones, so identical bytes are not stored twice
-        indices.sort_unstable_by_key(|&i| -(vocabs[i].len() as isize));
-        for i in indices {
-            let v = vocabs[i];
-            // Look for the piece as a substring; reuse it if found, otherwise append the new content to the buffer
-            let off = memchr::memmem::find(&text_buf, v).unwrap_or_else(|| {
-                let off = text_buf.len();
-                text_buf.extend(v);
-                off
-            });
-            slices[i] = (off, v.len());
-        }
-        // Pin the string content in place so the self-references stay sound
-        let _vocab = unsafe { Pin::new_unchecked(text_buf.into_boxed_slice()) };
+        let (vocabs, bytes, total_len) =
+            collect_vocabs_with_hint(vocabs.into_iter().map(|s| s.as_bytes()), is_byte, unk);
+        let CompressedVocab { vocabs, slices } = CompressedVocab::new(&vocabs, total_len);
         // Collect the token scores
         let scores = scores.into_iter().collect::<Vec<_>>();
         assert_eq!(
@@ -146,7 +109,7 @@ impl Bpe {
         // tokens reference the string positions directly, bound to scores re-ranked and converted to integers
         let tokens = zip(slices, rank(&scores))
             .map(|((off, len), rank)| TokenMeta {
-                ptr: unsafe { NonNull::new_unchecked(_vocab[off..].as_ptr().cast_mut()) },
+                ptr: unsafe { NonNull::new_unchecked(vocabs[off..].as_ptr().cast_mut()) },
                 len: len as _,
                 rank,
             })
@@ -160,13 +123,13 @@ impl Bpe {
         sorted_pieces.sort_unstable_by_key(|&i| &*tokens[i as usize]);
 
         // println!(
-        //     "Building BPE vocab, detected {} tokens, compressed to {} bytes from {len} bytes",
+        //     "Building BPE vocab, detected {} tokens, compressed to {} bytes from {total_len} bytes",
         //     tokens.len(),
-        //     _vocab.len(),
+        //     vocabs.len(),
         // );
 
         Self {
-            _vocab,
+            _vocabs: vocabs,
             tokens,
             sorted_pieces,
             bytes,
@@ -283,7 +246,7 @@ fn test() {
     println!(
         "bpe: detected {} tokens, compressed to {} bytes",
        bpe.vocab_size(),
-        bpe._vocab.len(),
+        bpe._vocabs.len(),
     );
     println!("inaccessible: {inaccessible:#?}");
 }
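
Aside: the renamed _vocabs field and the ptr/len pairs inside TokenMeta rely on the same pinned-buffer pattern before and after this change: a pinned byte buffer owns the storage, and raw pointer/length views into it stay valid because the heap allocation never moves. A minimal standalone sketch of that pattern (toy types, not the crate's Bpe/TokenMeta):

use std::{pin::Pin, ptr::NonNull};

/// A pinned byte buffer plus raw (ptr, len) views into it.
struct Table {
    _buf: Pin<Box<[u8]>>,
    views: Vec<(NonNull<u8>, usize)>,
}

impl Table {
    fn new(pieces: &[&[u8]]) -> Self {
        // Concatenate all pieces into one buffer and remember their spans.
        let mut bytes = Vec::new();
        let mut spans = Vec::new();
        for p in pieces {
            spans.push((bytes.len(), p.len()));
            bytes.extend_from_slice(p);
        }
        // Safety: the boxed slice is never moved out of the Pin afterwards,
        // and the heap allocation behind it has a stable address.
        let buf = unsafe { Pin::new_unchecked(bytes.into_boxed_slice()) };
        let views: Vec<_> = spans
            .into_iter()
            .map(|(off, len)| {
                let ptr = unsafe { NonNull::new_unchecked(buf[off..].as_ptr().cast_mut()) };
                (ptr, len)
            })
            .collect();
        Self { _buf: buf, views }
    }

    fn piece(&self, i: usize) -> &[u8] {
        let (ptr, len) = self.views[i];
        // Safety: ptr/len describe a range inside the pinned buffer owned by self.
        unsafe { std::slice::from_raw_parts(ptr.as_ptr(), len) }
    }
}

fn main() {
    let pieces: &[&[u8]] = &[b"hello", b"world"];
    let table = Table::new(pieces);
    assert_eq!(table.piece(1), b"world");
}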

src/functions.rs

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+use crate::utok;
+use std::{iter::zip, pin::Pin, slice::from_ref};
+
+// Collect the vocab string contents and byte tokens, and compute the total content length
+pub(crate) fn collect_vocabs<'s>(
+    vocabs: impl IntoIterator<Item = &'s [u8]>,
+    unk: utok,
+) -> (Vec<&'s [u8]>, Box<[utok; 256]>, usize) {
+    let mut bytes = Box::new([unk; 256]);
+    let mut total_len = 0;
+    let vocabs = vocabs
+        .into_iter()
+        .enumerate()
+        .map(|(i, piece)| {
+            let piece = match as_byte_token(piece) {
+                Some(b) => {
+                    let b = b as usize;
+                    bytes[b] = i as _;
+                    from_ref(&BYTES[b])
+                }
+                None => piece,
+            };
+            total_len += piece.len();
+            piece
+        })
+        .collect();
+    (vocabs, bytes, total_len)
+}
+
+// Collect the vocab string contents and byte tokens, and compute the total content length
+pub(crate) fn collect_vocabs_with_hint<'s>(
+    vocabs: impl IntoIterator<Item = &'s [u8]>,
+    is_byte: impl IntoIterator<Item = bool>,
+    unk: utok,
+) -> (Vec<&'s [u8]>, Box<[utok; 256]>, usize) {
+    let mut bytes = Box::new([unk; 256]);
+    let mut total_len = 0;
+    let vocabs = zip(vocabs, is_byte)
+        .enumerate()
+        .map(|(i, (piece, is_byte))| {
+            let piece = if is_byte {
+                let b = as_byte_token(piece)
+                    .unwrap_or_else(|| panic!("{piece:?} is not a valid byte token"))
+                    as usize;
+                bytes[b] = i as _;
+                from_ref(&BYTES[b])
+            } else {
+                piece
+            };
+            total_len += piece.len();
+            piece
+        })
+        .collect();
+    (vocabs, bytes, total_len)
+}
+
+pub(crate) struct CompressedVocab {
+    pub vocabs: Pin<Box<[u8]>>,
+    pub slices: Vec<(usize, usize)>,
+}
+
+impl CompressedVocab {
+    pub fn new(vocabs: &[&[u8]], total_len: usize) -> Self {
+        // Create the string content buffer
+        let mut slices = vec![(0usize, 0usize); vocabs.len()];
+        let mut text_buf = Vec::<u8>::with_capacity(total_len);
+        let mut indices = (0..vocabs.len()).collect::<Vec<_>>();
+        // Sort entries by content length, longest first: shorter contents may be substrings of longer ones, so identical bytes are not stored twice
+        indices.sort_unstable_by_key(|&i| -(vocabs[i].len() as isize));
+        for i in indices {
+            let v = vocabs[i];
+            // Look for the piece as a substring; reuse it if found, otherwise append the new content to the buffer
+            let off = memchr::memmem::find(&text_buf, v).unwrap_or_else(|| {
+                let off = text_buf.len();
+                text_buf.extend(v);
+                off
+            });
+            slices[i] = (off, v.len());
+        }
+        Self {
+            // Pin the string content in place so the self-references stay sound
+            vocabs: unsafe { Pin::new_unchecked(text_buf.into_boxed_slice()) },
+            slices,
+        }
+    }
+}
+
+const BYTES: [u8; 256] = {
+    let mut bytes = [0u8; 256];
+    let mut i = 0usize;
+    while i < 256 {
+        bytes[i] = i as _;
+        i += 1;
+    }
+    bytes
+};
+
+const fn as_byte_token(piece: &[u8]) -> Option<u8> {
+    // Destructure and convert
+    match piece {
+        &[b'<', b'0', b'x', a, b, b'>'] if a.is_ascii_hexdigit() && b.is_ascii_hexdigit() => {
+            // ASCII hex digit to its numeric value
+            #[inline(always)]
+            const fn to_num(c: u8) -> u8 {
+                match c {
+                    b'0'..=b'9' => c - b'0',
+                    b'a'..=b'f' => c - b'a' + 10,
+                    b'A'..=b'F' => c - b'A' + 10,
+                    _ => unreachable!(),
+                }
+            }
+
+            Some(to_num(a) * 16 + to_num(b))
+        }
+        _ => None,
+    }
+}
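
The compression step in CompressedVocab::new stores a piece only once when it already occurs as a substring of a longer piece. A standalone sketch of the same idea on toy data (the helper itself is crate-private), using the same sort-longest-first plus memchr::memmem::find logic:

use memchr::memmem;
use std::cmp::Reverse;

fn main() {
    let pieces: &[&[u8]] = &[b"apple", b"app", b"le", b"xyz"];

    // Visit pieces longest first, so shorter pieces can reuse bytes already stored.
    let mut order: Vec<usize> = (0..pieces.len()).collect();
    order.sort_unstable_by_key(|&i| Reverse(pieces[i].len()));

    let mut buf = Vec::<u8>::new();
    let mut slices = vec![(0usize, 0usize); pieces.len()];
    for i in order {
        let p = pieces[i];
        // Reuse an existing occurrence if found, otherwise append the new bytes.
        let off = memmem::find(&buf, p).unwrap_or_else(|| {
            let off = buf.len();
            buf.extend_from_slice(p);
            off
        });
        slices[i] = (off, p.len());
    }

    // "app" and "le" are substrings of "apple", so only "apple" and "xyz" are stored.
    assert_eq!(buf.as_slice(), b"applexyz");
    assert_eq!(&buf[slices[1].0..][..slices[1].1], b"app");
    println!("{} bytes for {} pieces: {slices:?}", buf.len(), pieces.len());
}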

src/lib.rs

Lines changed: 5 additions & 102 deletions
@@ -1,11 +1,13 @@
 #![deny(warnings)]
 
 mod bpe;
-
-use regex::Regex;
-use std::collections::HashMap;
+mod functions;
+mod lpe;
+mod tokeneer;
 
 pub use bpe::Bpe;
+pub use lpe::Lpe;
+pub use tokeneer::Tokeneer;
 
 /// `utok` for token id.
 #[allow(non_camel_case_types)]
@@ -18,102 +20,3 @@ pub trait Method {
     fn encode(&self, text: &str) -> impl IntoIterator<Item = utok> + '_;
     fn decode(&self, token: utok) -> &[u8];
 }
-
-pub struct Tokeneer<M> {
-    method: M,
-    special: HashMap<String, Vec<utok>>,
-    special_regex: regex::Regex,
-}
-
-impl<M: Method> Tokeneer<M> {
-    pub fn new(method: M) -> Self {
-        let special = method
-            .internal_special()
-            .into_iter()
-            .map(|(k, v)| (k.to_string(), vec![v]))
-            .collect::<HashMap<_, _>>();
-        let special_regex = build_pattern(special.keys());
-        Self {
-            method,
-            special,
-            special_regex,
-        }
-    }
-
-    pub fn extend_special(&mut self, patterns: impl IntoIterator<Item = (String, Vec<utok>)>) {
-        use std::collections::hash_map::Entry::{Occupied, Vacant};
-        let mut any = false;
-        for (k, v) in patterns {
-            match self.special.entry(k) {
-                Occupied(entry) => {
-                    assert_eq!(entry.get(), &v);
-                }
-                Vacant(entry) => {
-                    entry.insert(v);
-                    any = true;
-                }
-            }
-        }
-        if any {
-            self.special_regex = build_pattern(self.special.keys());
-        }
-    }
-
-    pub fn encode(&self, text: &str) -> Vec<utok> {
-        let mut ans = Vec::new();
-        let mut start = 0;
-        for m in self.special_regex.find_iter(text) {
-            ans.extend(self.method.encode(&text[start..m.start()]));
-            ans.extend_from_slice(&self.special[m.as_str()]);
-            start = m.end();
-        }
-        ans.extend(self.method.encode(&text[start..]));
-        ans
-    }
-
-    pub fn decode(&self, tokens: &[utok]) -> String {
-        let mut ans = Vec::new();
-        for &t in tokens {
-            ans.extend_from_slice(self.method.decode(t));
-        }
-        String::from_utf8(ans).unwrap()
-    }
-    #[inline]
-    pub fn internal(&self) -> &M {
-        &self.method
-    }
-}
-
-fn build_pattern<T: AsRef<str>>(text: impl IntoIterator<Item = T>) -> Regex {
-    let mut pattern = String::new();
-    let mut iter = text.into_iter();
-    if let Some(p) = iter.next() {
-        pattern.push_str(p.as_ref());
-    }
-    for p in iter {
-        pattern.push('|');
-        pattern.push_str(p.as_ref());
-    }
-    regex::Regex::new(&pattern).unwrap()
-}
-
-const fn as_byte_token(piece: &[u8]) -> Option<u8> {
-    // Destructure and convert
-    match piece {
-        &[b'<', b'0', b'x', a, b, b'>'] if a.is_ascii_hexdigit() && b.is_ascii_hexdigit() => {
-            // ASCII hex digit to its numeric value
-            #[inline(always)]
-            const fn to_num(c: u8) -> u8 {
-                match c {
-                    b'0'..=b'9' => c - b'0',
-                    b'a'..=b'f' => c - b'a' + 10,
-                    b'A'..=b'F' => c - b'A' + 10,
-                    _ => unreachable!(),
-                }
-            }
-
-            Some(to_num(a) * 16 + to_num(b))
-        }
-        _ => None,
-    }
-}
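
The Tokeneer removed here moves to src/tokeneer.rs; its encode splits the input on special tokens with a regex alternation and hands the spans in between to the underlying method. A standalone sketch of that splitting step, with made-up special tokens and a toy byte-level encoder closure in place of the crate's Method trait (the pattern is escaped here, which build_pattern did not do):

use regex::Regex;

fn main() {
    // Hypothetical special tokens and ids, for illustration only.
    let special = [("<|bos|>", 1u32), ("<|eot|>", 2u32)];

    // Alternation over all special strings, like build_pattern.
    let pattern = special
        .iter()
        .map(|(s, _)| regex::escape(s))
        .collect::<Vec<_>>()
        .join("|");
    let re = Regex::new(&pattern).unwrap();

    // Toy stand-in for Method::encode: one token per byte, offset by 100.
    let encode = |text: &str| text.bytes().map(|b| b as u32 + 100).collect::<Vec<_>>();

    let text = "<|bos|>hi<|eot|>";
    let mut ans = Vec::new();
    let mut start = 0;
    for m in re.find_iter(text) {
        // Encode the plain text before the special token, then emit the special id.
        ans.extend(encode(&text[start..m.start()]));
        let &(_, id) = special.iter().find(|(s, _)| *s == m.as_str()).unwrap();
        ans.push(id);
        start = m.end();
    }
    ans.extend(encode(&text[start..]));

    assert_eq!(ans, [1, 204, 205, 2]);
    println!("{ans:?}");
}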
