From 371ecd13c7fe00d281cb19c6588574134d7091b1 Mon Sep 17 00:00:00 2001
From: Leo Liang <ljcleo@qq.com>
Date: Sun, 3 Sep 2023 15:41:10 +0800
Subject: [PATCH] Fix: per-prediction seed (#198)

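Previously, llama_set_rng_seed was only called when a saved session
was loaded, so the seed carried in PredictOptions did not reliably
reach the sampler. Call llama_set_rng_seed right before the generation
loop so every prediction uses its own seed, and drop the unused local
std::mt19937 instances. Expose the seed in examples/main.go via a new
-s flag wired to llama.SetSeed, and add the missing RopeFreqBase
(10000) and RopeFreqScale (1.0) entries to DefaultOptions so default
predictions no longer trip the RoPE frequency warning in binding.cpp.
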
---
 binding.cpp      | 14 ++++++++++----
 examples/main.go |  4 +++-
 options.go       |  2 ++
 3 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/binding.cpp b/binding.cpp
index ae4c59d..58f6206 100644
--- a/binding.cpp
+++ b/binding.cpp
@@ -44,7 +44,8 @@ int get_embeddings(void* params_ptr, void* state_pr, float * res_embeddings) {
         params.seed = time(NULL);
     }
     
-    std::mt19937 rng(params.seed);
+    // no need for an RNG here
+    // std::mt19937 rng(params.seed);
   
     int n_past = 0;
 
@@ -127,7 +128,8 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
         params.seed = time(NULL);
     }
 
-    std::mt19937 rng(params.seed);
+    // no need for an RNG; the seed is applied via llama_set_rng_seed below
+    // std::mt19937 rng(params.seed);
 
     if (params.rope_freq_base != 10000.0) {
         fprintf(stderr, "%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
@@ -171,7 +173,8 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
                 return 1;
             }
             session_tokens.resize(n_token_count_out);
-            llama_set_rng_seed(ctx, params.seed);
+            // no need to set the seed here; it is always set right before prediction
+            // llama_set_rng_seed(ctx, params.seed);
             if (debug) {
                 fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
             }
@@ -311,6 +314,9 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
         llama_reset_timings(ctx);
     }
     
+    // set the seed right before predicting so the per-call seed always takes effect
+    llama_set_rng_seed(ctx, params.seed);
+
     while (n_remain != 0) {
                // predict
         if (embd.size() > 0) {
@@ -878,4 +884,4 @@ void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f
     state->model= model;
     return state;
 }
-*/
\ No newline at end of file
+*/
diff --git a/examples/main.go b/examples/main.go
index 1b3fedc..c9a4070 100644
--- a/examples/main.go
+++ b/examples/main.go
@@ -16,6 +16,7 @@ var (
 	threads   = 4
 	tokens    = 128
 	gpulayers = 0
+	seed      = -1
 )
 
 func main() {
@@ -26,6 +27,7 @@ func main() {
 	flags.IntVar(&gpulayers, "ngl", 0, "Number of GPU layers to use")
 	flags.IntVar(&threads, "t", runtime.NumCPU(), "number of threads to use during computation")
 	flags.IntVar(&tokens, "n", 512, "number of tokens to predict")
+	flags.IntVar(&seed, "s", -1, "RNG seed for prediction (-1 for a random seed)")
 
 	err := flags.Parse(os.Args[1:])
 	if err != nil {
@@ -47,7 +49,7 @@ func main() {
 		_, err := l.Predict(text, llama.Debug, llama.SetTokenCallback(func(token string) bool {
 			fmt.Print(token)
 			return true
-		}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
+		}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"), llama.SetSeed(seed))
 		if err != nil {
 			panic(err)
 		}
diff --git a/options.go b/options.go
index aa4cd3d..8da902d 100644
--- a/options.go
+++ b/options.go
@@ -91,6 +91,8 @@ var DefaultOptions PredictOptions = PredictOptions{
 	MirostatTAU:       5.0,
 	MirostatETA:       0.1,
 	MMap:              true,
+	RopeFreqBase:      10000,
+	RopeFreqScale:     1.0,
 }
 
 func SetMulMatQ(b bool) ModelOption {