5
5
// Persistent handle to the JS class constructor; assigned in Init() via
// Napi::Persistent(func) so the reference survives the defining HandleScope.
Napi::FunctionReference LlamaCPPBinding::constructor;
6
6
7
7
Napi::Object LlamaCPPBinding::Init (Napi::Env env, Napi::Object exports) {
8
- Napi::HandleScope scope (env);
9
-
10
8
Napi::Function func = DefineClass (env, " LlamaCPP" , {
11
9
InstanceMethod (" initialize" , &LlamaCPPBinding::Initialize),
12
- InstanceMethod (" runQuery" , &LlamaCPPBinding::RunQuery),
13
- InstanceMethod (" runQueryStream" , &LlamaCPPBinding::RunQueryStream)
10
+ InstanceMethod (" setSystemPrompt" , &LlamaCPPBinding::SetSystemPrompt),
11
+ InstanceMethod (" prompt" , &LlamaCPPBinding::PromptStream),
12
+ InstanceMethod (" resetConversation" , &LlamaCPPBinding::ResetConversation),
14
13
});
15
14
16
15
constructor = Napi::Persistent (func);
@@ -20,68 +19,93 @@ Napi::Object LlamaCPPBinding::Init(Napi::Env env, Napi::Object exports) {
20
19
return exports;
21
20
}
22
21
23
- LlamaCPPBinding::LlamaCPPBinding (const Napi::CallbackInfo& info) : Napi::ObjectWrap<LlamaCPPBinding>(info) {
24
- Napi::Env env = info.Env ();
25
- Napi::HandleScope scope (env);
26
-
27
- llama_ = std::make_unique<LlamaWrapper>();
22
+ LlamaCPPBinding::LlamaCPPBinding (const Napi::CallbackInfo& info)
23
+ : Napi::ObjectWrap<LlamaCPPBinding>(info) {
24
+ llama_ = std::make_unique<LlamaChat>();
28
25
}
29
26
30
27
// Loads the model and creates the inference context.
//
// JS signature: initialize(modelPath: string, modelParams?: object, contextParams?: object) -> boolean
//   modelParams.nGpuLayers (number)  — copied into ModelParams::nGpuLayers.
//   contextParams.nContext (number)  — copied into ContextParams::nContext.
//
// Throws a JS TypeError (bad arguments) or Error (engine failure).
// Returns true on success, false after an engine failure, null after a
// TypeError.
Napi::Value LlamaCPPBinding::Initialize(const Napi::CallbackInfo& info) {
    Napi::Env env = info.Env();

    if (info.Length() < 1 || !info[0].IsString()) {
        Napi::TypeError::New(env, "Model path must be a string").ThrowAsJavaScriptException();
        return env.Null();
    }

    std::string modelPath = info[0].As<Napi::String>().Utf8Value();

    // Defaults come from the structs' own initializers; JS options override.
    ModelParams modelParams;
    ContextParams contextParams;

    if (info.Length() > 1 && info[1].IsObject()) {
        Napi::Object modelParamsObj = info[1].As<Napi::Object>();
        // IsNumber() guard: As<Napi::Number>() on a non-number value is an
        // unchecked cast and would yield garbage or a JS-side error.
        if (modelParamsObj.Has("nGpuLayers") &&
            modelParamsObj.Get("nGpuLayers").IsNumber()) {
            modelParams.nGpuLayers =
                modelParamsObj.Get("nGpuLayers").As<Napi::Number>().Int32Value();
        }
        // Add parsing for other ModelParams if needed
    }

    if (info.Length() > 2 && info[2].IsObject()) {
        Napi::Object contextParamsObj = info[2].As<Napi::Object>();
        if (contextParamsObj.Has("nContext") &&
            contextParamsObj.Get("nContext").IsNumber()) {
            contextParams.nContext =
                contextParamsObj.Get("nContext").As<Napi::Number>().Uint32Value();
        }
        // Add parsing for other ContextParams if needed
    }

    // Model must be loaded before a context can be built on top of it.
    if (!llama_->InitializeModel(modelPath, modelParams)) {
        Napi::Error::New(env, "Failed to initialize the model").ThrowAsJavaScriptException();
        return Napi::Boolean::New(env, false);
    }

    if (!llama_->InitializeContext(contextParams)) {
        Napi::Error::New(env, "Failed to initialize the context").ThrowAsJavaScriptException();
        return Napi::Boolean::New(env, false);
    }

    return Napi::Boolean::New(env, true);
}
46
68
47
- Napi::Value LlamaCPPBinding::RunQuery (const Napi::CallbackInfo& info) {
69
+ Napi::Value LlamaCPPBinding::SetSystemPrompt (const Napi::CallbackInfo& info) {
48
70
Napi::Env env = info.Env ();
49
71
if (info.Length () < 1 || !info[0 ].IsString ()) {
50
- Napi::TypeError::New (env, " String expected " ).ThrowAsJavaScriptException ();
72
+ Napi::TypeError::New (env, " System prompt must be a string " ).ThrowAsJavaScriptException ();
51
73
return env.Null ();
52
74
}
53
75
54
- std::string prompt = info[0 ].As <Napi::String>().Utf8Value ();
55
- size_t max_tokens = 1000 ;
56
- if (info.Length () > 1 && info[1 ].IsNumber ()) {
57
- max_tokens = info[1 ].As <Napi::Number>().Uint32Value ();
58
- }
76
+ std::string systemPrompt = info[0 ].As <Napi::String>().Utf8Value ();
77
+ llama_->SetSystemPrompt (systemPrompt);
78
+ return env.Null ();
79
+ }
59
80
60
- std::string response = llama_->RunQuery (prompt, max_tokens);
61
- return Napi::String::New (env, response);
81
+ Napi::Value LlamaCPPBinding::ResetConversation (const Napi::CallbackInfo& info) {
82
+ llama_->ResetConversation ();
83
+ return info.Env ().Undefined ();
62
84
}
63
85
64
- Napi::Value LlamaCPPBinding::RunQueryStream (const Napi::CallbackInfo& info) {
86
+ Napi::Value LlamaCPPBinding::PromptStream (const Napi::CallbackInfo& info) {
65
87
Napi::Env env = info.Env ();
66
88
if (info.Length () < 1 || !info[0 ].IsString ()) {
67
- Napi::TypeError::New (env, " String expected " ).ThrowAsJavaScriptException ();
89
+ Napi::TypeError::New (env, " User message must be a string " ).ThrowAsJavaScriptException ();
68
90
return env.Null ();
69
91
}
70
92
71
- std::string prompt = info[0 ].As <Napi::String>().Utf8Value ();
72
- size_t max_tokens = 1000 ;
73
- if (info.Length () > 1 && info[1 ].IsNumber ()) {
74
- max_tokens = info[1 ].As <Napi::Number>().Uint32Value ();
75
- }
93
+ std::string userMessage = info[0 ].As <Napi::String>().Utf8Value ();
76
94
77
95
Napi::Object streamObj = TokenStream::NewInstance (env, env.Null ());
78
96
TokenStream* stream = Napi::ObjectWrap<TokenStream>::Unwrap (streamObj);
79
97
80
- std::thread ([this , prompt, max_tokens, stream]() {
81
- llama_->RunQueryStream (prompt, max_tokens, [stream](const std::string& token) {
82
- stream->Push (token);
83
- });
84
- stream->End ();
98
+ LlamaChat* llama_ptr = llama_.get ();
99
+
100
+ std::thread ([llama_ptr, userMessage, stream]() {
101
+ try {
102
+ llama_ptr->Prompt (userMessage, [stream](const std::string& piece) {
103
+ stream->Push (piece);
104
+ });
105
+ stream->End ();
106
+ } catch (const std::exception& e) {
107
+ stream->End ();
108
+ }
85
109
}).detach ();
86
110
87
111
return streamObj;
0 commit comments