Commit

Use structs and impl for the LLM model
randommm committed Nov 24, 2023
1 parent c140ff6 commit a19f86f
Showing 2 changed files with 162 additions and 149 deletions.
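
For orientation before the diff: this commit replaces the free functions start_model and run_model_iteraction (which threaded the weights, tokenizer, logits processor, and sampling parameters through every call) with a Model struct that owns that state and exposes the same two operations as methods. The sketch below is a minimal, self-contained illustration of that pattern only; the placeholder types, stub bodies, main driver, and example parameter values are assumptions for illustration, not the real candle/tokenizers code shown in the diff.

// Minimal sketch of the refactor: state moves from loose function arguments
// into a struct with methods. Unit structs stand in for the real
// candle-transformers / tokenizers types used in the repository.
struct ModelWeights;
struct Tokenizer;
struct LogitsProcessor;

pub struct Model {
    model_weights: ModelWeights,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
    sample_len: usize,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl Model {
    // Previously: start_model(temperature, top_p, seed)
    //             -> (ModelWeights, Tokenizer, LogitsProcessor)
    pub fn start_model(
        _temperature: Option<f64>,
        _top_p: Option<f64>,
        _seed: Option<u64>,
        sample_len: usize,
        repeat_penalty: f32,
        repeat_last_n: usize,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        // The real method downloads the GGUF weights and tokenizer and builds
        // the sampler; here the fields are stubbed so the example compiles.
        Ok(Self {
            model_weights: ModelWeights,
            tokenizer: Tokenizer,
            logits_processor: LogitsProcessor,
            sample_len,
            repeat_penalty,
            repeat_last_n,
        })
    }

    // Previously: run_model_iteraction(prompt, &mut weights, &tokenizer,
    //             &mut logits_processor, pre_prompt_tokens, sample_len, ...)
    pub fn run_model_iteraction(
        &mut self,
        prompt_str: String,
        pre_prompt_tokens: Vec<u32>,
    ) -> Result<(Vec<u32>, String), Box<dyn std::error::Error + Send + Sync>> {
        // The real method tokenizes the prompt and runs the generation loop;
        // this stub only shows the call shape and return type.
        let _ = (&self.model_weights, &self.tokenizer, &self.logits_processor);
        let _ = (self.sample_len, self.repeat_penalty, self.repeat_last_n);
        Ok((pre_prompt_tokens, format!("[INST] {prompt_str} [/INST]")))
    }
}

fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // The caller now holds one value instead of three pieces of state plus
    // loose sampling parameters (the parameter values here are illustrative).
    let mut llm_model = Model::start_model(Some(0.8), None, None, 1000, 1.1, 64)?;
    let (_history, text) = llm_model.run_model_iteraction("hello".to_string(), vec![])?;
    println!("{text}");
    Ok(())
}
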
29 changes: 11 additions & 18 deletions src/llm/mod.rs
@@ -2,7 +2,7 @@ mod utils;
 use super::routes::SlackOAuthToken;
 use std::thread;
 use tokio::runtime::Handle;
-use utils::{print_stats, run_model_iteraction, start_model};
+use utils::Model;
 
 use sqlx::SqlitePool;
 
@@ -19,20 +19,21 @@ pub async fn start_llm_worker(db_pool: SqlitePool, slack_oauth_token: SlackOAuth
     let seed = None;
     let repeat_penalty = 1.1;
     let repeat_last_n = 64;
-    print_stats(temperature, repeat_penalty, repeat_last_n);
     loop {
         let res = thread::scope(|s| {
             s.spawn(|| {
                 thread_priority::set_current_thread_priority(
                     thread_priority::ThreadPriority::Min,
                 )
                 .unwrap_or_default();
-                let (mut model_weights, tokenizer, mut logits_processor) =
-                    start_model(temperature, top_p, seed)
-                        .map_err(|e| {
-                            println!("Failed to start model:\n{e}");
-                        })
-                        .unwrap();
+                let mut llm_model = Model::start_model(
+                    temperature,
+                    top_p,
+                    seed,
+                    sample_len,
+                    repeat_penalty,
+                    repeat_last_n,
+                )?;
 
                 loop {
                     // async task to select a task from the queue
@@ -50,16 +51,8 @@
                            .map_err(|e| format!("Failed to get session state: {e}"))
                    })?;
 
-                    let (next_pre_prompt_tokens, generated_text) = run_model_iteraction(
-                        prompt_str,
-                        &mut model_weights,
-                        &tokenizer,
-                        &mut logits_processor,
-                        pre_prompt_tokens,
-                        sample_len,
-                        repeat_penalty,
-                        repeat_last_n,
-                    )?;
+                    let (next_pre_prompt_tokens, generated_text) =
+                        llm_model.run_model_iteraction(prompt_str, pre_prompt_tokens)?;
 
                     let encoded: Vec<u8> = bincode::serialize(&next_pre_prompt_tokens)
                         .map_err(|e| format!("Failed to encode model {e}"))?;
282 changes: 151 additions & 131 deletions src/llm/utils.rs
@@ -29,145 +29,165 @@ pub fn print_stats(temperature: Option<f64>, repeat_penalty: f32, repeat_last_n:
     );
 }
 
-pub fn start_model(
-    temperature: Option<f64>,
-    top_p: Option<f64>,
-    seed: Option<u64>,
-) -> Result<(ModelWeights, Tokenizer, LogitsProcessor), Box<dyn std::error::Error + Send + Sync>> {
-    let repo = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF";
-    // let filename = "mistral-7b-instruct-v0.1.Q5_K_M.gguf";
-    let filename = "mistral-7b-instruct-v0.1.Q2_K.gguf";
-
-    let api = hf_hub::api::sync::Api::new()?;
-    let api = api.model(repo.to_string());
-    let model_path = api.get(filename)?;
-
-    let mut file = std::fs::File::open(model_path)?;
-    let start = std::time::Instant::now();
-
-    let model_weights = {
-        let model = gguf_file::Content::read(&mut file)?;
-        let mut total_size_in_bytes = 0;
-        for (_, tensor) in model.tensor_infos.iter() {
-            let elem_count = tensor.shape.elem_count();
-            total_size_in_bytes +=
-                elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
-        }
-        println!(
-            "loaded {:?} tensors ({}) in {:.2}s",
-            model.tensor_infos.len(),
-            &format_size(total_size_in_bytes),
-            start.elapsed().as_secs_f32(),
-        );
-        ModelWeights::from_gguf(model, &mut file)?
-    };
-    println!("model built");
-
-    let api = hf_hub::api::sync::Api::new()?;
-    let repo = "mistralai/Mistral-7B-v0.1";
-    let api = api.model(repo.to_string());
-    let tokenizer_path = api.get("tokenizer.json")?;
-    let tokenizer = Tokenizer::from_file(tokenizer_path)
-        .map_err(|e| format!("Error loading tokenizer: {e}"))?;
-
-    let logits_processor = LogitsProcessor::new(
-        #[allow(clippy::unnecessary_literal_unwrap)]
-        seed.unwrap_or_else(|| {
-            let seed = std::time::SystemTime::now()
-                .duration_since(std::time::SystemTime::UNIX_EPOCH)
-                .unwrap_or_default()
-                .as_secs();
-            println!("Using {} as LogitsProcessor RNG seed", seed);
-            seed
-        }),
-        temperature,
-        top_p,
-    );
-    println!("Starting LLM model");
-
-    Ok((model_weights, tokenizer, logits_processor))
-}
-
-#[allow(clippy::too_many_arguments)]
-pub fn run_model_iteraction(
-    prompt_str: String,
-    model_weights: &mut ModelWeights,
-    tokenizer: &Tokenizer,
-    logits_processor: &mut LogitsProcessor,
-    pre_prompt_tokens: Vec<u32>,
+pub struct Model {
+    model_weights: ModelWeights,
+    tokenizer: Tokenizer,
+    logits_processor: LogitsProcessor,
     sample_len: usize,
     repeat_penalty: f32,
     repeat_last_n: usize,
-) -> Result<(Vec<u32>, String), Box<dyn std::error::Error + Send + Sync>> {
-    let prompt_str = format!("[INST] {prompt_str} [/INST]");
-    // print!("{}", &prompt_str);
-    let tokens = tokenizer
-        .encode(prompt_str, true)
-        .map_err(|e| format!("Error encoding tokenizer: {e}"))?;
-
-    let prompt_tokens = [pre_prompt_tokens, tokens.get_ids().to_owned()].concat();
-    let mut to_sample = sample_len.saturating_sub(1);
-    let prompt_tokens = if prompt_tokens.len() + to_sample > model::MAX_SEQ_LEN - 10 {
-        let to_remove = prompt_tokens.len() + to_sample + 10 - model::MAX_SEQ_LEN;
-        prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
-    } else {
-        prompt_tokens
-    };
-    let mut all_tokens = vec![];
-
-    let device = Device::Cpu;
-    let start_prompt_processing = std::time::Instant::now();
-    let mut next_token = {
-        let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
-        let logits = model_weights.forward(&input, 0)?;
-        let logits = logits.squeeze(0)?;
-        logits_processor.sample(&logits)?
-    };
-    let prompt_dt = start_prompt_processing.elapsed();
-    all_tokens.push(next_token);
-    let mut generated_text = String::new();
-    extract_token(next_token, tokenizer, &mut generated_text);
-
-    let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
-
-    let start_post_prompt = std::time::Instant::now();
-    for index in 0..to_sample {
-        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
-        let logits = model_weights.forward(&input, prompt_tokens.len() + index)?;
-        let logits = logits.squeeze(0)?;
-        let logits = if repeat_penalty == 1. {
-            logits
+}
+
+impl Model {
+    pub fn start_model(
+        temperature: Option<f64>,
+        top_p: Option<f64>,
+        seed: Option<u64>,
+        sample_len: usize,
+        repeat_penalty: f32,
+        repeat_last_n: usize,
+    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
+        print_stats(temperature, repeat_penalty, repeat_last_n);
+
+        let repo = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF";
+        // let filename = "mistral-7b-instruct-v0.1.Q5_K_M.gguf";
+        let filename = "mistral-7b-instruct-v0.1.Q2_K.gguf";
+
+        let api = hf_hub::api::sync::Api::new()?;
+        let api = api.model(repo.to_string());
+        let model_path = api.get(filename)?;
+
+        let mut file = std::fs::File::open(model_path)?;
+        let start = std::time::Instant::now();
+
+        let model_weights = {
+            let model = gguf_file::Content::read(&mut file)?;
+            let mut total_size_in_bytes = 0;
+            for (_, tensor) in model.tensor_infos.iter() {
+                let elem_count = tensor.shape.elem_count();
+                total_size_in_bytes +=
+                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
+            }
+            println!(
+                "loaded {:?} tensors ({}) in {:.2}s",
+                model.tensor_infos.len(),
+                &format_size(total_size_in_bytes),
+                start.elapsed().as_secs_f32(),
+            );
+            ModelWeights::from_gguf(model, &mut file)?
+        };
+        println!("model built");
+
+        let api = hf_hub::api::sync::Api::new()?;
+        let repo = "mistralai/Mistral-7B-v0.1";
+        let api = api.model(repo.to_string());
+        let tokenizer_path = api.get("tokenizer.json")?;
+        let tokenizer = Tokenizer::from_file(tokenizer_path)
+            .map_err(|e| format!("Error loading tokenizer: {e}"))?;
+
+        let logits_processor = LogitsProcessor::new(
+            #[allow(clippy::unnecessary_literal_unwrap)]
+            seed.unwrap_or_else(|| {
+                let seed = std::time::SystemTime::now()
+                    .duration_since(std::time::SystemTime::UNIX_EPOCH)
+                    .unwrap_or_default()
+                    .as_secs();
+                println!("Using {} as LogitsProcessor RNG seed", seed);
+                seed
+            }),
+            temperature,
+            top_p,
+        );
+        println!("Starting LLM model");
+
+        Ok(Self {
+            model_weights,
+            tokenizer,
+            logits_processor,
+            sample_len,
+            repeat_penalty,
+            repeat_last_n,
+        })
+    }
+
+    pub fn run_model_iteraction(
+        &mut self,
+        prompt_str: String,
+        pre_prompt_tokens: Vec<u32>,
+    ) -> Result<(Vec<u32>, String), Box<dyn std::error::Error + Send + Sync>> {
+        let prompt_str = format!("[INST] {prompt_str} [/INST]");
+        // print!("{}", &prompt_str);
+        let tokens = self
+            .tokenizer
+            .encode(prompt_str, true)
+            .map_err(|e| format!("Error encoding tokenizer: {e}"))?;
+
+        let prompt_tokens = [pre_prompt_tokens, tokens.get_ids().to_owned()].concat();
+        let mut to_sample = self.sample_len.saturating_sub(1);
+        let prompt_tokens = if prompt_tokens.len() + to_sample > model::MAX_SEQ_LEN - 10 {
+            let to_remove = prompt_tokens.len() + to_sample + 10 - model::MAX_SEQ_LEN;
+            prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
         } else {
-            let start_at = all_tokens.len().saturating_sub(repeat_last_n);
-            candle_transformers::utils::apply_repeat_penalty(
-                &logits,
-                repeat_penalty,
-                &all_tokens[start_at..],
-            )?
+            prompt_tokens
         };
-        next_token = logits_processor.sample(&logits)?;
-        all_tokens.push(next_token);
-        extract_token(next_token, tokenizer, &mut generated_text);
-        if next_token == eos_token {
-            to_sample = index + 1;
-            break;
+        let mut all_tokens = vec![];
+
+        let device = Device::Cpu;
+        let start_prompt_processing = std::time::Instant::now();
+        let mut next_token = {
+            let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
+            let logits = self.model_weights.forward(&input, 0)?;
+            let logits = logits.squeeze(0)?;
+            self.logits_processor.sample(&logits)?
        };
-    }
-    let dt = start_post_prompt.elapsed();
-    println!(
-        "\n\n{:4} prompt tokens processed: {:.2} token/s",
-        prompt_tokens.len(),
-        prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
-    );
-    println!(
-        "{:4} tokens generated: {:.2} token/s",
-        to_sample,
-        to_sample as f64 / dt.as_secs_f64(),
-    );
+        let prompt_dt = start_prompt_processing.elapsed();
+        all_tokens.push(next_token);
+        let mut generated_text = String::new();
+        extract_token(next_token, &self.tokenizer, &mut generated_text);
+
+        let eos_token = *self.tokenizer.get_vocab(true).get("</s>").unwrap();
+
+        let start_post_prompt = std::time::Instant::now();
+        for index in 0..to_sample {
+            let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
+            let logits = self
+                .model_weights
+                .forward(&input, prompt_tokens.len() + index)?;
+            let logits = logits.squeeze(0)?;
+            let logits = if self.repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = all_tokens.len().saturating_sub(self.repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    self.repeat_penalty,
+                    &all_tokens[start_at..],
+                )?
+            };
+            next_token = self.logits_processor.sample(&logits)?;
+            all_tokens.push(next_token);
+            extract_token(next_token, &self.tokenizer, &mut generated_text);
+            if next_token == eos_token {
+                to_sample = index + 1;
+                break;
+            };
+        }
+        let dt = start_post_prompt.elapsed();
+        println!(
+            "{:4} prompt tokens processed: {:.2} token/s",
+            prompt_tokens.len(),
+            prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
+        );
+        println!(
+            "{:4} tokens generated: {:.2} token/s",
+            to_sample,
+            to_sample as f64 / dt.as_secs_f64(),
+        );
 
-    let next_pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat();
+        let next_pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat();
 
-    Ok((next_pre_prompt_tokens, generated_text))
+        Ok((next_pre_prompt_tokens, generated_text))
+    }
 }
 
 fn extract_token(next_token: u32, tokenizer: &Tokenizer, output: &mut String) {
