diff --git a/Cargo.toml b/Cargo.toml
index 4efb156f..52881a05 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,7 +12,7 @@ crate-type = ["cdylib"]
 pyo3 = { version = "0.20.0", features = ["extension-module"] }
 
 # tiktoken dependencies
-fancy-regex = "0.11.0"
-regex = "1.8.3"
+fancy-regex = "0.13.0"
+regex = "1.10.3"
 rustc-hash = "1.1.0"
 bstr = "1.5.0"
diff --git a/src/lib.rs b/src/lib.rs
index b466edd1..46712ecd 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -6,6 +6,7 @@ use std::num::NonZeroU64;
 use std::thread;
 
 use fancy_regex::Regex;
+use fancy_regex::RegexBuilder;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::pyclass;
@@ -417,7 +418,7 @@ impl CoreBPE {
         special_tokens_encoder: HashMap<String, Rank>,
         pattern: &str,
     ) -> PyResult<Self> {
-        let regex = Regex::new(pattern)
+        let regex = RegexBuilder::new(pattern).backtrack_limit(10_000).build()
             .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;
 
         let special_regex = {
@@ -572,6 +573,7 @@ fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
 
 #[cfg(test)]
 mod tests {
+    use fancy_regex::RegexBuilder;
     use rustc_hash::FxHashMap as HashMap;
 
     use crate::{byte_pair_split, Rank};
@@ -596,4 +598,16 @@ mod tests {
         let res = byte_pair_split(b"abab", &ranks);
         assert_eq!(res, vec![b"ab", b"ab"]);
     }
+
+    #[test]
+    fn test_effect_of_backtrack_limit() {
+        let regex = RegexBuilder::new(r"(a|b|ab)*(?=c)")
+            .backtrack_limit(10)
+            .build()
+            .expect("Failed to build regex")
+            .clone();
+
+        let input = "ab".repeat(100) + "c";
+        assert!(regex.is_match(&input).is_err(), "Should throw");
+    }
 }
diff --git a/tests/test_encoding.py b/tests/test_encoding.py
index 27b21925..0e02b47a 100644
--- a/tests/test_encoding.py
+++ b/tests/test_encoding.py
@@ -11,6 +11,22 @@
 from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES
 
 
+@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
+def test_extremely_big_encoding(make_enc: Callable[[], tiktoken.Encoding]):
+    enc = make_enc()
+    for c in ["^", "0", "a", "'s", " ", "\n"]:
+        print(f"Validating `{c}`")
+
+        big_value = c * 10_000
+        assert big_value == enc.decode(enc.encode(big_value))
+
+        big_value = " " + big_value
+        assert big_value == enc.decode(enc.encode(big_value))
+
+        big_value = big_value + "\n"
+        assert big_value == enc.decode(enc.encode(big_value))
+
+
 def test_simple():
     enc = tiktoken.get_encoding("gpt2")
     assert enc.encode("hello world") == [31373, 995]
diff --git a/tiktoken_ext/openai_public.py b/tiktoken_ext/openai_public.py
index 6b29a711..449ec068 100644
--- a/tiktoken_ext/openai_public.py
+++ b/tiktoken_ext/openai_public.py
@@ -6,6 +6,11 @@
 FIM_SUFFIX = "<|fim_suffix|>"
 ENDOFPROMPT = "<|endofprompt|>"
 
+# The pattern in the original GPT-2 release is:
+# r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+# This is equivalent, but executes faster:
+_legacy_splitter_regex = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s"""
+
 
 def gpt2():
     mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
@@ -17,10 +22,7 @@ def gpt2():
     return {
         "name": "gpt2",
         "explicit_n_vocab": 50257,
-        # The pattern in the original GPT-2 release is:
-        # r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
-        # This is equivalent, but executes faster:
-        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        "pat_str": _legacy_splitter_regex,
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": {ENDOFTEXT: 50256},
     }
@@ -34,7 +36,7 @@ def r50k_base():
     return {
         "name": "r50k_base",
         "explicit_n_vocab": 50257,
-        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        "pat_str": _legacy_splitter_regex,
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": {ENDOFTEXT: 50256},
     }
@@ -48,7 +50,7 @@ def p50k_base():
     return {
         "name": "p50k_base",
         "explicit_n_vocab": 50281,
-        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        "pat_str": _legacy_splitter_regex,
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": {ENDOFTEXT: 50256},
     }
@@ -62,7 +64,7 @@ def p50k_edit():
     special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
     return {
         "name": "p50k_edit",
-        "pat_str": r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        "pat_str": _legacy_splitter_regex,
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": special_tokens,
     }
@@ -82,7 +84,7 @@ def cl100k_base():
     }
     return {
         "name": "cl100k_base",
-        "pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""",
+        "pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s""",
         "mergeable_ranks": mergeable_ranks,
         "special_tokens": special_tokens,
     }