Fix: per-prediction seed (#198)
ljcleo authored Sep 3, 2023
1 parent d8c8547 commit 371ecd1
Showing 3 changed files with 15 additions and 5 deletions.
14 changes: 10 additions & 4 deletions binding.cpp
@@ -44,7 +44,8 @@ int get_embeddings(void* params_ptr, void* state_pr, float * res_embeddings) {
         params.seed = time(NULL);
     }
 
-    std::mt19937 rng(params.seed);
+    // no need for a rng
+    // std::mt19937 rng(params.seed);
 
     int n_past = 0;
@@ -127,7 +128,8 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
         params.seed = time(NULL);
     }
 
-    std::mt19937 rng(params.seed);
+    // no need for a rng
+    // std::mt19937 rng(params.seed);
 
     if (params.rope_freq_base != 10000.0) {
         fprintf(stderr, "%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
@@ -171,7 +173,8 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
            return 1;
        }
        session_tokens.resize(n_token_count_out);
-       llama_set_rng_seed(ctx, params.seed);
+       // no need to set the seed here --- we'll always set it later
+       // llama_set_rng_seed(ctx, params.seed);
        if (debug) {
            fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
        }
@@ -311,6 +314,9 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
        llama_reset_timings(ctx);
    }
 
+   // set the seed before actually predicting
+   llama_set_rng_seed(ctx, params.seed);
+
    while (n_remain != 0) {
        // predict
        if (embd.size() > 0) {
@@ -878,4 +884,4 @@ void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f
    state->model= model;
    return state;
}
-*/
+*/
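
With this change the RNG seed is applied inside llama_predict itself rather than once when the model is loaded, so each call can reproduce its own sampling. A minimal sketch of that expectation from the Go side, assuming the usual go-llama.cpp API (the model path and context size here are placeholders):

package main

import (
	"fmt"

	llama "github.com/go-skynet/go-llama.cpp"
)

func main() {
	// placeholder model path and context size
	l, err := llama.New("./model.bin", llama.SetContext(128))
	if err != nil {
		panic(err)
	}
	defer l.Free()

	// Same prompt, same seed: llama_set_rng_seed now runs right before
	// the sampling loop of each call, so both completions should match.
	a, _ := l.Predict("The llama is", llama.SetSeed(42), llama.SetTokens(32))
	b, _ := l.Predict("The llama is", llama.SetSeed(42), llama.SetTokens(32))
	fmt.Println(a == b)
}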
4 changes: 3 additions & 1 deletion examples/main.go
@@ -16,6 +16,7 @@ var (
 	threads   = 4
 	tokens    = 128
 	gpulayers = 0
+	seed      = -1
 )
 
 func main() {
@@ -26,6 +27,7 @@ func main() {
 	flags.IntVar(&gpulayers, "ngl", 0, "Number of GPU layers to use")
 	flags.IntVar(&threads, "t", runtime.NumCPU(), "number of threads to use during computation")
 	flags.IntVar(&tokens, "n", 512, "number of tokens to predict")
+	flags.IntVar(&seed, "s", -1, "predict RNG seed, -1 for random seed")
 
 	err := flags.Parse(os.Args[1:])
 	if err != nil {
@@ -47,7 +49,7 @@ func main() {
 		_, err := l.Predict(text, llama.Debug, llama.SetTokenCallback(func(token string) bool {
 			fmt.Print(token)
 			return true
-		}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
+		}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"), llama.SetSeed(seed))
 		if err != nil {
 			panic(err)
 		}
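
With the new flag wired through to Predict, a fixed seed can be requested from the command line; assuming the example's existing -m model flag, a run might look like:

go run ./examples -m ./model.bin -n 128 -s 42

Passing -1 (the default) keeps the previous behaviour of seeding from the current time, per the time(NULL) fallback in binding.cpp.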
2 changes: 2 additions & 0 deletions options.go
@@ -91,6 +91,8 @@ var DefaultOptions PredictOptions = PredictOptions{
 	MirostatTAU:   5.0,
 	MirostatETA:   0.1,
 	MMap:          true,
+	RopeFreqBase:  10000,
+	RopeFreqScale: 1.0,
 }
 
 func SetMulMatQ(b bool) ModelOption {
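
These defaults match the values binding.cpp treats as standard (RoPE frequency base 10000.0, scale 1.0), so predictions that never set them no longer reach the C++ side with zeroed fields and trip the RoPE warnings. A sketch of overriding them per prediction, assuming PredictOption is the usual pointer-based functional option type:

// hypothetical ad-hoc option for a model fine-tuned with linear RoPE scaling
withScaledRope := func(p *llama.PredictOptions) {
	p.RopeFreqBase = 10000
	p.RopeFreqScale = 0.5 // compress positions 2x for extended context
}
_, err := l.Predict("The llama is", withScaledRope)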
