From df3dc138719a190845a07623d4c1c41af9006488 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Mon, 4 Sep 2023 19:06:29 +0200
Subject: [PATCH] feat(speculative-sampling): Add speculative sampling (#200)

Signed-off-by: mudler
---
 binding.cpp             | 317 ++++++++++++++++++++++++++--------------
 binding.h               |   8 +-
 llama.cpp               |   2 +-
 llama.go                |  65 +++++++-
 llama_test.go           |  33 +++++
 options.go              |  14 ++
 patches/1902-cuda.patch |  13 +-
 7 files changed, 335 insertions(+), 117 deletions(-)

diff --git a/binding.cpp b/binding.cpp
index 58f6206..2cec6a5 100644
--- a/binding.cpp
+++ b/binding.cpp
@@ -257,8 +257,8 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
         }
     }
 
+    struct llama_grammar * grammar = NULL;
     grammar_parser::parse_state parsed_grammar;
-    llama_grammar * grammar = NULL;
     if (!params.grammar.empty()) {
         parsed_grammar = grammar_parser::parse(params.grammar.c_str());
         // will be empty (default) if there are parse errors
@@ -284,8 +284,8 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
 
     // TODO: replace with ring-buffer
-    std::vector<llama_token> last_n_tokens(n_ctx);
-    std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+    std::vector<llama_token> last_tokens(n_ctx);
+    std::fill(last_tokens.begin(), last_tokens.end(), 0);
 
     bool is_antiprompt = false;
     bool input_echo    = true;
@@ -305,6 +305,9 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
     std::vector<llama_token> embd;
     std::vector<llama_token> embd_guidance;
 
+    const int n_vocab = llama_n_vocab(ctx);
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
 
     std::string res = "";
@@ -340,19 +343,11 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
             n_past = std::max(1, params.n_keep);
             n_past_guidance = std::max(1, params.n_keep + guidance_offset);
 
-            // insert n_left/2 tokens at the start of embd from last_n_tokens
-            embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
+            // insert n_left/2 tokens at the start of embd from last_tokens
+            embd.insert(embd.begin(), last_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_tokens.end() - embd.size());
 
             // stop saving session if we run out of context
             path_session.clear();
-
-            //printf("\n---\n");
-            //printf("resetting: '");
-            //for (int i = 0; i < (int) embd.size(); i++) {
-            //    printf("%s", llama_token_to_piece(ctx, embd[i]));
-            //}
-            //printf("'\n");
-            //printf("\n---\n");
         }
 
@@ -445,96 +440,16 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
             embd_guidance.clear();
 
         if ((int) embd_inp.size() <= n_consumed) {
-            // out of user input, sample next token
-            const float   temp            = params.temp;
-            const int32_t top_k           = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
-            const float   top_p           = params.top_p;
-            const float   tfs_z           = params.tfs_z;
-            const float   typical_p       = params.typical_p;
-            const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
-            const float   repeat_penalty  = params.repeat_penalty;
-            const float   alpha_presence  = params.presence_penalty;
-            const float   alpha_frequency = params.frequency_penalty;
-            const int     mirostat        = params.mirostat;
-            const float   mirostat_tau    = params.mirostat_tau;
-            const float   mirostat_eta    = params.mirostat_eta;
-            const bool    penalize_nl     = params.penalize_nl;
-
             // optionally save the session on first sample (for faster prompt loading next time)
             if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
                 need_to_save_session = false;
                 llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
             }
 
