Skip to content

Commit 751f2f8

Browse files
committed
style: 整理代码,添加注释
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 633b0fd commit 751f2f8

File tree

6 files changed

+174
-132
lines changed

6 files changed

+174
-132
lines changed

src/bpe/algorithm.rs

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ use std::{
77
ops::Range,
88
};
99

10-
pub struct BpeTokenizer<'v, 't> {
10+
pub struct MergeState<'v, 't> {
1111
text: &'t [u8],
1212
bpe: &'v Bpe,
1313
marks: Vec<Mark>,
@@ -26,7 +26,7 @@ pub struct Iter<'a> {
2626
}
2727

2828
impl Bpe {
29-
pub fn build_tokenizer<'v, 't>(&'v self, text: &'t str) -> BpeTokenizer<'v, 't> {
29+
pub fn begin_merge<'v, 't>(&'v self, text: &'t str) -> MergeState<'v, 't> {
3030
let mut marks = vec![Mark::unk(self.unk); text.len()];
3131
let mut merges = BinaryHeap::new();
3232

@@ -55,7 +55,7 @@ impl Bpe {
5555
};
5656
}
5757

58-
BpeTokenizer {
58+
MergeState {
5959
text: text.as_bytes(),
6060
bpe: self,
6161
marks,
@@ -119,7 +119,7 @@ impl PartialOrd for Merge {
119119
}
120120
}
121121

122-
impl BpeTokenizer<'_, '_> {
122+
impl MergeState<'_, '_> {
123123
/// 尝试执行一次合并,返回是否成功执行了一次合并。
124124
pub fn merge(&mut self) -> bool {
125125
// 一次合并将涉及至多 4 个 token:
@@ -203,7 +203,7 @@ impl BpeTokenizer<'_, '_> {
203203
}
204204
}
205205

206-
impl<'v> IntoIterator for BpeTokenizer<'v, '_> {
206+
impl<'v> IntoIterator for MergeState<'v, '_> {
207207
type Item = utok;
208208
type IntoIter = IntoIter<'v>;
209209
#[inline]
@@ -244,7 +244,7 @@ impl Iterator for Iter<'_> {
244244
}
245245
}
246246

247-
impl fmt::Display for BpeTokenizer<'_, '_> {
247+
impl fmt::Display for MergeState<'_, '_> {
248248
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
249249
use std::str::{from_utf8, from_utf8_unchecked};
250250

src/bpe/mod.rs

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,9 @@
33
mod algorithm;
44

55
use crate::{
6-
functions::{collect_vocabs_with_hint, CompressedVocab},
7-
utok, Method,
6+
utok,
7+
vocab::{CollectedVocab, CompressedVocab},
8+
Method,
89
};
910
use std::{
1011
collections::{HashMap, HashSet},
@@ -37,6 +38,7 @@ struct TokenMeta {
3738
rank: u32,
3839
}
3940

41+
// SAFETY: TokenMeta 中的指针是指向 Bpe 内容的自引用指针,且仅用于不可变引用。
4042
unsafe impl Send for TokenMeta {}
4143
unsafe impl Sync for TokenMeta {}
4244

@@ -96,8 +98,15 @@ impl Bpe {
9698
is_byte: impl IntoIterator<Item = bool>,
9799
unk: utok,
98100
) -> Self {
99-
let (vocabs, bytes, total_len) =
100-
collect_vocabs_with_hint(vocabs.into_iter().map(|s| s.as_bytes()), is_byte, unk);
101+
let CollectedVocab {
102+
vocabs,
103+
total_len,
104+
bytes,
105+
} = CollectedVocab::collect_with_hint(
106+
vocabs.into_iter().map(|s| s.as_bytes()),
107+
is_byte,
108+
unk,
109+
);
101110
let CompressedVocab { vocabs, slices } = CompressedVocab::new(&vocabs, total_len);
102111
// 收集合词评分
103112
let scores = scores.into_iter().collect::<Vec<_>>();
@@ -189,7 +198,7 @@ impl Method for Bpe {
189198
}
190199
#[inline]
191200
fn encode(&self, text: &str) -> impl IntoIterator<Item = utok> + '_ {
192-
let mut tokenizer = self.build_tokenizer(text);
201+
let mut tokenizer = self.begin_merge(text);
193202
while tokenizer.merge() {}
194203
tokenizer.into_iter()
195204
}

src/functions.rs

Lines changed: 0 additions & 117 deletions
This file was deleted.

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
11
#![deny(warnings)]
22

33
mod bpe;
4-
mod functions;
54
mod lpe;
65
mod tokeneer;
6+
mod vocab;
77

88
pub use bpe::Bpe;
99
pub use lpe::Lpe;

src/lpe/mod.rs

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
11
//! l-p-e for Longest Prefix Encoding
22
33
use crate::{
4-
functions::{collect_vocabs, CompressedVocab},
5-
utok, Method,
4+
utok,
5+
vocab::{CollectedVocab, CompressedVocab},
6+
Method,
67
};
78
use patricia_tree::PatriciaMap;
89
use std::{collections::HashSet, pin::Pin};
@@ -37,7 +38,11 @@ impl Lpe {
3738
}
3839

3940
pub fn new<'a>(vocabs: impl IntoIterator<Item = &'a [u8]>, unk: utok) -> Self {
40-
let (vocabs, bytes, total_len) = collect_vocabs(vocabs, unk);
41+
let CollectedVocab {
42+
vocabs,
43+
total_len,
44+
bytes,
45+
} = CollectedVocab::collect(vocabs, unk);
4146
let CompressedVocab { vocabs, slices } = CompressedVocab::new(&vocabs, total_len);
4247
let tokens = slices
4348
.into_iter()

0 commit comments

Comments (0)