diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml
index 5d5cdb1..c363f6d 100644
--- a/.github/workflows/codeql.yml
+++ b/.github/workflows/codeql.yml
@@ -12,9 +12,8 @@ name: "CodeQL"
 on:
-  push:
-    branches: [ "main" ]
   pull_request:
+  push:
     # The branches below must be a subset of the branches above
     branches: [ "main" ]
   schedule:
diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 87ec212..a45dd69 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -36,32 +36,6 @@ jobs:
         run: |
           make test

-  windows-latest:
-    runs-on: windows-latest
-    strategy:
-      matrix:
-        go-version: [ 'stable' ]
-    steps:
-      - name: Checkout repository
-        uses: actions/checkout@v3
-        with:
-          submodules: recursive
-      - name: Setup Go ${{ matrix.go-version }}
-        uses: actions/setup-go@v4
-        with:
-          go-version: ${{ matrix.go-version }}
-      - name: Display Go version
-        run: go version
-      - name: Display GCC version
-        run: gcc --version
-      - name: Display CMake version
-        run: cmake --version
-      - name: Install dependencies
-        run: go mod tidy
-      - name: Test
-        run: |
-          make test
-
   macOS-latest:
     runs-on: macOS-latest
     strategy:
@@ -88,29 +62,31 @@ jobs:
         run: |
           make test

-  macOS-metal-latest:
-    runs-on: macOS-latest
-    strategy:
-      matrix:
-        go-version: ['stable']
-    steps:
-      - name: Checkout repository
-        uses: actions/checkout@v3
-        with:
-          submodules: recursive
-      - name: Setup Go ${{ matrix.go-version }}
-        uses: actions/setup-go@v4
-        with:
-          go-version: ${{ matrix.go-version }}
-      # You can test your matrix by printing the current Go version
-      - name: Display Go version
-        run: go version
-      - name: Display GCC version
-        run: gcc --version
-      - name: Display CMake version
-        run: cmake --version
-      - name: Install dependencies
-        run: go mod tidy
-      - name: Test
-        run: |
-          make BUILD_TYPE=metal test
\ No newline at end of file
+# arm runners are not supported yet: https://github.com/actions/runner-images/issues/8610
+# Apple Silicon powered macOS runners are now available in public beta! https://github.com/actions/runner-images/issues/8439
+# macOS-arm64-metal-latest:
+#   runs-on: macos-13-arm64
+#   strategy:
+#     matrix:
+#       go-version: ['stable']
+#   steps:
+#     - name: Checkout repository
+#       uses: actions/checkout@v3
+#       with:
+#         submodules: recursive
+#     - name: Setup Go ${{ matrix.go-version }}
+#       uses: actions/setup-go@v4
+#       with:
+#         go-version: ${{ matrix.go-version }}
+#     # You can test your matrix by printing the current Go version
+#     - name: Display Go version
+#       run: go version
+#     - name: Display GCC version
+#       run: gcc --version
+#     - name: Display CMake version
+#       run: cmake --version
+#     - name: Install dependencies
+#       run: go mod tidy
+#     - name: Test
+#       run: |
+#         make BUILD_TYPE=metal test
\ No newline at end of file
diff --git a/Makefile b/Makefile
index a5a308b..3957a3a 100644
--- a/Makefile
+++ b/Makefile
@@ -2,29 +2,16 @@ INCLUDE_PATH := $(abspath ./)
 LIBRARY_PATH := $(abspath ./)

 ifndef UNAME_S
-	ifeq ($(OS),Windows_NT)
-		UNAME_S := $(shell ver)
-	else
-		UNAME_S := $(shell uname -s)
-	endif
+	UNAME_S := $(shell uname -s)
 endif
-
 ifndef UNAME_P
-	ifeq ($(OS),Windows_NT)
-		UNAME_P := $(shell wmic cpu get caption)
-	else
-		UNAME_P := $(shell uname -p)
-	endif
+	UNAME_P := $(shell uname -p)
 endif
-
 ifndef UNAME_M
-	ifeq ($(OS),Windows_NT)
-		UNAME_M := $(PROCESSOR_ARCHITECTURE)
-	else
-		UNAME_M := $(shell uname -s)
-	endif
+	UNAME_M := $(shell uname -m)
 endif
+
 CCV := $(shell $(CC) --version | head -n 1)
 CXXV := $(shell $(CXX) --version | head -n 1)

@@ -34,8 +21,8 @@ ifeq ($(UNAME_S),Darwin)
 	ifneq ($(UNAME_P),arm)
 		SYSCTL_M := $(shell sysctl -n hw.optional.arm64 2>/dev/null)
 		ifeq ($(SYSCTL_M),1)
