From 979a8eb84a7ab60bc887f00573ae3437338fff6d Mon Sep 17 00:00:00 2001 From: imShire Date: Fri, 19 Apr 2024 23:54:55 +0800 Subject: [PATCH] fix: fix needs: HashMap --- src/lib.rs | 64 +++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index cc37865..c2613b7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -use serde_json::{Value, json}; +use serde_json::{json, Value}; use std::collections::HashMap; use std::error::Error; @@ -6,30 +6,60 @@ use std::error::Error; pub fn translate( text: &str, // 待翻译文本 from: &str, // 源语言 - to: &str, // 目标语言 + to: &str, // 目标语言 // (pot会根据info.json 中的 language 字段传入插件需要的语言代码,无需再次转换) - _detect: &str, // 检测到的语言 (若使用 detect, 需要手动转换) - _needs: HashMap<&str, String>,// 插件需要的其他参数,由info.json定义 + detect: &str, // 检测到的语言 (若使用 detect, 需要手动转换) + needs: HashMap, // 插件需要的其他参数,由info.json定义 ) -> Result> { let client = reqwest::blocking::ClientBuilder::new().build()?; - let default_url = "https://api.cohere.ai".to_string(); - let default_mode = "1".to_string(); - let default_model = "command-r-plus".to_string(); - let default_prompt = "".to_string(); - let api_url = _needs.get("apiUrl").unwrap_or(&default_url); - let apikey = _needs.get("apiKey"); - let model = _needs.get("model").unwrap_or(&default_model); - let mode = _needs.get("mode").unwrap_or(&default_mode); - let customize_prompt = _needs.get("customizePrompt").unwrap_or(&default_prompt); - // let api_url = _needs.get("apiUrl"); + let api_url = match needs.get("apiUrl") { + Some(raw_url) => { + if !raw_url.starts_with("http") { + format!("https://{raw_url}") + } else { + raw_url.to_string() + } + } + None => { + String::from("https://api.cohere.ai") + } + }; + + let model = match needs.get("model") { + Some(raw_model) => { + raw_model.to_string() + } + None => { + String::from("command-r-plus") + } + }; + + let mode = match needs.get("model") { + Some(raw_mode) => { + raw_mode.to_string() + } + None => { + String::from("1") + } + }; + let customize_prompt = match needs.get("customizePrompt") { + Some(raw_prompt) => { + raw_prompt.to_string() + } + None => { + String::from("") + } + }; + let apikey = needs.get("apiKey"); + // let api_url = needs.get("apiUrl"); let api_url_path = "/v1/chat"; - if apikey.unwrap_or(&"".to_string()).is_empty() { + if apikey.unwrap_or(&&"".to_string()).is_empty() { return Err("apiKey is required".into()); } println!("using default: \n{}\n{}\n{}\n{}\n", api_url,model,mode,apikey.unwrap()); let full_url = format!("{}{}", api_url, api_url_path); let auth_header = format!("bearer {}", apikey.unwrap()); - let body = build_request_body(model, mode, customize_prompt, text, from, to); + let body = build_request_body(&model, &mode, &customize_prompt, text, from, to); println!("body: \n{}\n{}\n{}\n", full_url,auth_header,body); let res = client .post(&full_url) @@ -93,7 +123,7 @@ mod tests { #[test] fn try_request() { let needs = HashMap::new(); - let result = translate("Hello", "auto", "zh", "en", needs); + let result = translate("Hello", "auto", "zh", "en", needs).unwrap(); println!("{result:?}"); } }