Skip to content

Commit ea21a1a

Browse files
committed
feat: make threading can work with join
1 parent 3f28886 commit ea21a1a

File tree

2 files changed

+56
-47
lines changed

2 files changed

+56
-47
lines changed

interpreter/src/evaluator/builtins.rs

Lines changed: 54 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ pub fn new_builtins() -> HashMap<String, Object> {
4848
);
4949
#[cfg(feature="threading")] // threading
5050
builtins.insert(
51-
String::from("饱和式救援"),
52-
Object::Builtin(1, three_body_threading),
51+
String::from("程心"),
52+
Object::Builtin(0, three_body_threading),
5353
);
5454
builtins
5555
}
@@ -361,56 +361,63 @@ fn eval(input: &str) -> Option<Object> {
361361

362362
#[cfg(feature="threading")]
363363
fn three_body_threading(args: Vec<Object>) -> Object {
364-
match &args[0] {
365-
Object::Int(o) => {
366-
async fn local_task(id: i64) {
367-
println!("Local task {} is running!", id);
368-
tokio::time::sleep(Duration::from_secs(1)).await;
369-
println!("Local task {} completed!", id);
364+
let mut session_hash = HashMap::new();
365+
{
366+
fn three_body_thread_new(args: Vec<Object>) -> Object {
367+
match &args[0] {
368+
Object::String(input) => {
369+
let input = (*input).clone();
370+
371+
let mut handle = std::thread::spawn(move || {
372+
let local_set = tokio::task::LocalSet::new();
373+
let rt = tokio::runtime::Builder::new_current_thread()
374+
.enable_all()
375+
.build()
376+
.unwrap();
377+
378+
// 在 LocalSet 中安排任务
379+
local_set.spawn_local(async move { eval(&input) });
380+
381+
// 运行 LocalSet 直到其中的任务完成
382+
rt.block_on(local_set);
383+
});
384+
385+
let handle = Box::leak(Box::new(handle));
386+
let handle_ptr = &mut *handle as *mut std::thread::JoinHandle<()>;
387+
Object::Native(Box::new(NativeObject::Thread(handle_ptr)))
388+
},
389+
_ => panic!()
370390
}
391+
}
392+
session_hash.insert(Object::String("thread".to_owned()), Object::Builtin(1, three_body_thread_new));
393+
}
371394

372-
let o = (*o).clone();
373-
374-
let mut handle = std::thread::spawn(move || {
375-
let local_set = tokio::task::LocalSet::new();
376-
let rt = tokio::runtime::Builder::new_current_thread()
377-
.enable_all()
378-
.build()
379-
.unwrap();
380395

381-
// 在 LocalSet 中安排任务
382-
local_set.spawn_local(local_task(o));
383396

384-
// 运行 LocalSet 直到其中的任务完成
385-
rt.block_on(local_set);
386-
});
387-
Object::Null
388-
},
389-
Object::String(input) => {
390-
// async fn local_task(stmt: &BlockStmt) {
391-
// println!("Local task {} is running!", id);
392-
// tokio::time::sleep(Duration::from_secs(1)).await;
393-
// // println!("Local task {} completed!", id);
394-
// }
395-
let input = (*input).clone();
396-
397-
let mut handle = std::thread::spawn(move || {
398-
let local_set = tokio::task::LocalSet::new();
399-
let rt = tokio::runtime::Builder::new_current_thread()
400-
.enable_all()
401-
.build()
402-
.unwrap();
403-
404-
// 在 LocalSet 中安排任务
405-
local_set.spawn_local(async move { eval(&input) });
406-
407-
// 运行 LocalSet 直到其中的任务完成
408-
rt.block_on(local_set);
409-
});
410-
Object::Null
411-
},
412-
_ => Object::Null,
397+
{
398+
fn three_body_thread_join(args: Vec<Object>) -> Object {
399+
match &args[0] {
400+
Object::Native(ptr) => {
401+
let handle_ptr = match **ptr {
402+
NativeObject::Thread(handle_ptr) => {
403+
handle_ptr.clone()
404+
}
405+
_ => panic!()
406+
};
407+
// let model = unsafe { & *model_ptr };
408+
unsafe { Box::from_raw(handle_ptr) }.join();
409+
// std::mem::drop(model);
410+
Object::Null
411+
},
412+
_ => panic!()
413+
}
414+
}
415+
session_hash.insert(Object::String("join".to_owned()), Object::Builtin(1, three_body_thread_join));
413416
}
417+
418+
419+
Object::Hash(session_hash)
420+
414421
}
415422

416423
#[cfg(test)]

interpreter/src/evaluator/object.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ pub type BuiltinFunc = fn(Vec<Object>) -> Object;
1616
pub enum NativeObject {
1717
#[cfg(feature="sophon")]
1818
LLMModel(*mut dyn llm::Model),
19+
#[cfg(feature="threading")]
20+
Thread(*mut std::thread::JoinHandle<()>),
1921
}
2022

2123
#[derive(PartialEq, Clone, Debug)]

0 commit comments

Comments
 (0)