From ea21a1a8eda4b99d061b6e97a0245e25fc6c48a4 Mon Sep 17 00:00:00 2001 From: meloalright Date: Mon, 7 Oct 2024 22:12:22 +0800 Subject: [PATCH] feat: make threading can work with join --- interpreter/src/evaluator/builtins.rs | 101 ++++++++++++++------------ interpreter/src/evaluator/object.rs | 2 + 2 files changed, 56 insertions(+), 47 deletions(-) diff --git a/interpreter/src/evaluator/builtins.rs b/interpreter/src/evaluator/builtins.rs index 2e6d740..b55951e 100644 --- a/interpreter/src/evaluator/builtins.rs +++ b/interpreter/src/evaluator/builtins.rs @@ -48,8 +48,8 @@ pub fn new_builtins() -> HashMap { ); #[cfg(feature="threading")] // threading builtins.insert( - String::from("饱和式救援"), - Object::Builtin(1, three_body_threading), + String::from("程心"), + Object::Builtin(0, three_body_threading), ); builtins } @@ -361,56 +361,63 @@ fn eval(input: &str) -> Option { #[cfg(feature="threading")] fn three_body_threading(args: Vec) -> Object { - match &args[0] { - Object::Int(o) => { - async fn local_task(id: i64) { - println!("Local task {} is running!", id); - tokio::time::sleep(Duration::from_secs(1)).await; - println!("Local task {} completed!", id); + let mut session_hash = HashMap::new(); + { + fn three_body_thread_new(args: Vec) -> Object { + match &args[0] { + Object::String(input) => { + let input = (*input).clone(); + + let mut handle = std::thread::spawn(move || { + let local_set = tokio::task::LocalSet::new(); + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + // 在 LocalSet 中安排任务 + local_set.spawn_local(async move { eval(&input) }); + + // 运行 LocalSet 直到其中的任务完成 + rt.block_on(local_set); + }); + + let handle = Box::leak(Box::new(handle)); + let handle_ptr = &mut *handle as *mut std::thread::JoinHandle<()>; + Object::Native(Box::new(NativeObject::Thread(handle_ptr))) + }, + _ => panic!() } + } + session_hash.insert(Object::String("thread".to_owned()), Object::Builtin(1, three_body_thread_new)); + } - let o = (*o).clone(); - - let mut handle = std::thread::spawn(move || { - let local_set = tokio::task::LocalSet::new(); - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - // 在 LocalSet 中安排任务 - local_set.spawn_local(local_task(o)); - // 运行 LocalSet 直到其中的任务完成 - rt.block_on(local_set); - }); - Object::Null - }, - Object::String(input) => { - // async fn local_task(stmt: &BlockStmt) { - // println!("Local task {} is running!", id); - // tokio::time::sleep(Duration::from_secs(1)).await; - // // println!("Local task {} completed!", id); - // } - let input = (*input).clone(); - - let mut handle = std::thread::spawn(move || { - let local_set = tokio::task::LocalSet::new(); - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - - // 在 LocalSet 中安排任务 - local_set.spawn_local(async move { eval(&input) }); - - // 运行 LocalSet 直到其中的任务完成 - rt.block_on(local_set); - }); - Object::Null - }, - _ => Object::Null, + { + fn three_body_thread_join(args: Vec) -> Object { + match &args[0] { + Object::Native(ptr) => { + let handle_ptr = match **ptr { + NativeObject::Thread(handle_ptr) => { + handle_ptr.clone() + } + _ => panic!() + }; + // let model = unsafe { & *model_ptr }; + unsafe { Box::from_raw(handle_ptr) }.join(); + // std::mem::drop(model); + Object::Null + }, + _ => panic!() + } + } + session_hash.insert(Object::String("join".to_owned()), Object::Builtin(1, three_body_thread_join)); } + + + Object::Hash(session_hash) + } #[cfg(test)] diff --git a/interpreter/src/evaluator/object.rs b/interpreter/src/evaluator/object.rs index 026abda..3a0234d 100644 --- a/interpreter/src/evaluator/object.rs +++ b/interpreter/src/evaluator/object.rs @@ -16,6 +16,8 @@ pub type BuiltinFunc = fn(Vec) -> Object; pub enum NativeObject { #[cfg(feature="sophon")] LLMModel(*mut dyn llm::Model), + #[cfg(feature="threading")] + Thread(*mut std::thread::JoinHandle<()>), } #[derive(PartialEq, Clone, Debug)]