Skip to content

Commit

Permalink
use log instead of println!
Browse files Browse the repository at this point in the history
  • Loading branch information
randommm committed Mar 3, 2024
1 parent 6ed8fb2 commit 24a1de8
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 34 deletions.
51 changes: 51 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ tokenizers = "0.15"
thread-priority = "0.16"
regex = "1.5"
chrono = "0.4.33"
log = "0.4"
pretty_env_logger = "0.5"

[profile.dev.package."*"]
opt-level = 3
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
mod llm;
mod routes;
use log::info;
use tokio::net::TcpListener;

pub async fn run() -> Result<(), Box<dyn std::error::Error>> {
let app = routes::create_routes().await?;
let bind_addr = &"0.0.0.0:51005";
println!("Listening on: http://localhost:51005");
info!("Listening on: http://localhost:51005");
let listener = TcpListener::bind(bind_addr)
.await
.map_err(|e| format!("Failed to parse address: {}", e))?;
Expand Down
21 changes: 8 additions & 13 deletions src/llm/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
mod utils;
use super::routes::SlackOAuthToken;
use log::error;
use reqwest::{header::AUTHORIZATION, multipart};
use sqlx::SqlitePool;
use std::thread;
use std::time::SystemTime;
use tokio::runtime::Handle;
use utils::Model;

use sqlx::SqlitePool;

use reqwest::{header::AUTHORIZATION, multipart};

use std::time::SystemTime;

pub async fn start_llm_worker(db_pool: SqlitePool, slack_oauth_token: SlackOAuthToken) {
let async_handle = Handle::current();
thread::spawn(move || {
Expand Down Expand Up @@ -86,7 +84,7 @@ pub async fn start_llm_worker(db_pool: SqlitePool, slack_oauth_token: SlackOAuth
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
})
.unwrap_or_else(|e| {
println!("Failed to save model state:\n{e}");
error!("Failed to save model state:\n{e}");
});

async_handle
Expand Down Expand Up @@ -114,7 +112,7 @@ pub async fn start_llm_worker(db_pool: SqlitePool, slack_oauth_token: SlackOAuth
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
})
.unwrap_or_else(|e| {
println!("Failed to send user message:\n{e}");
error!("Failed to send user message:\n{e}");
});
}

Expand All @@ -123,7 +121,7 @@ pub async fn start_llm_worker(db_pool: SqlitePool, slack_oauth_token: SlackOAuth
})
.join()
});
println!("LLM worker thread exited with message: {res:?}, restarting in 5 seconds");
error!("LLM worker thread exited with message: {res:?}, restarting in 5 seconds");
thread::sleep(std::time::Duration::from_secs(5));
}
});
Expand Down Expand Up @@ -169,10 +167,7 @@ async fn get_next_task(
break (task_id, prompt_str, channel, thread_ts);
}
}
Err(_) => {
// println!("No work to do, sleeping {e} {now}");
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await
}
Err(_) => tokio::time::sleep(tokio::time::Duration::from_secs(1)).await,
}
};
Ok((task_id, prompt_str, channel, thread_ts))
Expand Down
19 changes: 10 additions & 9 deletions src/llm/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@ use candle::quantized::gguf_file;
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::quantized_llama as model;
use log::{info, trace};
use model::ModelWeights;
use tokenizers::Tokenizer;

