Skip to content

Commit

Permalink
fixed the byte tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
cahya-wirawan committed Jun 3, 2024
1 parent f5a4c8c commit b17667b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 14 deletions.
30 changes: 25 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,24 @@ impl Tokenizer {
let file = File::open(filename)?;
let reader = io::BufReader::new(file);

let re = Regex::new(r"(\d+)\s+b?(.+)\s+(\d+)").unwrap();
let re = Regex::new(r"(\d+)\s+(b?)(.+)\s+(\d+)").unwrap();
tokenizer.tokens.push("".to_string());
for line in reader.lines() {
let line = line?;
if let Some(captures) = re.captures(&line) {
let id = captures[1].parse::<u16>().unwrap();
let mut string = captures[2].to_string();
let _length = captures[3].parse::<usize>().unwrap();
let is_byte = captures[2].to_string();
let _length = captures[4].parse::<usize>().unwrap();
let mut string: String = captures[3].to_string();
string = string[1..string.len()-1].parse().unwrap();
let string = unescape(string.as_str()).unwrap();
tokenizer.trie.insert(string.as_str(), id);
let sbytes: Vec<u8>;
if is_byte.len() == 0 {
string = unescape(string.as_str()).unwrap();
sbytes = string.clone().into_bytes();
} else {
sbytes = hex_to_bytes(string.as_str()).unwrap();
}
tokenizer.trie.insert(&sbytes, id);
tokenizer.tokens.push(string.to_string());
}
else {
Expand All @@ -58,6 +65,19 @@ impl Tokenizer {
}
}

/// Decodes a string of hexadecimal byte escapes (e.g. `\xe4\xb8\xad`) into raw bytes.
///
/// Every `\x` marker is stripped before decoding, so plain hex such as `"e4b8ad"`
/// is accepted too. The remaining text must consist solely of pairs of ASCII hex
/// digits (upper or lower case).
///
/// Returns `None` when the stripped text has an odd length or contains any
/// character that is not a hexadecimal digit. An empty input yields `Some(vec![])`.
fn hex_to_bytes(hex: &str) -> Option<Vec<u8>> {
    let digits = hex.replace("\\x", "");
    if digits.len() % 2 != 0 {
        return None;
    }
    digits
        .as_bytes()
        .chunks_exact(2)
        .map(|pair| {
            // Decode digit by digit rather than via `u8::from_str_radix`, which
            // tolerates a leading '+' and would silently accept "+f" as 0x0f.
            // Non-ASCII bytes map to chars >= U+0080, which `to_digit` rejects.
            let high = (pair[0] as char).to_digit(16)?;
            let low = (pair[1] as char).to_digit(16)?;
            // Both nibbles are < 16, so the sum always fits in a u8.
            Some((high * 16 + low) as u8)
        })
        .collect()
}

#[pymodule]
fn rwkv_tokenizer(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Tokenizer>()?;
Expand Down
18 changes: 9 additions & 9 deletions src/trie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,27 @@ impl Trie {
}
}

pub(crate) fn insert(&mut self, word: &str, id: u16) {
pub(crate) fn insert(&mut self, word: &Vec<u8>, id: u16) {
let mut node = &mut self.root;
for ch in word.bytes() {
if node.children[ch as usize].is_none() {
node.children[ch as usize] = Option::from(TrieNode::new());
for ch in word {
if node.children[u8::from_be(*ch) as usize].is_none() {
node.children[u8::from_be(*ch) as usize] = Option::from(TrieNode::new());
}
match &mut node.children[ch as usize] {
match &mut node.children[u8::from_be(*ch) as usize] {
Some(next_node) => node = next_node,
None => unreachable!(), // We've just checked that it's not None
}
}
node.id = id
}

fn search_the_longest(&self, word: &str) -> (u16, u16) {
fn search_the_longest(&self, word: &[u8]) -> (u16, u16) {
let mut node = &self.root;
let mut old_node: &TrieNode = &self.root;
let mut index = 0;
let mut old_index = 0;
for ch in word.bytes() {
if let Some(next_node) = &node.children[ch as usize] {
for ch in word {
if let Some(next_node) = &node.children[*ch as usize] {
if node.id != 0 {
old_node = node;
old_index = index;
Expand Down Expand Up @@ -80,7 +80,7 @@ impl Trie {
let text_length = text.len();
let mut index: usize = 0;
loop {
let result = self.search_the_longest(&text[index..]);
let result = self.search_the_longest(&text.as_bytes()[index..]);
if result.0 != 0 {
vec.push(result.1.into());
index += <u16 as Into<usize>>::into(result.0);
Expand Down

0 comments on commit b17667b

Please sign in to comment.