From 9eacd4f54761b56d78b7df2c64239849c65e391a Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Tue, 22 Oct 2024 14:13:51 +0800 Subject: [PATCH 1/7] feat: replease crlf to lf in completions --- crates/tabby/src/routes/completions.rs | 33 +++++++++++++++++++++++-- crates/tabby/src/services/completion.rs | 10 ++++---- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/crates/tabby/src/routes/completions.rs b/crates/tabby/src/routes/completions.rs index e1d8eea3ae2..282ae3f34c2 100644 --- a/crates/tabby/src/routes/completions.rs +++ b/crates/tabby/src/routes/completions.rs @@ -1,9 +1,11 @@ use std::sync::Arc; -use axum::{extract::State, Extension, Json}; +use async_openai::types::Choice; +use axum::{extract::State, http::request, Extension, Json}; use axum_extra::{headers, TypedHeader}; use hyper::StatusCode; use tabby_common::axum::{AllowedCodeRepository, MaybeUser}; +use tantivy::Segment; use tracing::{instrument, warn}; use crate::services::completion::{CompletionRequest, CompletionResponse, CompletionService}; @@ -36,11 +38,38 @@ pub async fn completions( let user_agent = user_agent.map(|x| x.0.to_string()); + let mut use_crlf = false; + if let Some(segments) = request.segments { + let mut new_segments = segments.clone(); + if segments.prefix.contains("\r\n") { + use_crlf = true; + new_segments.prefix = segments.prefix.replace("\r\n", "\n"); + } + if let Some(suffix) = segments.suffix { + if suffix.contains("\r\n") { + use_crlf = true; + new_segments.suffix = Some(suffix.replace("\r\n", "\n")); + } + } + request.segments = Some(new_segments); + } + match state .generate(&request, &allowed_code_repository, user_agent.as_deref()) .await { - Ok(resp) => Ok(Json(resp)), + Ok(resp) => { + if use_crlf { + let mut response_crlf = resp.clone(); + for (index, choice) in resp.choices.iter().enumerate() { + response_crlf.choices[index].text = choice.text.replace("\n", "\r\n"); + } + + return Ok(Json(response_crlf)); + } + + Ok(Json(resp)) + } Err(err) => { 
warn!("{}", err); Err(StatusCode::BAD_REQUEST) diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs index b7a076fd766..4fa0dc167f5 100644 --- a/crates/tabby/src/services/completion.rs +++ b/crates/tabby/src/services/completion.rs @@ -42,7 +42,7 @@ pub struct CompletionRequest { language: Option, /// When segments are set, the `prompt` is ignored during the inference. - segments: Option, + pub segments: Option, /// A unique identifier representing your end-user, which can help Tabby to monitor & generating /// reports. @@ -105,10 +105,10 @@ fn default_false() -> bool { #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] pub struct Segments { /// Content that appears before the cursor in the editor window. - prefix: String, + pub prefix: String, /// Content that appears after the cursor in the editor window. - suffix: Option, + pub suffix: Option, /// The relative path of the file that is being edited. /// - When [Segments::git_url] is set, this is the path of the file in the git repository. 
@@ -191,7 +191,7 @@ impl From for api::event::Declaration { #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] pub struct Choice { index: u32, - text: String, + pub text: String, } impl Choice { @@ -214,7 +214,7 @@ pub struct Snippet { }))] pub struct CompletionResponse { id: String, - choices: Vec, + pub choices: Vec, #[serde(skip_serializing_if = "Option::is_none")] debug_data: Option, From 53fc9f947ae0290cf1263830176249f4eff6021c Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Tue, 22 Oct 2024 10:41:40 +0000 Subject: [PATCH 2/7] [autofix.ci] apply automated fixes --- crates/tabby/src/routes/completions.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/crates/tabby/src/routes/completions.rs b/crates/tabby/src/routes/completions.rs index 282ae3f34c2..99da65f68ba 100644 --- a/crates/tabby/src/routes/completions.rs +++ b/crates/tabby/src/routes/completions.rs @@ -1,11 +1,9 @@ use std::sync::Arc; -use async_openai::types::Choice; -use axum::{extract::State, http::request, Extension, Json}; +use axum::{extract::State, Extension, Json}; use axum_extra::{headers, TypedHeader}; use hyper::StatusCode; use tabby_common::axum::{AllowedCodeRepository, MaybeUser}; -use tantivy::Segment; use tracing::{instrument, warn}; use crate::services::completion::{CompletionRequest, CompletionResponse, CompletionService}; From e489cc9647ba2960a74d91381c54193abd4b8ee7 Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Wed, 23 Oct 2024 15:26:06 +0800 Subject: [PATCH 3/7] chore: crlf logic should under service --- Cargo.lock | 1 + crates/tabby/Cargo.toml | 1 + crates/tabby/src/routes/completions.rs | 29 +------------------------ crates/tabby/src/services/completion.rs | 27 ++++++++++++++++++----- 4 files changed, 25 insertions(+), 33 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index eebeaf59b63..e9d4ee0db72 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5001,6 +5001,7 @@ dependencies = [ 
"llama-cpp-server", "nvml-wrapper", "openssl", + "regex", "reqwest", "reqwest-eventsource", "serde", diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index 02d5242b284..2da06ef8808 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -56,6 +56,7 @@ color-eyre = { version = "0.6.3" } reqwest.workspace = true async-openai.workspace = true spinners = "4.1.1" +regex.workspace = true [dependencies.openssl] optional = true diff --git a/crates/tabby/src/routes/completions.rs b/crates/tabby/src/routes/completions.rs index 99da65f68ba..e1d8eea3ae2 100644 --- a/crates/tabby/src/routes/completions.rs +++ b/crates/tabby/src/routes/completions.rs @@ -36,38 +36,11 @@ pub async fn completions( let user_agent = user_agent.map(|x| x.0.to_string()); - let mut use_crlf = false; - if let Some(segments) = request.segments { - let mut new_segments = segments.clone(); - if segments.prefix.contains("\r\n") { - use_crlf = true; - new_segments.prefix = segments.prefix.replace("\r\n", "\n"); - } - if let Some(suffix) = segments.suffix { - if suffix.contains("\r\n") { - use_crlf = true; - new_segments.suffix = Some(suffix.replace("\r\n", "\n")); - } - } - request.segments = Some(new_segments); - } - match state .generate(&request, &allowed_code_repository, user_agent.as_deref()) .await { - Ok(resp) => { - if use_crlf { - let mut response_crlf = resp.clone(); - for (index, choice) in resp.choices.iter().enumerate() { - response_crlf.choices[index].text = choice.text.replace("\n", "\r\n"); - } - - return Ok(Json(response_crlf)); - } - - Ok(Json(resp)) - } + Ok(resp) => Ok(Json(resp)), Err(err) => { warn!("{}", err); Err(StatusCode::BAD_REQUEST) diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs index 4fa0dc167f5..291791364a9 100644 --- a/crates/tabby/src/services/completion.rs +++ b/crates/tabby/src/services/completion.rs @@ -1,5 +1,6 @@ mod completion_prompt; +use regex::Regex; use std::sync::Arc; use 
serde::{Deserialize, Serialize}; @@ -321,27 +322,43 @@ impl CompletionService { self.config.max_decoding_tokens, ); + let mut use_crlf = false; let (prompt, segments, snippets) = if let Some(prompt) = request.raw_prompt() { (prompt, None, vec![]) } else if let Some(segments) = request.segments.as_ref() { + let mut new_segments = segments.clone(); + if segments.prefix.contains("\r\n") { + use_crlf = true; + new_segments.prefix = segments.prefix.replace("\r\n", "\n"); + } + if let Some(suffix) = &segments.suffix { + if suffix.contains("\r\n") { + use_crlf = true; + new_segments.suffix = Some(suffix.replace("\r\n", "\n")); + } + } + let snippets = self .build_snippets( &language, - segments, + &new_segments, allowed_code_repository, request.disable_retrieval_augmented_code_completion(), ) .await; let prompt = self .prompt_builder - .build(&language, segments.clone(), &snippets); + .build(&language, new_segments.clone(), &snippets); (prompt, Some(segments), snippets) } else { return Err(CompletionError::EmptyPrompt); }; - let text = self.engine.generate(&prompt, options).await; - let segments = segments.cloned().map(|s| s.into()); + let mut text = self.engine.generate(&prompt, options).await; + if use_crlf { + let re = Regex::new(r"([^\r])\n").unwrap(); // Match \n that is preceded by anything except \r + text = re.replace_all(&text, "$1\r\n").to_string() // Replace with captured character and \r\n + } self.logger.log( request.user.clone(), @@ -349,7 +366,7 @@ impl CompletionService { completion_id: completion_id.clone(), language, prompt: prompt.clone(), - segments, + segments: segments.cloned().map(|x| x.into()), choices: vec![api::event::Choice { index: 0, text: text.clone(), From 14c9571d16b7d4c5f86c4fda78cbc45c79c190d9 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Wed, 23 Oct 2024 07:34:54 +0000 Subject: [PATCH 4/7] [autofix.ci] apply automated fixes --- crates/tabby/src/services/completion.rs | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs index 291791364a9..3730a3f55db 100644 --- a/crates/tabby/src/services/completion.rs +++ b/crates/tabby/src/services/completion.rs @@ -1,8 +1,8 @@ mod completion_prompt; -use regex::Regex; use std::sync::Arc; +use regex::Regex; use serde::{Deserialize, Serialize}; use tabby_common::{ api::{ From 5ed58c4eb4f824385c290855419ef8c6272a9c92 Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Wed, 23 Oct 2024 15:53:01 +0800 Subject: [PATCH 5/7] chore: replace CRLF for prompt instead of segments --- crates/tabby/src/services/completion.rs | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs index 3730a3f55db..978875d3256 100644 --- a/crates/tabby/src/services/completion.rs +++ b/crates/tabby/src/services/completion.rs @@ -43,7 +43,7 @@ pub struct CompletionRequest { language: Option, /// When segments are set, the `prompt` is ignored during the inference. - pub segments: Option, + segments: Option, /// A unique identifier representing your end-user, which can help Tabby to monitor & generating /// reports. @@ -106,10 +106,10 @@ fn default_false() -> bool { #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] pub struct Segments { /// Content that appears before the cursor in the editor window. - pub prefix: String, + prefix: String, /// Content that appears after the cursor in the editor window. - pub suffix: Option, + suffix: Option, /// The relative path of the file that is being edited. /// - When [Segments::git_url] is set, this is the path of the file in the git repository. 
@@ -192,7 +192,7 @@ impl From for api::event::Declaration { #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] pub struct Choice { index: u32, - pub text: String, + text: String, } impl Choice { @@ -215,7 +215,7 @@ pub struct Snippet { }))] pub struct CompletionResponse { id: String, - pub choices: Vec, + choices: Vec, #[serde(skip_serializing_if = "Option::is_none")] debug_data: Option, @@ -326,30 +326,31 @@ impl CompletionService { let (prompt, segments, snippets) = if let Some(prompt) = request.raw_prompt() { (prompt, None, vec![]) } else if let Some(segments) = request.segments.as_ref() { - let mut new_segments = segments.clone(); if segments.prefix.contains("\r\n") { use_crlf = true; - new_segments.prefix = segments.prefix.replace("\r\n", "\n"); } if let Some(suffix) = &segments.suffix { if suffix.contains("\r\n") { use_crlf = true; - new_segments.suffix = Some(suffix.replace("\r\n", "\n")); } } let snippets = self .build_snippets( &language, - &new_segments, + &segments, allowed_code_repository, request.disable_retrieval_augmented_code_completion(), ) .await; let prompt = self .prompt_builder - .build(&language, new_segments.clone(), &snippets); - (prompt, Some(segments), snippets) + .build(&language, segments.clone(), &snippets); + if use_crlf { + (prompt.replace("\r\n", "\n"), Some(segments), snippets) + } else { + (prompt, Some(segments), snippets) + } } else { return Err(CompletionError::EmptyPrompt); }; From 3197cb326083059bc27104ecb8a53f4f570162ca Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Wed, 23 Oct 2024 07:59:40 +0000 Subject: [PATCH 6/7] [autofix.ci] apply automated fixes --- crates/tabby/src/services/completion.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs index 978875d3256..8ac1c4adbcb 100644 --- a/crates/tabby/src/services/completion.rs +++ 
b/crates/tabby/src/services/completion.rs @@ -338,7 +338,7 @@ impl CompletionService { let snippets = self .build_snippets( &language, - &segments, + segments, allowed_code_repository, request.disable_retrieval_augmented_code_completion(), ) From 120602e616a0d32991bc59c18638e60feffe5ddd Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Wed, 23 Oct 2024 17:40:16 +0800 Subject: [PATCH 7/7] chore: use functions and ut for replacing crlf --- crates/tabby/src/services/completion.rs | 151 +++++++++++++++++++++--- 1 file changed, 133 insertions(+), 18 deletions(-) diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs index 8ac1c4adbcb..867fbad87ca 100644 --- a/crates/tabby/src/services/completion.rs +++ b/crates/tabby/src/services/completion.rs @@ -326,14 +326,9 @@ impl CompletionService { let (prompt, segments, snippets) = if let Some(prompt) = request.raw_prompt() { (prompt, None, vec![]) } else if let Some(segments) = request.segments.as_ref() { - if segments.prefix.contains("\r\n") { + if contains_crlf(segments) { use_crlf = true; } - if let Some(suffix) = &segments.suffix { - if suffix.contains("\r\n") { - use_crlf = true; - } - } let snippets = self .build_snippets( @@ -346,20 +341,13 @@ impl CompletionService { let prompt = self .prompt_builder .build(&language, segments.clone(), &snippets); - if use_crlf { - (prompt.replace("\r\n", "\n"), Some(segments), snippets) - } else { - (prompt, Some(segments), snippets) - } + + (override_prompt(prompt, use_crlf), Some(segments), snippets) } else { return Err(CompletionError::EmptyPrompt); }; - let mut text = self.engine.generate(&prompt, options).await; - if use_crlf { - let re = Regex::new(r"([^\r])\n").unwrap(); // Match \n that is preceded by anything except \r - text = re.replace_all(&text, "$1\r\n").to_string() // Replace with captured character and \r\n - } + let generated = override_generated(self.engine.generate(&prompt, options).await, use_crlf); self.logger.log( 
request.user.clone(), @@ -370,7 +358,7 @@ impl CompletionService { segments: segments.cloned().map(|x| x.into()), choices: vec![api::event::Choice { index: 0, - text: text.clone(), + text: generated.clone(), }], user_agent: user_agent.map(|x| x.to_owned()), }, @@ -386,12 +374,47 @@ impl CompletionService { Ok(CompletionResponse::new( completion_id, - vec![Choice::new(text)], + vec![Choice::new(generated)], debug_data, )) } } +fn contains_crlf(segments: &Segments) -> bool { + if segments.prefix.contains("\r\n") { + return true; + } + if let Some(suffix) = &segments.suffix { + if suffix.contains("\r\n") { + return true; + } + } + + false +} + +fn override_prompt(prompt: String, use_crlf: bool) -> String { + if use_crlf { + prompt.replace("\r\n", "\n") + } else { + prompt + } +} + +/// override_generated replaces \n with \r\n in the generated text if use_crlf is true. +/// This is used to ensure that the generated text has the same line endings as the prompt. +/// +/// Because there might be \r\n in the text, which also has a `\n` and should not be replaced, +/// we can not simply replace \n with \r\n. +fn override_generated(generated: String, use_crlf: bool) -> String { + if use_crlf { + let re = Regex::new(r"([^\r])\n").unwrap(); // Match \n that is preceded by anything except \r + re.replace_all(&generated, "$1\r\n").to_string() // Replace with captured character and \r\n + } else { + generated + } +} + pub async fn create_completion_service_and_chat( config: &CompletionConfig, code: Arc, @@ -508,4 +531,96 @@ mod tests { .build("rust", segment.clone(), &[]); assert_eq!(prompt, "
fn hello_world() -> &'static str {}");
     }
+
+    /// `contains_crlf` must report `true` as soon as either the prefix or the
+    /// suffix holds a `\r\n` pair, and `false` when `\r` and `\n` only occur
+    /// separately, when only `\n` is used, or when the suffix is absent.
+    #[test]
+    fn test_contains_crlf() {
+        /// Builds a `Segments` with the given prefix/suffix; every other field
+        /// is left as `None` because the check only looks at these two.
+        fn build(prefix: &str, suffix: Option<&str>) -> Segments {
+            Segments {
+                prefix: prefix.into(),
+                suffix: suffix.map(String::from),
+                filepath: None,
+                git_url: None,
+                declarations: None,
+                relevant_snippets_from_changed_files: None,
+                relevant_snippets_from_recently_opened_files: None,
+                clipboard: None,
+            }
+        }
+
+        let contained_crlf = [
+            // CRLF only in the prefix.
+            build("fn hello_world() -> &'static str {\r\n", Some("}")),
+            // CRLF only in the suffix.
+            build("fn hello_world() -> &'static str {", Some("}\r\n")),
+            // CRLF in both.
+            build("fn hello_world() -> &'static str {\r\n", Some("}\r\n")),
+            // CRLF in the prefix with no suffix at all.
+            build("fn hello_world() -> &'static str {\r\n", None),
+        ];
+        for case in &contained_crlf {
+            assert!(contains_crlf(case), "CRLF not detected in {:?}", case);
+        }
+
+        let not_contained_crlf = [
+            // A lone `\r` and a lone `\n` in different fields are not a CRLF pair.
+            build("fn hello_world() -> &'static str {\r", Some("}\n")),
+            // LF-only content.
+            build("fn hello_world() -> &'static str {\n", Some("}\n")),
+            // No line ending and no suffix.
+            build("fn hello_world() -> &'static str {", None),
+        ];
+        for case in &not_contained_crlf {
+            assert!(!contains_crlf(case), "CRLF wrongly detected in {:?}", case);
+        }
+    }
+
+    /// `override_prompt` rewrites `\r\n` to `\n` only when CRLF handling is
+    /// enabled; otherwise the prompt is handed back unchanged.
+    #[test]
+    fn test_override_prompt() {
+        let crlf_prompt = String::from("fn hello_world() -> &'static str {\r\n");
+
+        // Enabled: the CRLF pair is collapsed to a single LF.
+        let normalized = override_prompt(crlf_prompt.clone(), true);
+        assert_eq!(normalized, "fn hello_world() -> &'static str {\n");
+
+        // Disabled: the prompt passes through untouched.
+        let untouched = override_prompt(crlf_prompt.clone(), false);
+        assert_eq!(untouched, crlf_prompt);
+    }
+
+    /// With `use_crlf` enabled, a bare `\n` becomes `\r\n` while existing `\r\n`
+    /// pairs and lone `\r` characters are preserved; with it disabled the
+    /// generated text is returned unchanged.
+    ///
+    /// NOTE(review): inputs with a leading `\n` or with consecutive newlines
+    /// (`\n\n`, i.e. blank lines) are intentionally absent. The pattern
+    /// `([^\r])\n` used by `override_generated` needs a preceding, not yet
+    /// consumed, non-`\r` character, so one `\n` of such a run is left
+    /// unconverted and these cases would currently fail. Fix the
+    /// implementation first, then add them here.
+    #[test]
+    fn test_override_generated() {
+        // (generated text, expected result when `use_crlf` is true)
+        let cases = [
+            // Already CRLF: must not be turned into "\r\r\n".
+            (
+                "fn hello_world() -> &'static str {\r\n",
+                "fn hello_world() -> &'static str {\r\n",
+            ),
+            // Bare LF is upgraded to CRLF.
+            (
+                "fn hello_world() -> &'static str {\n",
+                "fn hello_world() -> &'static str {\r\n",
+            ),
+            // A lone CR is not a line feed and stays as it is.
+            (
+                "fn hello_world() -> &'static str {\r",
+                "fn hello_world() -> &'static str {\r",
+            ),
+            // No line ending at all.
+            (
+                "fn hello_world() -> &'static str {",
+                "fn hello_world() -> &'static str {",
+            ),
+            // Several LF-terminated lines are all converted.
+            (
+                "fn hello_world() {\n    1\n}\n",
+                "fn hello_world() {\r\n    1\r\n}\r\n",
+            ),
+            // Mixed endings: only the bare LF changes.
+            (
+                "fn hello_world() {\r\n}\n",
+                "fn hello_world() {\r\n}\r\n",
+            ),
+        ];
+
+        for (generated, expected) in cases {
+            assert_eq!(override_generated(generated.to_string(), true), expected);
+            // The disabled path must hand the text back byte-for-byte.
+            assert_eq!(override_generated(generated.to_string(), false), generated);
+        }
+    }
 }