From b17667b72c28095628781ab0bf73d0d39dc8c045 Mon Sep 17 00:00:00 2001 From: Cahya Wirawan Date: Mon, 3 Jun 2024 14:38:43 +0200 Subject: [PATCH] fixed the byte tokens --- src/lib.rs | 30 +++++++++++++++++++++++++----- src/trie.rs | 18 +++++++++--------- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 11a66b3..54586fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,17 +25,24 @@ impl Tokenizer { let file = File::open(filename)?; let reader = io::BufReader::new(file); - let re = Regex::new(r"(\d+)\s+b?(.+)\s+(\d+)").unwrap(); + let re = Regex::new(r"(\d+)\s+(b?)(.+)\s+(\d+)").unwrap(); tokenizer.tokens.push("".to_string()); for line in reader.lines() { let line = line?; if let Some(captures) = re.captures(&line) { let id = captures[1].parse::().unwrap(); - let mut string = captures[2].to_string(); - let _length = captures[3].parse::().unwrap(); + let is_byte = captures[2].to_string(); + let _length = captures[4].parse::().unwrap(); + let mut string: String = captures[3].to_string(); string = string[1..string.len()-1].parse().unwrap(); - let string = unescape(string.as_str()).unwrap(); - tokenizer.trie.insert(string.as_str(), id); + let sbytes: Vec; + if is_byte.len() == 0 { + string = unescape(string.as_str()).unwrap(); + sbytes = string.clone().into_bytes(); + } else { + sbytes = hex_to_bytes(string.as_str()).unwrap(); + } + tokenizer.trie.insert(&sbytes, id); tokenizer.tokens.push(string.to_string()); } else { @@ -58,6 +65,19 @@ impl Tokenizer { } } +fn hex_to_bytes(hex: &str) -> Option> { + let hex = hex.replace("\\x", ""); + if hex.len() % 2 == 0 { + (0..hex.len()) + .step_by(2) + .map(|i| hex.get(i..i + 2) + .and_then(|sub| u8::from_str_radix(sub, 16).ok())) + .collect() + } else { + None + } +} + #[pymodule] fn rwkv_tokenizer(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; diff --git a/src/trie.rs b/src/trie.rs index 568bb6d..f9ad11d 100644 --- a/src/trie.rs +++ b/src/trie.rs @@ -31,13 +31,13 @@ impl Trie { } } - pub(crate) fn insert(&mut self, word: &str, id: u16) { + pub(crate) fn insert(&mut self, word: &Vec, id: u16) { let mut node = &mut self.root; - for ch in word.bytes() { - if node.children[ch as usize].is_none() { - node.children[ch as usize] = Option::from(TrieNode::new()); + for ch in word { + if node.children[u8::from_be(*ch) as usize].is_none() { + node.children[u8::from_be(*ch) as usize] = Option::from(TrieNode::new()); } - match &mut node.children[ch as usize] { + match &mut node.children[u8::from_be(*ch) as usize] { Some(next_node) => node = next_node, None => unreachable!(), // We've just checked that it's not None } @@ -45,13 +45,13 @@ impl Trie { node.id = id } - fn search_the_longest(&self, word: &str) -> (u16, u16) { + fn search_the_longest(&self, word: &[u8]) -> (u16, u16) { let mut node = &self.root; let mut old_node: &TrieNode = &self.root; let mut index = 0; let mut old_index = 0; - for ch in word.bytes() { - if let Some(next_node) = &node.children[ch as usize] { + for ch in word { + if let Some(next_node) = &node.children[*ch as usize] { if node.id != 0 { old_node = node; old_index = index; @@ -80,7 +80,7 @@ impl Trie { let text_length = text.len(); let mut index: usize = 0; loop { - let result = self.search_the_longest(&text[index..]); + let result = self.search_the_longest(&text.as_bytes()[index..]); if result.0 != 0 { vec.push(result.1.into()); index += >::into(result.0);