Skip to content

Commit

Permalink
ci: add GPU tests (#245)
Browse files Browse the repository at this point in the history
Signed-off-by: mudler <mudler@localai.io>
  • Loading branch information
mudler authored Sep 29, 2023
1 parent b8a1245 commit 8173a5b
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 3 deletions.
63 changes: 63 additions & 0 deletions .github/workflows/test-gpu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
---
# GPU test workflow: builds the bindings with BUILD_TYPE=cublas on a
# self-hosted runner, runs `make test` with GPU_TESTS=true, and checks the
# test log for evidence that CUDA was actually used.
name: 'GPU tests'

# NOTE(review): `pull_request` runs on a self-hosted runner; confirm that
# untrusted forks cannot trigger this workflow without approval.
on:
  pull_request:
  push:
    branches:
      - master
    tags:
      - '*'

# Only one run per ref/repository at a time; a newer run cancels the older one.
concurrency:
  group: ci-gpu-tests-${{ github.head_ref || github.ref }}-${{ github.repository }}
  cancel-in-progress: true

jobs:
  ubuntu-latest:
    runs-on: self-hosted
    strategy:
      matrix:
        go-version: ['1.21.x']
    steps:
      - name: Clone
        uses: actions/checkout@v3
        with:
          submodules: true
      - name: Setup Go ${{ matrix.go-version }}
        uses: actions/setup-go@v4
        with:
          go-version: ${{ matrix.go-version }}
      # You can test your matrix by printing the current Go version
      - name: Display Go version
        run: go version
      # Step names are unique so the two dependency steps can be told apart
      # in the run log.
      - name: Base dependencies
        run: |
          sudo apt-get update
          sudo DEBIAN_FRONTEND=noninteractive apt-get install -y make wget
      - name: Build dependencies
        run: |
          # This fixes libc6-dev installation errors on containers...
          sudo rm -rfv /run/systemd/system
          sudo apt-get update
          sudo DEBIAN_FRONTEND=noninteractive apt-get install -y build-essential ffmpeg nvidia-cuda-toolkit cmake
          sudo DEBIAN_FRONTEND=noninteractive apt-get install -y ca-certificates cmake curl patch
          sudo DEBIAN_FRONTEND=noninteractive apt-get install -y pip wget
      - name: Build and test
        # An explicit `shell: bash` makes the runner invoke bash with
        # `-eo pipefail`. Without it the default is `bash -e`, where the exit
        # status of `make test | tee` is tee's, so a failing test run would
        # be masked and the step could pass on the grep alone.
        shell: bash
        run: |
          GPU_TESTS=true BUILD_TYPE=cublas CMAKE_ARGS="-DLLAMA_METAL=OFF -DLLAMA_F16C=OFF -DLLAMA_AVX512=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF" \
          make test 2>&1 | tee test_log.log
          # Passing tests are not enough: the log must show CUDA was used.
          if grep -q "using CUDA for GPU acceleration" test_log.log; then
            echo "All good";
          else
            echo "No CUDA found";
            exit 1;
          fi
      - name: Release space from worker ♻
        # Best-effort cleanup that runs even when earlier steps failed.
        if: always()
        run: |
          sudo rm -rf build || true
          sudo rm -rf bin || true
          sudo rm -rf dist || true
          sudo rm -rf *.log || true
          make clean || true
16 changes: 13 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,14 @@ ifdef CLBLAST_DIR
CMAKE_ARGS+=-DCLBlast_dir=$(CLBLAST_DIR)
endif

# TODO: support Windows
# Directory holding the CUDA runtime libraries. Overridable from the command
# line or environment (e.g. `make test GPU_TESTS=true CUDA_LIBPATH=/opt/cuda/lib64/`);
# the default keeps the previously hard-coded location.
CUDA_LIBPATH ?= /usr/local/cuda/lib64/

# Select which labelled specs `make test` passes to --label-filter.
# GPU_TESTS=true picks the "gpu" label and links cuBLAS/cudart through cgo;
# any other value excludes the "gpu" label.
# The double quotes are part of the CGO_LDFLAGS value on purpose: the `test`
# recipe expands the variable unquoted and relies on them to keep the flags
# together as one shell word.
ifeq ($(GPU_TESTS),true)
  CGO_LDFLAGS="-lcublas -lcudart -L$(CUDA_LIBPATH)"
  TEST_LABEL=gpu
else
  TEST_LABEL=!gpu
endif

#
# Print build information
#
Expand Down Expand Up @@ -236,6 +244,8 @@ clean:
$(MAKE) -C llama.cpp clean
rm -rf build

# Runs the Go test suite once libbinding.a is built.
# The first recipe line downloads ggllm-test-model.bin with wget only when the
# file does not exist yet; the second runs `go test -v ./...` with the
# include/library paths, cgo link flags and TEST_MODEL passed in the environment.
# NOTE(review): another `test` rule with a different recipe follows directly
# below; this one looks like the pre-change version shown by the diff — confirm
# which of the two is meant to be live.
test: libbinding.a
test -f ggllm-test-model.bin || wget -q https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GGUF/resolve/main/codellama-7b-instruct.Q2_K.gguf -O ggllm-test-model.bin
C_INCLUDE_PATH=${INCLUDE_PATH} CGO_LDFLAGS=${CGO_LDFLAGS} LIBRARY_PATH=${LIBRARY_PATH} TEST_MODEL=ggllm-test-model.bin go test -v ./...
# Model file used by the test suite.
# `wget -O` creates its output file even when the download fails, which would
# leave an empty or truncated target that make then treats as up to date.
# Download to a temporary name and rename into place only on success.
ggllm-test-model.bin:
	wget -q https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GGUF/resolve/main/codellama-7b-instruct.Q2_K.gguf -O $@.tmp
	mv -f $@.tmp $@

# `test` names a command, not a file: declare it phony so a stray file called
# `test` can never make the suite look up to date.
.PHONY: test

# Runs the Ginkgo suite recursively, restricted to the specs selected by
# TEST_LABEL, retrying flaky specs up to 5 times.
# CGO_LDFLAGS is deliberately expanded without extra quoting; do not wrap it in
# quotes here, since its value may already carry its own.
test: ggllm-test-model.bin libbinding.a
	C_INCLUDE_PATH=${INCLUDE_PATH} CGO_LDFLAGS=${CGO_LDFLAGS} LIBRARY_PATH=${LIBRARY_PATH} TEST_MODEL=ggllm-test-model.bin go run github.com/onsi/ginkgo/v2/ginkgo --label-filter="$(TEST_LABEL)" --flake-attempts 5 -v -r ./...
25 changes: 25 additions & 0 deletions llama_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,29 @@ how much is 2+2?
Expect(int(l)).To(Equal(len(tokens)))
})
})

Context("Inferencing tests with GPU (using "+testModelPath+") ", Label("gpu"), func() {
getModel := func() (*LLama, error) {
model, err := New(
testModelPath,
llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(10),
)
Expect(err).ToNot(HaveOccurred())
Expect(model).ToNot(BeNil())
return model, err
}

It("predicts successfully", func() {
if testModelPath == "" {
Skip("test skipped - only makes sense if the TEST_MODEL environment variable is set.")
}

model, err := getModel()
text, err := model.Predict(`[INST] Answer to the following question:
how much is 2+2?
[/INST]`)
Expect(err).ToNot(HaveOccurred(), text)
Expect(text).To(ContainSubstring("4"), text)
})
})
})

0 comments on commit 8173a5b

Please sign in to comment.