From 371ecd13c7fe00d281cb19c6588574134d7091b1 Mon Sep 17 00:00:00 2001
From: Leo Liang
Date: Sun, 3 Sep 2023 15:41:10 +0800
Subject: [PATCH] Fix: per-prediction seed (#198)

---
 binding.cpp      | 14 ++++++++++----
 examples/main.go |  4 +++-
 options.go       |  2 ++
 3 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/binding.cpp b/binding.cpp
index ae4c59d..58f6206 100644
--- a/binding.cpp
+++ b/binding.cpp
@@ -44,7 +44,8 @@ int get_embeddings(void* params_ptr, void* state_pr, float * res_embeddings) {
         params.seed = time(NULL);
     }
 
-    std::mt19937 rng(params.seed);
+    // no need for a rng
+    // std::mt19937 rng(params.seed);
 
     int n_past = 0;
 
@@ -127,7 +128,8 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
         params.seed = time(NULL);
     }
 
-    std::mt19937 rng(params.seed);
+    // no need for a rng
+    // std::mt19937 rng(params.seed);
 
     if (params.rope_freq_base != 10000.0) {
         fprintf(stderr, "%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
@@ -171,7 +173,8 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
                 return 1;
             }
             session_tokens.resize(n_token_count_out);
-            llama_set_rng_seed(ctx, params.seed);
+            // no need to set the seed here --- we'll always set it later
+            // llama_set_rng_seed(ctx, params.seed);
             if (debug) {
                 fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
             }
@@ -311,6 +314,9 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
         llama_reset_timings(ctx);
     }
 
+    // set the seed before actually predicting
+    llama_set_rng_seed(ctx, params.seed);
+
     while (n_remain != 0) {
         // predict
         if (embd.size() > 0) {
@@ -878,4 +884,4 @@ void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f
     state->model= model;
     return state;
 }
-*/
\ No newline at end of file
+*/
diff --git a/examples/main.go b/examples/main.go
index 1b3fedc..c9a4070 100644
--- a/examples/main.go
+++ b/examples/main.go
@@ -16,6 +16,7 @@ var (
 	threads   = 4
 	tokens    = 128
 	gpulayers = 0
+	seed      = -1
 )
 
 func main() {
@@ -26,6 +27,7 @@ func main() {
 	flags.IntVar(&gpulayers, "ngl", 0, "Number of GPU layers to use")
 	flags.IntVar(&threads, "t", runtime.NumCPU(), "number of threads to use during computation")
 	flags.IntVar(&tokens, "n", 512, "number of tokens to predict")
+	flags.IntVar(&seed, "s", -1, "predict RNG seed, -1 for random seed")
 
 	err := flags.Parse(os.Args[1:])
 	if err != nil {
@@ -47,7 +49,7 @@ func main() {
 		_, err := l.Predict(text, llama.Debug, llama.SetTokenCallback(func(token string) bool {
 			fmt.Print(token)
 			return true
-		}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
+		}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"), llama.SetSeed(seed))
 		if err != nil {
 			panic(err)
 		}
diff --git a/options.go b/options.go
index aa4cd3d..8da902d 100644
--- a/options.go
+++ b/options.go
@@ -91,6 +91,8 @@ var DefaultOptions PredictOptions = PredictOptions{
 	MirostatTAU: 5.0,
 	MirostatETA: 0.1,
 	MMap:        true,
+	RopeFreqBase:  10000,
+	RopeFreqScale: 1.0,
 }
 
 func SetMulMatQ(b bool) ModelOption {
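
Note: for reference, below is a minimal usage sketch of the per-prediction seed, assuming the usual go-llama.cpp entry points (llama.New, llama.SetContext, Free) together with the SetSeed predict option used in examples/main.go above; the model path, prompt, and token counts are placeholders, not part of this patch.

    package main

    import (
        "fmt"

        llama "github.com/go-skynet/go-llama.cpp"
    )

    func main() {
        // Placeholder model path; point this at a real model file.
        l, err := llama.New("./model.bin", llama.SetContext(128))
        if err != nil {
            panic(err)
        }
        defer l.Free()

        // With the seed applied via llama_set_rng_seed right before the
        // prediction loop, two calls with the same prompt and seed should
        // produce the same tokens.
        prompt := "The meaning of life is"
        first, err := l.Predict(prompt, llama.SetTokens(32), llama.SetSeed(42))
        if err != nil {
            panic(err)
        }
        second, err := l.Predict(prompt, llama.SetTokens(32), llama.SetSeed(42))
        if err != nil {
            panic(err)
        }
        fmt.Println("outputs identical:", first == second)
    }

Passing a negative seed (the default -1 in the example flag) leaves the time(NULL) fallback in binding.cpp in effect, so unseeded predictions keep their previous random behavior.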