-            llama_token id = 0;
-
-            {
-                auto logits  = llama_get_logits(ctx);
-                auto n_vocab = llama_n_vocab(ctx);
+            const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
 
-                // Apply params.logit_bias map
-                for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-                    logits[it->first] += it->second;
-                }
-
-                std::vector<llama_token_data> candidates;
-                candidates.reserve(n_vocab);
-                for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                    candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-                }
-
-                llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
-
-                // Apply penalties
-                float nl_logit = logits[llama_token_nl(ctx)];
-                auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
-                llama_sample_repetition_penalty(ctx, &cur_p,
-                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                    last_n_repeat, repeat_penalty);
-                llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
-                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                    last_n_repeat, alpha_frequency, alpha_presence);
-                if (!penalize_nl) {
-                    for (size_t idx = 0; idx < cur_p.size; idx++) {
-                        if (cur_p.data[idx].id == llama_token_nl(ctx)) {
-                            cur_p.data[idx].logit = nl_logit;
-                            break;
-                        }
-                    }
-                }
-
-                if (grammar != NULL) {
-                    llama_sample_grammar(ctx, &cur_p, grammar);
-                }
-
-                if (temp <= 0) {
-                    // Greedy sampling
-                    id = llama_sample_token_greedy(ctx, &cur_p);
-                } else {
-                    if (mirostat == 1) {
-                        static float mirostat_mu = 2.0f * mirostat_tau;
-                        const int mirostat_m = 100;
-                        llama_sample_temperature(ctx, &cur_p, temp);
-                        id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
-                    } else if (mirostat == 2) {
-                        static float mirostat_mu = 2.0f * mirostat_tau;
-                        llama_sample_temperature(ctx, &cur_p, temp);
-                        id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
-                    } else {
-                        // Temperature sampling
-                        llama_sample_top_k(ctx, &cur_p, top_k, 1);
-                        llama_sample_tail_free(ctx, &cur_p, tfs_z, 1);
-                        llama_sample_typical(ctx, &cur_p, typical_p, 1);
-                        llama_sample_top_p(ctx, &cur_p, top_p, 1);
-                        llama_sample_temperature(ctx, &cur_p, temp);
-                        id = llama_sample_token(ctx, &cur_p);
-                    }
-                }
-                // printf("`%d`", candidates_p.size);
-
-                if (grammar != NULL) {
-                    llama_grammar_accept_token(ctx, grammar, id);
-                }
-
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(id);
-            }
+            last_tokens.erase(last_tokens.begin());
+            last_tokens.push_back(id);
 
             // add it to the context
             embd.push_back(id);
@@ -553,8 +468,8 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
             // some user input remains from prompt or interaction, forward it to processing
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(embd_inp[n_consumed]);
+                last_tokens.erase(last_tokens.begin());
+                last_tokens.push_back(embd_inp[n_consumed]);
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
                     break;
                 }
