feat: add lpe tokenizer implementation
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Aug 7, 2024
1 parent f78f387 commit ea8ffc7
Showing 6 changed files with 319 additions and 156 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -9,3 +9,4 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]
[dependencies]
regex = "1.10"
memchr = "2.7"
patricia_tree = "0.8"
71 changes: 17 additions & 54 deletions src/bpe/mod.rs
@@ -1,6 +1,11 @@
mod algorithm;
//! b-p-e for Byte Pair Encoding
use crate::{utok, Method};
mod algorithm;

use crate::{
functions::{collect_vocabs_with_hint, CompressedVocab},
utok, Method,
};
use std::{
collections::{HashMap, HashSet},
iter::zip,
@@ -11,7 +16,7 @@ use std::{

pub struct Bpe {
/// Stores the string content of every token; stored as u8 so no alignment is needed and the footprint stays small
_vocab: Pin<Box<[u8]>>,
_vocabs: Pin<Box<[u8]>>,
/// Metadata stored in token order
tokens: Box<[TokenMeta]>,
/// Token indices sorted by the lexicographic order of their strings, used to binary-search a token from its string.
@@ -46,7 +51,7 @@ impl Deref for TokenMeta {
impl Bpe {
/// Parses a tokenizer.model file and builds a bpe tokenizer.
pub fn from_tokenizer_model(model: &[u8]) -> Self {
// Walk the file, marking every vocab entry's position and recording the maximum length
// Walk the file, marking every vocab entry's position
let offsets = (0..)
.scan(0usize, |offset, _| match &model[*offset..] {
[10, total_len, 10, content @ ..] => {
@@ -91,51 +96,9 @@ impl Bpe {
is_byte: impl IntoIterator<Item = bool>,
unk: utok,
) -> Self {
let mut bytes = Box::new([unk; 256]);
let mut total_len = 0;
// Collect the vocab string contents and byte tokens, and compute the total content length
let vocabs = zip(vocabs, is_byte)
.enumerate()
.map(|(i, (piece, is_byte))| {
let piece = if is_byte {
const BYTES: [u8; 256] = {
let mut bytes = [0u8; 256];
let mut i = 0usize;
while i < 256 {
bytes[i] = i as _;
i += 1;
}
bytes
};

let b = crate::as_byte_token(piece.as_bytes()).unwrap() as usize;
bytes[b] = i as utok;
std::slice::from_ref(&BYTES[b])
} else {
piece.as_bytes()
};
total_len += piece.len();
piece
})
.collect::<Vec<_>>();
// Create the string-content buffer
let mut slices = vec![(0usize, 0usize); vocabs.len()];
let mut text_buf = Vec::<u8>::with_capacity(total_len);
let mut indices = (0..vocabs.len()).collect::<Vec<_>>();
// Sort pieces from longest to shortest: a shorter piece may be a substring of a longer one, so identical content is stored only once
indices.sort_unstable_by_key(|&i| -(vocabs[i].len() as isize));
for i in indices {
let v = vocabs[i];
// Search for the piece as a substring; reuse it if found, otherwise append the new content to the buffer
let off = memchr::memmem::find(&text_buf, v).unwrap_or_else(|| {
let off = text_buf.len();
text_buf.extend(v);
off
});
slices[i] = (off, v.len());
}
// Pin the string content in place so the self-references are safe
let _vocab = unsafe { Pin::new_unchecked(text_buf.into_boxed_slice()) };
let (vocabs, bytes, total_len) =
collect_vocabs_with_hint(vocabs.into_iter().map(|s| s.as_bytes()), is_byte, unk);
let CompressedVocab { vocabs, slices } = CompressedVocab::new(&vocabs, total_len);
// Collect the piece scores
let scores = scores.into_iter().collect::<Vec<_>>();
assert_eq!(
@@ -146,7 +109,7 @@ impl Bpe {
// tokens reference the string positions directly, bound to scores that are re-ranked and converted to integers
let tokens = zip(slices, rank(&scores))
.map(|((off, len), rank)| TokenMeta {
ptr: unsafe { NonNull::new_unchecked(_vocab[off..].as_ptr().cast_mut()) },
ptr: unsafe { NonNull::new_unchecked(vocabs[off..].as_ptr().cast_mut()) },
len: len as _,
rank,
})
@@ -160,13 +123,13 @@ impl Bpe {
sorted_pieces.sort_unstable_by_key(|&i| &*tokens[i as usize]);

// println!(
// "Building BPE vocab, detected {} tokens, compressed to {} bytes from {len} bytes",
// "Building BPE vocab, detected {} tokens, compressed to {} bytes from {total_len} bytes",
// tokens.len(),
// _vocab.len(),
// vocabs.len(),
// );

Self {
_vocab,
_vocabs: vocabs,
tokens,
sorted_pieces,
bytes,
@@ -283,7 +246,7 @@ fn test() {
println!(
"bpe: detected {} tokens, compressed to {} bytes",
bpe.vocab_size(),
bpe._vocab.len(),
bpe._vocabs.len(),
);
println!("inaccessible: {inaccessible:#?}");
}
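The `_vocabs` buffer is kept behind a pinned `Box`, so the raw `ptr` in each `TokenMeta` can keep pointing into it as an (offset, len) self-reference. A minimal standalone sketch of that pattern (illustrative code, not part of this commit):

use std::{pin::Pin, ptr::NonNull};

fn main() {
    // Pin the boxed byte buffer: the heap allocation will not move, so raw
    // pointers into it stay valid for as long as the buffer is owned.
    let text_buf: Pin<Box<[u8]>> = Pin::new(b"hello world".to_vec().into_boxed_slice());
    // A (ptr, len) pair pointing at "world" inside the pinned buffer,
    // mirroring how TokenMeta records each token's slice.
    let ptr = NonNull::new(text_buf[6..].as_ptr().cast_mut()).unwrap();
    let len = 5usize;
    let piece = unsafe { std::slice::from_raw_parts(ptr.as_ptr(), len) };
    assert_eq!(piece, &b"world"[..]);
}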
117 changes: 117 additions & 0 deletions src/functions.rs
@@ -0,0 +1,117 @@
use crate::utok;
use std::{iter::zip, pin::Pin, slice::from_ref};

// Collect the vocab string contents and byte tokens, and compute the total content length
pub(crate) fn collect_vocabs<'s>(
vocabs: impl IntoIterator<Item = &'s [u8]>,
unk: utok,
) -> (Vec<&'s [u8]>, Box<[utok; 256]>, usize) {
let mut bytes = Box::new([unk; 256]);
let mut total_len = 0;
let vocabs = vocabs
.into_iter()
.enumerate()
.map(|(i, piece)| {
let piece = match as_byte_token(piece) {
Some(b) => {
let b = b as usize;
bytes[b] = i as _;
from_ref(&BYTES[b])
}
None => piece,
};
total_len += piece.len();
piece
})
.collect();
(vocabs, bytes, total_len)
}

// Collect the vocab string contents and byte tokens, and compute the total content length
pub(crate) fn collect_vocabs_with_hint<'s>(
vocabs: impl IntoIterator<Item = &'s [u8]>,
is_byte: impl IntoIterator<Item = bool>,
unk: utok,
) -> (Vec<&'s [u8]>, Box<[utok; 256]>, usize) {
let mut bytes = Box::new([unk; 256]);
let mut total_len = 0;
let vocabs = zip(vocabs, is_byte)
.enumerate()
.map(|(i, (piece, is_byte))| {
let piece = if is_byte {
let b = as_byte_token(piece)
.unwrap_or_else(|| panic!("{piece:?} is not a valid byte token"))
as usize;
bytes[b] = i as _;
from_ref(&BYTES[b])
} else {
piece
};
total_len += piece.len();
piece
})
.collect();
(vocabs, bytes, total_len)
}

pub(crate) struct CompressedVocab {
pub vocabs: Pin<Box<[u8]>>,
pub slices: Vec<(usize, usize)>,
}

impl CompressedVocab {
pub fn new(vocabs: &[&[u8]], total_len: usize) -> Self {
// Create the string-content buffer
let mut slices = vec![(0usize, 0usize); vocabs.len()];
let mut text_buf = Vec::<u8>::with_capacity(total_len);
let mut indices = (0..vocabs.len()).collect::<Vec<_>>();
// Sort pieces from longest to shortest: a shorter piece may be a substring of a longer one, so identical content is stored only once
indices.sort_unstable_by_key(|&i| -(vocabs[i].len() as isize));
for i in indices {
let v = vocabs[i];
// Search for the piece as a substring; reuse it if found, otherwise append the new content to the buffer
let off = memchr::memmem::find(&text_buf, v).unwrap_or_else(|| {
let off = text_buf.len();
text_buf.extend(v);
off
});
slices[i] = (off, v.len());
}
Self {
// Pin the string content in place so the self-references are safe
vocabs: unsafe { Pin::new_unchecked(text_buf.into_boxed_slice()) },
slices,
}
}
}

const BYTES: [u8; 256] = {
let mut bytes = [0u8; 256];
let mut i = 0usize;
while i < 256 {
bytes[i] = i as _;
i += 1;
}
bytes
};

const fn as_byte_token(piece: &[u8]) -> Option<u8> {
// Destructure the piece by pattern and convert it
match piece {
&[b'<', b'0', b'x', a, b, b'>'] if a.is_ascii_hexdigit() && b.is_ascii_hexdigit() => {
// ASCII hex digit to its numeric value
#[inline(always)]
const fn to_num(c: u8) -> u8 {
match c {
b'0'..=b'9' => c - b'0',
b'a'..=b'f' => c - b'a' + 10,
b'A'..=b'F' => c - b'A' + 10,
_ => unreachable!(),
}
}

Some(to_num(a) * 16 + to_num(b))
}
_ => None,
}
}
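A standalone sketch (not from this commit) of the substring-reuse idea in `CompressedVocab::new`: pieces are laid out from longest to shortest, and a piece that already occurs in the buffer is recorded as an (offset, len) slice instead of being stored again. The sample vocab below is made up for illustration and relies only on the `memchr` dependency already listed in Cargo.toml:

fn main() {
    let vocabs: Vec<&[u8]> = vec![&b"lowering"[..], &b"lower"[..], &b"low"[..], &b"new"[..]];
    let mut slices = vec![(0usize, 0usize); vocabs.len()];
    let mut text_buf = Vec::<u8>::new();
    // Longest pieces first, so shorter pieces can be found as substrings.
    let mut indices = (0..vocabs.len()).collect::<Vec<_>>();
    indices.sort_unstable_by_key(|&i| -(vocabs[i].len() as isize));
    for i in indices {
        let v = vocabs[i];
        // Reuse an existing occurrence, or append the piece to the buffer.
        let off = memchr::memmem::find(&text_buf, v).unwrap_or_else(|| {
            let off = text_buf.len();
            text_buf.extend(v);
            off
        });
        slices[i] = (off, v.len());
    }
    // "lower" and "low" are prefixes of "lowering", so only "lowering" and
    // "new" occupy space in the compressed buffer.
    assert_eq!(text_buf, b"loweringnew".to_vec());
    assert_eq!(slices, vec![(0usize, 8usize), (0, 5), (0, 3), (8, 3)]);
}

For reference, `as_byte_token(b"<0x0A>")` evaluates to `Some(10)`, which is how byte-fallback pieces are mapped onto the 256-entry `bytes` table.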
106 changes: 4 additions & 102 deletions src/lib.rs
@@ -1,11 +1,12 @@
#![deny(warnings)]

mod bpe;

use regex::Regex;
use std::collections::HashMap;
mod functions;
mod lpe;
mod tokeneer;

pub use bpe::Bpe;
pub use lpe::Lpe;

/// `utok` for token id.
#[allow(non_camel_case_types)]
@@ -18,102 +19,3 @@ pub trait Method {
fn encode(&self, text: &str) -> impl IntoIterator<Item = utok> + '_;
fn decode(&self, token: utok) -> &[u8];
}

pub struct Tokeneer<M> {
method: M,
special: HashMap<String, Vec<utok>>,
special_regex: regex::Regex,
}

impl<M: Method> Tokeneer<M> {
pub fn new(method: M) -> Self {
let special = method
.internal_special()
.into_iter()
.map(|(k, v)| (k.to_string(), vec![v]))
.collect::<HashMap<_, _>>();
let special_regex = build_pattern(special.keys());
Self {
method,
special,
special_regex,
}
}

pub fn extend_special(&mut self, patterns: impl IntoIterator<Item = (String, Vec<utok>)>) {
use std::collections::hash_map::Entry::{Occupied, Vacant};
let mut any = false;
for (k, v) in patterns {
match self.special.entry(k) {
Occupied(entry) => {
assert_eq!(entry.get(), &v);
}
Vacant(entry) => {
entry.insert(v);
any = true;
}
}
}
if any {
self.special_regex = build_pattern(self.special.keys());
}
}

pub fn encode(&self, text: &str) -> Vec<utok> {
let mut ans = Vec::new();
let mut start = 0;
for m in self.special_regex.find_iter(text) {
ans.extend(self.method.encode(&text[start..m.start()]));
ans.extend_from_slice(&self.special[m.as_str()]);
start = m.end();
}
ans.extend(self.method.encode(&text[start..]));
ans
}

pub fn decode(&self, tokens: &[utok]) -> String {
let mut ans = Vec::new();
for &t in tokens {
ans.extend_from_slice(self.method.decode(t));
}
String::from_utf8(ans).unwrap()
}
#[inline]
pub fn internal(&self) -> &M {
&self.method
}
}

fn build_pattern<T: AsRef<str>>(text: impl IntoIterator<Item = T>) -> Regex {
let mut pattern = String::new();
let mut iter = text.into_iter();
if let Some(p) = iter.next() {
pattern.push_str(p.as_ref());
}
for p in iter {
pattern.push('|');
pattern.push_str(p.as_ref());
}
regex::Regex::new(&pattern).unwrap()
}

const fn as_byte_token(piece: &[u8]) -> Option<u8> {
// Destructure the piece by pattern and convert it
match piece {
&[b'<', b'0', b'x', a, b, b'>'] if a.is_ascii_hexdigit() && b.is_ascii_hexdigit() => {
// ASCII hex digit to its numeric value
#[inline(always)]
const fn to_num(c: u8) -> u8 {
match c {
b'0'..=b'9' => c - b'0',
b'a'..=b'f' => c - b'a' + 10,
b'A'..=b'F' => c - b'A' + 10,
_ => unreachable!(),
}
}

Some(to_num(a) * 16 + to_num(b))
}
_ => None,
}
}
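The `Tokeneer` implementation removed from lib.rs here (the new `mod tokeneer` declared above presumably takes it over; its file is not among the diffs loaded on this page) splits the input around special tokens with a single alternation regex before delegating to `Method::encode`. A standalone sketch of that splitting step, using made-up special pieces that contain no regex metacharacters:

use regex::Regex;

fn main() {
    // Join the special pieces into one alternation pattern, as build_pattern does.
    let specials = ["<s>", "</s>", "<unk>"];
    let regex = Regex::new(&specials.join("|")).unwrap();

    let text = "<s>hello world</s>";
    let mut start = 0;
    // Text between matches would go to the underlying Method::encode;
    // each match maps directly to the ids registered for that special piece.
    for m in regex.find_iter(text) {
        println!("plain  : {:?}", &text[start..m.start()]);
        println!("special: {:?}", m.as_str());
        start = m.end();
    }
    println!("plain  : {:?}", &text[start..]);
}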