Skip to content

Commit

Permalink
added a unit test + some refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
randommm committed Mar 9, 2024
1 parent 24a1de8 commit 023c21a
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 66 deletions.
65 changes: 10 additions & 55 deletions src/llm/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod utils;
use super::routes::pages::send_user_message;
use super::routes::SlackOAuthToken;
use log::error;
use reqwest::{header::AUTHORIZATION, multipart};
Expand Down Expand Up @@ -87,32 +88,17 @@ pub async fn start_llm_worker(db_pool: SqlitePool, slack_oauth_token: SlackOAuth
error!("Failed to save model state:\n{e}");
});

let reply_to_user = "Reply from the LLM:\n".to_owned()
+ &generated_text[1..generated_text.len() - 4];
async_handle
.block_on(async {
let reply_to_user = "Reply from the LLM:\n".to_owned()
+ &generated_text[1..generated_text.len() - 4];

let form = multipart::Form::new()
.text("text", reply_to_user)
.text("channel", channel.to_owned())
.text("thread_ts", thread_ts.clone());

let reqw_response = reqwest::Client::new()
.post("https://slack.com/api/chat.postMessage")
.header(
AUTHORIZATION,
format!("Bearer {}", slack_oauth_token.0),
)
.multipart(form)
.send()
.await?;
reqw_response.text().await.map_err(|e| {
format!("Failed to read reqwest response body: {e}")
})?;
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
})
.block_on(send_user_message(
&slack_oauth_token,
channel,
thread_ts,
reply_to_user,
))
.unwrap_or_else(|e| {
error!("Failed to send user message:\n{e}");
error!("{:?}", e);
});
}

Expand Down Expand Up @@ -227,34 +213,3 @@ async fn get_session_state(

Ok(pre_prompt_tokens)
}

// #[cfg(test)]
// mod tests {
// use super::ModelBuilder;
// use tokio::sync::oneshot;

// #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
// async fn sequential_dialog() {
// let model = ModelBuilder {
// sample_len: 30,
// ..Default::default()
// }
// .build()
// .unwrap();
// let tx = model.run().await;

// let prompt = "Create a Rust program in 20 words".to_string();
// let pre_prompt_tokens = vec![];

// let (oneshot_tx, oneshot_rx) = oneshot::channel();
// tx.send((prompt, pre_prompt_tokens, oneshot_tx)).unwrap();
// let (output, pre_prompt_tokens) = oneshot_rx.await.unwrap();
// println!("{output}");

// let prompt = "Give me the Cargo.toml in 20 words".to_string();
// let (oneshot_tx, oneshot_rx) = oneshot::channel();
// tx.send((prompt, pre_prompt_tokens, oneshot_tx)).unwrap();
// let (output, _) = oneshot_rx.await.unwrap();
// println!("{output}");
// }
// }
44 changes: 44 additions & 0 deletions src/llm/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,47 @@ fn format_size(size_in_bytes: usize) -> String {
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn sequential_dialog() {
let sample_len = 10_usize;
let temperature = Some(0.8);
let top_p = None;
let seed = Some(1);
let repeat_penalty = 1.1;
let repeat_last_n = 64;

let mut llm_model = Model::start_model(
temperature,
top_p,
seed,
sample_len,
repeat_penalty,
repeat_last_n,
)
.unwrap();

let prompt = "Create a Rust program, just code, no explanation".to_string();
let pre_prompt_tokens = vec![];

let (next_prompt_tokens, generated_text) = llm_model
.run_model_iteraction(prompt, pre_prompt_tokens)
.unwrap();

assert_eq!(generated_text, " ```rust\nuse std::fs;\n\n");

let prompt = "Explain previous code.".to_string();

let (_, generated_text) = llm_model
.run_model_iteraction(prompt, next_prompt_tokens)
.unwrap();
assert_eq!(
generated_text,
" The previous code is written in the Rust programming"
);
}
}
5 changes: 4 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ use rust_slackbot_llm::run;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
pretty_env_logger::init();
if std::env::var("RUST_LOG").is_err() {
std::env::set_var("RUST_LOG", "rust_slackbot_llm=info");
}
pretty_env_logger::init_timed();
run().await
}
4 changes: 2 additions & 2 deletions src/routes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod error_handling;
mod pages;
pub(crate) mod pages;
use super::llm::start_llm_worker;
use axum::{
extract::FromRef,
Expand Down Expand Up @@ -50,7 +50,7 @@ pub async fn create_routes() -> Result<Router, Box<dyn std::error::Error>> {
};

let api = Router::new()
.route("/slack_events", post(pages::get_slack_events))
.route("/slack_events", post(pages::receive_slack_events))
.route("/", get(pages::index))
.with_state(app_state.clone());

Expand Down
30 changes: 22 additions & 8 deletions src/routes/pages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::time::SystemTime;

const PRINT_SLACK_EVENTS: bool = false;

pub async fn get_slack_events(
pub async fn receive_slack_events(
State(db_pool): State<SqlitePool>,
State(slack_signing_secret): State<SlackSigningSecret>,
State(slack_oauth_token): State<SlackOAuthToken>,
Expand Down Expand Up @@ -149,8 +149,6 @@ async fn process_slack_events(

info!("Processed message: {text}.");

let reqw_client = reqwest::Client::new();

let reply_to_user = if text == "delete" || text == "\"delete\"" {
let _ = sqlx::query("DELETE FROM sessions WHERE channel = $1 AND thread_ts = $2")
.bind(channel)
Expand Down Expand Up @@ -198,12 +196,29 @@ async fn process_slack_events(
initial_message
};

send_user_message(
&slack_oauth_token,
channel.to_owned(),
thread_ts.to_owned(),
reply_to_user,
)
.await?;

Ok(())
}

pub async fn send_user_message(
slack_oauth_token: &SlackOAuthToken,
channel: String,
thread_ts: String,
text: String,
) -> Result<(), AppError> {
let form = multipart::Form::new()
.text("text", reply_to_user)
.text("channel", channel.to_owned())
.text("thread_ts", thread_ts.to_owned());
.text("text", text)
.text("channel", channel)
.text("thread_ts", thread_ts);

let reqw_response = reqw_client
let reqw_response = reqwest::Client::new()
.post("https://slack.com/api/chat.postMessage")
.header(AUTHORIZATION, format!("Bearer {}", slack_oauth_token.0))
.multipart(form)
Expand All @@ -213,7 +228,6 @@ async fn process_slack_events(
.text()
.await
.map_err(|e| format!("Failed to read reqwest response body: {e}"))?;

Ok(())
}

Expand Down

0 comments on commit 023c21a

Please sign in to comment.