@@ -580,7 +495,7 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
         // check for reverse prompt
         if (params.antiprompt.size()) {
             std::string last_output;
-            for (auto id : last_n_tokens) {
+            for (auto id : last_tokens) {
                 last_output += llama_token_to_piece(ctx, id);
             }
 
@@ -632,6 +547,196 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
     return 0;
 }
 
+// this is a bit of a hack now - ideally this should be in the predict function
+// and be transparent to the caller, however this now maps (mostly) 1:1 to the upstream implementation
+// Note: both models have to be loaded with perplexity "true" to enable all logits
+int speculative_sampling(void* params_ptr, void* target_model, void* draft_model, char* result, bool debug) {
+
+    gpt_params* params_p = (gpt_params*) params_ptr;
+    llama_binding_state* target_model_state = (llama_binding_state*) target_model;
+    llama_binding_state* draft_model_state = (llama_binding_state*) draft_model;
+
+    gpt_params params = *params_p;
+    llama_context * ctx_tgt = target_model_state->ctx;
+    llama_context * ctx_dft = draft_model_state->ctx;
+
+    llama_model * model_tgt = target_model_state->model;
+    llama_model * model_dft = draft_model_state->model;
+
+    std::string res = "";
+
+    // tokenize the prompt
+    std::vector<llama_token> inp;
+    inp = ::llama_tokenize(ctx_tgt, params.prompt, true);
+
+    const int max_context_size     = llama_n_ctx(ctx_tgt);
+    const int max_tokens_list_size = max_context_size - 4;
+
+    if ((int) inp.size() > max_tokens_list_size) {
+        fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
+        return 1;
+    }
+
+    const int n_input = inp.size();
+
+    const auto t_enc_start = ggml_time_us();
+
+    // eval the prompt with both models
+    llama_eval(ctx_tgt, inp.data(), int(inp.size() - 1), 0, params.n_threads);
+    llama_eval(ctx_tgt, &inp.back(), 1, inp.size() - 1, params.n_threads);
+    llama_eval(ctx_dft, inp.data(), int(inp.size()), 0, params.n_threads);
+
+    const auto t_enc_end = ggml_time_us();
+
+    // the 2 models should have the same vocab
+    const int n_ctx   = llama_n_ctx(ctx_tgt);
+    const int n_vocab = llama_n_vocab(ctx_tgt);
+    //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
+
+    // how many tokens to draft each time
+    const int n_draft = params.n_draft;
+
+    int n_predict = 0;
+    int n_drafted = 0;
+    int n_accept  = 0;
+
+    int n_past_tgt = inp.size();
+    int n_past_dft = inp.size();
+
+    std::vector<llama_token> drafted;
+
+    std::vector<llama_token> last_tokens(n_ctx);
+    std::fill(last_tokens.begin(), last_tokens.end(), 0);
+
+    for (auto & id : inp) {
+        last_tokens.erase(last_tokens.begin());
+        last_tokens.push_back(id);
+    }
+
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+
+    // used to determine end of generation
+    bool has_eos = false;
+
+    const auto t_dec_start = ggml_time_us();
+
+    while (true) {
+        // sample from the drafted tokens if any
+        int i_dft = 0;
+        while (true) {
+            const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft);
+
+            last_tokens.erase(last_tokens.begin());
+            last_tokens.push_back(id);
+
+            //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
+
+            const std::string token_str = llama_token_to_piece(ctx_tgt, id);
+            if (!tokenCallback(draft_model, (char*)token_str.c_str())) {
+                break;
+            }
+            res += token_str.c_str();
+
+            if (id == llama_token_eos(ctx_tgt)) {
+                has_eos = true;
+            }
+
+            ++n_predict;
+
+            if (i_dft < (int) drafted.size() && id == drafted[i_dft]) {
+                LOG("drafted token %d accepted\n", id);
+                ++n_accept;
+                ++n_past_tgt;
+                ++n_past_dft;
+                ++i_dft;
+
+                continue;
+            }
+
+            // the drafted token was rejected or we are out of drafted tokens
+            llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
+            ++n_past_dft;
+
+            drafted.clear();
+            drafted.push_back(id);
+
+            break;
+        }
+
+        if (n_predict > params.n_predict || has_eos) {
+            break;
+        }
+
+        // sample n_draft tokens from the draft model picking the best token
+        int n_past_cur = n_past_dft;
+        for (int i = 0; i < n_draft; ++i) {
+            float * logits = llama_get_logits(ctx_dft);
+
+            candidates.clear();
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+            }
+
+            llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+
+            // computes softmax and sorts the candidates
+            llama_sample_softmax(ctx_dft, &cur_p);
+
+            for (int i = 0; i < 3; ++i) {
+                LOG(" - draft candidate %d: %d (%.3f)\n", i, cur_p.data[i].id, cur_p.data[i].p);
+            }
+
+            // too low probability, stop drafting
+            if (cur_p.data[0].p < 2*cur_p.data[1].p) {
+                break;
+            }
+
+            drafted.push_back(cur_p.data[0].id);
+            ++n_drafted;
+
+            if (i < n_draft - 1) {
+                // evaluate the drafted token on the draft model
+                llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
+                ++n_past_cur;
+            }
+        }
+
+        // evaluate the target model on the drafted tokens
+        llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads);
+        ++n_past_tgt;
+
+        drafted.erase(drafted.begin());
+    }
+
+    if (debug) {
+        auto t_dec_end = ggml_time_us();
+
+        LOG_TEE("\n\n");
+
+        LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input,   (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
+        LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict  / ((t_dec_end - t_dec_start) / 1e6f));
+
+        // TODO: make sure these numbers are computed correctly
+        LOG_TEE("\n");
+        LOG_TEE("n_draft   = %d\n", n_draft);
+        LOG_TEE("n_predict = %d\n", n_predict);
+        LOG_TEE("n_drafted = %d\n", n_drafted);
+        LOG_TEE("n_accept  = %d\n", n_accept);
+        LOG_TEE("accept    = %.3f%%\n", 100.0f * n_accept / n_drafted);
+
+        LOG_TEE("\ndraft:\n");
+        llama_print_timings(ctx_dft);
+
+        LOG_TEE("\ntarget:\n");
+        llama_print_timings(ctx_tgt);
+
+        fprintf(stderr, "\n\n");
+    }
+
+    strcpy(result, res.c_str());
+    return 0;
+}
+
 void llama_binding_free_model(void *state_ptr) {
     llama_binding_state* ctx = (llama_binding_state*) state_ptr;
     llama_free(ctx->ctx);
@@ -711,8 +816,7 @@ void* llama_allocate_params(const char *prompt, int seed, int threads, int token
                             float top_p, float temp, float repeat_penalty, int repeat_last_n, bool ignore_eos, bool memory_f16, int n_batch, int n_keep, const char** antiprompt, int antiprompt_count,
                             float tfs_z, float typical_p, float frequency_penalty, float presence_penalty, int mirostat, float mirostat_eta, float mirostat_tau, bool penalize_nl, const char *logit_bias, const char *session_file, bool prompt_cache_all, bool mlock, bool mmap, const char *maingpu,const char *tensorsplit , bool prompt_cache_ro, const char *grammar,
-                            float rope_freq_base, float rope_freq_scale, float negative_prompt_scale, const char* negative_prompt
-                            ) {
+                            float rope_freq_base, float rope_freq_scale, float negative_prompt_scale, const char* negative_prompt, int n_draft) {
     gpt_params* params = new gpt_params;
     params->seed = seed;
     params->n_threads = threads;
@@ -733,7 +837,7 @@ void* llama_allocate_params(const char *prompt, int seed, int threads, int token
     params->rope_freq_scale = rope_freq_scale;
     params->cfg_scale = negative_prompt_scale;
     params->cfg_negative_prompt = std::string(negative_prompt);
-
+    params->n_draft = n_draft;
     if (maingpu[0] != '\0') {
         params->main_gpu = std::stoi(maingpu);
     }
@@ -784,8 +888,8 @@ void* llama_allocate_params(const char *prompt, int seed, int threads, int token
     return params;
 }
 
-void* load_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base) {
-    return load_binding_model(fname, n_ctx, n_seed, memory_f16, mlock, embeddings, mmap, low_vram, n_gpu_layers, n_batch, maingpu, tensorsplit, numa, rope_freq_base, rope_freq_scale, mul_mat_q, lora, lora_base);
+void* load_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity) {
+    return load_binding_model(fname, n_ctx, n_seed, memory_f16, mlock, embeddings, mmap, low_vram, n_gpu_layers, n_batch, maingpu, tensorsplit, numa, rope_freq_base, rope_freq_scale, mul_mat_q, lora, lora_base, perplexity);
 }
 
 /*
@@ -803,7 +907,7 @@ struct llama_binding_state {
     llama_model * model;
 };
 
-void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base);
+void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity);
 
 common.cpp:
 
 gpt_params* create_gpt_params(const std::string& fname,const std::string& lora,const std::string& lora_base) {
@@ -821,7 +925,7 @@ gpt_params* create_gpt_params(const std::string& fname,const std::string& lora,c
     return lparams;
 }
 
-void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base) {
+void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity) {
     // load the model
     gpt_params * lparams = create_gpt_params(fname, lora, lora_base);
     llama_model * model;
@@ -835,6 +939,7 @@ void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f
     lparams->use_mlock = mlock;
     lparams->n_gpu_layers = n_gpu_layers;
     lparams->use_mmap = mmap;
+    lparams->perplexity = perplexity;
     lparams->low_vram = low_vram;
 
     if (rope_freq_base != 0.0f) {
diff --git a/binding.h b/binding.h
index 0b2a32e..44664eb 100644
--- a/binding.h
+++ b/binding.h
@@ -29,7 +29,7 @@ void* load_model(const char *fname,
                  bool numa,
                  float rope_freq_base,
                  float rope_freq_scale,
-                 bool mul_mat_q, const char *lora, const char *lora_base
+                 bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity
                  );
 
 int get_embeddings(void* params_ptr, void* state_pr, float * res_embeddings);
@@ -41,8 +41,10 @@ void* llama_allocate_params(const char *prompt, int seed, int threads, int token
                             int repeat_last_n, bool ignore_eos, bool memory_f16, int n_batch, int n_keep, const char** antiprompt, int antiprompt_count,
                             float tfs_z, float typical_p, float frequency_penalty, float presence_penalty, int mirostat, float mirostat_eta, float mirostat_tau, bool penalize_nl, const char *logit_bias, const char *session_file, bool prompt_cache_all, bool mlock, bool mmap, const char *maingpu, const char *tensorsplit ,
-                            bool prompt_cache_ro, const char *grammar, float rope_freq_base, float rope_freq_scale, float negative_prompt_scale, const char* negative_prompt
-                            );
+                            bool prompt_cache_ro, const char *grammar, float rope_freq_base, float rope_freq_scale, float negative_prompt_scale, const char* negative_prompt,
+                            int n_draft);
+
+int speculative_sampling(void* params_ptr, void* target_model, void* draft_model, char* result, bool debug);
 
 void llama_free_params(void* params_ptr);
 
diff --git a/llama.cpp b/llama.cpp
index 69fdbb9..e4386f4 160000
--- a/llama.cpp
+++ b/llama.cpp
@@ -1 +1 @@
-Subproject commit 69fdbb9abc8907dd2a9ffdd840cba92d678a660a
+Subproject commit e4386f417faf894f6706eec005e24d142b577fcb
diff --git a/llama.go b/llama.go
index 9aa0034..2fef3b7 100644
--- a/llama.go
+++ b/llama.go
@@ -41,7 +41,7 @@ func New(model string, opts ...ModelOption) (*LLama, error) {
 		C.bool(mo.F16Memory), C.bool(mo.MLock), C.bool(mo.Embeddings), C.bool(mo.MMap), C.bool(mo.LowVRAM),
 		C.int(mo.NGPULayers), C.int(mo.NBatch), C.CString(mo.MainGPU), C.CString(mo.TensorSplit), C.bool(mo.NUMA),
 		C.float(mo.FreqRopeBase), C.float(mo.FreqRopeScale),
-		C.bool(MulMatQ), loraAdapter, loraBase,
+		C.bool(MulMatQ), loraAdapter, loraBase, C.bool(mo.Perplexity),
 	)
 
 	if result == nil {
@@ -123,6 +123,7 @@ func (l *LLama) TokenEmbeddings(tokens []int, opts ...PredictOption) ([]float32,
 		C.bool(po.PromptCacheRO),
 		C.CString(po.Grammar),
 		C.float(po.RopeFreqBase), C.float(po.RopeFreqScale), C.float(po.NegativePromptScale), C.CString(po.NegativePrompt),
+		C.int(po.NDraft),
 	)
 	ret := C.get_token_embeddings(params, l.state, myArray, C.int(len(tokens)), (*C.float)(&floats[0]))
 	if ret != 0 {
@@ -164,6 +165,7 @@ func (l *LLama) Embeddings(text string, opts ...PredictOption) ([]float32, error
 		C.bool(po.PromptCacheRO),
 		C.CString(po.Grammar),
 		C.float(po.RopeFreqBase), C.float(po.RopeFreqScale), C.float(po.NegativePromptScale), C.CString(po.NegativePrompt),
+		C.int(po.NDraft),
 	)
 
 	ret := C.get_embeddings(params, l.state, (*C.float)(&floats[0]))
@@ -202,6 +204,7 @@ func (l *LLama) Eval(text string, opts ...PredictOption) error {
 		C.bool(po.PromptCacheRO),
 		C.CString(po.Grammar),
 		C.float(po.RopeFreqBase), C.float(po.RopeFreqScale), C.float(po.NegativePromptScale), C.CString(po.NegativePrompt),
+		C.int(po.NDraft),
 	)
 	ret := C.eval(params, l.state, input)
 	if ret != 0 {
@@ -213,6 +216,64 @@ func (l *LLama) Eval(text string, opts ...PredictOption) error {
 	return nil
 }
 
+func (l *LLama) SpeculativeSampling(ll *LLama, text string, opts ...PredictOption) (string, error) {
+	po := NewPredictOptions(opts...)
+
+	if po.TokenCallback != nil {
+		setCallback(l.state, po.TokenCallback)
+	}
+
+	input := C.CString(text)
+	if po.Tokens == 0 {
+		po.Tokens = 99999999
+	}
+	out := make([]byte, po.Tokens)
+
+	reverseCount := len(po.StopPrompts)
+	reversePrompt := make([]*C.char, reverseCount)
+	var pass **C.char
+	for i, s := range po.StopPrompts {
+		cs := C.CString(s)
+		reversePrompt[i] = cs
+		pass = &reversePrompt[0]
+	}
+
+	params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
+		C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat),
+		C.bool(po.IgnoreEOS), C.bool(po.F16KV),
+		C.int(po.Batch), C.int(po.NKeep), pass, C.int(reverseCount),
+		C.float(po.TailFreeSamplingZ), C.float(po.TypicalP), C.float(po.FrequencyPenalty), C.float(po.PresencePenalty),
+		C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), C.CString(po.LogitBias),
+		C.CString(po.PathPromptCache), C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap),
+		C.CString(po.MainGPU), C.CString(po.TensorSplit),
+		C.bool(po.PromptCacheRO),
+		C.CString(po.Grammar),
+		C.float(po.RopeFreqBase), C.float(po.RopeFreqScale), C.float(po.NegativePromptScale), C.CString(po.NegativePrompt),
+		C.int(po.NDraft),
+	)
+	ret := C.speculative_sampling(params, l.state, ll.state, (*C.char)(unsafe.Pointer(&out[0])), C.bool(po.DebugMode))
+	if ret != 0 {
+		return "", fmt.Errorf("inference failed")
+	}
+	res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
+
+	res = strings.TrimPrefix(res, " ")
+	res = strings.TrimPrefix(res, text)
+	res = strings.TrimPrefix(res, "\n")
+
+	for _, s := range po.StopPrompts {
+		res = strings.TrimRight(res, s)
+	}
+
+	C.llama_free_params(params)
+
+	if po.TokenCallback != nil {
+		setCallback(l.state, nil)
+	}
+
+	return res, nil
+}
+
 func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) {
 	po := NewPredictOptions(opts...)
 
@@ -246,6 +307,7 @@ func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) {
 		C.bool(po.PromptCacheRO),
 		C.CString(po.Grammar),
 		C.float(po.RopeFreqBase), C.float(po.RopeFreqScale), C.float(po.NegativePromptScale), C.CString(po.NegativePrompt),
+		C.int(po.NDraft),
 	)
 	ret := C.llama_predict(params, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.bool(po.DebugMode))
 	if ret != 0 {
@@ -294,6 +356,7 @@ func (l *LLama) TokenizeString(text string, opts ...PredictOption) (int32, []int
 		C.bool(po.PromptCacheRO),
 		C.CString(po.Grammar),
 		C.float(po.RopeFreqBase), C.float(po.RopeFreqScale), C.float(po.NegativePromptScale), C.CString(po.NegativePrompt),
+		C.int(po.NDraft),
 	)
 
 	tokRet := C.llama_tokenize_string(params, l.state, (*C.int)(unsafe.Pointer(&out[0]))) //, C.int(po.Tokens), true)
diff --git a/llama_test.go b/llama_test.go
index a053cb2..3bc8d0c 100644
--- a/llama_test.go
+++ b/llama_test.go
@@ -3,6 +3,7 @@ package llama_test
 import (
 	"os"
 
+	"github.com/go-skynet/go-llama.cpp"
 	. "github.com/go-skynet/go-llama.cpp"
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
@@ -45,6 +46,38 @@ how much is 2+2?
 			Expect(text).To(ContainSubstring("4"), text)
 		})
 
+		It("speculative sampling predicts", func() {
+			if testModelPath == "" {
+				Skip("test skipped - only makes sense if the TEST_MODEL environment variable is set.")
+			}
+			model, err := New(
+				testModelPath,
+				EnableF16Memory,
+				SetContext(128),
+				SetMMap(true),
+				SetNBatch(512),
+				SetPerplexity(true),
+			)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(model).ToNot(BeNil())
+			model2, err := New(
+				testModelPath,
+				EnableF16Memory,
+				SetContext(128),
+				SetMMap(true),
+				SetNBatch(512),
+				SetPerplexity(true),
+			)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(model).ToNot(BeNil())
+			text, err := model.SpeculativeSampling(model2, `[INST] Answer to the following question:
+how much is 2+2?
+[/INST]`, llama.SetNDraft(16),
+			)
+			Expect(err).ToNot(HaveOccurred(), text)
+			Expect(text).To(ContainSubstring("4"), text)
+		})
+
 	It("tokenizes strings successfully", func() {
 		if testModelPath == "" {
 			Skip("test skipped - only makes sense if the TEST_MODEL environment variable is set.")
diff --git a/options.go b/options.go
index 8da902d..b36c671 100644
--- a/options.go
+++ b/options.go
@@ -18,11 +18,13 @@ type ModelOptions struct {
 	MulMatQ     *bool
 	LoraBase    string
 	LoraAdapter string
+	Perplexity  bool
 }
 
 type PredictOptions struct {
 	Seed, Threads, Tokens, TopK, Repeat, Batch, NKeep int
 	TopP, Temperature, Penalty                        float32
+	NDraft                                            int
 	F16KV                                             bool
 	DebugMode                                         bool
 	StopPrompts                                       []string
@@ -193,6 +195,18 @@ func SetRopeFreqScale(rfs float32) PredictOption {
 	}
 }
 
+func SetNDraft(nd int) PredictOption {
+	return func(p *PredictOptions) {
+		p.NDraft = nd
+	}
+}
+
+func SetPerplexity(b bool) ModelOption {
+	return func(p *ModelOptions) {
+		p.Perplexity = b
+	}
+}
+
 func SetNegativePromptScale(nps float32) PredictOption {
 	return func(p *PredictOptions) {
 		p.NegativePromptScale = nps
diff --git a/patches/1902-cuda.patch b/patches/1902-cuda.patch
index 72ff31b..eb18e20 100644
--- a/patches/1902-cuda.patch
+++ b/patches/1902-cuda.patch
@@ -1,8 +1,8 @@
 diff --git a/common/common.cpp b/common/common.cpp
-index ed09fc2..ced02e8 100644
+index 3138213..af93a32 100644
 --- a/common/common.cpp
 +++ b/common/common.cpp
-@@ -1107,3 +1107,82 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
+@@ -1257,3 +1257,83 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
  fprintf(stream, "typical_p: %f # default: 1.0\n", params.typical_p);
  fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
"true" : "false"); } @@ -22,7 +22,7 @@ index ed09fc2..ced02e8 100644 + return lparams; +} + -+void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base) { ++void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity) { + // load the model + gpt_params * lparams = create_gpt_params(fname, lora, lora_base); + llama_model * model; @@ -35,6 +35,7 @@ index ed09fc2..ced02e8 100644 + lparams->embedding = embeddings; + lparams->use_mlock = mlock; + lparams->n_gpu_layers = n_gpu_layers; ++ lparams->perplexity = perplexity; + lparams->use_mmap = mmap; + + lparams->low_vram = low_vram; @@ -87,10 +88,10 @@ index ed09fc2..ced02e8 100644 +} \ No newline at end of file diff --git a/common/common.h b/common/common.h -index 5a37968..8b09050 100644 +index 105fb09..8f60434 100644 --- a/common/common.h +++ b/common/common.h -@@ -165,3 +165,10 @@ std::string get_sortable_timestamp(); +@@ -201,3 +201,10 @@ std::string get_sortable_timestamp(); void dump_non_result_info_yaml( FILE * stream, const gpt_params & params, const llama_context * lctx, const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc); @@ -100,4 +101,4 @@ index 5a37968..8b09050 100644 + llama_model * model; +}; + -+void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base); ++void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity);