diff --git a/.gitignore b/.gitignore index f189325..bc0cb50 100644 --- a/.gitignore +++ b/.gitignore @@ -3,9 +3,13 @@ /build /ggml-metal.metal /libbinding.a +prepare # just in case .DS_Store .idea .vscode Thumbs.db + +# model files downloaded by the test script +ggllm-test-model.bin \ No newline at end of file diff --git a/binding.cpp b/binding.cpp index bf32f42..873e9a7 100644 --- a/binding.cpp +++ b/binding.cpp @@ -618,6 +618,16 @@ void llama_free_params(void* params_ptr) { delete params; } +int llama_tokenize_string(void* params_ptr, void* state_pr, int* result) { + gpt_params* params_p = (gpt_params*) params_ptr; + llama_state* state = (llama_state*) state_pr; + llama_context* ctx = state->ctx; + + // TODO: add_bos + + return llama_tokenize(ctx, params_p->prompt.c_str(), result, params_p->n_ctx, true); +} + std::vector create_vector(const char** strings, int count) { std::vector* vec = new std::vector; diff --git a/binding.h b/binding.h index d609022..77d116f 100644 --- a/binding.h +++ b/binding.h @@ -48,6 +48,8 @@ void llama_free_params(void* params_ptr); void llama_binding_free_model(void* state); +int llama_tokenize_string(void* params_ptr, void* state_pr, int* result); + int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug); #ifdef __cplusplus diff --git a/lama_test.go b/lama_test.go deleted file mode 100644 index 1a0e9ee..0000000 --- a/lama_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package llama_test - -import ( - "os" - - . "github.com/go-skynet/go-llama.cpp" - . "github.com/onsi/ginkgo/v2" - . 
"github.com/onsi/gomega" -) - -var _ = Describe("LLama binding", func() { - Context("Declaration", func() { - It("fails with no model", func() { - model, err := New("not-existing") - Expect(err).To(HaveOccurred()) - Expect(model).To(BeNil()) - }) - }) - Context("Inferencing", func() { - It("Works with "+os.Getenv("TEST_MODEL"), func() { - model, err := New( - os.Getenv("TEST_MODEL"), - EnableF16Memory, - SetContext(128), - SetMMap(true), - SetNBatch(512), - ) - Expect(err).ToNot(HaveOccurred()) - Expect(model).ToNot(BeNil()) - text, err := model.Predict(`Below is an instruction that describes a task. Write a response that appropriately completes the request. - -### Instruction: How much is 2+2? - -### Response: `, SetRopeFreqBase(10000.0), SetRopeFreqScale(1)) - Expect(err).ToNot(HaveOccurred()) - Expect(text).To(ContainSubstring("4")) - }) - }) -}) diff --git a/llama.go b/llama.go index 5d5de4b..7bd6992 100644 --- a/llama.go +++ b/llama.go @@ -259,6 +259,48 @@ func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) { return res, nil } +// tokenize has an interesting return property: negative lengths (potentially) have meaning. Therefore, return the length separate from the slice and error - all three can be used together +func (l *LLama) TokenizeString(text string, opts ...PredictOption) (int32, []int32, error) { + po := NewPredictOptions(opts...) + + input := C.CString(text) + if po.Tokens == 0 { + po.Tokens = 4096 // ??? + } + out := make([]C.int, po.Tokens) + + var fakeDblPtr **C.char + + // copy pasted and modified minimally. 
Should I simplify down / do we need an "allocate defaults" + params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK), + C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat), + C.bool(po.IgnoreEOS), C.bool(po.F16KV), + C.int(po.Batch), C.int(po.NKeep), fakeDblPtr, C.int(0), + C.float(po.TailFreeSamplingZ), C.float(po.TypicalP), C.float(po.FrequencyPenalty), C.float(po.PresencePenalty), + C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), C.CString(po.LogitBias), + C.CString(po.PathPromptCache), C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap), + C.CString(po.MainGPU), C.CString(po.TensorSplit), + C.bool(po.PromptCacheRO), + C.CString(po.Grammar), + C.float(po.RopeFreqBase), C.float(po.RopeFreqScale), C.float(po.NegativePromptScale), C.CString(po.NegativePrompt), + ) + + tokRet := C.llama_tokenize_string(params, l.state, (*C.int)(unsafe.Pointer(&out[0]))) //, C.int(po.Tokens), true) + + if tokRet < 0 { + return int32(tokRet), []int32{}, fmt.Errorf("llama_tokenize_string returned negative count %d", tokRet) + } + + // TODO: Is this loop still required to unbox cgo to go? + gTokRet := int32(tokRet) + goSlice := make([]int32, gTokRet) + for i := int32(0); i < gTokRet; i++ { + goSlice[i] = int32(out[i]) + } + + return gTokRet, goSlice, nil +} + // CGo only allows us to use static calls from C to Go, we can't just dynamically pass in func's. // This is the next best thing, we register the callbacks in this map and call tokenCallback from // the C code. We also attach a finalizer to LLama, so it will unregister the callback when the diff --git a/llama_test.go b/llama_test.go new file mode 100644 index 0000000..cde8d83 --- /dev/null +++ b/llama_test.go @@ -0,0 +1,64 @@ +package llama_test + +import ( + "os" + + . "github.com/go-skynet/go-llama.cpp" + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +var _ = Describe("LLama binding", func() { + testModelPath := os.Getenv("TEST_MODEL") + + Context("Declaration", func() { + It("fails with no model", func() { + model, err := New("not-existing") + Expect(err).To(HaveOccurred()) + Expect(model).To(BeNil()) + }) + }) + Context("Inferencing tests (using "+testModelPath+") ", func() { + getModel := func() (*LLama, error) { + model, err := New( + testModelPath, + EnableF16Memory, + SetContext(128), + SetMMap(true), + SetNBatch(512), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(model).ToNot(BeNil()) + return model, err + } + + It("predicts successfully", func() { + if testModelPath == "" { + Skip("test skipped - only makes sense if the TEST_MODEL environment variable is set.") + } + + model, err := getModel() + text, err := model.Predict(`Below is an instruction that describes a task. Write a response that appropriately completes the request. + +### Instruction: How much is 2+2? + +### Response: `, SetRopeFreqBase(10000.0), SetRopeFreqScale(1)) + Expect(err).ToNot(HaveOccurred()) + Expect(text).To(ContainSubstring("4")) + }) + + It("tokenizes strings successfully", func() { + if testModelPath == "" { + Skip("test skipped - only makes sense if the TEST_MODEL environment variable is set.") + } + + model, err := getModel() + l, tokens, err := model.TokenizeString("A STRANGE GAME.\nTHE ONLY WINNING MOVE IS NOT TO PLAY.\n\nHOW ABOUT A NICE GAME OF CHESS?", + SetRopeFreqBase(10000.0), SetRopeFreqScale(1)) + + Expect(err).ToNot(HaveOccurred()) + Expect(l).To(BeNumerically(">", 0)) + Expect(int(l)).To(Equal(len(tokens))) + }) + }) +})