Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): support configuring stop words in model config #3209

Merged
merged 7 commits into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions crates/tabby-common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ fn default_embedding_config() -> ModelConfig {
num_gpu_layers: 9999,
enable_fast_attention: None,
context_size: default_context_size(),
additional_stop_words: None,
})
}

Expand Down Expand Up @@ -221,6 +222,7 @@ impl ModelConfig {
num_gpu_layers,
enable_fast_attention: None,
context_size: default_context_size(),
additional_stop_words: None,
})
}
}
Expand Down Expand Up @@ -256,6 +258,9 @@ pub struct HttpModelConfig {
/// Used by Chat/Completion API allowing users to get supported models info.
#[builder(default)]
pub supported_models: Option<Vec<String>>,

#[builder(default)]
pub additional_stop_words: Option<Vec<String>>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
Expand All @@ -273,6 +278,9 @@ pub struct LocalModelConfig {

#[serde(default = "default_context_size")]
pub context_size: usize,

#[serde(default)]
pub additional_stop_words: Option<Vec<String>>,
}

fn default_parallelism() -> u8 {
Expand Down
2 changes: 1 addition & 1 deletion crates/tabby-index/src/code/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ mod tests {

#[tokio::test]
async fn test_code_splitter() {
// First file, chat/openai_chat.rs
// First file, tabby-inference/src/decoding.rs
let file_contents = include_str!("../../../tabby-inference/src/decoding.rs");

let rust_chunks = CodeIntelligence::chunks(file_contents, "rust")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@ source: crates/tabby-index/src/code/index.rs
expression: "format!(\"{:#?}\", text_chunks)"
---
[
"use dashmap::DashMap;\nuse tabby_common::languages::Language;\nuse trie_rs::{Trie, TrieBuilder};\n\npub struct StopConditionFactory {\n stop_trie_cache: DashMap<String, Trie<u8>>,\n}\n\nfn reverse<T>(s: T) -> String\nwhere\n T: Into<String>,\n{\n s.into().chars().rev().collect()\n}\n\nimpl Default for StopConditionFactory {\n fn default() -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n }\n }\n}\n\ntype CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie<u8>>;",
"impl StopConditionFactory {\n pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {\n if let Some(language) = language {\n StopCondition::new(self.get_trie(language), text)\n } else {\n StopCondition::new(None, text)\n }\n }",
"fn get_trie<'a>(&'a self, language: &'static Language) -> Option<CachedTrie<'a>> {\n let stop_words = language.get_stop_words();\n if stop_words.is_empty() {\n None\n } else {\n let hashkey = language.language().to_owned();\n let mut trie = self.stop_trie_cache.get(&hashkey);\n if trie.is_none() {\n self.stop_trie_cache\n .insert(hashkey.clone(), create_stop_trie(stop_words));",
"trie = self.stop_trie_cache.get(&hashkey);\n }\n\n trie\n }\n }\n}\n\nfn create_stop_trie(stop_words: Vec<String>) -> Trie<u8> {\n let mut builder = TrieBuilder::new();\n for word in stop_words {\n builder.push(reverse(word))\n }\n builder.build()\n}\n\npub struct StopCondition<'a> {\n stop_trie: Option<CachedTrie<'a>>,\n reversed_text: String,\n num_decoded: usize,\n}",
"use dashmap::DashMap;\nuse tabby_common::languages::Language;\nuse trie_rs::{Trie, TrieBuilder};\n\npub struct StopConditionFactory {\n stop_trie_cache: DashMap<String, Trie<u8>>,\n stop_words_from_model_config: Vec<String>,\n}\n\nfn reverse<T>(s: T) -> String\nwhere\n T: Into<String>,\n{\n s.into().chars().rev().collect()\n}",
"impl Default for StopConditionFactory {\n fn default() -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n stop_words_from_model_config: vec![],\n }\n }\n}\n\ntype CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie<u8>>;\n\nimpl StopConditionFactory {\n pub fn with_stop_words(stop_words: Vec<String>) -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n stop_words_from_model_config: stop_words,\n }\n }",
"pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {\n if let Some(language) = language {\n StopCondition::new(self.get_trie(language), text)\n } else {\n StopCondition::new(None, text)\n }\n }",
"fn get_trie<'a>(&'a self, language: &'static Language) -> Option<CachedTrie<'a>> {\n let mut stop_words = language.get_stop_words();\n // append model stop words\n stop_words.extend(self.stop_words_from_model_config.iter().cloned());",
"if stop_words.is_empty() {\n None\n } else {\n let hashkey = language.language().to_owned();\n let mut trie = self.stop_trie_cache.get(&hashkey);\n if trie.is_none() {\n self.stop_trie_cache\n .insert(hashkey.clone(), create_stop_trie(stop_words));\n trie = self.stop_trie_cache.get(&hashkey);\n }\n\n trie\n }\n }\n}",
"fn create_stop_trie(stop_words: Vec<String>) -> Trie<u8> {\n let mut builder = TrieBuilder::new();\n for word in stop_words {\n builder.push(reverse(word))\n }\n builder.build()\n}\n\npub struct StopCondition<'a> {\n stop_trie: Option<CachedTrie<'a>>,\n reversed_text: String,\n num_decoded: usize,\n}",
"impl<'a> StopCondition<'a> {\n pub fn new(stop_trie: Option<CachedTrie<'a>>, text: &str) -> Self {\n Self {\n stop_trie,\n reversed_text: reverse(text),\n num_decoded: 0,\n }\n }\n\n pub fn should_stop(&mut self, new_text: &str) -> (bool, usize) {\n self.num_decoded += 1;\n if !new_text.is_empty() {\n self.reversed_text = reverse(new_text) + &self.reversed_text;",
"if let Some(re) = &self.stop_trie {\n let matches = re.common_prefix_search(&self.reversed_text);\n let matched_length = matches.into_iter().map(|x| x.len()).max();\n if let Some(matched_length) = matched_length {\n return (true, matched_length);\n }\n }\n }\n (false, 0)\n }\n}\n\n#[cfg(test)]\nmod tests {\n use tabby_common::languages::UNKNOWN_LANGUAGE;\n\n use super::*;",
"if let Some(re) = &self.stop_trie {\n let matches = re.common_prefix_search(&self.reversed_text);\n let matched_length = matches.into_iter().map(|x| x.len()).max();\n if let Some(matched_length) = matched_length {\n return (true, matched_length);\n }\n }\n }\n (false, 0)\n }\n}\n\n#[cfg(test)]\nmod tests {\n\n use tabby_common::languages::UNKNOWN_LANGUAGE;\n\n use super::*;",
"#[test]\n fn test_trie_works() {\n let text = reverse(\"void write_u32(std::uint32_t val) const {\\n write_raw(&val, sizeof(val));\\n }\\n\\n ~llama_file() {\\n if (fp) {\\n std::fclose(fp);\\n }\\n }\\n};\\n\\nvoid\");\n\n let trie = create_stop_trie(vec![\"\\n\\n\".to_owned(), \"\\n\\n \".to_owned()]);\n assert!(trie.common_prefix_search(&text).is_empty());",
"let trie = create_stop_trie(vec![\n \"\\n\\n\".to_owned(),\n \"\\n\\n \".to_owned(),\n \"\\nvoid\".to_owned(),\n ]);\n assert!(!trie.common_prefix_search(&text).is_empty());\n }",
"let trie = create_stop_trie(vec![\n \"\\n\\n\".to_owned(),\n \"\\n\\n \".to_owned(),\n \"\\nvoid\".to_owned(),\n \"<|file_sep|>\".to_owned(), // qwen 2.5 coder style\n ]);\n assert!(!trie.common_prefix_search(&text).is_empty());\n\n let qwen25coder = reverse(\"qwen25 style stop words;<|file_sep|>\");\n assert!(!trie.common_prefix_search(&qwen25coder).is_empty());\n }",
"#[test]\n fn test_stop_condition_max_length() {\n let factory = StopConditionFactory::default();\n let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n let (should_stop, _) = cond.should_stop(\"1\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"2\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"3\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"4\");\n assert!(!should_stop)",
"}\n}",
"}\n\n #[test]\n fn test_stop_condition_additional_stop_words() {\n let factory = StopConditionFactory::with_stop_words(vec![\"<|endoftext|>\".to_owned()]);\n let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n let (should_stop, _) = cond.should_stop(\"1\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"<|endoftext|>\");\n assert!(should_stop);\n }\n}",
]
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,24 @@ source: crates/tabby-index/src/code/index.rs
expression: "format!(\"{:#?}\", rust_chunks)"
---
[
"use dashmap::DashMap;\nuse tabby_common::languages::Language;\nuse trie_rs::{Trie, TrieBuilder};\n\npub struct StopConditionFactory {\n stop_trie_cache: DashMap<String, Trie<u8>>,\n}\n\nfn reverse<T>(s: T) -> String\nwhere\n T: Into<String>,\n{\n s.into().chars().rev().collect()\n}\n\nimpl Default for StopConditionFactory {\n fn default() -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n }\n }\n}\n\ntype CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie<u8>>;",
"use dashmap::DashMap;\nuse tabby_common::languages::Language;\nuse trie_rs::{Trie, TrieBuilder};\n\npub struct StopConditionFactory {\n stop_trie_cache: DashMap<String, Trie<u8>>,\n stop_words_from_model_config: Vec<String>,\n}\n\nfn reverse<T>(s: T) -> String\nwhere\n T: Into<String>,\n{\n s.into().chars().rev().collect()\n}",
"impl Default for StopConditionFactory {\n fn default() -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n stop_words_from_model_config: vec![],\n }\n }\n}\n\ntype CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie<u8>>;",
"impl StopConditionFactory",
"{\n pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {\n if let Some(language) = language {\n StopCondition::new(self.get_trie(language), text)\n } else {\n StopCondition::new(None, text)\n }\n }",
"{\n pub fn with_stop_words(stop_words: Vec<String>) -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n stop_words_from_model_config: stop_words,\n }\n }\n\n pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {\n if let Some(language) = language {\n StopCondition::new(self.get_trie(language), text)\n } else {\n StopCondition::new(None, text)\n }\n }",
"fn get_trie<'a>(&'a self, language: &'static Language) -> Option<CachedTrie<'a>>",
"{\n let stop_words = language.get_stop_words();\n if stop_words.is_empty() {\n None\n } else {\n let hashkey = language.language().to_owned();\n let mut trie = self.stop_trie_cache.get(&hashkey);\n if trie.is_none() {\n self.stop_trie_cache\n .insert(hashkey.clone(), create_stop_trie(stop_words));\n trie = self.stop_trie_cache.get(&hashkey);\n }\n\n trie\n }\n }\n}",
"{\n let mut stop_words = language.get_stop_words();\n // append model stop words\n stop_words.extend(self.stop_words_from_model_config.iter().cloned());",
"if stop_words.is_empty() {\n None\n } else {\n let hashkey = language.language().to_owned();\n let mut trie = self.stop_trie_cache.get(&hashkey);\n if trie.is_none() {\n self.stop_trie_cache\n .insert(hashkey.clone(), create_stop_trie(stop_words));\n trie = self.stop_trie_cache.get(&hashkey);\n }\n\n trie\n }\n }\n}",
"fn create_stop_trie(stop_words: Vec<String>) -> Trie<u8> {\n let mut builder = TrieBuilder::new();\n for word in stop_words {\n builder.push(reverse(word))\n }\n builder.build()\n}\n\npub struct StopCondition<'a> {\n stop_trie: Option<CachedTrie<'a>>,\n reversed_text: String,\n num_decoded: usize,\n}",
"impl<'a> StopCondition<'a>",
"{\n pub fn new(stop_trie: Option<CachedTrie<'a>>, text: &str) -> Self {\n Self {\n stop_trie,\n reversed_text: reverse(text),\n num_decoded: 0,\n }\n }",
"pub fn should_stop(&mut self, new_text: &str) -> (bool, usize)",
"{\n self.num_decoded += 1;\n if !new_text.is_empty() {\n self.reversed_text = reverse(new_text) + &self.reversed_text;\n\n if let Some(re) = &self.stop_trie {\n let matches = re.common_prefix_search(&self.reversed_text);\n let matched_length = matches.into_iter().map(|x| x.len()).max();\n if let Some(matched_length) = matched_length {\n return (true, matched_length);\n }\n }\n }",
"(false, 0)\n }\n}\n\n#[cfg(test)]",
"mod tests",
"{\n use tabby_common::languages::UNKNOWN_LANGUAGE;\n\n use super::*;\n\n #[test]",
"{\n\n use tabby_common::languages::UNKNOWN_LANGUAGE;\n\n use super::*;\n\n #[test]",
"fn test_trie_works()",
"{\n let text = reverse(\"void write_u32(std::uint32_t val) const {\\n write_raw(&val, sizeof(val));\\n }\\n\\n ~llama_file() {\\n if (fp) {\\n std::fclose(fp);\\n }\\n }\\n};\\n\\nvoid\");\n\n let trie = create_stop_trie(vec![\"\\n\\n\".to_owned(), \"\\n\\n \".to_owned()]);\n assert!(trie.common_prefix_search(&text).is_empty());",
"let trie = create_stop_trie(vec![\n \"\\n\\n\".to_owned(),\n \"\\n\\n \".to_owned(),\n \"\\nvoid\".to_owned(),\n ]);\n assert!(!trie.common_prefix_search(&text).is_empty());\n }\n\n #[test]",
"fn test_stop_condition_max_length() {\n let factory = StopConditionFactory::default();\n let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n let (should_stop, _) = cond.should_stop(\"1\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"2\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"3\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"4\");\n assert!(!should_stop)\n }\n}",
"let trie = create_stop_trie(vec![\n \"\\n\\n\".to_owned(),\n \"\\n\\n \".to_owned(),\n \"\\nvoid\".to_owned(),\n \"<|file_sep|>\".to_owned(), // qwen 2.5 coder style\n ]);\n assert!(!trie.common_prefix_search(&text).is_empty());\n\n let qwen25coder = reverse(\"qwen25 style stop words;<|file_sep|>\");\n assert!(!trie.common_prefix_search(&qwen25coder).is_empty());\n }\n\n #[test]",
"fn test_stop_condition_max_length() {\n let factory = StopConditionFactory::default();\n let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n let (should_stop, _) = cond.should_stop(\"1\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"2\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"3\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"4\");\n assert!(!should_stop)\n }",
"#[test]\n fn test_stop_condition_additional_stop_words() {\n let factory = StopConditionFactory::with_stop_words(vec![\"<|endoftext|>\".to_owned()]);\n let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n let (should_stop, _) = cond.should_stop(\"1\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"<|endoftext|>\");\n assert!(should_stop);\n }\n}",
]
13 changes: 10 additions & 3 deletions crates/tabby-inference/src/code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use async_stream::stream;
use derive_builder::Builder;
use futures::StreamExt;
use tabby_common::languages::Language;
use tabby_common::{config::ModelConfig, languages::Language};

use crate::{decoding::StopConditionFactory, CompletionOptionsBuilder, CompletionStream};

Expand Down Expand Up @@ -31,10 +31,17 @@ pub struct CodeGeneration {
}

impl CodeGeneration {
pub fn new(imp: Arc<dyn CompletionStream>) -> Self {
pub fn new(imp: Arc<dyn CompletionStream>, config: Option<ModelConfig>) -> Self {
let additional_stop_words = match config {
Some(ModelConfig::Local(config)) => config.additional_stop_words.unwrap_or_default(),
Some(ModelConfig::Http(config)) => config.additional_stop_words.unwrap_or_default(),
_ => vec![],
};
let stop_condition_factory = StopConditionFactory::with_stop_words(additional_stop_words);

Self {
imp,
stop_condition_factory: StopConditionFactory::default(),
stop_condition_factory,
}
}
}
Expand Down
29 changes: 28 additions & 1 deletion crates/tabby-inference/src/decoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use trie_rs::{Trie, TrieBuilder};

/// Factory that builds [`StopCondition`]s and caches the stop-word tries they use.
pub struct StopConditionFactory {
/// Cache of built stop-word tries; `get_trie` keys entries by language name.
stop_trie_cache: DashMap<String, Trie<u8>>,
/// Extra stop words taken from the model config; `get_trie` appends them to the
/// language's own stop words before building a trie.
stop_words_from_model_config: Vec<String>,
}

fn reverse<T>(s: T) -> String
Expand All @@ -17,13 +18,21 @@ impl Default for StopConditionFactory {
/// Creates a factory with an empty trie cache and no model-configured stop words.
fn default() -> Self {
    Self {
        stop_words_from_model_config: Vec::new(),
        stop_trie_cache: DashMap::new(),
    }
}
}

type CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie<u8>>;

impl StopConditionFactory {
/// Creates a factory that carries additional, model-configured stop words.
///
/// `stop_words` is stored unchanged; `get_trie` appends these words to the
/// language-specific stop words. The trie cache starts out empty.
pub fn with_stop_words(stop_words: Vec<String>) -> Self {
    let stop_trie_cache = DashMap::new();
    Self {
        stop_trie_cache,
        stop_words_from_model_config: stop_words,
    }
}

pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {
if let Some(language) = language {
StopCondition::new(self.get_trie(language), text)
Expand All @@ -33,7 +42,10 @@ impl StopConditionFactory {
}

fn get_trie<'a>(&'a self, language: &'static Language) -> Option<CachedTrie<'a>> {
let stop_words = language.get_stop_words();
let mut stop_words = language.get_stop_words();
// append model stop words
stop_words.extend(self.stop_words_from_model_config.iter().cloned());

if stop_words.is_empty() {
None
} else {
Expand Down Expand Up @@ -92,6 +104,7 @@ impl<'a> StopCondition<'a> {

#[cfg(test)]
mod tests {

use tabby_common::languages::UNKNOWN_LANGUAGE;

use super::*;
Expand All @@ -107,8 +120,12 @@ mod tests {
"\n\n".to_owned(),
"\n\n ".to_owned(),
"\nvoid".to_owned(),
"<|file_sep|>".to_owned(), // qwen 2.5 coder style
]);
assert!(!trie.common_prefix_search(&text).is_empty());

let qwen25coder = reverse("qwen25 style stop words;<|file_sep|>");
assert!(!trie.common_prefix_search(&qwen25coder).is_empty());
}

#[test]
Expand All @@ -124,4 +141,14 @@ mod tests {
let (should_stop, _) = cond.should_stop("4");
assert!(!should_stop)
}

/// A stop word supplied through `with_stop_words` must trigger the stop
/// condition, while text that is not a stop word must not.
#[test]
fn test_stop_condition_additional_stop_words() {
    let extra_stop_words = vec!["<|endoftext|>".to_owned()];
    let factory = StopConditionFactory::with_stop_words(extra_stop_words);
    let mut cond = factory.create("", Some(&UNKNOWN_LANGUAGE));

    // Text that is not a stop word keeps decoding going.
    assert!(!cond.should_stop("1").0);
    // The configured stop word ends decoding.
    assert!(cond.should_stop("<|endoftext|>").0);
}
}
Loading
Loading