From a19f86f392bb1025d68b39030023f9790ee58921 Mon Sep 17 00:00:00 2001
From: Marco Inacio
Date: Fri, 24 Nov 2023 19:12:06 +0000
Subject: [PATCH] using structs and impl for llm model

---
 src/llm/mod.rs   |  29 ++---
 src/llm/utils.rs | 282 +++++++++++++++++++++++++----------------
 2 files changed, 162 insertions(+), 149 deletions(-)

diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index 6ef9db2..7c68208 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -2,7 +2,7 @@ mod utils;
 use super::routes::SlackOAuthToken;
 use std::thread;
 use tokio::runtime::Handle;
-use utils::{print_stats, run_model_iteraction, start_model};
+use utils::Model;
 
 use sqlx::SqlitePool;
 
@@ -19,7 +19,6 @@ pub async fn start_llm_worker(db_pool: SqlitePool, slack_oauth_token: SlackOAuth
     let seed = None;
     let repeat_penalty = 1.1;
     let repeat_last_n = 64;
-    print_stats(temperature, repeat_penalty, repeat_last_n);
     loop {
         let res = thread::scope(|s| {
             s.spawn(|| {
@@ -27,12 +26,14 @@ pub async fn start_llm_worker(db_pool: SqlitePool, slack_oauth_token: SlackOAuth
                     thread_priority::ThreadPriority::Min,
                 )
                 .unwrap_or_default();
-                let (mut model_weights, tokenizer, mut logits_processor) =
-                    start_model(temperature, top_p, seed)
-                        .map_err(|e| {
-                            println!("Failed to start model:\n{e}");
-                        })
-                        .unwrap();
+                let mut llm_model = Model::start_model(
+                    temperature,
+                    top_p,
+                    seed,
+                    sample_len,
+                    repeat_penalty,
+                    repeat_last_n,
+                )?;
 
                 loop {
                     // async task to select a task from the queue
@@ -50,16 +51,8 @@ pub async fn start_llm_worker(db_pool: SqlitePool, slack_oauth_token: SlackOAuth
                             .map_err(|e| format!("Failed to get session state: {e}"))
                     })?;
 
-                    let (next_pre_prompt_tokens, generated_text) = run_model_iteraction(
-                        prompt_str,
-                        &mut model_weights,
-                        &tokenizer,
-                        &mut logits_processor,
-                        pre_prompt_tokens,
-                        sample_len,
-                        repeat_penalty,
-                        repeat_last_n,
-                    )?;
+                    let (next_pre_prompt_tokens, generated_text) =
+                        llm_model.run_model_iteraction(prompt_str, pre_prompt_tokens)?;
 
                     let encoded: Vec<u8> = bincode::serialize(&next_pre_prompt_tokens)
                         .map_err(|e| format!("Failed to encode model {e}"))?;
diff --git a/src/llm/utils.rs b/src/llm/utils.rs
index 9c5cc55..a641858 100644
--- a/src/llm/utils.rs
+++ b/src/llm/utils.rs
@@ -29,145 +29,165 @@ pub fn print_stats(temperature: Option<f64>, repeat_penalty: f32, repeat_last_n:
     );
 }
 
-pub fn start_model(
-    temperature: Option<f64>,
-    top_p: Option<f64>,
-    seed: Option<u64>,
-) -> Result<(ModelWeights, Tokenizer, LogitsProcessor), Box<dyn std::error::Error>> {
-    let repo = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF";
-    // let filename = "mistral-7b-instruct-v0.1.Q5_K_M.gguf";
-    let filename = "mistral-7b-instruct-v0.1.Q2_K.gguf";
-
-    let api = hf_hub::api::sync::Api::new()?;
-    let api = api.model(repo.to_string());
-    let model_path = api.get(filename)?;
-
-    let mut file = std::fs::File::open(model_path)?;
-    let start = std::time::Instant::now();
-
-    let model_weights = {
-        let model = gguf_file::Content::read(&mut file)?;
-        let mut total_size_in_bytes = 0;
-        for (_, tensor) in model.tensor_infos.iter() {
-            let elem_count = tensor.shape.elem_count();
-            total_size_in_bytes +=
-                elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
-        }
-        println!(
-            "loaded {:?} tensors ({}) in {:.2}s",
-            model.tensor_infos.len(),
-            &format_size(total_size_in_bytes),
-            start.elapsed().as_secs_f32(),
-        );
-        ModelWeights::from_gguf(model, &mut file)?
-    };
-    println!("model built");
-
-    let api = hf_hub::api::sync::Api::new()?;
-    let repo = "mistralai/Mistral-7B-v0.1";
-    let api = api.model(repo.to_string());
-    let tokenizer_path = api.get("tokenizer.json")?;
-    let tokenizer = Tokenizer::from_file(tokenizer_path)
-        .map_err(|e| format!("Error loading tokenizer: {e}"))?;
-
-    let logits_processor = LogitsProcessor::new(
-        #[allow(clippy::unnecessary_literal_unwrap)]
-        seed.unwrap_or_else(|| {
-            let seed = std::time::SystemTime::now()
-                .duration_since(std::time::SystemTime::UNIX_EPOCH)
-                .unwrap_or_default()
-                .as_secs();
-            println!("Using {} as LogitsProcessor RNG seed", seed);
-            seed
-        }),
-        temperature,
-        top_p,
-    );
-    println!("Starting LLM model");
-
-    Ok((model_weights, tokenizer, logits_processor))
-}
-
-#[allow(clippy::too_many_arguments)]
-pub fn run_model_iteraction(
-    prompt_str: String,
-    model_weights: &mut ModelWeights,
-    tokenizer: &Tokenizer,
-    logits_processor: &mut LogitsProcessor,
-    pre_prompt_tokens: Vec<u32>,
+pub struct Model {
+    model_weights: ModelWeights,
+    tokenizer: Tokenizer,
+    logits_processor: LogitsProcessor,
     sample_len: usize,
     repeat_penalty: f32,
     repeat_last_n: usize,
-) -> Result<(Vec<u32>, String), Box<dyn std::error::Error>> {
-    let prompt_str = format!("[INST] {prompt_str} [/INST]");
-    // print!("{}", &prompt_str);
-    let tokens = tokenizer
-        .encode(prompt_str, true)
-        .map_err(|e| format!("Error encoding tokenizer: {e}"))?;
-
-    let prompt_tokens = [pre_prompt_tokens, tokens.get_ids().to_owned()].concat();
-    let mut to_sample = sample_len.saturating_sub(1);
-    let prompt_tokens = if prompt_tokens.len() + to_sample > model::MAX_SEQ_LEN - 10 {
-        let to_remove = prompt_tokens.len() + to_sample + 10 - model::MAX_SEQ_LEN;
-        prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
-    } else {
-        prompt_tokens
-    };
-    let mut all_tokens = vec![];
-
-    let device = Device::Cpu;
-    let start_prompt_processing = std::time::Instant::now();
-    let mut next_token = {
-        let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
-        let logits = model_weights.forward(&input, 0)?;
-        let logits = logits.squeeze(0)?;
-        logits_processor.sample(&logits)?
-    };
-    let prompt_dt = start_prompt_processing.elapsed();
-    all_tokens.push(next_token);
-    let mut generated_text = String::new();
-    extract_token(next_token, tokenizer, &mut generated_text);
-
-    let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
-
-    let start_post_prompt = std::time::Instant::now();
-    for index in 0..to_sample {
-        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
-        let logits = model_weights.forward(&input, prompt_tokens.len() + index)?;
-        let logits = logits.squeeze(0)?;
-        let logits = if repeat_penalty == 1. {
-            logits
+}
+
+impl Model {
+    pub fn start_model(
+        temperature: Option<f64>,
+        top_p: Option<f64>,
+        seed: Option<u64>,
+        sample_len: usize,
+        repeat_penalty: f32,
+        repeat_last_n: usize,
+    ) -> Result<Self, Box<dyn std::error::Error>> {
+        print_stats(temperature, repeat_penalty, repeat_last_n);
+
+        let repo = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF";
+        // let filename = "mistral-7b-instruct-v0.1.Q5_K_M.gguf";
+        let filename = "mistral-7b-instruct-v0.1.Q2_K.gguf";
+
+        let api = hf_hub::api::sync::Api::new()?;
+        let api = api.model(repo.to_string());
+        let model_path = api.get(filename)?;
+
+        let mut file = std::fs::File::open(model_path)?;
+        let start = std::time::Instant::now();
+
+        let model_weights = {
+            let model = gguf_file::Content::read(&mut file)?;
+            let mut total_size_in_bytes = 0;
+            for (_, tensor) in model.tensor_infos.iter() {
+                let elem_count = tensor.shape.elem_count();
+                total_size_in_bytes +=
+                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
+            }
+            println!(
+                "loaded {:?} tensors ({}) in {:.2}s",
+                model.tensor_infos.len(),
+                &format_size(total_size_in_bytes),
+                start.elapsed().as_secs_f32(),
+            );
+            ModelWeights::from_gguf(model, &mut file)?
+        };
+        println!("model built");
+
+        let api = hf_hub::api::sync::Api::new()?;
+        let repo = "mistralai/Mistral-7B-v0.1";
+        let api = api.model(repo.to_string());
+        let tokenizer_path = api.get("tokenizer.json")?;
+        let tokenizer = Tokenizer::from_file(tokenizer_path)
+            .map_err(|e| format!("Error loading tokenizer: {e}"))?;
+
+        let logits_processor = LogitsProcessor::new(
+            #[allow(clippy::unnecessary_literal_unwrap)]
+            seed.unwrap_or_else(|| {
+                let seed = std::time::SystemTime::now()
+                    .duration_since(std::time::SystemTime::UNIX_EPOCH)
+                    .unwrap_or_default()
+                    .as_secs();
+                println!("Using {} as LogitsProcessor RNG seed", seed);
+                seed
+            }),
+            temperature,
+            top_p,
+        );
+        println!("Starting LLM model");
+
+        Ok(Self {
+            model_weights,
+            tokenizer,
+            logits_processor,
+            sample_len,
+            repeat_penalty,
+            repeat_last_n,
+        })
+    }
+
+    pub fn run_model_iteraction(
+        &mut self,
+        prompt_str: String,
+        pre_prompt_tokens: Vec<u32>,
+    ) -> Result<(Vec<u32>, String), Box<dyn std::error::Error>> {
+        let prompt_str = format!("[INST] {prompt_str} [/INST]");
+        // print!("{}", &prompt_str);
+        let tokens = self
+            .tokenizer
+            .encode(prompt_str, true)
+            .map_err(|e| format!("Error encoding tokenizer: {e}"))?;
+
+        let prompt_tokens = [pre_prompt_tokens, tokens.get_ids().to_owned()].concat();
+        let mut to_sample = self.sample_len.saturating_sub(1);
+        let prompt_tokens = if prompt_tokens.len() + to_sample > model::MAX_SEQ_LEN - 10 {
+            let to_remove = prompt_tokens.len() + to_sample + 10 - model::MAX_SEQ_LEN;
+            prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
         } else {
-            let start_at = all_tokens.len().saturating_sub(repeat_last_n);
-            candle_transformers::utils::apply_repeat_penalty(
-                &logits,
-                repeat_penalty,
-                &all_tokens[start_at..],
-            )?
+            prompt_tokens
         };
-        next_token = logits_processor.sample(&logits)?;
-        all_tokens.push(next_token);
-        extract_token(next_token, tokenizer, &mut generated_text);
-        if next_token == eos_token {
-            to_sample = index + 1;
-            break;
+        let mut all_tokens = vec![];
+
+        let device = Device::Cpu;
+        let start_prompt_processing = std::time::Instant::now();
+        let mut next_token = {
+            let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
+            let logits = self.model_weights.forward(&input, 0)?;
+            let logits = logits.squeeze(0)?;
+            self.logits_processor.sample(&logits)?
         };
-    }
-    let dt = start_post_prompt.elapsed();
-    println!(
-        "\n\n{:4} prompt tokens processed: {:.2} token/s",
-        prompt_tokens.len(),
-        prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
-    );
-    println!(
-        "{:4} tokens generated: {:.2} token/s",
-        to_sample,
-        to_sample as f64 / dt.as_secs_f64(),
-    );
+        let prompt_dt = start_prompt_processing.elapsed();
+        all_tokens.push(next_token);
+        let mut generated_text = String::new();
+        extract_token(next_token, &self.tokenizer, &mut generated_text);
+
+        let eos_token = *self.tokenizer.get_vocab(true).get("</s>").unwrap();
+
+        let start_post_prompt = std::time::Instant::now();
+        for index in 0..to_sample {
+            let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
+            let logits = self
+                .model_weights
+                .forward(&input, prompt_tokens.len() + index)?;
+            let logits = logits.squeeze(0)?;
+            let logits = if self.repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = all_tokens.len().saturating_sub(self.repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    self.repeat_penalty,
+                    &all_tokens[start_at..],
+                )?
+            };
+            next_token = self.logits_processor.sample(&logits)?;
+            all_tokens.push(next_token);
+            extract_token(next_token, &self.tokenizer, &mut generated_text);
+            if next_token == eos_token {
+                to_sample = index + 1;
+                break;
+            };
+        }
+        let dt = start_post_prompt.elapsed();
+        println!(
+            "{:4} prompt tokens processed: {:.2} token/s",
+            prompt_tokens.len(),
+            prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
+        );
+        println!(
+            "{:4} tokens generated: {:.2} token/s",
+            to_sample,
+            to_sample as f64 / dt.as_secs_f64(),
+        );
 
-    let next_pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat();
+        let next_pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat();
 
-    Ok((next_pre_prompt_tokens, generated_text))
+        Ok((next_pre_prompt_tokens, generated_text))
+    }
 }
 
 fn extract_token(next_token: u32, tokenizer: &Tokenizer, output: &mut String) {
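
Reviewer note (not part of the diff): a rough sketch of how the new Model API is meant to be driven. The example() wrapper and the concrete parameter values below are made up for illustration; the real caller is the worker loop in src/llm/mod.rs shown above, which supplies its own configuration.

    // Hypothetical usage sketch, assuming `Model` from src/llm/utils.rs is in scope.
    fn example() -> Result<(), Box<dyn std::error::Error>> {
        // Build the model once per worker thread; this fetches the GGUF weights
        // and tokenizer from the Hugging Face Hub if they are not cached yet.
        let mut llm_model = Model::start_model(
            Some(0.8), // temperature (assumed value)
            None,      // top_p
            None,      // seed: None falls back to a time-based seed
            1000,      // sample_len (assumed value)
            1.1,       // repeat_penalty
            64,        // repeat_last_n
        )?;

        // The returned token history is threaded back into the next call so the
        // model keeps conversational context between prompts.
        let mut pre_prompt_tokens: Vec<u32> = Vec::new();
        for prompt in ["Hello!", "Summarize that in one sentence."] {
            let (next_tokens, reply) =
                llm_model.run_model_iteraction(prompt.to_string(), pre_prompt_tokens)?;
            pre_prompt_tokens = next_tokens;
            println!("{reply}");
        }
        Ok(())
    }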