-			# UNAME_P := arm
-			# UNAME_M := arm64
+# UNAME_P := arm
+# UNAME_M := arm64
 			warn := $(warning Your arch is announced as x86_64, but it seems to actually be ARM64. Not fixing that can lead to bad performance. For more info see: https://github.com/ggerganov/whisper.cpp/issues/66\#issuecomment-1282546789)
 		endif
 	endif
@@ -47,30 +34,25 @@ endif
 BUILD_TYPE?=
 # keep standard at C17 and C++17
-CFLAGS = -I. -O3 -DNDEBUG -std=c17 -fPIC -pthread
 CXXFLAGS = -I. -O3 -DNDEBUG -std=c++17 -fPIC -pthread
-LDFLAGS =
-CMAKE_ARGS = -DCMAKE_C_COMPILER=$(shell which gcc) -DCMAKE_CXX_COMPILER=$(shell which g++)

 # warnings
-CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function
-CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function
+CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -pedantic-errors

 # GPGPU specific
 GGML_CUDA_OBJ_PATH=third_party/ggml/src/CMakeFiles/ggml.dir/ggml-cuda.cu.o

 # Architecture specific
-# TODO: probably these flags need to be tweaked on some architectures
-# feel free to update the Makefile for your architecture and send a pull request or issue
+# feel free to update the Makefile for your architecture and send a pull request or issue
 ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
 	# Use all CPU extensions that are available:
-	CFLAGS += -march=native -mtune=native
+	CMAKE_ARGS += -DCMAKE_C_FLAGS=-march=native -DGGML_F16C=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF
+	CXXFLAGS += -march=native -mtune=native
 endif
 ifneq ($(filter ppc64%,$(UNAME_M)),)
 	POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
 	ifneq (,$(findstring POWER9,$(POWER9_M)))
-		CFLAGS += -mcpu=power9
 		CXXFLAGS += -mcpu=power9
 	endif
 	# Require c++23's std::byteswap for big-endian support.
@@ -79,61 +61,49 @@ ifneq ($(filter ppc64%,$(UNAME_M)),)
 	endif
 endif
 ifdef CHATGLM_GPROF
-	CFLAGS += -pg
 	CXXFLAGS += -pg
 endif
 ifneq ($(filter aarch64%,$(UNAME_M)),)
-	CFLAGS += -mcpu=native
 	CXXFLAGS += -mcpu=native
 endif
 ifneq ($(filter armv6%,$(UNAME_M)),)
 	# Raspberry Pi 1, 2, 3
-	CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
+	CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
 endif
 ifneq ($(filter armv7%,$(UNAME_M)),)
 	# Raspberry Pi 4
-	CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
+	CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
 endif
 ifneq ($(filter armv8%,$(UNAME_M)),)
 	# Raspberry Pi 4
-	CFLAGS += -mfp16-format=ieee -mno-unaligned-access
+	CXXFLAGS += -mfp16-format=ieee -mno-unaligned-access
 endif

-# Build Acceleration
 ifeq ($(BUILD_TYPE),cublas)
-	EXTRA_LIBS=
 	CMAKE_ARGS+=-DGGML_CUBLAS=ON
-	EXTRA_TARGETS+=ggml.dir/ggml-cuda.o
 endif
 ifeq ($(BUILD_TYPE),openblas)
-	EXTRA_LIBS=
 	CMAKE_ARGS+=-DGGML_OPENBLAS=ON
-	CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
-	LDFLAGS += -lopenblas
+	CXXFLAGS += -I/usr/local/include/openblas -lopenblas
 	CGO_TAGS=-tags openblas
 endif
 ifeq ($(BUILD_TYPE),hipblas)
 	ROCM_HOME ?= "/opt/rocm"
 	CXX="$(ROCM_HOME)"/llvm/bin/clang++
 	CC="$(ROCM_HOME)"/llvm/bin/clang
-	EXTRA_LIBS=
 	GPU_TARGETS ?= gfx900,gfx90a,gfx1030,gfx1031,gfx1100
 	AMDGPU_TARGETS ?= "$(GPU_TARGETS)"
 	CMAKE_ARGS+=-DGGML_HIPBLAS=ON -DAMDGPU_TARGETS="$(AMDGPU_TARGETS)" -DGPU_TARGETS="$(GPU_TARGETS)"
-	EXTRA_TARGETS+=ggml.dir/ggml-cuda.o
 	GGML_CUDA_OBJ_PATH=CMakeFiles/ggml-rocm.dir/ggml-cuda.cu.o
 endif
 ifeq ($(BUILD_TYPE),clblas)
-	EXTRA_LIBS=
 	CMAKE_ARGS+=-DGGML_CLBLAST=ON
-	EXTRA_TARGETS+=ggml.dir/ggml-opencl.o
 	CGO_TAGS=-tags cublas
 endif
 ifeq ($(BUILD_TYPE),metal)
-	EXTRA_LIBS=
 	CMAKE_ARGS+=-DGGML_METAL=ON
-	EXTRA_TARGETS+=ggml.dir/ggml-metal.o
 	CGO_TAGS=-tags metal
+	EXTRA_TARGETS+=ggml-metal
 endif

 ifdef CLBLAST_DIR
@@ -147,14 +117,13 @@ $(info I chatglm.cpp build info: )
 $(info I UNAME_S: $(UNAME_S))
 $(info I UNAME_P: $(UNAME_P))
 $(info I UNAME_M: $(UNAME_M))
-$(info I CFLAGS: $(CFLAGS))
 $(info I CXXFLAGS: $(CXXFLAGS))
-$(info I LDFLAGS: $(LDFLAGS))
 $(info I BUILD_TYPE: $(BUILD_TYPE))
 $(info I CMAKE_ARGS: $(CMAKE_ARGS))
 $(info I EXTRA_TARGETS: $(EXTRA_TARGETS))
 $(info I CC: $(CCV))
 $(info I CXX: $(CXXV))
+$(info I CGO_TAGS: $(CGO_TAGS))
 $(info )

 # Use this if you want to set the default behavior
@@ -164,45 +133,37 @@ prepare:

 # build chatglm.cpp
 build/chatglm.cpp: prepare
-	cd build && CC="$(CC)" CXX="$(CXX)" cmake ../chatglm.cpp $(CMAKE_ARGS) && VERBOSE=1 cmake --build . -j --config Release
+	cd build && CC="$(CC)" CXX="$(CXX)" cmake $(CMAKE_ARGS) ../chatglm.cpp && VERBOSE=1 cmake --build . -j --config Release

 # chatglm.dir
 chatglm.dir: build/chatglm.cpp
-	cd out && mkdir -p chatglm.dir && cd ../build && \
-	cp -rp CMakeFiles/chatglm.dir/chatglm.cpp.o ../out/chatglm.dir/chatglm.o
+	cd out && mkdir -p chatglm.dir
+	cp build/CMakeFiles/chatglm.dir/chatglm.cpp.o out/chatglm.dir/

 # ggml.dir
 ggml.dir: build/chatglm.cpp
-	cd out && mkdir -p ggml.dir && cd ../build && \
-	cp -rf third_party/ggml/src/CMakeFiles/ggml.dir/ggml.c.o ../out/ggml.dir/ggml.o && \
-	cp -rf third_party/ggml/src/CMakeFiles/ggml.dir/ggml-alloc.c.o ../out/ggml.dir/ggml-alloc.o
+	cd out && mkdir -p ggml.dir
+	cp build/third_party/ggml/src/CMakeFiles/ggml.dir/*.o out/ggml.dir/

 # sentencepiece.dir
 sentencepiece.dir: build/chatglm.cpp
-	cd out && mkdir -p sentencepiece.dir && cd ../build && \
-	cp -rf third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/sentencepiece_processor.cc.o ../out/sentencepiece.dir/sentencepiece_processor.o && \
-	cp -rf third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/error.cc.o ../out/sentencepiece.dir/error.o && \
-	cp -rf third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/model_factory.cc.o ../out/sentencepiece.dir/model_factory.o && \
-	cp -rf third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/model_interface.cc.o ../out/sentencepiece.dir/model_interface.o && \
-	cp -rf third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/bpe_model.cc.o ../out/sentencepiece.dir/bpe_model.o && \
-	cp -rf third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/char_model.cc.o ../out/sentencepiece.dir/char_model.o && \
-	cp -rf third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/word_model.cc.o ../out/sentencepiece.dir/word_model.o && \
-	cp -rf third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/unigram_model.cc.o ../out/sentencepiece.dir/unigram_model.o && \
-	cp -rf third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/util.cc.o ../out/sentencepiece.dir/util.o && \
-	cp -rf third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/normalizer.cc.o ../out/sentencepiece.dir/normalizer.o && \
-	cp -rf third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/filesystem.cc.o ../out/sentencepiece.dir/filesystem.o && \
-	cp -rf third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/builtin_pb/sentencepiece.pb.cc.o ../out/sentencepiece.dir/sentencepiece.pb.o && \
-	cp -rf third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/builtin_pb/sentencepiece_model.pb.cc.o ../out/sentencepiece.dir/sentencepiece_model.pb.o
+	cd out && mkdir -p sentencepiece.dir
+	cp build/third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/*.cc.o out/sentencepiece.dir/
+	cp build/third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/builtin_pb/*.cc.o out/sentencepiece.dir/

 # protobuf-lite.dir
 protobuf-lite.dir: sentencepiece.dir
-	cd out && mkdir -p protobuf-lite.dir && cd ../build && \
-	find third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/__/third_party/protobuf-lite -name '*.cc.o' -exec cp {} ../out/protobuf-lite.dir/ \;
+	cd out && mkdir -p protobuf-lite.dir
+	cp build/third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/__/third_party/protobuf-lite/*.cc.o out/protobuf-lite.dir/

 # absl.dir
 absl.dir: sentencepiece.dir
-	cd out && mkdir -p absl.dir && cd ../build && \
-	find third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/__/third_party/absl/flags/ -name '*.cc.o' -exec cp {} ../out/absl.dir/ \;
+	cd out && mkdir -p absl.dir
+	cp build/third_party/sentencepiece/src/CMakeFiles/sentencepiece-static.dir/__/third_party/absl/flags/flag.cc.o out/absl.dir/
+
+# ggml-metal
+ggml-metal: ggml.dir
+	cp build/bin/ggml-metal.metal .

 # binding
 binding.o: prepare build/chatglm.cpp chatglm.dir ggml.dir sentencepiece.dir protobuf-lite.dir absl.dir
@@ -210,47 +171,25 @@ binding.o: prepare build/chatglm.cpp chatglm.dir ggml.dir sentencepiece.dir prot
 	-I./chatglm.cpp \
 	-I./chatglm.cpp/third_party/ggml/include/ggml \
 	-I./chatglm.cpp/third_party/sentencepiece/src \
-	binding.cpp -o binding.o -c $(LDFLAGS)
-
-# ggml-cuda
-ggml.dir/ggml-cuda.o: ggml.dir
-	cd build && cp -rf "$(GGML_CUDA_OBJ_PATH)" ../out/ggml.dir/ggml-cuda.o
-
-# ggml-opencl
-ggml.dir/ggml-opencl.o: ggml.dir
-	cd build && cp -rf third_party/ggml/src/CMakeFiles/ggml.dir/ggml-opencl.cpp.o ../out/ggml.dir/ggml-opencl.o
-
-# ggml-metal
-ggml.dir/ggml-metal.o: ggml.dir ggml.dir/ggml-backend.o
-	cd build && cp -rf bin/ggml-metal.metal ../ggml-metal.metal && \
-	cp -rf third_party/ggml/src/CMakeFiles/ggml.dir/ggml-metal.m.o ../out/ggml.dir/ggml-metal.o
-
-# ggml-backend
-ggml.dir/ggml-backend.o:
-	cd build && cp -rf third_party/ggml/src/CMakeFiles/ggml.dir/ggml-backend.c.o ../out/ggml.dir/ggml-backend.o
+	binding.cpp -MD -MT binding.o -MF binding.d -o binding.o -c

 libbinding.a: prepare binding.o $(EXTRA_TARGETS)
 	ar src libbinding.a \
-	out/chatglm.dir/chatglm.o \
+	out/chatglm.dir/*.o \
 	out/ggml.dir/*.o out/sentencepiece.dir/*.o \
 	out/protobuf-lite.dir/*.o out/absl.dir/*.o \
 	binding.o

 clean:
 	rm -rf *.o
+	rm -rf *.d
 	rm -rf *.a
 	rm -rf out
 	rm -rf build

-DOWNLOAD_TARGETS=ggllm-test-model.bin
-ifeq ($(OS),Windows_NT)
-	DOWNLOAD_TARGETS:=windows/ggllm-test-model.bin
-endif
-
 ggllm-test-model.bin:
 	wget -q -N https://huggingface.co/Xorbits/chatglm3-6B-GGML/resolve/main/chatglm3-ggml-q4_0.bin -O ggllm-test-model.bin

-windows/ggllm-test-model.bin:
-	powershell -Command "Invoke-WebRequest -Uri 'https://huggingface.co/Xorbits/chatglm3-6B-GGML/resolve/main/chatglm3-ggml-q4_0.bin' -OutFile 'ggllm-test-model.bin'"
-
-test: $(DOWNLOAD_TARGETS) libbinding.a
-	TEST_MODEL=ggllm-test-model.bin go test ${CGO_TAGS} .
+test: ggllm-test-model.bin libbinding.a
+	go test ${CGO_TAGS} -timeout 1800s -o go-chatglm.cpp.test -c -cover
+	TEST_MODEL=ggllm-test-model.bin ./go-chatglm.cpp.test
diff --git a/Makefile.win b/Makefile.win
new file mode 100644
index 0000000..e1e3711
--- /dev/null
+++ b/Makefile.win
@@ -0,0 +1,144 @@
+INCLUDE_PATH := $(abspath ./)
+LIBRARY_PATH := $(abspath ./)
+
+ifndef UNAME_S
+	UNAME_S := $(shell ver)
+endif
+ifndef UNAME_P
+	UNAME_P := $(shell wmic cpu get caption)
+endif
+ifndef UNAME_M
+	UNAME_M := $(PROCESSOR_ARCHITECTURE)
+endif
+
+
+CCV := $(shell $(CC) --version | head -n 1)
+CXXV := $(shell $(CXX) --version | head -n 1)
+
+
+#
+# Compile flags
+#
+BUILD_TYPE?=
+# keep standard at C17 and C++17
+CXXFLAGS = -I. -O3 -DNDEBUG -std=c++17 -fPIC -pthread
+
+# GPGPU specific
+GGML_CUDA_OBJ_PATH=third_party/ggml/src/CMakeFiles/ggml.dir/ggml-cuda.cu.o
+
+
+# Architecture specific
+# feel free to update the Makefile for your architecture and send a pull request or issue
+ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
+	# Use all CPU extensions that are available:
+	CXXFLAGS += -march=native -mtune=native
+endif
+ifneq ($(filter ppc64%,$(UNAME_M)),)
+	POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
+	ifneq (,$(findstring POWER9,$(POWER9_M)))
+		CXXFLAGS += -mcpu=power9
+	endif
+	# Require c++23's std::byteswap for big-endian support.
+	ifeq ($(UNAME_M),ppc64)
+		CXXFLAGS += -std=c++23 -DGGML_BIG_ENDIAN
+	endif
+endif
+ifdef CHATGLM_GPROF
+	CXXFLAGS += -pg
+endif
+ifneq ($(filter aarch64%,$(UNAME_M)),)
+	CXXFLAGS += -mcpu=native
+endif
+ifneq ($(filter armv6%,$(UNAME_M)),)
+	# Raspberry Pi 1, 2, 3
+	CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
+endif
+ifneq ($(filter armv7%,$(UNAME_M)),)
+	# Raspberry Pi 4
+	CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
+endif
+ifneq ($(filter armv8%,$(UNAME_M)),)
+	# Raspberry Pi 4
+	CXXFLAGS += -mfp16-format=ieee -mno-unaligned-access
+endif
+
+ifeq ($(BUILD_TYPE),cublas)
+	CMAKE_ARGS+=-DGGML_CUBLAS=ON
+endif
+ifeq ($(BUILD_TYPE),openblas)
+	CMAKE_ARGS+=-DGGML_OPENBLAS=ON
+	CXXFLAGS += -I/usr/local/include/openblas -lopenblas
+	CGO_TAGS=-tags openblas
+endif
+ifeq ($(BUILD_TYPE),hipblas)
+	ROCM_HOME ?= "/opt/rocm"
+	CXX="$(ROCM_HOME)"/llvm/bin/clang++
+	CC="$(ROCM_HOME)"/llvm/bin/clang
+	GPU_TARGETS ?= gfx900,gfx90a,gfx1030,gfx1031,gfx1100
+	AMDGPU_TARGETS ?= "$(GPU_TARGETS)"
+	CMAKE_ARGS+=-DGGML_HIPBLAS=ON -DAMDGPU_TARGETS="$(AMDGPU_TARGETS)" -DGPU_TARGETS="$(GPU_TARGETS)"
+	GGML_CUDA_OBJ_PATH=CMakeFiles/ggml-rocm.dir/ggml-cuda.cu.o
+endif
+ifeq ($(BUILD_TYPE),clblas)
+	CMAKE_ARGS+=-DGGML_CLBLAST=ON
+	CGO_TAGS=-tags cublas
+endif
+
+ifdef CLBLAST_DIR
+	CMAKE_ARGS+=-DCLBlast_dir=$(CLBLAST_DIR)
+endif
+
+#
+# Print build information
+#
+$(info I chatglm.cpp build info: )
+$(info I UNAME_S: $(UNAME_S))
+$(info I UNAME_P: $(UNAME_P))
+$(info I UNAME_M: $(UNAME_M))
+$(info I CXXFLAGS: $(CXXFLAGS))
+$(info I BUILD_TYPE: $(BUILD_TYPE))
+$(info I CMAKE_ARGS: $(CMAKE_ARGS))
+$(info I EXTRA_TARGETS: $(EXTRA_TARGETS))
+$(info I CC: $(CCV))
+$(info I CXX: $(CXXV))
+$(info I CGO_TAGS: $(CGO_TAGS))
+$(info )
+
+# Use this if you want to set the default behavior
+
+prepare:
+	mkdir -p build && mkdir -p out
+
+# build chatglm.cpp
+build/chatglm.cpp: prepare
+	cd build && CC="$(CC)" CXX="$(CXX)" cmake $(CMAKE_ARGS) ../chatglm.cpp && VERBOSE=1 cmake --build . -j --config Release
+
+chatglm.dir: build/chatglm.cpp
+	xcopy build\\lib\\Release\\*.lib out\\
+
+# binding
+binding.o: prepare build/chatglm.cpp chatglm.dir ggml.dir sentencepiece.dir protobuf-lite.dir absl.dir
+	$(CXX) $(CXXFLAGS) \
+	-I./chatglm.cpp \
+	-I./chatglm.cpp/third_party/ggml/include/ggml \
+	-I./chatglm.cpp/third_party/sentencepiece/src \
+	binding.cpp -MD -MT binding.lib -MF binding.d -o binding.lib -c
+
+libbinding.a: prepare binding.o $(EXTRA_TARGETS)
+	lib.exe /OUT:libbinding.lib out/*.lib binding.lib
+
+clean:
+	rm -rf *.o
+	rm -rf *.d
+	rm -rf *.a
+	rm -rf out
+	rm -rf build
+
+ggllm-test-model.bin:
+	powershell -Command "Invoke-WebRequest -Uri 'https://huggingface.co/Xorbits/chatglm3-6B-GGML/resolve/main/chatglm3-ggml-q4_0.bin' -OutFile 'ggllm-test-model.bin'"
+
+test: ggllm-test-model.bin libbinding.a
+	go test ${CGO_TAGS} -timeout 1800s -o go-chatglm.cpp.test -c -cover
+	TEST_MODEL=ggllm-test-model.bin ./go-chatglm.cpp.test
+
+# build\lib\Release\chatglm.lib
\ No newline at end of file
diff --git a/README.md b/README.md
index 26ec2f2..33d56e3 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,5 @@
 # go-chatglm.cpp
+
 [![GoDoc](https://godoc.org/github.com/Weaxs/go-chatglm.cpp?status.svg)](https://godoc.org/github.com/Weaxs/go-chatglm.cpp)
 [![Go Report Card](https://goreportcard.com/badge/github.com/Weaxs/go-chatglm.cpp)](https://goreportcard.com/report/github.com/Weaxs/go-chatglm.cpp)
 [![License](https://img.shields.io/github/license/Weaxs/go-chatglm.cpp)](https://github.com/Weaxs/go-chatglm.cpp/blob/main/LICENSE)
@@ -7,42 +8,90 @@
 The go-chatglm.cpp bindings are high level, as such most of the work is kept into the C/C++ code to avoid any extra computational cost, be more performant and lastly ease out maintenance, while keeping the usage as simple as possible.

-# Attention
+# Attention!

 ### Environment

 You need to make sure there are `make`, `cmake`, `gcc` command in your machine, otherwise should support C++17.
-If you want to run on **Windows OS**, you can use [cygwin](https://www.cygwin.com/).
+If you want to run on **Windows OS**, you can use [cygwin](https://www.cygwin.com/) or [MinGW](https://www.mingw-w64.org/).

 > **`cmake` > 3.8** and **`gcc` > 5.1.0** (support C++17)

 ### Not Support LoRA model
+
 go-chatglm.cpp is not anymore compatible with `LoRA model`, but it woks ONLY with the model which merged by LoRA model and base model.
-You can use [convert.py](https://github.com/li-plus/chatglm.cpp/blob/main/chatglm_cpp/convert.py) in [chatglm.cpp](https://github.com/li-plus/chatglm.cpp)
-to merge LoRA model into base model.
+You can use [convert.py](https://github.com/li-plus/chatglm.cpp/blob/main/chatglm_cpp/convert.py) in [chatglm.cpp](https://github.com/li-plus/chatglm.cpp) to merge LoRA model into base model.

 # Usage

 Note: This repository uses git submodules to keep track of [chatglm.cpp](https://github.com/li-plus/chatglm.cpp) .
 Clone the repository locally:
+
 ```shell
 git clone --recurse-submodules https://github.com/Weaxs/go-chatglm.cpp
 ```

 To build the bindings locally, run:
+
 ```shell
 cd go-chatglm.cpp
 make libbinding.a
 ```

 Now you can run the example with:
+
 ```shell
-LIBRARY_PATH=$PWD C_INCLUDE_PATH=$PWD go run ./examples -m "/model/path/here" -t 14
+go run ./examples -m "/model/path/here"
+ ____ _ _ ____ _ __ __ 
+ __ _ ___ / ___| |__ __ _| |_ / ___| | | \/ | ___ _ __ _ __ 
+ / _` |/ _ \ _____| | | '_ \ / _` | __| | _| | | |\/| | / __| '_ \| '_ \ 
+| (_| | (_) |_____| |___| | | | (_| | |_| |_| | |___| | | || (__| |_) | |_) |
+ \__, |\___/ \____|_| |_|\__,_|\__|\____|_____|_| |_(_)___| .__/| .__/ 
+ |___/ |_| |_| 
+
+>>> 你好
+
+Sending 你好
+
+
+你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。
 ```
+# Acceleration
+
+## Metal (Apple Silicon)
+
+MPS (Metal Performance Shaders) allows computation to run on powerful Apple Silicon GPU.
+
+```
+BUILD_TYPE=metal make libbinding.a
+go build -tags metal ./examples/main.go
+./main -m "/model/path/here"
+```
+
+## OpenBLAS
+
+OpenBLAS provides acceleration on CPU.
+
+```
+BUILD_TYPE=openblas make libbinding.a
+go build -tags openblas ./examples/main.go
+./main -m "/model/path/here"
+```
+
+## cuBLAS
+
+cuBLAS uses NVIDIA GPU to accelerate BLAS.
+
+```
+BUILD_TYPE=cublas make libbinding.a
+go build -tags cublas ./examples/main.go
+./main -m "/model/path/here"
+```

 # Acknowledgements
-  * This project is greatly inspired by [@mudler](https://github.com/mudler)'s [go-llama.cpp](https://github.com/go-skynet/go-llama.cpp)
+
+* This project is greatly inspired by [@mudler](https://github.com/mudler)'s [go-llama.cpp](https://github.com/go-skynet/go-llama.cpp)
diff --git a/binding.cpp b/binding.cpp
index 93418c9..f111cbf 100644
--- a/binding.cpp
+++ b/binding.cpp
@@ -1,5 +1,4 @@
 #include "chatglm.h"
-
 #include "binding.h"
 #include 
 #include 
@@ -12,15 +11,17 @@
 #include 
 #include 
 #include 
+#include 
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
-#include 
 #include 
-#elif defined (_WIN32)
+#endif
+#if defined (_WIN32)
 #define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
 #define NOMINMAX
+#endif
 #include 
-#include 
 #endif

 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
@@ -47,40 +48,88 @@ class TextBindStreamer : public chatglm::BaseStreamer {
     int print_len_;
 };

-std::vector<std::string> create_vector(const char** strings, int count) {
-    auto vec = new std::vector<std::string>;
+std::vector<chatglm::ChatMessage> create_chat_message_vector(void** history, int count) {
+    std::vector<chatglm::ChatMessage>* vec = new std::vector<chatglm::ChatMessage>;
     for (int i = 0; i < count; i++) {
-        vec->push_back(std::string(strings[i]));
+        chatglm::ChatMessage* msg = (chatglm::ChatMessage*) history[i];
+        vec->push_back(*msg);
     }
+
     return *vec;
 }

+std::vector<chatglm::ToolCallMessage> create_tool_call_vector(void** tool_calls, int count) {
+    std::vector<chatglm::ToolCallMessage>* vec = new std::vector<chatglm::ToolCallMessage>;
+    for (int i = 0; i < count; i++) {
+        chatglm::ToolCallMessage* msg = (chatglm::ToolCallMessage*) tool_calls[i];
+        vec->push_back(*msg);
+    }
+
+    return *vec;
+}
+
+std::string decode_with_special_tokens(chatglm::ChatGLM3Tokenizer* tokenizer, const std::vector<int> &ids) {
+    std::vector<std::string> pieces;
+    for (int id : ids) {
+        auto pos = tokenizer->index_special_tokens.find(id);
+        if (pos != tokenizer->index_special_tokens.end()) {
+            // special tokens
+            pieces.emplace_back(pos->second);
+        } else {
+            // normal tokens
+            pieces.emplace_back(tokenizer->sp.IdToPiece(id));
+        }
+    }
+
+    std::string text = tokenizer->sp.DecodePieces(pieces);
+    return text;
+}
+
 void* load_model(const char *name) {
     return new chatglm::Pipeline(name);
 }

-int chat(void* pipe_pr, const char** history, int history_count, void* params_ptr, char* result) {
-    std::vector<std::string> vectors = create_vector(history, history_count);
+int chat(void* pipe_pr, void** history, int history_count, void* params_ptr, char* result) {
+    std::vector<chatglm::ChatMessage> vectors = create_chat_message_vector(history, history_count);
     chatglm::Pipeline* pipe_p = (chatglm::Pipeline*) pipe_pr;
     chatglm::GenerationConfig* params = (chatglm::GenerationConfig*) params_ptr;

-    std::string res = pipe_p->chat(vectors, *params);
-    strcpy(result, res.c_str());
+    chatglm::ChatMessage res = pipe_p->chat(vectors, *params);

-    vectors.clear();
+    std::string out = res.content;
+    // ChatGLM3Tokenizer::decode_message converts the raw output into a ChatMessage,
+    // so we convert it back to plain text here
+    if (pipe_p->model->config.model_type == chatglm::ModelType::CHATGLM3) {
+        std::vector<chatglm::ChatMessage>* resultVec = new std::vector<chatglm::ChatMessage>{res};
+        chatglm::ChatGLM3Tokenizer* tokenizer = dynamic_cast<chatglm::ChatGLM3Tokenizer*>(pipe_p->tokenizer.get());
+        std::vector<int> input_ids = tokenizer->encode_messages(*resultVec, params->max_context_length);
+        out = decode_with_special_tokens(tokenizer, input_ids);
+    }
+    strcpy(result, out.c_str());

+    vectors.clear();
     return 0;
 }

-int stream_chat(void* pipe_pr, const char** history, int history_count,void* params_ptr, char* result) {
-    std::vector<std::string> vectors = create_vector(history, history_count);
+int stream_chat(void* pipe_pr, void** history, int history_count,void* params_ptr, char* result) {
+    std::vector<chatglm::ChatMessage> vectors = create_chat_message_vector(history, history_count);
     chatglm::Pipeline* pipe_p = (chatglm::Pipeline*) pipe_pr;
     chatglm::GenerationConfig* params = (chatglm::GenerationConfig*) params_ptr;

     TextBindStreamer* text_stream = new TextBindStreamer(pipe_p->tokenizer.get(), pipe_pr);

-    std::string res = pipe_p->chat(vectors, *params, text_stream);
-    strcpy(result, res.c_str());
+    chatglm::ChatMessage res = pipe_p->chat(vectors, *params, text_stream);
+
+    std::string out = res.content;
+    // ChatGLM3Tokenizer::decode_message converts the raw output into a ChatMessage,
+    // so we convert it back to plain text here
+    if (pipe_p->model->config.model_type == chatglm::ModelType::CHATGLM3) {
+        std::vector<chatglm::ChatMessage>* resultVec = new std::vector<chatglm::ChatMessage>{res};
+        chatglm::ChatGLM3Tokenizer* tokenizer = dynamic_cast<chatglm::ChatGLM3Tokenizer*>(pipe_p->tokenizer.get());
+        std::vector<int> input_ids = tokenizer->encode_messages(*resultVec, params->max_context_length);
+        out = decode_with_special_tokens(tokenizer, input_ids);
+    }
+    strcpy(result, out.c_str());

     vectors.clear();
     return 0;
@@ -108,11 +157,10 @@ int stream_generate(void* pipe_pr, const char *prompt, void* params_ptr, char* r
     return 0;
 }

-int get_embedding(void* pipe_pr, void* params_ptr, const char *prompt, int * result) {
+int get_embedding(void* pipe_pr, const char *prompt, int max_length, int * result) {
     chatglm::Pipeline* pipe_p = (chatglm::Pipeline*) pipe_pr;
-    chatglm::GenerationConfig* params = (chatglm::GenerationConfig*) params_ptr;

-    std::vector<int> embeddings = pipe_p->tokenizer->encode(prompt, params->max_length);
+    std::vector<int> embeddings = pipe_p->tokenizer->encode(prompt, max_length);

     for (size_t i = 0; i < embeddings.size(); i++) {
         result[i]=embeddings[i];
@@ -122,7 +170,7 @@ int get_embedding(void* pipe_pr, void* params_ptr, const char *prompt, int * res
 }

 void* allocate_params(int max_length, int max_context_length, bool do_sample, int top_k,
-                        float top_p, float temperature, float repetition_penalty, int num_threads) {
+                       float top_p, float temperature, float repetition_penalty, int num_threads) {
     chatglm::GenerationConfig* gen_config = new chatglm::GenerationConfig;
     gen_config->max_length = max_length;
     gen_config->max_context_length = max_context_length;
@@ -145,6 +193,38 @@ void free_model(void* pipe_pr) {
     delete pipe_p;
 }

+void* create_chat_message(const char* role, const char *content, void** tool_calls, int tool_calls_count) {
+    std::vector<chatglm::ToolCallMessage> vector = create_tool_call_vector(tool_calls, tool_calls_count);
+    return new chatglm::ChatMessage(role, content, vector);
+}
+
+void* create_tool_call(const char* type, void* codeOrFunc) {
+    if (type == chatglm::ToolCallMessage::TYPE_FUNCTION) {
+        chatglm::FunctionMessage* function_p = (chatglm::FunctionMessage*) codeOrFunc;
+        return new chatglm::ToolCallMessage(*function_p);
+    } else if (type == chatglm::ToolCallMessage::TYPE_CODE) {
+        chatglm::CodeMessage* code_p = (chatglm::CodeMessage*) codeOrFunc;
+        return new chatglm::ToolCallMessage(*code_p);
+    }
+    return nullptr;
+}
+
+void* create_function(const char* name, const char *arguments) {
+    return new chatglm::FunctionMessage(name, arguments);
+}
+
+
+void* create_code(const char* input) {
+    return new chatglm::CodeMessage(input);
+}
+
+char* get_model_type(void* pipe_pr) {
+    chatglm::Pipeline* pipe_p = (chatglm::Pipeline*) pipe_pr;
+    chatglm::ModelLoader loader(pipe_p->mapped_file->data, pipe_p->mapped_file->size);
+    loader.read_string(4);
+    return strdup(chatglm::to_string((chatglm::ModelType)loader.read_basic<int>()).data());
+}
+
 // copy from chatglm::TextStreamer
 void TextBindStreamer::put(const std::vector<int> &output_ids) {
     if (is_prompt_) {
@@ -178,7 +258,7 @@ void TextBindStreamer::put(const std::vector<int> &output_ids) {
     }

     // callback go function
-    if (!streamCallback(draft_pipe, (char*)printable_text.c_str())) {
+    if (!streamCallback(draft_pipe, printable_text.data())) {
         return;
     }
 }
@@ -187,7 +267,7 @@ void TextBindStreamer::end() {
     std::string text = tokenizer_->decode(token_cache_);

     // callback go function
-    if (!streamCallback(draft_pipe, (char*)text.substr(print_len_).c_str())) {
+    if (!streamCallback(draft_pipe, text.substr(print_len_).data())) {
         return;
     }
     is_prompt_ = true;
diff --git a/binding.h b/binding.h
index 3b651f1..c1a518d 100644
--- a/binding.h
+++ b/binding.h
@@ -9,15 +9,15 @@ extern bool streamCallback(void *, char *);

 void* load_model(const char *name);

-int chat(void* pipe_pr, const char** history, int history_count, void* params_ptr, char* result);
+int chat(void* pipe_pr, void** history, int history_count, void* params_ptr, char* result);

-int stream_chat(void* pipe_pr, const char** history, int history_count, void* params_ptr, char* result);
+int stream_chat(void* pipe_pr, void** history, int history_count, void* params_ptr, char* result);

 int generate(void* pipe_pr, const char *prompt, void* params_ptr, char* result);

 int stream_generate(void* pipe_pr, const char *prompt, void* params_ptr, char* result);

-int get_embedding(void* pipe_pr, void* params_ptr, const char *prompt, int * result);
+int get_embedding(void* pipe_pr, const char *prompt, int max_length, int * result);

 void* allocate_params(int max_length, int max_context_length, bool do_sample, int top_k, float top_p, float temperature, float repetition_penalty, int num_threads);

@@ -26,6 +26,16 @@ void free_params(void* params_ptr);

 void free_model(void* pipe_pr);

+void* create_chat_message(const char* role, const char *content, void** tool_calls, int tool_calls_count);
+
+void* create_tool_call(const char* type, void* codeOrFunc);
+
+void* create_function(const char* name, const char *arguments);
+
+void* create_code(const char* code);
+
+char* get_model_type(void* pipe_pr);
+
 #ifdef __cplusplus
 }
diff --git a/chatglm.cpp b/chatglm.cpp
index 95d3b8c..3286db5 160000
--- a/chatglm.cpp
+++ b/chatglm.cpp
@@ -1 +1 @@
-Subproject commit 95d3b8c1730d646c1916701eaf4ce03fe98baa8c
+Subproject commit 3286db5306c5d3245ea147082e69313010617a92
diff --git a/chatglm.go b/chatglm.go
index 87fd560..a27a6e7 100644
--- a/chatglm.go
+++ b/chatglm.go
@@ -4,7 +4,7 @@ package chatglm
 // #cgo CXXFLAGS: -I${SRCDIR}/chatglm.cpp
 // #cgo CXXFLAGS: -I${SRCDIR}/chatglm.cpp/third_party/ggml/include/ggml -I${SRCDIR}/chatglm.cpp/third_party/ggml/src
 // #cgo CXXFLAGS: -I${SRCDIR}/chatglm.cpp/third_party/sentencepiece/src
-// #cgo LDFLAGS: -L${SRCDIR}/ -lbinding -lm -lstdc++
+// #cgo LDFLAGS: -L${SRCDIR}/ -lbinding -lm -v
 // #cgo darwin LDFLAGS: -framework Accelerate
 // #include "binding.h"
 // #include 
@@ -36,21 +36,56 @@ func New(model string) (*Chatglm, error) {
 	return llm, nil
 }

-// Chat sync chat
-func (llm *Chatglm) Chat(history []string, opts ...GenerationOption) (string, error) {
+func NewAssistantMsg(input string, modelType string) *ChatMessage {
+	result := &ChatMessage{Role: RoleAssistant, Content: input}
+	if modelType != "ChatGLM3" {
+		return result
+	}
+
+	if !strings.Contains(input, DELIMITER) {
+		return result
+	}
+
+	ciPos := strings.Index(input, DELIMITER)
+	if ciPos != 0 {
+		content := input[:ciPos]
+		code := input[ciPos+len(DELIMITER):]
+		toolCalls := []*ToolCallMessage{{Type: TypeCode, Code: &CodeMessage{code}}}
+		result.Content = content
+		result.ToolCalls = toolCalls
+	}
+	return result
+}
+
+func NewUserMsg(content string) *ChatMessage {
+	return &ChatMessage{Role: RoleUser, Content: content}
+}
+
+func NewSystemMsg(content string) *ChatMessage {
+	return &ChatMessage{Role: RoleSystem, Content: content}
+}
+
+func NewObservationMsg(content string) *ChatMessage {
+	return &ChatMessage{Role: RoleObservation, Content: content}
+}
+
+// Chat by history [synchronous]
+func (llm *Chatglm) Chat(messages []*ChatMessage, opts ...GenerationOption) (string, error) {
+	err := checkChatMessages(messages)
+	if err != nil {
+		return "", err
+	}
+	reverseMsgs, err := allocateChatMessages(messages)
+	if err != nil {
+		return "", err
+	}
+	reverseCount := len(reverseMsgs)
+	pass := &reverseMsgs[0]
+
 	opt := NewGenerationOptions(opts...)
 	params := allocateParams(opt)
 	defer freeParams(params)

-	reverseCount := len(history)
-	reversePrompt := make([]*C.char, reverseCount)
-	var pass **C.char
-	for i, s := range history {
-		cs := C.CString(s)
-		reversePrompt[i] = cs
-		pass = &reversePrompt[0]
-	}
-
 	if opt.MaxContextLength == 0 {
 		opt.MaxContextLength = 99999999
 	}
@@ -61,25 +96,27 @@ func (llm *Chatglm) Chat(history []string, opts ...GenerationOption) (string, er
 		return "", fmt.Errorf("model chat failed")
 	}
 	res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
-	res = strings.TrimPrefix(res, " ")
-	res = strings.TrimPrefix(res, "\n")
+	res = removeSpecialTokens(res)
 	return res, nil
 }

-func (llm *Chatglm) StreamChat(history []string, opts ...GenerationOption) (string, error) {
+// StreamChat chats with streaming output via StreamCallback
+func (llm *Chatglm) StreamChat(messages []*ChatMessage, opts ...GenerationOption) (string, error) {
+	err := checkChatMessages(messages)
+	if err != nil {
+		return "", err
+	}
+	reverseMsgs, err := allocateChatMessages(messages)
+	if err != nil {
+		return "", err
+	}
+	reverseCount := len(reverseMsgs)
+	pass := &reverseMsgs[0]
+
 	opt := NewGenerationOptions(opts...)
 	params := allocateParams(opt)
 	defer freeParams(params)

-	reverseCount := len(history)
-	reversePrompt := make([]*C.char, reverseCount)
-	var pass **C.char
-	for i, s := range history {
-		cs := C.CString(s)
-		reversePrompt[i] = cs
-		pass = &reversePrompt[0]
-	}
-
 	if opt.StreamCallback != nil {
 		setStreamCallback(llm.pipeline, opt.StreamCallback)
 	} else {
@@ -96,11 +133,11 @@ func (llm *Chatglm) StreamChat(history []string, opts ...GenerationOption) (stri
 		return "", fmt.Errorf("model chat failed")
 	}
 	res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
-	res = strings.TrimPrefix(res, " ")
-	res = strings.TrimPrefix(res, "\n")
+	res = removeSpecialTokens(res)
 	return res, nil
 }

+// Generate by prompt [synchronous]
 func (llm *Chatglm) Generate(prompt string, opts ...GenerationOption) (string, error) {
 	opt := NewGenerationOptions(opts...)
 	params := allocateParams(opt)
@@ -121,6 +158,7 @@ func (llm *Chatglm) Generate(prompt string, opts ...GenerationOption) (string, e
 	return res, nil
 }

+// StreamGenerate generates with streaming output via StreamCallback
 func (llm *Chatglm) StreamGenerate(prompt string, opts ...GenerationOption) (string, error) {
 	opt := NewGenerationOptions(opts...)
 	params := allocateParams(opt)
@@ -148,6 +186,7 @@ func (llm *Chatglm) StreamGenerate(prompt string, opts ...GenerationOption) (str
 	return res, nil
 }

+// Embeddings returns the input_ids for the given text
 func (llm *Chatglm) Embeddings(text string, opts ...GenerationOption) ([]int, error) {
 	opt := NewGenerationOptions(opts...)
 	input := C.CString(text)
@@ -156,8 +195,7 @@ func (llm *Chatglm) Embeddings(text string, opts ...GenerationOption) ([]int, er
 	}

 	ints := make([]int, opt.MaxLength)
-	params := allocateParams(opt)
-	ret := C.get_embedding(llm.pipeline, params, input, (*C.int)(unsafe.Pointer(&ints[0])))
+	ret := C.get_embedding(llm.pipeline, input, C.int(opt.MaxLength), (*C.int)(unsafe.Pointer(&ints[0])))
 	if ret != 0 {
 		return ints, fmt.Errorf("embedding failed")
 	}
@@ -169,16 +207,99 @@ func (llm *Chatglm) Free() {
 	C.free_model(llm.pipeline)
 }

+func (llm *Chatglm) ModelType() string {
+	return C.GoString(C.get_model_type(llm.pipeline))
+}
+
+// allocateParams converts GenerationOptions into the C GenerationConfig
 func allocateParams(opt *GenerationOptions) unsafe.Pointer {
 	return C.allocate_params(C.int(opt.MaxLength), C.int(opt.MaxContextLength), C.bool(opt.DoSample),
 		C.int(opt.TopK), C.float(opt.TopP), C.float(opt.Temperature), C.float(opt.RepetitionPenalty),
 		C.int(opt.NumThreads))
 }

+// freeParams releases the C GenerationConfig
 func freeParams(params unsafe.Pointer) {
 	C.free_params(params)
 }

+// checkChatMessages checks the messages format
+func checkChatMessages(messages []*ChatMessage) error {
+	n := len(messages)
+	if n < 1 {
+		return fmt.Errorf("invalid chat messages size: %d", n)
+	}
+	isSys := messages[0].Role == RoleSystem
+
+	if !isSys && n%2 == 0 {
+		return fmt.Errorf("invalid chat messages size: %d", n)
+	}
+	if isSys && n%2 == 1 {
+		return fmt.Errorf("invalid chat messages size: %d", n)
+	}
+
+	for i, message := range messages {
+		if message.ToolCalls == nil {
+			continue
+		}
+
+		for j, toolCall := range message.ToolCalls {
+			if toolCall.Type == TypeCode && toolCall.Code == nil {
+				return fmt.Errorf("expect messages[%d].ToolCalls[%d].Code is not nil", i, j)
+			}
+			if toolCall.Type == TypeFunction && toolCall.Function == nil {
+				return fmt.Errorf("expect messages[%d].ToolCalls[%d].Function is not nil", i, j)
+			}
+		}
+	}
+	return nil
+}
+
+// allocateChatMessages converts []*ChatMessage in Go to []C.ChatMessage in C++
+func allocateChatMessages(messages []*ChatMessage) ([]unsafe.Pointer, error) {
+	reverseMessages := make([]unsafe.Pointer, len(messages))
+	for i, message := range messages {
+		var reverseToolCalls []unsafe.Pointer
+		if message.ToolCalls != nil {
+			for _, toolCall := range message.ToolCalls {
+				var codeOrFunc unsafe.Pointer
+				if toolCall.Type == TypeCode {
+					codeOrFunc = C.create_code(C.CString(toolCall.Code.Input))
+				} else if toolCall.Type == TypeFunction {
+					codeOrFunc = C.create_function(
+						C.CString(toolCall.Function.Name), C.CString(toolCall.Function.Arguments))
+				}
+				toolCallPoint := C.create_tool_call(C.CString(toolCall.Type), codeOrFunc)
+				if toolCallPoint != nil {
+					reverseToolCalls = append(reverseToolCalls, toolCallPoint)
+				}
+			}
+		}
+		var pass *unsafe.Pointer
+		if len(reverseToolCalls) > 0 {
+			pass = &reverseToolCalls[0]
+		}
+		reverseMessages[i] = C.create_chat_message(
+			C.CString(message.Role), C.CString(message.Content), pass, C.int(len(reverseToolCalls)))
+	}
+	return reverseMessages, nil
+}
+
+func removeSpecialTokens(data string) string {
+	output := strings.ReplaceAll(data, "[MASK]", "")
+	output = strings.ReplaceAll(output, "[gMASK]", "")
+	output = strings.ReplaceAll(output, "[sMASK]", "")
+	output = strings.ReplaceAll(output, "sop", "")
+	output = strings.ReplaceAll(output, "eop", "")
+	output = strings.Replace(output, "<|assistant|>", "", 1)
+	output = strings.TrimSuffix(output, "<|assistant|>")
+	output = strings.ReplaceAll(output, "<|assistant|>", DELIMITER)
+	output = strings.TrimLeftFunc(output, func(r rune) bool {
+		return r == '\n' || r == ' '
+	})
+	return output
+}
+
 var (
 	m         sync.RWMutex
 	callbacks = map[unsafe.Pointer]func(string) bool{}
@@ -196,6 +317,7 @@ func streamCallback(pipeline unsafe.Pointer, printableText *C.char) C.bool {
 	return C.bool(true)
 }

+// setStreamCallback adds the callback to the global callbacks map
 func setStreamCallback(pipeline unsafe.Pointer, callback func(string) bool) {
 	m.Lock()
 	defer m.Unlock()
diff --git a/chatglm_test.go b/chatglm_test.go
index 61f08a9..86e79f0 100644
--- a/chatglm_test.go
+++ b/chatglm_test.go
@@ -7,12 +7,15 @@ import (
 	"testing"
 )

-var chatglm *Chatglm
+var (
+	chatglm   *Chatglm
+	modelType string
+)

 func setup() {
 	testModelPath, exist := os.LookupEnv("TEST_MODEL")
 	if !exist {
-		testModelPath = "./chatglm3-ggml-q4_0.bin"
+		testModelPath = "chatglm3-ggml-q4_0.bin"
 	}

 	var err error
@@ -20,6 +23,7 @@ func setup() {
 	if err != nil {
 		panic("load model failed.")
 	}
+	modelType = chatglm.ModelType()
 }

 func TestMain(m *testing.M) {
@@ -51,29 +55,31 @@ func TestStreamGenerate(t *testing.T) {
 }

 func TestChat(t *testing.T) {
-	history := []string{"2+2等于多少"}
-	ret, err := chatglm.Chat(history)
+	var messages []*ChatMessage
+	messages = append(messages, NewUserMsg("2+2等于多少"))
+	ret, err := chatglm.Chat(messages)
 	if err != nil {
 		assert.Fail(t, "first chat failed")
 	}
 	assert.Contains(t, ret, "4")

-	history = append(history, ret)
-	history = append(history, "再加4等于多少")
-	ret, err = chatglm.Chat(history)
+	messages = append(messages, NewAssistantMsg(ret, modelType))
+	messages = append(messages, NewUserMsg("再加4等于多少"))
+	ret, err = chatglm.Chat(messages)
 	if err != nil {
 		assert.Fail(t, "second chat failed")
 	}
 	assert.Contains(t, ret, "8")

-	history = append(history, ret)
-	assert.Len(t, history, 4)
+	messages = append(messages, NewAssistantMsg(ret, modelType))
+	assert.Len(t, messages, 4)
 }

 func TestChatStream(t *testing.T) {
-	history := []string{"2+2等于多少"}
+	var messages []*ChatMessage
+	messages = append(messages, NewUserMsg("2+2等于多少"))
 	out1 := strings.Builder{}
-	ret, err := chatglm.StreamChat(history, SetStreamCallback(func(s string) bool {
+	ret, err := chatglm.StreamChat(messages, SetStreamCallback(func(s string) bool {
 		out1.WriteString(s)
 		return true
 	}))
@@ -86,9 +92,9 @@ func TestChatStream(t *testing.T) {
 	assert.Contains(t, ret, "4")
 	assert.Contains(t, outStr1, "4")

-	history = append(history, ret)
-	history = append(history, "再加4等于多少")
-	ret, err = chatglm.StreamChat(history)
+	messages = append(messages, NewAssistantMsg(ret, modelType))
+	messages = append(messages, NewUserMsg("再加4等于多少"))
+	ret, err = chatglm.StreamChat(messages)
 	if err != nil {
 		assert.Fail(t, "second chat failed")
 	}
@@ -98,15 +104,64 @@ func TestChatStream(t *testing.T) {
 	assert.Contains(t, ret, "8")
 	assert.Contains(t, out2, "8")

-	history = append(history, ret)
-	assert.Len(t, history, 4)
+	messages = append(messages, NewAssistantMsg(ret, modelType))
+	assert.Len(t, messages, 4)
 }

 func TestEmbedding(t *testing.T) {
 	maxLength := 1024
-	embeddings, err := chatglm.Embeddings("你好", SetMaxLength(1024))
+	embeddings, err := chatglm.Embeddings("你好", SetMaxLength(maxLength))
 	if err != nil {
 		assert.Fail(t, "embedding failed.")
 	}
 	assert.Len(t, embeddings, maxLength)
 }
+
+func TestSystemToolCall(t *testing.T) {
+	file, err := os.ReadFile("examples/system/function_call.txt")
+	if err != nil {
+		return
+	}
+	var messages []*ChatMessage
+	messages = append(messages, NewSystemMsg(string(file)))
+	messages = append(messages, NewUserMsg("生成一个随机数"))
+
+	ret, err := chatglm.Chat(messages, SetDoSample(false))
+	if err != nil {
+		assert.Fail(t, "call system tool failed.")
+	}
+	assert.Contains(t, ret, "```python\ntool_call(seed=42, range=(0, 100))\n```")
+	messages = append(messages, NewAssistantMsg(ret, modelType))
+	messages = append(messages, NewObservationMsg("22"))
+
+	ret, err = chatglm.Chat(messages, SetDoSample(false))
+	if err != nil {
+		assert.Fail(t, "call system tool failed.")
+	}
+	assert.Contains(t, ret, "22")
+}
+
+func TestCodeInterpreter(t *testing.T) {
+	file, err := os.ReadFile("examples/system/code_interpreter.txt")
+	if err != nil {
+		return
+	}
+	var messages []*ChatMessage
+	messages = append(messages, NewSystemMsg(string(file)))
+	messages = append(messages, NewUserMsg("列出100以内的所有质数"))
+	ret, err := chatglm.Chat(messages, SetDoSample(false))
+	if err != nil {
+		assert.Fail(t, "call code interpreter failed.")
+	}
+	msg := NewAssistantMsg(ret, modelType)
+	msg.ToolCalls = append(msg.ToolCalls, &ToolCallMessage{Type: TypeCode, Code: &CodeMessage{Input: "```python\ndef is_prime(n):\n \"\"\"Check if a number is prime.\"\"\"\n if n <= 1:\n return False\n if n <= 3:\n return True\n if n % 2 == 0 or n % 3 == 0:\n return False\n i = 5\n while i * i <= n:\n if n % i == 0 or n % (i + 2) == 0:\n return False\n i += 6\n return True\n\n# Get all prime numbers up to 100\nprimes_upto_100 = [i for i in range(2, 101) if is_prime(i)]\nprimes_upto_100\n```"}})
+	messages = append(messages, msg)
+	assert.Contains(t, ret, "好的,我会为您列出100以内的所有质数。\n\n质数是指只能被1和它本身整除的大于1的整数。例如,2、3、5、7等都是质数。\n\n让我们开始吧!")
+	messages = append(messages, NewObservationMsg("[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]"))
+
+	ret, err = chatglm.Chat(messages, SetDoSample(false))
+	if err != nil {
+		assert.Fail(t, "call code interpreter failed.")
+	}
+	assert.Contains(t, ret, "2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97")
+}
diff --git a/examples/ChatGLM3-6B/main.go b/examples/ChatGLM3-6B/main.go
deleted file mode 100644
index 94af103..0000000
--- a/examples/ChatGLM3-6B/main.go
+++ /dev/null
@@ -1,28 +0,0 @@
-package main
-
-import (
-	"fmt"
-	"github.com/Weaxs/go-chatglm.cpp"
-)
-
-func main() {
-	llm, err := chatglm.New("./chatglm3-ggml-q4_0.bin")
-	if err != nil {
-		return
-	}
-
-	var history []string
-	history = append(history, "你好,我叫 Weaxs")
-	res, err := llm.Generate(history[0])
-	if err != nil {
-		return
-	}
-	fmt.Println(res)
-	history = append(history, res)
-	history = append(history, "我的名字是什么")
-	res, err = llm.Chat(history)
-	if err != nil {
-		return
-	}
-	fmt.Println(res)
-}
diff --git a/examples/main.go b/examples/main.go
new file mode 100644
index 0000000..c7a3d74
--- /dev/null
+++ b/examples/main.go
@@ -0,0 +1,107 @@
+package main
+
+import (
+	"bufio"
+	"flag"
+	"fmt"
+	c "github.com/Weaxs/go-chatglm.cpp"
+	"io"
+	"os"
+	"strings"
+)
+
+func main() {
+	var model string
+	var system string
+	var temp float64
+	var topK int
+	var topP float64
+	var maxLength int
+	var maxContentLength int
+	var threads int
+	var repeatPenalty float64
+
+	flags := flag.NewFlagSet(os.Args[0], flag.ExitOnError)
+	flags.StringVar(&model, "m", "./chatglm3-ggml-q4_0.bin", "path to model file to load")
+	flags.StringVar(&system, "s", "", "system message to set the behavior of the assistant")
+	flags.Float64Var(&temp, "temp", 0.95, "temperature (default: 0.95)")
+	flags.IntVar(&maxLength, "max_length", 2048, "max total length including prompt and output (default: 2048)")
+	flags.IntVar(&maxContentLength, "max_context_length", 512, "max context length (default: 512)")
+	flags.IntVar(&topK, "top_k", 0, "top-k sampling (default: 0)")
+	flags.Float64Var(&topP, "top_p", 0.7, "top-p sampling (default: 0.7)")
+	flags.Float64Var(&repeatPenalty, "repeat_penalty", 1.0, "penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled)")
+	flags.IntVar(&threads, "threads", 0, "number of threads for inference")
+
+	err := flags.Parse(os.Args[1:])
+
+	chatglm, err := c.New(model)
+	modelType := chatglm.ModelType()
+
+	if err != nil {
+		fmt.Printf("Parsing program arguments failed: %s", err)
+		os.Exit(1)
+	}
+
+	fmt.Printf(" ____ _ _ ____ _ __ __ \n" +
+		" __ _ ___ / ___| |__ __ _| |_ / ___| | | \\/ | ___ _ __ _ __ \n" +
+		" / _` |/ _ \\ _____| | | '_ \\ / _` | __| | _| | | |\\/| | / __| '_ \\| '_ \\ \n" +
+		"| (_| | (_) |_____| |___| | | | (_| | |_| |_| | |___| | | || (__| |_) | |_) |\n" +
+		" \\__, |\\___/ \\____|_| |_|\\__,_|\\__|\\____|_____|_| |_(_)___| .__/| .__/ \n" +
+		" |___/ |_| |_| \n\n")
+
+	reader := bufio.NewReader(os.Stdin)
+	var message []*c.ChatMessage
+
+	for {
+		text := readMultiLineInput(reader)
+		message = append(message, c.NewUserMsg(text))
+		r, err := chatglm.StreamChat(message,
+			c.SetTemperature(float32(temp)), c.SetTopP(float32(topP)), c.SetTopK(topK),
+			c.SetMaxLength(maxLength), c.SetMaxContextLength(maxContentLength),
+			c.SetRepetitionPenalty(float32(repeatPenalty)), c.SetNumThreads(threads),
+			c.SetStreamCallback(callback))
+		if err != nil {
+			panic(err)
+		}
+		message = append(message, c.NewAssistantMsg(r, modelType))
+
+		_, err = chatglm.Embeddings(text)
+		if err != nil {
+			fmt.Printf("Embeddings: error %s \n", err.Error())
+		}
+		fmt.Printf("\n\n")
+	}
+
+}
+
+func callback(s string) bool {
+	fmt.Print(s)
+	return true
+}
+
+// readMultiLineInput reads input until an empty line is entered.
+func readMultiLineInput(reader *bufio.Reader) string {
+	var lines []string
+	fmt.Print(">>> ")
+
+	for {
+		line, err := reader.ReadString('\n')
+		if err != nil {
+			if err == io.EOF {
+				os.Exit(0)
+			}
+			fmt.Printf("Reading the prompt failed: %s", err)
+			os.Exit(1)
+		}
+
+		if len(strings.TrimSpace(line)) == 0 {
+			break
+		}
+
+		lines = append(lines, line)
+	}
+
+	text := strings.Join(lines, "")
+	fmt.Println("Sending", text)
+	return text
+}
diff --git a/examples/system/code_interpreter.txt b/examples/system/code_interpreter.txt
new file mode 100644
index 0000000..7561ed2
--- /dev/null
+++ b/examples/system/code_interpreter.txt
@@ -0,0 +1 @@
+你是一位智能AI助手,你叫ChatGLM,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是/mnt/data/。
\ No newline at end of file
diff --git a/examples/system/default.txt b/examples/system/default.txt
new file mode 100644
index 0000000..2345cf5
--- /dev/null
+++ b/examples/system/default.txt
@@ -0,0 +1 @@
+You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
\ No newline at end of file
diff --git a/examples/system/function_call.txt b/examples/system/function_call.txt
new file mode 100644
index 0000000..d25a10c
--- /dev/null
+++ b/examples/system/function_call.txt
@@ -0,0 +1,33 @@
+Answer the following questions as best as you can. You have access to the following tools:
+{
+    "random_number_generator": {
+        "name": "random_number_generator",
+        "description": "Generates a random number x, s.t. range[0] <= x < range[1]",
+        "params": [
+            {
+                "name": "seed",
+                "description": "The random seed used by the generator",
+                "type": "int",
+                "required": true
+            },
+            {
+                "name": "range",
+                "description": "The range of the generated numbers",
+                "type": "tuple[int, int]",
+                "required": true
+            }
+        ]
+    },
+    "get_weather": {
+        "name": "get_weather",
+        "description": "Get the current weather for `city_name`",
+        "params": [
+            {
+                "name": "city_name",
+                "description": "The name of the city to be queried",
+                "type": "str",
+                "required": true
+            }
+        ]
+    }
+}
\ No newline at end of file
diff --git a/options.go b/options.go
index 43eb5ea..c9651d4 100644
--- a/options.go
+++ b/options.go
@@ -1,5 +1,17 @@
 package chatglm

+const (
+	RoleUser        = "user"
+	RoleAssistant   = "assistant"
+	RoleSystem      = "system"
+	RoleObservation = "observation"
+
+	TypeFunction = "function"
+	TypeCode     = "code"
+
+	DELIMITER = "<|delimiter|>"
+)
+
 type GenerationOptions struct {
 	MaxLength        int
 	MaxContextLength int
@@ -12,6 +24,24 @@ type GenerationOptions struct {
 	StreamCallback func(string) bool
 }

+type ChatMessage struct {
+	Role      string
+	Content   string
+	ToolCalls []*ToolCallMessage
+}
+type ToolCallMessage struct {
+	Type     string
+	Function *FunctionMessage
+	Code     *CodeMessage
+}
+type FunctionMessage struct {
+	Name      string
+	Arguments string
+}
+type CodeMessage struct {
+	Input string
+}
+
 type GenerationOption func(g *GenerationOptions)

 var DefaultGenerationOptions GenerationOptions = GenerationOptions{