diff --git a/Cargo.lock b/Cargo.lock index 81417cec9d9..c6c3671ae2d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5107,7 +5107,11 @@ version = "0.18.0-dev.0" dependencies = [ "aim-downloader", "anyhow", + "futures", + "regex", + "reqwest", "tabby-common", + "tokio", "tokio-retry", "tracing", ] diff --git a/crates/aim-downloader/src/https.rs b/crates/aim-downloader/src/https.rs index a2fea8c7013..0a45202f333 100644 --- a/crates/aim-downloader/src/https.rs +++ b/crates/aim-downloader/src/https.rs @@ -2,20 +2,36 @@ use std::{cmp::min, io::Error}; use futures_util::StreamExt; use regex::Regex; -use reqwest::Client; +use reqwest::{header::HeaderMap, Client}; use tokio_util::io::ReaderStream; use crate::{ address::ParsedAddress, bar::WrappedBar, consts::*, - error::{DownloadError, ValidateError}, + error::{DownloadError, HTTPHeaderError, ValidateError}, hash::HashChecker, io, }; pub struct HTTPSHandler; impl HTTPSHandler { + pub async fn head(input: &str) -> Result { + let parsed_address = ParsedAddress::parse_address(input, true); + let res = Client::new() + .head(input) + .header( + reqwest::header::USER_AGENT, + reqwest::header::HeaderValue::from_static(CLIENT_ID), + ) + .basic_auth(parsed_address.username, Some(parsed_address.password)) + .send() + .await + .map_err(|_| format!("Failed to HEAD from {}", &input)) + .unwrap(); + Ok(res.headers().clone()) + } + pub async fn get( input: &str, output: &str, @@ -275,3 +291,10 @@ async fn get_links_works_when_typical() { assert_eq!(result[0], expected); } + +#[ignore] +#[tokio::test] +async fn head_works() { + let result = HTTPSHandler::head("https://github.com/XAMPPRocky/tokei/releases/download/v12.0.4/tokei-x86_64-unknown-linux-gnu.tar.gz").await; + assert!(result.is_ok()); +} diff --git a/crates/llama-cpp-server/src/lib.rs b/crates/llama-cpp-server/src/lib.rs index b6294235231..b7084932062 100644 --- a/crates/llama-cpp-server/src/lib.rs +++ b/crates/llama-cpp-server/src/lib.rs @@ -10,7 +10,7 @@ use serde::Deserialize; use 
/// Host that model download URLs must contain to be considered for download.
/// Defaults to `huggingface.co`; override with `TABBY_DOWNLOAD_HOST`.
pub fn get_download_host() -> String {
    match std::env::var("TABBY_DOWNLOAD_HOST") {
        Ok(host) => host,
        Err(_) => "huggingface.co".to_string(),
    }
}

/// Optional mirror host substituted for `huggingface.co` in download URLs,
/// taken from `TABBY_HUGGINGFACE_HOST_OVERRIDE` when set.
pub fn get_huggingface_mirror_host() -> Option<String> {
    match std::env::var("TABBY_HUGGINGFACE_HOST_OVERRIDE") {
        Ok(host) => Some(host),
        Err(_) => None,
    }
}

/// Whether to load the model registry from the local models.json instead of
/// fetching it remotely. Debug-only switch: enabled whenever
/// `TABBY_USE_LOCAL_MODEL_JSON` is present, regardless of its value.
pub fn use_local_model_json() -> bool {
    std::env::var("TABBY_USE_LOCAL_MODEL_JSON").is_ok()
}
b/crates/tabby-common/src/registry.rs @@ -1,10 +1,14 @@ use std::{fs, path::PathBuf}; use anyhow::{Context, Result}; -use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; -use crate::path::models_dir; +use crate::{env::use_local_model_json, path::models_dir}; + +// default_entrypoint is legacy entrypoint for single model file +fn default_entrypoint() -> String { + "model.gguf".to_string() +} #[derive(Serialize, Deserialize)] pub struct ModelInfo { @@ -16,6 +20,10 @@ pub struct ModelInfo { #[serde(skip_serializing_if = "Option::is_none")] pub urls: Option>, pub sha256: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub urls_sha256: Option>, + #[serde(default = "default_entrypoint")] + pub entrypoint: String, } fn models_json_file(registry: &str) -> PathBuf { @@ -54,30 +62,71 @@ pub struct ModelRegistry { pub models: Vec, } +// model registry tree structure + +// root: ~/.tabby/models/TABBYML + +// fn get_model_root_dir(model_name) -> {root}/{model_name} + +// fn get_model_dir(model_name) -> {root}/{model_name}/ggml + +// fn get_model_path(model_name) +// for single model file +// -> {root}/{model_name}/ggml/model.gguf +// for multiple model files +// -> {root}/{model_name}/ggml/{entrypoint} + impl ModelRegistry { pub async fn new(registry: &str) -> Self { - Self { - name: registry.to_owned(), - models: load_remote_registry(registry).await.unwrap_or_else(|err| { - load_local_registry(registry).unwrap_or_else(|_| { - panic!( - "Failed to fetch model organization <{}>: {:?}", - registry, err - ) - }) - }), + if use_local_model_json() { + Self { + name: registry.to_owned(), + models: load_local_registry(registry).unwrap_or_else(|_| { + panic!("Failed to fetch model organization <{}>", registry) + }), + } + } else { + Self { + name: registry.to_owned(), + models: load_remote_registry(registry).await.unwrap_or_else(|err| { + load_local_registry(registry).unwrap_or_else(|_| { + panic!( + "Failed to fetch model organization <{}>: {:?}", + 
registry, err + ) + }) + }), + } } } - fn get_model_dir(&self, name: &str) -> PathBuf { + // get_model_store_dir returns {root}/{name}/ggml, e.g.. ~/.tabby/models/TABBYML/StarCoder-1B/ggml + pub fn get_model_store_dir(&self, name: &str) -> PathBuf { + models_dir().join(&self.name).join(name).join("ggml") + } + + // get_model_dir returns {root}/{name}, e.g. ~/.tabby/models/TABBYML/StarCoder-1B + pub fn get_model_dir(&self, name: &str) -> PathBuf { models_dir().join(&self.name).join(name) } + // get_legacy_model_path returns {root}/{name}/q8_0.v2.gguf, e.g. ~/.tabby/models/TABBYML/StarCoder-1B/q8_0.v2.gguf + fn get_legacy_model_path(&self, name: &str) -> PathBuf { + self.get_model_store_dir(name).join("q8_0.v2.gguf") + } + + // get_model_path returns the entrypoint of the model, + // for single model file, it returns {root}/{name}/ggml/model.gguf + // for multiple model files, it returns {root}/{name}/ggml/{entrypoint} + pub fn get_model_entry_path(&self, name: &str) -> PathBuf { + let model_info = self.get_model_info(name); + self.get_model_store_dir(name) + .join(model_info.entrypoint.clone()) + } + pub fn migrate_model_path(&self, name: &str) -> Result<(), std::io::Error> { - let model_path = self.get_model_path(name); - let old_model_path = self - .get_model_dir(name) - .join(LEGACY_GGML_MODEL_RELATIVE_PATH.as_str()); + let model_path = self.get_model_entry_path(name); + let old_model_path = self.get_legacy_model_path(name); if !model_path.exists() && old_model_path.exists() { std::fs::rename(&old_model_path, &model_path)?; @@ -89,11 +138,6 @@ impl ModelRegistry { Ok(()) } - pub fn get_model_path(&self, name: &str) -> PathBuf { - self.get_model_dir(name) - .join(GGML_MODEL_RELATIVE_PATH.as_str()) - } - pub fn save_model_info(&self, name: &str) { let model_info = self.get_model_info(name); let path = self.get_model_dir(name).join("tabby.json"); @@ -120,18 +164,11 @@ pub fn parse_model_id(model_id: &str) -> (&str, &str) { } } -lazy_static! 
{ - pub static ref LEGACY_GGML_MODEL_RELATIVE_PATH: String = - format!("ggml{}q8_0.v2.gguf", std::path::MAIN_SEPARATOR_STR); - pub static ref GGML_MODEL_RELATIVE_PATH: String = - format!("ggml{}model.gguf", std::path::MAIN_SEPARATOR_STR); -} - #[cfg(test)] mod tests { use temp_testdir::TempDir; - use super::{ModelRegistry, *}; + use super::ModelRegistry; use crate::path::set_tabby_root; #[tokio::test] @@ -140,9 +177,9 @@ mod tests { set_tabby_root(root.to_path_buf()); let registry = ModelRegistry::new("TabbyML").await; - let dir = registry.get_model_dir("StarCoder-1B"); + let name = "StarCoder-1B"; - let old_model_path = dir.join(LEGACY_GGML_MODEL_RELATIVE_PATH.as_str()); + let old_model_path = registry.get_legacy_model_path(name); tokio::fs::create_dir_all(old_model_path.parent().unwrap()) .await .unwrap(); @@ -154,7 +191,7 @@ mod tests { .unwrap(); registry.migrate_model_path("StarCoder-1B").unwrap(); - assert!(registry.get_model_path("StarCoder-1B").exists()); + assert!(registry.get_model_entry_path("StarCoder-1B").exists()); assert!(old_model_path.exists()); } } diff --git a/crates/tabby-download/Cargo.toml b/crates/tabby-download/Cargo.toml index 57834e3681a..50b6e92c6e6 100644 --- a/crates/tabby-download/Cargo.toml +++ b/crates/tabby-download/Cargo.toml @@ -9,3 +9,7 @@ tabby-common = { path = "../tabby-common" } anyhow = { workspace = true } tracing = { workspace = true } tokio-retry = "0.3.0" +futures.workspace = true +reqwest = { workspace = true } +tokio = {workspace=true} +regex = {workspace=true} \ No newline at end of file diff --git a/crates/tabby-download/src/lib.rs b/crates/tabby-download/src/lib.rs index cb05301b1e5..2f20b8d5d83 100644 --- a/crates/tabby-download/src/lib.rs +++ b/crates/tabby-download/src/lib.rs @@ -4,21 +4,40 @@ use std::{ path::Path, }; -use aim_downloader::{bar::WrappedBar, error::DownloadError, hash::HashChecker, https}; +use aim_downloader::{ + bar::WrappedBar, + error::DownloadError, + hash::HashChecker, + https::{self, 
HTTPSHandler}, +}; use anyhow::{anyhow, bail, Result}; -use tabby_common::registry::{parse_model_id, ModelRegistry}; +use futures::future::join_all; +use regex::Regex; +use tabby_common::registry::{parse_model_id, ModelInfo, ModelRegistry}; use tokio_retry::{ strategy::{jitter, ExponentialBackoff}, Retry, }; use tracing::{info, warn}; -fn select_by_download_host(url: &String) -> bool { - if let Ok(host) = std::env::var("TABBY_DOWNLOAD_HOST") { - url.contains(&host) - } else { - true - } +fn filter_download_urls(model_info: &ModelInfo) -> Vec { + let download_host = tabby_common::env::get_download_host(); + model_info + .urls + .iter() + .flatten() + .filter_map(|f| { + if f.contains(&download_host) { + if let Some(mirror_host) = tabby_common::env::get_huggingface_mirror_host() { + Some(f.replace("huggingface.co", &mirror_host)) + } else { + Some(f.to_owned()) + } + } else { + None + } + }) + .collect() } async fn download_model_impl( @@ -30,7 +49,8 @@ async fn download_model_impl( registry.save_model_info(name); registry.migrate_model_path(name)?; - let model_path = registry.get_model_path(name); + + let model_path = registry.get_model_entry_path(name); if model_path.exists() { if !prefer_local_file { info!("Checking model integrity.."); @@ -41,49 +61,132 @@ async fn download_model_impl( "Checksum doesn't match for <{}/{}>, re-downloading...", registry.name, name ); - fs::remove_file(&model_path)?; + + fs::remove_dir_all(registry.get_model_store_dir(name))?; } else { return Ok(()); } } - let Some(model_url) = model_info - .urls - .iter() - .flatten() - .find(|x| select_by_download_host(x)) - else { - return Err(anyhow!("No valid url for model <{}>", model_info.name)); - }; - - // Replace the huggingface.co domain with the mirror host if it is set. 
- let model_url = if let Ok(host) = std::env::var("TABBY_HUGGINGFACE_HOST_OVERRIDE") { - model_url.replace("huggingface.co", &host) - } else { - model_url.to_owned() - }; - - let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2); - let download_job = Retry::spawn(strategy, || { - download_file(&model_url, model_path.as_path(), &model_info.sha256) - }); - download_job.await?; - Ok(()) -} + let urls = filter_download_urls(model_info); + + if urls.is_empty() { + bail!( + "No download URLs available for <{}/{}>", + registry.name, + model_info.name + ); + } + if urls.len() > 1 { + // if model_info.entrypoint.is_none(){ + // bail!("Multiple download URLs available for <{}/{}>, but no entrypoint specified", registry.name, model_info.name); + // } + if let Some(urls_sha256) = &model_info.urls_sha256 { + if urls_sha256.len() != urls.len() { + bail!( + "Number of urls_sha256 does not match number of URLs for <{}/{}>", + registry.name, + model_info.name + ); + } + } else { + bail!( + "No urls_sha256 available for <{}/{}>", + registry.name, + model_info.name + ); + } + } -async fn download_file(url: &str, path: &Path, expected_sha256: &str) -> Result<()> { - let dir = path + // prepare for download + let dir = model_path .parent() .ok_or_else(|| anyhow!("Must not be in root directory"))?; fs::create_dir_all(dir)?; - let filename = path - .to_str() - .ok_or_else(|| anyhow!("Could not convert filename to UTF-8"))?; - let intermediate_filename = filename.to_owned() + ".tmp"; + let mut urls_sha256 = vec![]; + if urls.len() > 1 { + urls_sha256.extend(model_info.urls_sha256.clone().unwrap()); + } else { + urls_sha256.push(model_info.sha256.clone()); + } - let mut bar = WrappedBar::new(0, url, false); + let mut download_tasks = vec![]; + for (url, sha256) in urls.iter().zip(urls_sha256.iter()) { + let dir = registry + .get_model_store_dir(name) + .to_string_lossy() + .into_owned(); + let filename = if urls.len() == 1 { + Some(model_info.entrypoint.clone()) + } else 
{ + None + }; + let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2); + download_tasks.push(Retry::spawn(strategy, move || { + let dir = dir.clone(); + let filename = filename.clone(); + download_file(url, dir, filename, sha256) + })); + } + + let results = join_all(download_tasks).await; + + // handle errors + let errors: Vec = results + .into_iter() + .filter_map(|e| if let Err(e) = e { Some(e) } else { None }) + .collect(); + if errors.is_empty() { + Ok(()) + } else { + let combined_error = errors + .into_iter() + .fold(anyhow::anyhow!("Multiple errors occurred"), |acc, err| { + acc.context(err) + }); + Err(combined_error) + } +} +async fn tryget_download_filename(url: &str) -> Result { + // try to get filename from Content-Disposition header + let response = HTTPSHandler::head(url).await?; + if let Some(content_disposition) = response.get(reqwest::header::CONTENT_DISPOSITION) { + if let Ok(disposition_str) = content_disposition.to_str() { + let re = Regex::new(r#"filename="(.+?)""#).unwrap(); + let file_name = re + .captures(disposition_str) + .and_then(|cap| cap.get(1)) + .map(|m| m.as_str().to_owned()); + if let Some(file_name) = file_name { + return Ok(file_name); + } + } + } + // try to parse filename from URL + if let Some(parsed_name) = Path::new(url).file_name() { + let parsed_name = parsed_name.to_string_lossy().to_string(); + if parsed_name.is_empty() { + Err(anyhow!("Failed to get filename from URL {}", url)) + } else { + Ok(parsed_name) + } + } else { + Err(anyhow!("Failed to get filename from URL {}", url)) + } +} + +async fn download_file( + url: &str, + dir: String, + filename: Option, + expected_sha256: &str, +) -> Result<()> { + let filename = filename.unwrap_or(tryget_download_filename(url).await?); + let fullpath = format! 
{"{}{}{}", dir,std::path::MAIN_SEPARATOR ,filename}; + let intermediate_filename = fullpath.clone() + ".tmp"; + let mut bar = WrappedBar::new(0, url, false); if let Err(e) = https::HTTPSHandler::get(url, &intermediate_filename, &mut bar, expected_sha256).await { @@ -97,7 +200,7 @@ async fn download_file(url: &str, path: &Path, expected_sha256: &str) -> Result< } } - fs::rename(intermediate_filename, filename)?; + fs::rename(intermediate_filename, fullpath)?; Ok(()) } @@ -111,3 +214,56 @@ pub async fn download_model(model_id: &str, prefer_local_file: bool) { .await .unwrap_or_else(handler) } + +#[cfg(test)] +mod tests { + use tabby_common::registry::ModelInfo; + + use super::*; + #[test] + fn test_filter_download_urls() { + // multiple urls + let model_info = ModelInfo { + name: "test".to_string(), + urls: Some(vec![ + "https://huggingface.co/test".to_string(), + "https://huggingface.co/test2".to_string(), + "https://modelscope.co/test2".to_string(), + ]), + urls_sha256: Some(vec!["test_sha256".to_string(), "test2_sha256".to_string()]), + entrypoint: "test".to_string(), + sha256: "test_sha256".to_string(), + prompt_template: None, + chat_template: None, + }; + let urls = super::filter_download_urls(&model_info); + assert_eq!(urls.len(), 2); + assert_eq!(urls[0], "https://huggingface.co/test"); + assert_eq!(urls[1], "https://huggingface.co/test2"); + + // single url + let model_info = ModelInfo { + name: "test".to_string(), + urls: Some(vec![ + "https://huggingface.co/test".to_string(), + "https://modelscope.co/test2".to_string(), + ]), + urls_sha256: None, + entrypoint: "model.gguf".to_string(), + sha256: "test_sha256".to_string(), + prompt_template: None, + chat_template: None, + }; + let urls = super::filter_download_urls(&model_info); + assert_eq!(urls.len(), 1); + assert_eq!(urls[0], "https://huggingface.co/test"); + } + + #[tokio::test] + async fn test_tryget_download_filename() { + let url = "https://huggingface.co/TabbyML/models/resolve/main/.gitattributes"; 
+ let result = tryget_download_filename(url).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), ".gitattributes"); + } +}