From 34f519587ed82d12ba3e6b2f67cc44e6b38d6d34 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Sun, 29 Sep 2024 06:05:25 +0000 Subject: [PATCH] [autofix.ci] apply automated fixes --- crates/aim-downloader/src/https.rs | 8 +- crates/tabby-common/src/env.rs | 2 +- crates/tabby-common/src/lib.rs | 2 +- crates/tabby-common/src/registry.rs | 23 ++--- crates/tabby-download/src/lib.rs | 128 +++++++++++++++++----------- 5 files changed, 88 insertions(+), 75 deletions(-) diff --git a/crates/aim-downloader/src/https.rs b/crates/aim-downloader/src/https.rs index 6d54f3145b3..0a45202f333 100644 --- a/crates/aim-downloader/src/https.rs +++ b/crates/aim-downloader/src/https.rs @@ -9,14 +9,13 @@ use crate::{ address::ParsedAddress, bar::WrappedBar, consts::*, - error::{DownloadError, ValidateError,HTTPHeaderError}, + error::{DownloadError, HTTPHeaderError, ValidateError}, hash::HashChecker, io, }; pub struct HTTPSHandler; impl HTTPSHandler { - pub async fn head(input: &str) -> Result { let parsed_address = ParsedAddress::parse_address(input, true); let res = Client::new() @@ -157,8 +156,6 @@ impl HTTPSHandler { name: input.into(), code: e.to_string(), })?; - - let total_size = downloaded + res.content_length().unwrap_or(0); @@ -295,10 +292,9 @@ async fn get_links_works_when_typical() { assert_eq!(result[0], expected); } - #[ignore] #[tokio::test] async fn head_works() { let result = HTTPSHandler::head("https://github.com/XAMPPRocky/tokei/releases/download/v12.0.4/tokei-x86_64-unknown-linux-gnu.tar.gz").await; assert!(result.is_ok()); -} \ No newline at end of file +} diff --git a/crates/tabby-common/src/env.rs b/crates/tabby-common/src/env.rs index 3054b7539f2..630b4a20ce0 100644 --- a/crates/tabby-common/src/env.rs +++ b/crates/tabby-common/src/env.rs @@ -9,4 +9,4 @@ pub fn get_huggingface_mirror_host() -> Option { // for debug only pub fn use_local_model_json() -> bool { std::env::var("TABBY_USE_LOCAL_MODEL_JSON").is_ok() -} \ No newline at end of file +} diff --git a/crates/tabby-common/src/lib.rs b/crates/tabby-common/src/lib.rs index f090475c388..d461c992985 100644 --- a/crates/tabby-common/src/lib.rs +++ b/crates/tabby-common/src/lib.rs @@ -4,10 +4,10 @@ pub mod api; pub mod axum; pub mod config; pub mod constants; +pub mod env; pub mod index; pub mod languages; pub mod path; pub mod registry; pub mod terminal; pub mod usage; -pub mod env; diff --git a/crates/tabby-common/src/registry.rs b/crates/tabby-common/src/registry.rs index 938aea91acd..a1929f25878 100644 --- a/crates/tabby-common/src/registry.rs +++ b/crates/tabby-common/src/registry.rs @@ -1,7 +1,6 @@ use std::{fs, path::PathBuf}; use anyhow::{Context, Result}; -use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use crate::{env::use_local_model_json, path::models_dir}; @@ -11,7 +10,6 @@ fn default_entrypoint() -> String { "model.gguf".to_string() } - #[derive(Serialize, Deserialize)] pub struct ModelInfo { pub name: String, @@ -64,7 +62,6 @@ pub struct ModelRegistry { pub models: Vec, } - // model registry tree structure // root: ~/.tabby/models/TABBYML @@ -75,20 +72,19 @@ pub struct ModelRegistry { // fn get_model_path(model_name) // for single model file -// -> {root}/{model_name}/ggml/model.gguf +// -> {root}/{model_name}/ggml/model.gguf // for multiple model files // -> {root}/{model_name}/ggml/{entrypoint} impl ModelRegistry { pub async fn new(registry: &str) -> Self { - if use_local_model_json() { - return Self { + Self { name: registry.to_owned(), models: load_local_registry(registry).unwrap_or_else(|_| { panic!("Failed to fetch model organization <{}>", registry) }), - }; + } } else { Self { name: registry.to_owned(), @@ -102,8 +98,6 @@ impl ModelRegistry { }), } } - - } // get_model_store_dir returns {root}/{name}/ggml, e.g.. ~/.tabby/models/TABBYML/StarCoder-1B/ggml @@ -117,7 +111,7 @@ impl ModelRegistry { } // get_legacy_model_path returns {root}/{name}/q8_0.v2.gguf, e.g. ~/.tabby/models/TABBYML/StarCoder-1B/q8_0.v2.gguf - fn get_legacy_model_path(&self, name:&str) ->PathBuf { + fn get_legacy_model_path(&self, name: &str) -> PathBuf { self.get_model_store_dir(name).join("q8_0.v2.gguf") } @@ -126,7 +120,8 @@ impl ModelRegistry { // for multiple model files, it returns {root}/{name}/ggml/{entrypoint} pub fn get_model_entry_path(&self, name: &str) -> PathBuf { let model_info = self.get_model_info(name); - self.get_model_store_dir(name).join(model_info.entrypoint.clone()) + self.get_model_store_dir(name) + .join(model_info.entrypoint.clone()) } pub fn migrate_model_path(&self, name: &str) -> Result<(), std::io::Error> { @@ -143,8 +138,6 @@ impl ModelRegistry { Ok(()) } - - pub fn save_model_info(&self, name: &str) { let model_info = self.get_model_info(name); let path = self.get_model_dir(name).join("tabby.json"); @@ -171,13 +164,11 @@ pub fn parse_model_id(model_id: &str) -> (&str, &str) { } } - - #[cfg(test)] mod tests { use temp_testdir::TempDir; - use super::{ModelRegistry, *}; + use super::ModelRegistry; use crate::path::set_tabby_root; #[tokio::test] diff --git a/crates/tabby-download/src/lib.rs b/crates/tabby-download/src/lib.rs index 7ffd7c2e099..2f20b8d5d83 100644 --- a/crates/tabby-download/src/lib.rs +++ b/crates/tabby-download/src/lib.rs @@ -3,37 +3,41 @@ use std::{ fs::{self}, path::Path, }; -use reqwest; -use aim_downloader::{bar::WrappedBar, error::DownloadError, hash::HashChecker, https::{self, HTTPSHandler}}; + +use aim_downloader::{ + bar::WrappedBar, + error::DownloadError, + hash::HashChecker, + https::{self, HTTPSHandler}, +}; use anyhow::{anyhow, bail, Result}; +use futures::future::join_all; +use regex::Regex; use tabby_common::registry::{parse_model_id, ModelInfo, ModelRegistry}; use tokio_retry::{ strategy::{jitter, ExponentialBackoff}, Retry, }; -use regex::Regex; use tracing::{info, warn}; -use futures::future::join_all; -fn filter_download_urls( model_info: &ModelInfo) -> Vec { +fn filter_download_urls(model_info: &ModelInfo) -> Vec { let download_host = tabby_common::env::get_download_host(); model_info .urls .iter() .flatten() - .filter_map(|f| + .filter_map(|f| { if f.contains(&download_host) { if let Some(mirror_host) = tabby_common::env::get_huggingface_mirror_host() { Some(f.replace("huggingface.co", &mirror_host)) } else { - Some(f.to_owned()) + Some(f.to_owned()) } } else { None } - ) + }) .collect() - } async fn download_model_impl( @@ -64,10 +68,14 @@ async fn download_model_impl( } } - let urls = filter_download_urls(&model_info); + let urls = filter_download_urls(model_info); - if urls.len() == 0 { - bail!("No download URLs available for <{}/{}>", registry.name, model_info.name); + if urls.is_empty() { + bail!( + "No download URLs available for <{}/{}>", + registry.name, + model_info.name + ); } if urls.len() > 1 { // if model_info.entrypoint.is_none(){ @@ -75,15 +83,25 @@ async fn download_model_impl( // } if let Some(urls_sha256) = &model_info.urls_sha256 { if urls_sha256.len() != urls.len() { - bail!("Number of urls_sha256 does not match number of URLs for <{}/{}>", registry.name, model_info.name); + bail!( + "Number of urls_sha256 does not match number of URLs for <{}/{}>", + registry.name, + model_info.name + ); } } else { - bail!("No urls_sha256 available for <{}/{}>", registry.name, model_info.name); + bail!( + "No urls_sha256 available for <{}/{}>", + registry.name, + model_info.name + ); } } - + // prepare for download - let dir = model_path.parent().ok_or_else(|| anyhow!("Must not be in root directory"))?; + let dir = model_path + .parent() + .ok_or_else(|| anyhow!("Must not be in root directory"))?; fs::create_dir_all(dir)?; let mut urls_sha256 = vec![]; @@ -95,37 +113,40 @@ async fn download_model_impl( let mut download_tasks = vec![]; for (url, sha256) in urls.iter().zip(urls_sha256.iter()) { - let dir = registry.get_model_store_dir(name).to_string_lossy().into_owned(); + let dir = registry + .get_model_store_dir(name) + .to_string_lossy() + .into_owned(); let filename = if urls.len() == 1 { Some(model_info.entrypoint.clone()) } else { None }; let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2); - download_tasks.push(Retry::spawn(strategy, move || { - let dir = dir.clone(); - let filename = filename.clone(); - download_file(&url, dir, filename, &sha256) - })); + download_tasks.push(Retry::spawn(strategy, move || { + let dir = dir.clone(); + let filename = filename.clone(); + download_file(url, dir, filename, sha256) + })); } let results = join_all(download_tasks).await; - + // handle errors - let errors: Vec = results.into_iter().filter_map(|e| { - if let Err(e) = e { - Some(e) - } else { - None - } - }).collect(); + let errors: Vec = results + .into_iter() + .filter_map(|e| if let Err(e) = e { Some(e) } else { None }) + .collect(); if errors.is_empty() { Ok(()) } else { - let combined_error = errors.into_iter().fold(anyhow::anyhow!("Multiple errors occurred"), |acc, err| acc.context(err)); + let combined_error = errors + .into_iter() + .fold(anyhow::anyhow!("Multiple errors occurred"), |acc, err| { + acc.context(err) + }); Err(combined_error) } - } async fn tryget_download_filename(url: &str) -> Result { @@ -134,10 +155,10 @@ async fn tryget_download_filename(url: &str) -> Result { if let Some(content_disposition) = response.get(reqwest::header::CONTENT_DISPOSITION) { if let Ok(disposition_str) = content_disposition.to_str() { let re = Regex::new(r#"filename="(.+?)""#).unwrap(); - let file_name = re.captures(disposition_str). - and_then(|cap| cap.get(1)). - map(|m| m.as_str(). - to_owned()); + let file_name = re + .captures(disposition_str) + .and_then(|cap| cap.get(1)) + .map(|m| m.as_str().to_owned()); if let Some(file_name) = file_name { return Ok(file_name); } @@ -154,13 +175,16 @@ async fn tryget_download_filename(url: &str) -> Result { } else { Err(anyhow!("Failed to get filename from URL {}", url)) } - } -async fn download_file(url: &str, dir: String, filename:Option, expected_sha256: &str) -> Result<()> { - +async fn download_file( + url: &str, + dir: String, + filename: Option, + expected_sha256: &str, +) -> Result<()> { let filename = filename.unwrap_or(tryget_download_filename(url).await?); - let fullpath = format!{"{}{}{}", dir,std::path::MAIN_SEPARATOR ,filename}; + let fullpath = format! {"{}{}{}", dir,std::path::MAIN_SEPARATOR ,filename}; let intermediate_filename = fullpath.clone() + ".tmp"; let mut bar = WrappedBar::new(0, url, false); if let Err(e) = @@ -194,50 +218,52 @@ pub async fn download_model(model_id: &str, prefer_local_file: bool) { #[cfg(test)] mod tests { use tabby_common::registry::ModelInfo; + use super::*; #[test] fn test_filter_download_urls() { // multiple urls let model_info = ModelInfo { name: "test".to_string(), - urls: Some(vec!["https://huggingface.co/test".to_string(), "https://huggingface.co/test2".to_string(), "https://modelscope.co/test2".to_string()]), + urls: Some(vec![ + "https://huggingface.co/test".to_string(), + "https://huggingface.co/test2".to_string(), + "https://modelscope.co/test2".to_string(), + ]), urls_sha256: Some(vec!["test_sha256".to_string(), "test2_sha256".to_string()]), entrypoint: "test".to_string(), sha256: "test_sha256".to_string(), prompt_template: None, - chat_template: None, + chat_template: None, }; let urls = super::filter_download_urls(&model_info); assert_eq!(urls.len(), 2); assert_eq!(urls[0], "https://huggingface.co/test"); assert_eq!(urls[1], "https://huggingface.co/test2"); - // single url let model_info = ModelInfo { name: "test".to_string(), - urls: Some(vec!["https://huggingface.co/test".to_string(), "https://modelscope.co/test2".to_string()]), + urls: Some(vec![ + "https://huggingface.co/test".to_string(), + "https://modelscope.co/test2".to_string(), + ]), urls_sha256: None, entrypoint: "model.gguf".to_string(), sha256: "test_sha256".to_string(), prompt_template: None, - chat_template: None, + chat_template: None, }; let urls = super::filter_download_urls(&model_info); assert_eq!(urls.len(), 1); assert_eq!(urls[0], "https://huggingface.co/test"); - - } #[tokio::test] async fn test_tryget_download_filename() { - let url = "https://huggingface.co/TabbyML/models/resolve/main/.gitattributes"; let result = tryget_download_filename(url).await; assert!(result.is_ok()); assert_eq!(result.unwrap(), ".gitattributes"); } - - -} \ No newline at end of file +}