From 00e1f593bda0979a6b20c47c5715de7a01cd6e4c Mon Sep 17 00:00:00 2001 From: Marco Inacio Date: Tue, 21 Nov 2023 21:56:39 +0000 Subject: [PATCH] each session LLM on a thread --- migrations/20231021093421_initial.sql | 14 ++- migrations/20231120202003_queue.sql | 7 -- src/routes/llm.rs | 148 +++++++++++++++----------- src/routes/pages.rs | 28 +++-- 4 files changed, 121 insertions(+), 76 deletions(-) delete mode 100644 migrations/20231120202003_queue.sql diff --git a/migrations/20231021093421_initial.sql b/migrations/20231021093421_initial.sql index 940c4a4..d997bfe 100644 --- a/migrations/20231021093421_initial.sql +++ b/migrations/20231021093421_initial.sql @@ -1,8 +1,18 @@ CREATE TABLE "sessions" ( "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, - "channel" text NOT NULL UNIQUE, + "channel" text NOT NULL, + "thread_ts" text NOT NULL, "model_state" blob, "created_at" integer NOT NULL, - "updated_at" integer NOT NULL + "updated_at" integer NOT NULL, + UNIQUE(channel, thread_ts) +); +CREATE TABLE "queue" ( + "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, + "text" text NOT NULL, + "channel" text NOT NULL, + "thread_ts" text NOT NULL, + "created_at" integer NOT NULL, + "leased_at" integer NOT NULL ); PRAGMA journal_mode=WAL; diff --git a/migrations/20231120202003_queue.sql b/migrations/20231120202003_queue.sql deleted file mode 100644 index c955596..0000000 --- a/migrations/20231120202003_queue.sql +++ /dev/null @@ -1,7 +0,0 @@ -CREATE TABLE "queue" ( - "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, - "text" text NOT NULL, - "channel" text NOT NULL, - "created_at" integer NOT NULL, - "leased_at" integer NOT NULL -); diff --git a/src/routes/llm.rs b/src/routes/llm.rs index d50f4d4..3826edd 100644 --- a/src/routes/llm.rs +++ b/src/routes/llm.rs @@ -77,7 +77,7 @@ pub async fn start_llm_worker(db_pool: SqlitePool, slack_oauth_token: SlackOAuth repeat_last_n ); loop { - thread::scope(|s| { + let res = thread::scope(|s| { s.spawn(|| { thread_priority::set_current_thread_priority( thread_priority::ThreadPriority::Min, @@ -108,15 +108,16 @@ pub async fn start_llm_worker(db_pool: SqlitePool, slack_oauth_token: SlackOAuth loop { // async task to select a task from the queue - let (task_id, prompt_str, channel) = async_handle.block_on(async { - get_next_task(&db_pool) - .await - .map_err(|e| format!("Failed to get next task from queue: {e}")) - })?; + let (task_id, prompt_str, channel, thread_ts) = + async_handle.block_on(async { + get_next_task(&db_pool) + .await + .map_err(|e| format!("Failed to get next task from queue: {e}")) + })?; // async task to get the state if it exists let pre_prompt_tokens = async_handle.block_on(async { - get_session_state(&db_pool, &channel, &slack_oauth_token) + get_session_state(&db_pool, &channel, &thread_ts, &slack_oauth_token) .await .map_err(|e| format!("Failed to get session state: {e}")) })?; @@ -201,55 +202,74 @@ pub async fn start_llm_worker(db_pool: SqlitePool, slack_oauth_token: SlackOAuth let encoded: Vec = bincode::serialize(&next_pre_prompt_tokens) .map_err(|e| format!("Failed to encode model {e}"))?; - let _ = async_handle.block_on(async { - let now = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .map_err(|e| format!("Error: {:?}", e))? - .as_secs() as i64; - sqlx::query("DELETE FROM queue WHERE id = $1;") - .bind(task_id) - .execute(&db_pool) - .await?; - sqlx::query( - "INSERT INTO sessions - (channel, created_at, updated_at, model_state) - VALUES ($1, $2, $3, $4) - ON CONFLICT (channel) + async_handle + .block_on(async { + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map_err(|e| format!("Error: {:?}", e))? + .as_secs() as i64; + sqlx::query("DELETE FROM queue WHERE id = $1;") + .bind(task_id) + .execute(&db_pool) + .await?; + sqlx::query( + "INSERT INTO sessions + (channel, thread_ts, created_at, updated_at, model_state) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (channel, thread_ts) DO UPDATE SET model_state = EXCLUDED.model_state, updated_at = EXCLUDED.updated_at;", - ) - .bind(&channel) - .bind(now) - .bind(now) - .bind(encoded) - .execute(&db_pool) - .await?; - - let reply_to_user = "Reply from the LLM:\n".to_owned() - + &generated_text[1..generated_text.len() - 4]; - - let form = multipart::Form::new() - .text("text", reply_to_user) - .text("channel", channel.to_owned()); - - let reqw_response = reqwest::Client::new() - .post("https://slack.com/api/chat.postMessage") - .header(AUTHORIZATION, format!("Bearer {}", slack_oauth_token.0)) - .multipart(form) - .send() + ) + .bind(&channel) + .bind(&thread_ts) + .bind(now) + .bind(now) + .bind(encoded) + .execute(&db_pool) .await?; - reqw_response.text().await.map_err(|e| { - format!("Failed to read reqwest response body: {e}") - })?; - Ok::<(), Box>(()) - }); + Ok::<(), Box>(()) + }) + .unwrap_or_else(|e| { + println!("Failed to save model state:\n{e}"); + }); + + async_handle + .block_on(async { + let reply_to_user = "Reply from the LLM:\n".to_owned() + + &generated_text[1..generated_text.len() - 4]; + + let form = multipart::Form::new() + .text("text", reply_to_user) + .text("channel", channel.to_owned()) + .text("thread_ts", thread_ts.clone()); + + let reqw_response = reqwest::Client::new() + .post("https://slack.com/api/chat.postMessage") + .header( + AUTHORIZATION, + format!("Bearer {}", slack_oauth_token.0), + ) + .multipart(form) + .send() + .await?; + reqw_response.text().await.map_err(|e| { + format!("Failed to read reqwest response body: {e}") + })?; + Ok::<(), Box>(()) + }) + .unwrap_or_else(|e| { + println!("Failed to send user message:\n{e}"); + }); } #[allow(unreachable_code)] Ok::<(), Box>(()) - }); + }) + .join() }); + println!("LLM worker thread exited with message: {res:?}, restarting in 5 seconds"); + thread::sleep(std::time::Duration::from_secs(5)); } }); } @@ -294,8 +314,8 @@ fn format_size(size_in_bytes: usize) -> String { async fn get_next_task( db_pool: &SqlitePool, -) -> Result<(i64, String, String), Box> { - let (task_id, prompt_str, channel) = loop { +) -> Result<(i64, String, String, String), Box> { + let (task_id, prompt_str, channel, thread_ts) = loop { let now = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .map_err(|e| format!("Error: {:?}", e))? @@ -303,7 +323,7 @@ async fn get_next_task( let mut tx = db_pool.begin().await?; match sqlx::query_as( " - SELECT id,text,channel FROM queue + SELECT id,text,channel,thread_ts FROM queue WHERE leased_at <= $1 ORDER BY created_at ASC LIMIT 0,1 @@ -314,7 +334,7 @@ async fn get_next_task( .await { Ok(res) => { - let (task_id, prompt_str, channel): (i64, String, String) = res; + let (task_id, prompt_str, channel, thread_ts) = res; if sqlx::query( " @@ -329,7 +349,7 @@ async fn get_next_task( .is_ok() && tx.commit().await.is_ok() { - break (task_id, prompt_str, channel); + break (task_id, prompt_str, channel, thread_ts); } } Err(_) => { @@ -338,24 +358,27 @@ async fn get_next_task( } } }; - Ok((task_id, prompt_str, channel)) + Ok((task_id, prompt_str, channel, thread_ts)) } async fn get_session_state( db_pool: &SqlitePool, channel: &str, + thread_ts: &str, slack_oauth_token: &SlackOAuthToken, ) -> Result, Box> { let mut initial_message = "Running LLM ".to_owned(); - let query: Result<(Vec,), _> = - sqlx::query_as(r#"SELECT model_state FROM sessions WHERE channel = $1;"#) - .bind(channel) - .fetch_one(db_pool) - .await; + let query: Result<(Vec,), _> = sqlx::query_as( + r#"SELECT model_state FROM sessions WHERE channel = $1 AND thread_ts = $2;"#, + ) + .bind(channel) + .bind(thread_ts) + .fetch_one(db_pool) + .await; let pre_prompt_tokens = if let Ok(query) = query { initial_message.push_str("reusing section. "); let (model_state,) = query; - let deserialized: Result, _> = bincode::deserialize(&model_state[..]); + let deserialized = bincode::deserialize(&model_state[..]); deserialized.unwrap_or_default() } else { initial_message.push_str("with new section. "); @@ -365,9 +388,11 @@ async fn get_session_state( .as_secs() as i64; sqlx::query( r#"INSERT OR IGNORE INTO - sessions (channel, created_at, updated_at) VALUES ($1, $2, $3);"#, + sessions (channel, thread_ts, created_at, updated_at) + VALUES ($1, $2, $3, $4);"#, ) .bind(channel) + .bind(thread_ts) .bind(timestamp) .bind(timestamp) .execute(db_pool) @@ -378,7 +403,8 @@ async fn get_session_state( let reqw_client = reqwest::Client::new(); let form = multipart::Form::new() .text("text", initial_message) - .text("channel", channel.to_owned()); + .text("channel", channel.to_owned()) + .text("thread_ts", thread_ts.to_owned()); tokio::spawn( reqw_client .post("https://slack.com/api/chat.postMessage") diff --git a/src/routes/pages.rs b/src/routes/pages.rs index 00ce9ce..259b3c2 100644 --- a/src/routes/pages.rs +++ b/src/routes/pages.rs @@ -129,6 +129,9 @@ async fn process_slack_events( }; print!("From user {user} at channel {channel} and type {type_}, received message: {text}. "); + let thread_ts = event.get("event_ts").ok_or("event_ts not found on query")?; + let thread_ts = thread_ts.as_str().ok_or("event_ts is not a string")?; + let text = match Regex::new(r" ?<@.*> ?") { Ok(pattern) if type_ == "app_mention" => { let text = pattern.replace_all(text, " "); @@ -142,28 +145,37 @@ async fn process_slack_events( let reqw_client = reqwest::Client::new(); let reply_to_user = if text == "delete" || text == "\"delete\"" { - let _ = sqlx::query("DELETE FROM sessions WHERE channel = $1") + let _ = sqlx::query("DELETE FROM sessions WHERE channel = $1 AND thread_ts = $2") .bind(channel) + .bind(thread_ts) .execute(&db_pool) .await; - let _ = sqlx::query("DELETE FROM queue WHERE channel = $1") + let _ = sqlx::query("DELETE FROM queue WHERE channel = $1 AND thread_ts = $2") .bind(channel) + .bind(thread_ts) .execute(&db_pool) .await; "Ok, the LLM section was deleted. A new message will start a fresh LLM section.".to_owned() } else if text == "plot" || text == "\"plot\"" { - return plot_random_stuff(channel.to_owned(), slack_oauth_token.clone()).await; + return plot_random_stuff( + channel.to_owned(), + thread_ts.to_owned(), + slack_oauth_token.clone(), + ) + .await; } else { let created_at = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .map_err(|e| format!("Error: {:?}", e))? .as_secs() as i64; sqlx::query( - "INSERT INTO queue (text, channel, created_at, leased_at) VALUES ($1, $2, $3, 0);", + "INSERT INTO queue (text, channel, thread_ts, created_at, leased_at) + VALUES ($1, $2, $3, $4, 0);", ) .bind(text) .bind(channel) + .bind(thread_ts) .bind(created_at) .execute(&db_pool) .await?; @@ -181,7 +193,9 @@ async fn process_slack_events( let form = multipart::Form::new() .text("text", reply_to_user) - .text("channel", channel.to_owned()); + .text("channel", channel.to_owned()) + .text("thread_ts", thread_ts.to_owned()); + let reqw_response = reqw_client .post("https://slack.com/api/chat.postMessage") .header(AUTHORIZATION, format!("Bearer {}", slack_oauth_token.0)) @@ -198,6 +212,7 @@ async fn process_slack_events( pub async fn plot_random_stuff( channel: String, + thread_ts: String, slack_oauth_token: SlackOAuthToken, ) -> Result<(), AppError> { let mut buffer_ = vec![0; 640 * 480 * 3]; @@ -268,7 +283,8 @@ pub async fn plot_random_stuff( let form = multipart::Form::new() .text("channels", channel) .text("title", "A plot for ya") - .part("file", part); + .part("file", part) + .text("thread_ts", thread_ts); let reqw_response = reqw_client .post("https://slack.com/api/files.upload") .header(AUTHORIZATION, format!("Bearer {}", slack_oauth_token.0))