pub fn print_stats(temperature: Option<f64>, repeat_penalty: f32, repeat_last_n: usize) {
println!(
info!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
info!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
{
#[allow(clippy::unnecessary_literal_unwrap)]
Expand Down Expand Up @@ -71,15 +72,15 @@ impl Model {
total_size_in_bytes +=
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
}
println!(
info!(
"loaded {:?} tensors ({}) in {:.2}s",
model.tensor_infos.len(),
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
ModelWeights::from_gguf(model, &mut file, &device)?
};
println!("model built");
info!("model built");

let api = hf_hub::api::sync::Api::new()?;
let repo = "mistralai/Mixtral-8x7B-v0.1";
Expand All @@ -95,13 +96,13 @@ impl Model {
.duration_since(std::time::SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
println!("Using {} as LogitsProcessor RNG seed", seed);
info!("Using {} as LogitsProcessor RNG seed", seed);
seed
}),
temperature,
top_p,
);
println!("Starting LLM model");
info!("Starting LLM model");

Ok(Self {
model_weights,
Expand All @@ -120,7 +121,7 @@ impl Model {
pre_prompt_tokens: Vec<u32>,
) -> Result<(Vec<u32>, String), Box<dyn std::error::Error + Send + Sync>> {
let prompt_str = format!("[INST] {prompt_str} [/INST]");
// print!("{}", &prompt_str);
trace!("{}", &prompt_str);
let tokens = self
.tokenizer
.encode(prompt_str, true)
Expand Down Expand Up @@ -176,12 +177,12 @@ impl Model {
};
}
let dt = start_post_prompt.elapsed();
println!(
info!(
"{:4} prompt tokens processed: {:.2} token/s",
prompt_tokens.len(),
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
);
println!(
info!(
"{:4} tokens generated: {:.2} token/s",
to_sample,
to_sample as f64 / dt.as_secs_f64(),
Expand Down
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ use rust_slackbot_llm::run;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
pretty_env_logger::init();
run().await
}
3 changes: 2 additions & 1 deletion src/routes/error_handling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use axum::{
http::StatusCode,
response::{IntoResponse, Json},
};
use log::error;
use serde::{Deserialize, Serialize};
use uuid::Uuid;

Expand Down Expand Up @@ -51,7 +52,7 @@ impl IntoResponse for AppError {

let internal_message =
format!("Error id: {}. Message: {}", error_id, self.internal_message);
println!("{}. AppError: {}", error_id, internal_message);
error!("{}. AppError: {}", error_id, internal_message);

let user_message = self.user_message.unwrap_or("Server error".to_owned());

Expand Down
18 changes: 8 additions & 10 deletions src/routes/pages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use axum::{
response::IntoResponse,
Json,
};
use chrono::Local;
use hmac::{Hmac, Mac};
use log::{error, info, trace};
use regex::Regex;
use reqwest::{header::AUTHORIZATION, multipart};
use serde_json::Value;
Expand All @@ -24,8 +24,8 @@ pub async fn get_slack_events(
body: String,
) -> Result<impl IntoResponse, AppError> {
if PRINT_SLACK_EVENTS {
println!("Slack event body: {:?}", body);
println!("Slack event headers: {:?}", headers);
info!("Slack event body: {:?}", body);
info!("Slack event headers: {:?}", headers);
}

let provided_timestamp = headers
Expand All @@ -50,7 +50,7 @@ pub async fn get_slack_events(
.map(|i| u8::from_str_radix(i.iter().collect::<String>().as_str(), 16))
.collect::<Result<Vec<_>, _>>()
.map_err(|e| format!("Invalid hex on proposed slack signature: {e}"))?;
// println!("provided_signature: {:?}", provided_signature);
trace!("provided_signature: {:?}", provided_signature);

// check timestamp
let now = SystemTime::now()
Expand Down Expand Up @@ -87,8 +87,7 @@ pub async fn try_process_slack_events(
let value = process_slack_events(slack_oauth_token, db_pool, &query).await;

if let Err(ref value) = value {
print!("{}: ", Local::now());
println!(
error!(
"failed to process Slack event.\nGot error:\n{:?}\nGot payload:{:?} ",
value, query
);
Expand Down Expand Up @@ -129,8 +128,7 @@ async fn process_slack_events(
.unwrap_or(x.to_owned()),
None => "unknown".to_owned(),
};
print!("{}: ", Local::now());
print!("from user {user} at channel {channel} and type {type_}, received message: {text}. ");
info!("from user {user} at channel {channel} and type {type_}, received message: {text}. ");

let thread_ts = event.get("thread_ts");
let thread_ts = match thread_ts {
Expand All @@ -149,7 +147,7 @@ async fn process_slack_events(
_ => text.trim().to_owned(),
};

println!("Processed message: {text}.");
info!("Processed message: {text}.");

let reqw_client = reqwest::Client::new();

Expand Down Expand Up @@ -307,7 +305,7 @@ pub async fn plot_random_stuff(
let reqw_response: Value = serde_json::from_str(&reqw_response)
.map_err(|e| format!("Could not parse response body: {e}"))?;
if PRINT_SLACK_EVENTS {
println!("Received send plot response {:?}", reqw_response);
info!("Received send plot response {:?}", reqw_response);
}
Ok(())
}
Expand Down

0 comments on commit 24a1de8

Please sign in to comment.