diff --git a/.devops/cuda.Dockerfile b/.devops/cuda.Dockerfile index 974dd78a8b08a..0bc4e7ee13f66 100644 --- a/.devops/cuda.Dockerfile +++ b/.devops/cuda.Dockerfile @@ -1,6 +1,6 @@ ARG UBUNTU_VERSION=22.04 # This needs to generally match the container host's environment. -ARG CUDA_VERSION=12.6.0 +ARG CUDA_VERSION=12.4.0 # Target the CUDA build image ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} diff --git a/.devops/llama-cpp-cuda.srpm.spec b/.devops/llama-cpp-cuda.srpm.spec index 7425d3a9d7a40..3bbf4a4def2a5 100644 --- a/.devops/llama-cpp-cuda.srpm.spec +++ b/.devops/llama-cpp-cuda.srpm.spec @@ -17,10 +17,10 @@ Version: %( date "+%%Y%%m%%d" ) Release: 1%{?dist} Summary: CPU Inference of LLaMA model in pure C/C++ (no CUDA/OpenCL) License: MIT -Source0: https://github.com/ggerganov/llama.cpp/archive/refs/heads/master.tar.gz +Source0: https://github.com/ggml-org/llama.cpp/archive/refs/heads/master.tar.gz BuildRequires: coreutils make gcc-c++ git cuda-toolkit Requires: cuda-toolkit -URL: https://github.com/ggerganov/llama.cpp +URL: https://github.com/ggml-org/llama.cpp %define debug_package %{nil} %define source_date_epoch_from_changelog 0 diff --git a/.devops/llama-cpp.srpm.spec b/.devops/llama-cpp.srpm.spec index 4d5560089816c..45902dcf896e0 100644 --- a/.devops/llama-cpp.srpm.spec +++ b/.devops/llama-cpp.srpm.spec @@ -18,10 +18,10 @@ Version: %( date "+%%Y%%m%%d" ) Release: 1%{?dist} Summary: CPU Inference of LLaMA model in pure C/C++ (no CUDA/OpenCL) License: MIT -Source0: https://github.com/ggerganov/llama.cpp/archive/refs/heads/master.tar.gz +Source0: https://github.com/ggml-org/llama.cpp/archive/refs/heads/master.tar.gz BuildRequires: coreutils make gcc-c++ git libstdc++-devel Requires: libstdc++ -URL: https://github.com/ggerganov/llama.cpp +URL: https://github.com/ggml-org/llama.cpp %define debug_package %{nil} %define source_date_epoch_from_changelog 0 diff --git a/.devops/musa.Dockerfile b/.devops/musa.Dockerfile index bfd7fc1c1740f..1e87737abfb71 100644 --- a/.devops/musa.Dockerfile +++ b/.devops/musa.Dockerfile @@ -1,6 +1,6 @@ ARG UBUNTU_VERSION=22.04 # This needs to generally match the container host's environment. -ARG MUSA_VERSION=rc3.1.0 +ARG MUSA_VERSION=rc3.1.1 # Target the MUSA build image ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION} diff --git a/.devops/nix/package.nix b/.devops/nix/package.nix index 043c4364b956a..6e8050a499635 100644 --- a/.devops/nix/package.nix +++ b/.devops/nix/package.nix @@ -133,12 +133,12 @@ effectiveStdenv.mkDerivation (finalAttrs: { --replace '[bundle pathForResource:@"default" ofType:@"metallib"];' "@\"$out/bin/default.metallib\";" ''; - # With PR#6015 https://github.com/ggerganov/llama.cpp/pull/6015, + # With PR#6015 https://github.com/ggml-org/llama.cpp/pull/6015, # `default.metallib` may be compiled with Metal compiler from XCode # and we need to escape sandbox on MacOS to access Metal compiler. 
# `xcrun` is used to find the path of the Metal compiler, which is variable # and not on $PATH - # see https://github.com/ggerganov/llama.cpp/pull/6118 for discussion + # see https://github.com/ggml-org/llama.cpp/pull/6118 for discussion __noChroot = effectiveStdenv.isDarwin && useMetalKit && precompileMetalShaders; nativeBuildInputs = @@ -220,7 +220,7 @@ effectiveStdenv.mkDerivation (finalAttrs: { broken = (useMetalKit && !effectiveStdenv.isDarwin); description = "Inference of LLaMA model in pure C/C++${descriptionSuffix}"; - homepage = "https://github.com/ggerganov/llama.cpp/"; + homepage = "https://github.com/ggml-org/llama.cpp/"; license = lib.licenses.mit; # Accommodates `nix run` and `lib.getExe` diff --git a/.devops/rocm.Dockerfile b/.devops/rocm.Dockerfile index a8088ea00da5b..48e7e6aaa5b77 100644 --- a/.devops/rocm.Dockerfile +++ b/.devops/rocm.Dockerfile @@ -11,7 +11,7 @@ ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-co FROM ${BASE_ROCM_DEV_CONTAINER} AS build # Unless otherwise specified, we make a fat build. -# List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878 +# List from https://github.com/ggml-org/llama.cpp/pull/1087#issuecomment-1682807878 # This is mostly tied to rocBLAS supported archs. # gfx803, gfx900, gfx1032, gfx1101, gfx1102, not officially supported # gfx906 is deprecated diff --git a/.github/ISSUE_TEMPLATE/020-enhancement.yml b/.github/ISSUE_TEMPLATE/020-enhancement.yml index 02dd4f575a686..cee1446f5a097 100644 --- a/.github/ISSUE_TEMPLATE/020-enhancement.yml +++ b/.github/ISSUE_TEMPLATE/020-enhancement.yml @@ -6,7 +6,7 @@ body: - type: markdown attributes: value: | - [Please post your idea first in Discussion if there is not yet a consensus for this enhancement request. This will help to keep this issue tracker focused on enhancements that the community has agreed needs to be implemented.](https://github.com/ggerganov/llama.cpp/discussions/categories/ideas) + [Please post your idea first in Discussion if there is not yet a consensus for this enhancement request. This will help to keep this issue tracker focused on enhancements that the community has agreed need to be implemented.](https://github.com/ggml-org/llama.cpp/discussions/categories/ideas) - type: checkboxes id: prerequisites @@ -16,11 +16,11 @@ body: options: - label: I am running the latest code. Mention the version if possible as well. required: true - - label: I carefully followed the [README.md](https://github.com/ggerganov/llama.cpp/blob/master/README.md). + - label: I carefully followed the [README.md](https://github.com/ggml-org/llama.cpp/blob/master/README.md). required: true - label: I searched using keywords relevant to my issue to make sure that I am creating a new issue that is not already open (or closed). required: true - - label: I reviewed the [Discussions](https://github.com/ggerganov/llama.cpp/discussions), and have a new and useful enhancement to share. + - label: I reviewed the [Discussions](https://github.com/ggml-org/llama.cpp/discussions), and have a new and useful enhancement to share.
required: true - type: textarea diff --git a/.github/ISSUE_TEMPLATE/030-research.yml b/.github/ISSUE_TEMPLATE/030-research.yml index 18975dbbfd0fe..e774550d5908c 100644 --- a/.github/ISSUE_TEMPLATE/030-research.yml +++ b/.github/ISSUE_TEMPLATE/030-research.yml @@ -6,7 +6,7 @@ body: - type: markdown attributes: value: | - Don't forget to check for any [duplicate research issue tickets](https://github.com/ggerganov/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3A%22research+%F0%9F%94%AC%22) + Don't forget to check for any [duplicate research issue tickets](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3A%22research+%F0%9F%94%AC%22) - type: checkboxes id: research-stage diff --git a/.github/ISSUE_TEMPLATE/040-refactor.yml b/.github/ISSUE_TEMPLATE/040-refactor.yml index b6e6ab36defd6..2fe94e26c6988 100644 --- a/.github/ISSUE_TEMPLATE/040-refactor.yml +++ b/.github/ISSUE_TEMPLATE/040-refactor.yml @@ -6,8 +6,8 @@ body: - type: markdown attributes: value: | - Don't forget to [check for existing refactor issue tickets](https://github.com/ggerganov/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3Arefactoring) in case it's already covered. - Also you may want to check [Pull request refactor label as well](https://github.com/ggerganov/llama.cpp/pulls?q=is%3Aopen+is%3Apr+label%3Arefactoring) for duplicates too. + Don't forget to [check for existing refactor issue tickets](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3Arefactoring) in case it's already covered. + You may also want to check the [Pull request refactor label](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Aopen+is%3Apr+label%3Arefactoring) for duplicates. - type: textarea id: background-description diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index eb8c4b472df4c..0d246533c9515 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,11 +1,11 @@ blank_issues_enabled: true contact_links: - name: Got an idea? - url: https://github.com/ggerganov/llama.cpp/discussions/categories/ideas + url: https://github.com/ggml-org/llama.cpp/discussions/categories/ideas about: Pop it there. It may then become an enhancement ticket. - name: Got a question? - url: https://github.com/ggerganov/llama.cpp/discussions/categories/q-a + url: https://github.com/ggml-org/llama.cpp/discussions/categories/q-a about: Ask a question there! - name: Want to contribute?
- url: https://github.com/ggerganov/llama.cpp/wiki/contribute + url: https://github.com/ggml-org/llama.cpp/wiki/contribute about: Head to the contribution guide page of the wiki for areas you can help with diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index d9f5bdc235a00..d0bdd73c4439c 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1 +1 @@ -*Make sure to read the [contributing guidelines](https://github.com/ggerganov/llama.cpp/blob/master/CONTRIBUTING.md) before submitting a PR* +*Make sure to read the [contributing guidelines](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md) before submitting a PR* diff --git a/.github/workflows/bench.yml.disabled b/.github/workflows/bench.yml.disabled index 1c8787ef78f7e..0370c8943fa0e 100644 --- a/.github/workflows/bench.yml.disabled +++ b/.github/workflows/bench.yml.disabled @@ -1,5 +1,5 @@ # TODO: there have been some issues with the workflow, so disabling for now -# https://github.com/ggerganov/llama.cpp/issues/7893 +# https://github.com/ggml-org/llama.cpp/issues/7893 # # Benchmark name: Benchmark @@ -57,17 +57,7 @@ jobs: if: | inputs.gpu-series == 'Standard_NC4as_T4_v3' - || ( - github.event_name == 'schedule' - && github.ref_name == 'master' - && github.repository_owner == 'ggerganov' - ) || github.event_name == 'pull_request_target' - || ( - github.event_name == 'push' - && github.event.ref == 'refs/heads/master' - && github.repository_owner == 'ggerganov' - ) steps: - name: Clone id: checkout diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6841ba5897809..e6555eb479c28 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -129,7 +129,7 @@ jobs: run: | sysctl -a # Metal is disabled due to intermittent failures with Github runners not having a GPU: - # https://github.com/ggerganov/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313 + # https://github.com/ggml-org/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313 cmake -B build \ -DCMAKE_BUILD_RPATH="@loader_path" \ -DLLAMA_FATAL_WARNINGS=ON \ @@ -374,6 +374,8 @@ jobs: - name: Clone id: checkout uses: actions/checkout@v4 + with: + fetch-depth: 0 - name: ccache uses: hendrikmuhs/ccache-action@v1.2.16 @@ -401,7 +403,35 @@ jobs: run: | cd build # This is using llvmpipe and runs slower than other backends - ctest -L main --verbose --timeout 1800 + ctest -L main --verbose --timeout 2700 + + - name: Determine tag name + id: tag + shell: bash + run: | + BUILD_NUMBER="$(git rev-list --count HEAD)" + SHORT_HASH="$(git rev-parse --short=7 HEAD)" + if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then + echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT + else + SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') + echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT + fi + + - name: Pack artifacts + id: pack_artifacts + if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} + run: | + cp LICENSE ./build/bin/ + cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp + zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip ./build/bin/* + + - name: Upload artifacts + if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip + name: 
llama-bin-ubuntu-vulkan-x64.zip ubuntu-22-cmake-hip: runs-on: ubuntu-22.04 @@ -443,7 +473,7 @@ jobs: ubuntu-22-cmake-musa: runs-on: ubuntu-22.04 - container: mthreads/musa:rc3.1.0-devel-ubuntu22.04 + container: mthreads/musa:rc3.1.1-devel-ubuntu22.04 steps: - name: Clone @@ -1345,8 +1375,10 @@ jobs: needs: - ubuntu-cpu-cmake + - ubuntu-22-cmake-vulkan - windows-latest-cmake - windows-2019-cmake-cuda + - windows-latest-cmake-sycl - windows-latest-cmake-hip-release - macOS-latest-cmake-arm64 - macOS-latest-cmake-x64 diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 6955a7dc8234d..c81d21fcd7e67 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -51,6 +51,8 @@ jobs: - name: Set up QEMU uses: docker/setup-qemu-action@v3 + with: + image: tonistiigi/binfmt:qemu-v7.0.0-28 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 368dbdbe5dccc..0b0f300aa402a 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -11,7 +11,7 @@ jobs: steps: - uses: actions/checkout@v4 with: - repository: "ggerganov/llama.cpp" + repository: "ggml-org/llama.cpp" - uses: actions/labeler@v5 with: configuration-path: '.github/labeler.yml' diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8d411982b4379..9d4e5a56fe48d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -12,7 +12,7 @@ - Squash-merge PRs - Use the following format for the squashed commit title: `<module> : <commit title> (#<issue number>)`. For example: `utils : fix typo in utils.py (#1234)` -- Optionally pick a `<module>` from here: https://github.com/ggerganov/llama.cpp/wiki/Modules +- Optionally pick a `<module>` from here: https://github.com/ggml-org/llama.cpp/wiki/Modules - Consider adding yourself to [CODEOWNERS](CODEOWNERS) # Coding guidelines @@ -40,14 +40,14 @@ - Try to follow the existing patterns in the code (indentation, spaces, etc.). In case of doubt use `clang-format` to format the added code - For anything not covered in the current guidelines, refer to the [C++ Core Guidelines](https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines) - Tensors store data in row-major order. We refer to dimension 0 as columns, 1 as rows, 2 as matrices -- Matrix multiplication is unconventional: [`C = ggml_mul_mat(ctx, A, B)`](https://github.com/ggerganov/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means $C^T = A B^T \Leftrightarrow C = B A^T.$ +- Matrix multiplication is unconventional: [`C = ggml_mul_mat(ctx, A, B)`](https://github.com/ggml-org/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means $C^T = A B^T \Leftrightarrow C = B A^T.$ ![matmul](media/matmul.png) # Naming guidelines - Use `snake_case` for function, variable and type names -- Naming usually optimizes for longest common prefix (see https://github.com/ggerganov/ggml/pull/302#discussion_r1243240963) +- Naming usually optimizes for longest common prefix (see https://github.com/ggml-org/ggml/pull/302#discussion_r1243240963) ```cpp // not OK @@ -122,4 +122,4 @@ The Github issues, PRs and discussions contain a lot of information that can be useful to get familiar with the codebase.
For convenience, some of the more important information is referenced from Github projects: -https://github.com/ggerganov/llama.cpp/projects +https://github.com/ggml-org/llama.cpp/projects diff --git a/Makefile b/Makefile index dc3de3cb14e44..662194086eaaf 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ ifndef LLAMA_MAKEFILE -$(error The Makefile build is deprecated. Use the CMake build instead. For more details, see https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) +$(error The Makefile build is deprecated. Use the CMake build instead. For more details, see https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md) endif # Define the default target now so that it is always the first target @@ -463,7 +463,7 @@ endif ifneq '' '$(findstring mingw,$(shell $(CC) -dumpmachine))' # The stack is only 16-byte aligned on Windows, so don't let gcc emit aligned moves. # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=54412 - # https://github.com/ggerganov/llama.cpp/issues/2922 + # https://github.com/ggml-org/llama.cpp/issues/2922 MK_CFLAGS += -Xassembler -muse-unaligned-vector-move MK_CXXFLAGS += -Xassembler -muse-unaligned-vector-move @@ -1078,8 +1078,8 @@ endif ifdef REMOVE_WARNING $(info !!! REMOVAL WARNING !!!) $(info The following LLAMA_ options have been removed and are no longer supported) -$(info - LLAMA_DISABLE_LOGS (https://github.com/ggerganov/llama.cpp/pull/9418)) -$(info - LLAMA_SERVER_VERBOSE (https://github.com/ggerganov/llama.cpp/pull/9418)) +$(info - LLAMA_DISABLE_LOGS (https://github.com/ggml-org/llama.cpp/pull/9418)) +$(info - LLAMA_SERVER_VERBOSE (https://github.com/ggml-org/llama.cpp/pull/9418)) $(info ) endif diff --git a/README.md b/README.md index 11f3d0286aa2b..f70c8ae1ed851 100644 --- a/README.md +++ b/README.md @@ -3,26 +3,33 @@ ![llama](https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png) [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) -[![Server](https://github.com/ggerganov/llama.cpp/actions/workflows/server.yml/badge.svg)](https://github.com/ggerganov/llama.cpp/actions/workflows/server.yml) +[![Server](https://github.com/ggml-org/llama.cpp/actions/workflows/server.yml/badge.svg)](https://github.com/ggml-org/llama.cpp/actions/workflows/server.yml) -[Roadmap](https://github.com/users/ggerganov/projects/7) / [Project status](https://github.com/ggerganov/llama.cpp/discussions/3471) / [Manifesto](https://github.com/ggerganov/llama.cpp/discussions/205) / [ggml](https://github.com/ggerganov/ggml) +[Roadmap](https://github.com/users/ggml-org/projects/7) / [Project status](https://github.com/ggml-org/llama.cpp/discussions/3471) / [Manifesto](https://github.com/ggml-org/llama.cpp/discussions/205) / [ggml](https://github.com/ggml-org/ggml) Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) in pure C/C++ +> [!IMPORTANT] +> New `llama.cpp` package location: [ggml-org/llama.cpp](https://github.com/ggml-org/llama.cpp/pkgs/container/llama.cpp) +> +> Update your container URLs to: `ghcr.io/ggml-org/llama.cpp` +> +> More info: https://github.com/ggml-org/llama.cpp/discussions/11801 + ## Recent API changes -- [Changelog for `libllama` API](https://github.com/ggerganov/llama.cpp/issues/9289) -- [Changelog for `llama-server` REST API](https://github.com/ggerganov/llama.cpp/issues/9291) +- [Changelog for `libllama` API](https://github.com/ggml-org/llama.cpp/issues/9289) +- [Changelog for `llama-server` REST 
API](https://github.com/ggml-org/llama.cpp/issues/9291) ## Hot topics -- **How to use [MTLResidencySet](https://developer.apple.com/documentation/metal/mtlresidencyset?language=objc) to keep the GPU memory active?** https://github.com/ggerganov/llama.cpp/pull/11427 +- **How to use [MTLResidencySet](https://developer.apple.com/documentation/metal/mtlresidencyset?language=objc) to keep the GPU memory active?** https://github.com/ggml-org/llama.cpp/pull/11427 - **VS Code extension for FIM completions:** https://github.com/ggml-org/llama.vscode -- Universal tool call support in `llama-server`: https://github.com/ggerganov/llama.cpp/pull/9639 +- Universal tool call support in `llama-server`: https://github.com/ggml-org/llama.cpp/pull/9639 - Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim -- Introducing GGUF-my-LoRA https://github.com/ggerganov/llama.cpp/discussions/10123 -- Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggerganov/llama.cpp/discussions/9669 -- Hugging Face GGUF editor: [discussion](https://github.com/ggerganov/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor) +- Introducing GGUF-my-LoRA https://github.com/ggml-org/llama.cpp/discussions/10123 +- Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggml-org/llama.cpp/discussions/9669 +- Hugging Face GGUF editor: [discussion](https://github.com/ggml-org/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor) ---- @@ -39,7 +46,7 @@ range of hardware - locally and in the cloud. - Vulkan and SYCL backend support - CPU+GPU hybrid inference to partially accelerate models larger than the total VRAM capacity -The `llama.cpp` project is the main playground for developing new features for the [ggml](https://github.com/ggerganov/ggml) library. +The `llama.cpp` project is the main playground for developing new features for the [ggml](https://github.com/ggml-org/ggml) library.
Models @@ -59,23 +66,23 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [X] [Falcon](https://huggingface.co/models?search=tiiuae/falcon) - [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2) - [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne) -- [X] [BERT](https://github.com/ggerganov/llama.cpp/pull/5423) +- [X] [BERT](https://github.com/ggml-org/llama.cpp/pull/5423) - [X] [Koala](https://bair.berkeley.edu/blog/2023/04/03/koala/) - [X] [Baichuan 1 & 2](https://huggingface.co/models?search=baichuan-inc/Baichuan) + [derivations](https://huggingface.co/hiyouga/baichuan-7b-sft) - [X] [Aquila 1 & 2](https://huggingface.co/models?search=BAAI/Aquila) -- [X] [Starcoder models](https://github.com/ggerganov/llama.cpp/pull/3187) +- [X] [Starcoder models](https://github.com/ggml-org/llama.cpp/pull/3187) - [X] [Refact](https://huggingface.co/smallcloudai/Refact-1_6B-fim) -- [X] [MPT](https://github.com/ggerganov/llama.cpp/pull/3417) -- [X] [Bloom](https://github.com/ggerganov/llama.cpp/pull/3553) +- [X] [MPT](https://github.com/ggml-org/llama.cpp/pull/3417) +- [X] [Bloom](https://github.com/ggml-org/llama.cpp/pull/3553) - [x] [Yi models](https://huggingface.co/models?search=01-ai/Yi) - [X] [StableLM models](https://huggingface.co/stabilityai) - [x] [Deepseek models](https://huggingface.co/models?search=deepseek-ai/deepseek) - [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen) -- [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557) +- [x] [PLaMo-13B](https://github.com/ggml-org/llama.cpp/pull/3557) - [x] [Phi models](https://huggingface.co/models?search=microsoft/phi) -- [x] [PhiMoE](https://github.com/ggerganov/llama.cpp/pull/11003) +- [x] [PhiMoE](https://github.com/ggml-org/llama.cpp/pull/11003) - [x] [GPT-2](https://huggingface.co/gpt2) -- [x] [Orion 14B](https://github.com/ggerganov/llama.cpp/pull/5118) +- [x] [Orion 14B](https://github.com/ggml-org/llama.cpp/pull/5118) - [x] [InternLM2](https://huggingface.co/models?search=internlm2) - [x] [CodeShell](https://github.com/WisdomShell/codeshell) - [x] [Gemma](https://ai.google.dev/gemma) @@ -146,7 +153,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - Zig: [deins/llama.cpp.zig](https://github.com/Deins/llama.cpp.zig) - Flutter/Dart: [netdur/llama_cpp_dart](https://github.com/netdur/llama_cpp_dart) - Flutter: [xuegao-tzx/Fllama](https://github.com/xuegao-tzx/Fllama) -- PHP (API bindings and features built on top of llama.cpp): [distantmagic/resonance](https://github.com/distantmagic/resonance) [(more info)](https://github.com/ggerganov/llama.cpp/pull/6326) +- PHP (API bindings and features built on top of llama.cpp): [distantmagic/resonance](https://github.com/distantmagic/resonance) [(more info)](https://github.com/ggml-org/llama.cpp/pull/6326) - Guile Scheme: [guile_llama_cpp](https://savannah.nongnu.org/projects/guile-llama-cpp) - Swift [srgtuszy/llama-cpp-swift](https://github.com/srgtuszy/llama-cpp-swift) - Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama) @@ -235,6 +242,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo | [HIP](docs/build.md#hip) | AMD GPU | | [Vulkan](docs/build.md#vulkan) | GPU | | [CANN](docs/build.md#cann) | Ascend NPU | +| [OpenCL](docs/backend/OPENCL.md) | Adreno GPU | ## Building the project @@ -244,7 +252,7 @@ The project also includes many 
example programs and tools using the `llama` library. - Clone this repository and build locally, see [how to build](docs/build.md) - On MacOS or Linux, install `llama.cpp` via [brew, flox or nix](docs/install.md) - Use a Docker image, see [documentation for Docker](docs/docker.md) -- Download pre-built binaries from [releases](https://github.com/ggerganov/llama.cpp/releases) +- Download pre-built binaries from [releases](https://github.com/ggml-org/llama.cpp/releases) ## Obtaining and quantizing models @@ -257,14 +265,14 @@ You can either manually download the GGUF file or directly use any `llama.cpp`-c After downloading a model, use the CLI tools to run it locally - see below. -`llama.cpp` requires the model to be stored in the [GGUF](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md) file format. Models in other data formats can be converted to GGUF using the `convert_*.py` Python scripts in this repo. +`llama.cpp` requires the model to be stored in the [GGUF](https://github.com/ggml-org/ggml/blob/master/docs/gguf.md) file format. Models in other data formats can be converted to GGUF using the `convert_*.py` Python scripts in this repo. The Hugging Face platform provides a variety of online tools for converting, quantizing and hosting models with `llama.cpp`: - Use the [GGUF-my-repo space](https://huggingface.co/spaces/ggml-org/gguf-my-repo) to convert to GGUF format and quantize model weights to smaller sizes -- Use the [GGUF-my-LoRA space](https://huggingface.co/spaces/ggml-org/gguf-my-lora) to convert LoRA adapters to GGUF format (more info: https://github.com/ggerganov/llama.cpp/discussions/10123) -- Use the [GGUF-editor space](https://huggingface.co/spaces/CISCai/gguf-editor) to edit GGUF meta data in the browser (more info: https://github.com/ggerganov/llama.cpp/discussions/9268) -- Use the [Inference Endpoints](https://ui.endpoints.huggingface.co/) to directly host `llama.cpp` in the cloud (more info: https://github.com/ggerganov/llama.cpp/discussions/9669) +- Use the [GGUF-my-LoRA space](https://huggingface.co/spaces/ggml-org/gguf-my-lora) to convert LoRA adapters to GGUF format (more info: https://github.com/ggml-org/llama.cpp/discussions/10123) +- Use the [GGUF-editor space](https://huggingface.co/spaces/CISCai/gguf-editor) to edit GGUF meta data in the browser (more info: https://github.com/ggml-org/llama.cpp/discussions/9268) +- Use the [Inference Endpoints](https://ui.endpoints.huggingface.co/) to directly host `llama.cpp` in the cloud (more info: https://github.com/ggml-org/llama.cpp/discussions/9669) To learn more about model quantization, [read this documentation](examples/quantize/README.md) @@ -487,9 +495,9 @@ To learn more about model quantization, [read this documentation](examples/quant - Collaborators can push to branches in the `llama.cpp` repo and merge PRs into the `master` branch - Collaborators will be invited based on contributions - Any help with managing issues, PRs and projects is very appreciated!
-- See [good first issues](https://github.com/ggerganov/llama.cpp/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) for tasks suitable for first contributions +- See [good first issues](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) for tasks suitable for first contributions - Read the [CONTRIBUTING.md](CONTRIBUTING.md) for more information -- Make sure to read this: [Inference at the edge](https://github.com/ggerganov/llama.cpp/discussions/205) +- Make sure to read this: [Inference at the edge](https://github.com/ggml-org/llama.cpp/discussions/205) - A bit of backstory for those who are interested: [Changelog podcast](https://changelog.com/podcast/532) ## Other documentation @@ -504,7 +512,7 @@ To learn more about model quantization, [read this documentation](examples/quant - [Running on Docker](docs/docker.md) - [Build on Android](docs/android.md) - [Performance troubleshooting](docs/development/token_generation_performance_tips.md) -- [GGML tips & tricks](https://github.com/ggerganov/llama.cpp/wiki/GGML-Tips-&-Tricks) +- [GGML tips & tricks](https://github.com/ggml-org/llama.cpp/wiki/GGML-Tips-&-Tricks) #### Seminal papers and background on the models @@ -518,5 +526,18 @@ If your issue is with model generation quality, then please at least scan the fo - [Aligning language models to follow instructions](https://openai.com/research/instruction-following) - [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155) -#### References - +## Completions +Command-line completion is available for some environments. + +#### Bash Completion +```bash +$ build/bin/llama-cli --completion-bash > ~/.llama-completion.bash +$ source ~/.llama-completion.bash +``` +Optionally this can be added to your `.bashrc` or `.bash_profile` to load it +automatically. For example: +```console +$ echo "source ~/.llama-completion.bash" >> ~/.bashrc +``` + +## References diff --git a/SECURITY.md b/SECURITY.md index f4322c6ee4d18..6a1bb6c32cd8e 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -62,6 +62,6 @@ Beware that none of the topics under [Using llama.cpp securely](#using-llamacpp- However, if you have discovered a security vulnerability in this project, please report it privately. **Do not disclose it as a public issue.** This gives us time to work with you to fix the issue before public exposure, reducing the chance that the exploit will be used before a patch is released. -Please disclose it as a private [security advisory](https://github.com/ggerganov/llama.cpp/security/advisories/new). +Please disclose it as a private [security advisory](https://github.com/ggml-org/llama.cpp/security/advisories/new). A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure. diff --git a/ci/README.md b/ci/README.md index 4064705190697..8245c9df65db8 100644 --- a/ci/README.md +++ b/ci/README.md @@ -1,11 +1,11 @@ # CI -In addition to [Github Actions](https://github.com/ggerganov/llama.cpp/actions) `llama.cpp` uses a custom CI framework: +In addition to [Github Actions](https://github.com/ggml-org/llama.cpp/actions), `llama.cpp` uses a custom CI framework: https://github.com/ggml-org/ci It monitors the `master` branch for new commits and runs the -[ci/run.sh](https://github.com/ggerganov/llama.cpp/blob/master/ci/run.sh) script on dedicated cloud instances.
This allows us +[ci/run.sh](https://github.com/ggml-org/llama.cpp/blob/master/ci/run.sh) script on dedicated cloud instances. This allows us to execute heavier workloads compared to just using Github Actions. Also with time, the cloud instances will be scaled to cover various hardware architectures, including GPU and Apple Silicon instances. diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index e61015d2ad7f9..c2b4aa7d09f1c 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -96,6 +96,22 @@ if (LLAMA_LLGUIDANCE) include(ExternalProject) set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source) set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release) + + # Set the correct library file extension based on platform + if (WIN32) + set(LLGUIDANCE_LIB_NAME "llguidance.lib") + # Add Windows-specific libraries + set(LLGUIDANCE_PLATFORM_LIBS + ws2_32 # Windows Sockets API + userenv # For GetUserProfileDirectoryW + ntdll # For NT functions + bcrypt # For BCryptGenRandom + ) + else() + set(LLGUIDANCE_LIB_NAME "libllguidance.a") + set(LLGUIDANCE_PLATFORM_LIBS "") + endif() + ExternalProject_Add(llguidance_ext GIT_REPOSITORY https://github.com/guidance-ai/llguidance # v0.6.12: @@ -106,17 +122,18 @@ if (LLAMA_LLGUIDANCE) CONFIGURE_COMMAND "" BUILD_COMMAND cargo build --release INSTALL_COMMAND "" - BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/libllguidance.a ${LLGUIDANCE_PATH}/llguidance.h + BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/${LLGUIDANCE_LIB_NAME} ${LLGUIDANCE_PATH}/llguidance.h UPDATE_COMMAND "" ) target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_LLGUIDANCE) add_library(llguidance STATIC IMPORTED) - set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/libllguidance.a) + set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/${LLGUIDANCE_LIB_NAME}) add_dependencies(llguidance llguidance_ext) target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH}) - set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance) + # Add platform libraries to the main target + set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS}) endif () target_include_directories(${TARGET} PUBLIC .) 
diff --git a/common/arg.cpp b/common/arg.cpp index 152f671ab738e..f06aa1076cca7 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -365,6 +365,112 @@ static void common_params_print_usage(common_params_context & ctx_arg) { print_options(specific_options); } +static void common_params_print_completion(common_params_context & ctx_arg) { + std::vector<common_arg *> common_options; + std::vector<common_arg *> sparam_options; + std::vector<common_arg *> specific_options; + + for (auto & opt : ctx_arg.options) { + if (opt.is_sparam) { + sparam_options.push_back(&opt); + } else if (opt.in_example(ctx_arg.ex)) { + specific_options.push_back(&opt); + } else { + common_options.push_back(&opt); + } + } + + printf("_llama_completions() {\n"); + printf(" local cur prev opts\n"); + printf(" COMPREPLY=()\n"); + printf(" cur=\"${COMP_WORDS[COMP_CWORD]}\"\n"); + printf(" prev=\"${COMP_WORDS[COMP_CWORD-1]}\"\n\n"); + + printf(" opts=\""); + auto print_options = [](const std::vector<common_arg *> & options) { + for (const common_arg * opt : options) { + for (const char * arg : opt->args) { + printf("%s ", arg); + } + } + }; + + print_options(common_options); + print_options(sparam_options); + print_options(specific_options); + printf("\"\n\n"); + + printf(" case \"$prev\" in\n"); + printf(" --model)\n"); + printf(" COMPREPLY=( $(compgen -f -X '!*.gguf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n"); + printf(" return 0\n"); + printf(" ;;\n"); + printf(" --grammar-file)\n"); + printf(" COMPREPLY=( $(compgen -f -X '!*.gbnf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n"); + printf(" return 0\n"); + printf(" ;;\n"); + printf(" --chat-template-file)\n"); + printf(" COMPREPLY=( $(compgen -f -X '!*.jinja' -- \"$cur\") $(compgen -d -- \"$cur\") )\n"); + printf(" return 0\n"); + printf(" ;;\n"); + printf(" *)\n"); + printf(" COMPREPLY=( $(compgen -W \"${opts}\" -- \"$cur\") )\n"); + printf(" return 0\n"); + printf(" ;;\n"); + printf(" esac\n"); + printf("}\n\n"); + + std::set<std::string> executables = { + "llama-batched", + "llama-batched-bench", + "llama-bench", + "llama-cli", + "llama-convert-llama2c-to-ggml", + "llama-cvector-generator", + "llama-embedding", + "llama-eval-callback", + "llama-export-lora", + "llama-gbnf-validator", + "llama-gen-docs", + "llama-gguf", + "llama-gguf-hash", + "llama-gguf-split", + "llama-gritlm", + "llama-imatrix", + "llama-infill", + "llama-llava-cli", + "llama-llava-clip-quantize-cli", + "llama-lookahead", + "llama-lookup", + "llama-lookup-create", + "llama-lookup-merge", + "llama-lookup-stats", + "llama-minicpmv-cli", + "llama-parallel", + "llama-passkey", + "llama-perplexity", + "llama-q8dot", + "llama-quantize", + "llama-quantize-stats", + "llama-qwen2vl-cli", + "llama-retrieval", + "llama-run", + "llama-save-load-state", + "llama-server", + "llama-simple", + "llama-simple-chat", + "llama-speculative", + "llama-speculative-simple", + "llama-tokenize", + "llama-tts", + "llama-vdot" + }; + + for (const auto& exe : executables) { + printf("complete -F _llama_completions %s\n", exe.c_str()); + } +} + static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & value) { std::vector<ggml_backend_dev_t> devices; auto dev_names = string_split<std::string>(value, ','); @@ -426,6 +532,10 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e } exit(0); } + if (ctx_arg.params.completion) { + common_params_print_completion(ctx_arg); + exit(0); + } } catch (const std::invalid_argument & ex) { fprintf(stderr, "%s\n", ex.what()); ctx_arg.params = params_org; @@ -494,6 +604,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex exit(0); }
)); + add_opt(common_arg( + {"--completion-bash"}, + "print source-able bash completion script for llama.cpp", + [](common_params & params) { + params.completion = true; + } + )); add_opt(common_arg( {"--verbose-prompt"}, string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"), @@ -674,7 +791,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex )); add_opt(common_arg( {"--no-context-shift"}, - string_format("disables context shift on inifinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"), + string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"), [](common_params & params) { params.ctx_shift = false; } @@ -946,6 +1063,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sampling.min_p = std::stof(value); } ).set_sparam()); + add_opt(common_arg( + {"--top-nsigma"}, "N", + string_format("top-n-sigma sampling (default: %.1f, -1.0 = disabled)", params.sampling.top_n_sigma), + [](common_params & params, const std::string & value) { + params.sampling.top_n_sigma = std::stof(value); + } + ).set_examples({LLAMA_EXAMPLE_MAIN}).set_sparam()); add_opt(common_arg( {"--xtc-probability"}, "N", string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), @@ -1445,7 +1569,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "- isolate: only spawn threads on CPUs on the node that execution started on\n" "- numactl: use the CPU map provided by numactl\n" "if run without this previously, it is recommended to drop the system page cache before using this\n" - "see https://github.com/ggerganov/llama.cpp/issues/1437", + "see https://github.com/ggml-org/llama.cpp/issues/1437", [](common_params & params, const std::string & value) { /**/ if (value == "distribute" || value == "") { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; } else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; } @@ -1975,6 +2099,17 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.use_jinja = true; } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA")); + add_opt(common_arg( + {"--reasoning-format"}, "FORMAT", + "reasoning format (default: deepseek; allowed values: deepseek, none)\n" + "controls whether thought tags are extracted from the response, and in which format they're returned. 
'none' leaves thoughts unparsed in `message.content`, 'deepseek' puts them in `message.reasoning_content` (for DeepSeek R1 & Command R7B only).\n" + "only supported for non-streamed responses", + [](common_params & params, const std::string & value) { + /**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; } + else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; } + else { throw std::invalid_argument("invalid value"); } + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK")); add_opt(common_arg( {"--chat-template"}, "JINJA_TEMPLATE", string_format( @@ -2112,7 +2247,7 @@ ).set_env("LLAMA_LOG_VERBOSITY")); add_opt(common_arg( {"--log-prefix"}, - "Enable prefx in log messages", + "Enable prefix in log messages", [](common_params &) { common_log_set_prefix(common_log_main(), true); } diff --git a/common/chat.cpp b/common/chat.cpp index ef1c6fb3d3978..f21a9d2a63a4b 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -12,11 +12,13 @@ std::string common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x"; case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools"; case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1"; + case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING: return "DeepSeek R1 (extract reasoning)"; case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; + case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)"; default: throw std::runtime_error("Unknown chat format"); } @@ -105,7 +107,6 @@ static common_chat_msg parse_json_tool_calls( std::sregex_iterator rend; std::sregex_iterator rit(it, end, function_regex); if (rit == rend) { - fprintf(stderr, "No more tool calls found\n"); result.content += std::string(it, end); break; } @@ -115,14 +116,21 @@ json arguments; if (!parse_json(it, end, arguments)) { - throw std::runtime_error("Failed to parse json tool call arguments"); + throw std::runtime_error("Failed to parse json tool call arguments: " + input); } if (!std::regex_search(it, end, match, close_regex)) { - throw std::runtime_error("Malformed input, missing closing pattern"); + throw std::runtime_error("Malformed input, missing closing pattern: " + input); } it = match.suffix().first; result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""}); } + + if (!result.tool_calls.empty()) { + if (!string_strip(result.content).empty()) { + LOG_WRN("Content found with tool calls: %s\n", result.content.c_str()); + } + result.content = ""; + } return result; } @@ -134,11 +142,11 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in result.role = "assistant"; const auto process_tool_calls = [&](const json & tool_calls) { for (const auto & tool_call : tool_calls) { - const auto & arguments = tool_call["arguments"]; + const auto & arguments = tool_call.at("arguments"); result.tool_calls.push_back({ - tool_call["name"], + tool_call.at("name"), arguments.is_string() ?
arguments.get<std::string>() : arguments.dump(), - tool_call.contains("id") ? tool_call["id"] : "", + tool_call.contains("id") ? tool_call.at("id") : "", }); } }; @@ -155,7 +163,7 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) { for (const auto & tool : tools) { - if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) { + if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) { LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str()); continue; } @@ -190,27 +198,27 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp auto tool_call_schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; + const auto & function = tool.at("function"); auto tool_schema = json { {"type", "object"}, {"properties", { {"name", { {"type", "string"}, - {"const", function["name"]}, + {"const", function.at("name")}, }}, - {"arguments", function["parameters"]}, + {"arguments", function.at("parameters")}, }}, {"required", json::array({"name", "arguments"})}, }; if (function.contains("description")) { - tool_schema["description"] = function["description"]; + tool_schema["description"] = function.at("description"); } if (inputs.parallel_tool_calls) { - tool_schema["properties"]["id"] = { + tool_schema.at("properties")["id"] = { {"type", "string"}, {"minLength", 4}, }; - tool_schema["required"].push_back("id"); + tool_schema.at("required").push_back("id"); } tool_call_schemas.emplace_back(tool_schema); }); @@ -275,21 +283,21 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) { common_chat_msg result; result.role = "assistant"; if (data.contains("tool_calls")) { - for (const auto & tool_call : data["tool_calls"]) { + for (const auto & tool_call : data.at("tool_calls")) { result.tool_calls.push_back({ - tool_call["name"], - tool_call["arguments"].dump(), - tool_call.contains("id") ? tool_call["id"] : "", + tool_call.at("name"), + tool_call.at("arguments").dump(), + tool_call.contains("id") ? tool_call.at("id") : "", }); } } else if (data.contains("tool_call")) { result.tool_calls.push_back({ - data["tool_call"]["name"], - data["tool_call"]["arguments"].dump(), + data.at("tool_call").at("name"), + data.at("tool_call").at("arguments").dump(), /* id= */ "", }); } else if (data.contains("response")) { - const auto & response = data["response"]; + const auto & response = data.at("response"); result.content = response.is_string() ? response.get<std::string>() : response.dump(2); } return result; @@ -301,7 +309,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; + const auto & function = tool.at("function"); schemas.push_back({ {"type", "object"}, {"properties", { // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
{"name", { {"type", "string"}, - {"const", function["name"]}, + {"const", function.at("name")}, }}, - {"arguments", function["parameters"]}, + {"arguments", function.at("parameters")}, {"id", { {"type", "string"}, // Nemo's template expects a 9-character alphanumeric ID. @@ -346,7 +354,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; + const auto & function = tool.at("function"); schemas.push_back({ {"type", "object"}, {"properties", { @@ -357,9 +365,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ }}, {"tool_name", { {"type", "string"}, - {"const", function["name"]}, + {"const", function.at("name")}, }}, - {"parameters", function["parameters"]}, + {"parameters", function.at("parameters")}, }}, {"required", json::array({"tool_call_id", "tool_name", "parameters"})}, }); @@ -382,39 +390,65 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ "<|END_THINKING|>", "<|END_ACTION|>", }; - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); - data.format = COMMON_CHAT_FORMAT_COMMAND_R7B; + auto adjusted_messages = json::array(); + for (const auto & msg : inputs.messages) { + auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string(); + auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array(); + if (has_reasoning_content && has_tool_calls) { + auto adjusted_message = msg; + adjusted_message["tool_plan"] = msg.at("reasoning_content"); + adjusted_message.erase("reasoning_content"); + adjusted_messages.push_back(adjusted_message); + } else { + adjusted_messages.push_back(msg); + } + } + data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {}); + data.format = inputs.extract_reasoning ? 
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING : COMMON_CHAT_FORMAT_COMMAND_R7B; return data; } -static common_chat_msg common_chat_parse_command_r7b(const std::string & input) { - static std::regex response_regex("<\\|START_RESPONSE\\|>([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>"); - static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>"); +static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) { + static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|>)([\\s\\S\\n\\r]*)"); + static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>"); + static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>"); + std::smatch match; common_chat_msg result; result.role = "assistant"; - if (std::regex_match(input, match, response_regex)) { - result.content = match[1].str(); - } else if (std::regex_match(input, match, thought_action_regex)) { - result.tool_plan = match[1].str(); - auto actions_str = match[2].str(); + + std::string rest = input; + + if (std::regex_match(rest, match, thought_regex)) { + if (extract_reasoning) { + result.reasoning_content = match[2].str(); + } else if (!match[2].str().empty()) { + // Let the unparsed thinking tags through in content only if their insides aren't empty. + result.content = match[1].str(); + } + rest = match[3].str(); + } + if (std::regex_match(rest, match, action_regex)) { + auto actions_str = match[1].str(); auto actions = json::parse(actions_str); for (const auto & action : actions) { result.tool_calls.push_back({ - /* .name = */ action["tool_name"], - /* .arguments = */ action["parameters"].dump(), - /* .id = */ action["tool_call_id"], + /* .name = */ action.at("tool_name"), + /* .arguments = */ action.at("parameters").dump(), + /* .id = */ action.at("tool_call_id"), }); } + } else if (std::regex_match(rest, match, response_regex)) { + auto response = match[1].str(); + result.content += response; } else { - LOG_ERR("Failed to parse command_r output"); - result.content = input; + result.content += rest; } return result; } static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) { - if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) { + if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) { throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties"); } const auto & parameters_properties = parameters.at("properties"); @@ -468,9 +502,9 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com }; foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); builder.resolve_refs(parameters); // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime @@ -546,34 +580,90 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { common_chat_params data; - data.grammar_lazy = inputs.tool_choice != "required"; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector<std::string> tool_rules; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; - auto args_rule = builder.add_schema(name + "-args", parameters); - tool_rules.push_back(builder.add_rule(name + "-call", - "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\"")); - }); - data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false}); - data.preserved_tokens = { - "<|tool▁sep|>", - "<|tool▁call▁end|>", - }; - builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space"); - }, grammar_options); + if (inputs.tools.is_array() && !inputs.tools.empty()) { + data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null(); + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector<std::string> tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + auto args_rule = builder.add_schema(name + "-args", parameters); + tool_rules.push_back(builder.add_rule(name + "-call", + "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n" + "```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\"")); + }); + // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, + // so we accept common variants (then it's all constrained) + builder.add_rule("root", + "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) " + "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " " + "\"<|tool▁calls▁end|>\"" + " space"); + data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false}); + data.grammar_triggers.push_back({"<|tool_calls_begin|>", /* .at_start = */ false}); + data.grammar_triggers.push_back({"<|tool calls begin|>", /* .at_start = */ false}); + data.grammar_triggers.push_back({"<|tool\\_calls\\_begin|>", /* .at_start = */ false}); + data.preserved_tokens = { + "<think>", + "</think>", + "<|tool▁sep|>", + "<|tool▁calls▁end|", + "<|tool▁call▁end|>", + }; + }, grammar_options); + } auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + + // Hacks to fix the official (broken) prompt. + // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead, + // until the official template is fixed.
+ if (tmpl.source().find("{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}") != std::string::npos) { + // Don't leave the chat dangling after tool results + if (string_ends_with(prompt, "<|tool▁outputs▁end|>")) { + prompt += "<|end▁of▁sentence|>"; + if (inputs.add_generation_prompt) { + prompt += "<|Assistant|>"; + } + } + // Fix up tool call delta example added by Minja + prompt = std::regex_replace( + prompt, + std::regex("(<|tool▁call▁end|>)[\\s\\r\\n]*(<|tool▁outputs▁begin|>|<|User|>)"), + "$1<|tool▁calls▁end|><|end▁of▁sentence|>$2"); + } data.prompt = prompt; - data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1; + data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1; return data; } -static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) { - static std::regex trigger_regex("<|tool▁calls▁begin|>"); +static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) { static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n"); - static std::regex close_regex("```<|tool▁call▁end|>"); - return parse_json_tool_calls(input, trigger_regex, function_regex, close_regex); + static std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); + static std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)"); + static std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>"); + common_chat_msg msg; + msg.role = "assistant"; + std::smatch match; + if (std::regex_match(input, match, reasoning_content_regex)) { + std::string rest; + if (extract_reasoning) { + msg.reasoning_content = string_strip(match[2].str()); + } else { + msg.content = match[1].str(); + } + rest = match[3].str(); + + if (std::regex_search(rest, match, tool_calls_regex)) { + auto tool_calls = match[1].str(); + auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex); + msg.tool_calls = std::move(msg2.tool_calls); + } else { + msg.content += std::string(rest.begin() + rest.find_first_not_of(" \r\n"), rest.end()); + } + } else { + msg.content = input; + } + return msg; } static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { @@ -583,20 +673,20 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c {"datetime", "Jan 29 2025 13:00:00 GMT"}, {"functions", json(inputs.tools.empty() ?
"" : inputs.tools.dump(2))}, }); - if (!inputs.tools.is_null() && !inputs.tools.empty()) { + if (inputs.tools.is_array() && !inputs.tools.empty()) { data.grammar_lazy = inputs.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; + const auto & function = tool.at("function"); schemas.push_back({ {"type", "object"}, {"properties", { {"name", { {"type", "string"}, - {"const", function["name"]}, + {"const", function.at("name")}, }}, - {"arguments", function["parameters"]}, + {"arguments", function.at("parameters")}, }}, {"required", json::array({"name", "arguments", "id"})}, }); @@ -628,15 +718,15 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ common_chat_params data; data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; - if (!inputs.tools.is_null() && !inputs.tools.empty()) { + if (inputs.tools.is_array() && !inputs.tools.empty()) { data.grammar_lazy = inputs.tool_choice != "required"; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector first_tool_rules; std::vector subsequent_tool_rules; foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); auto args_rule = builder.add_schema(name + "-args", parameters); first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); @@ -716,9 +806,9 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; - const auto & parameters = function["parameters"]; - std::string name = function["name"]; + const auto & function = tool.at("function"); + const auto & parameters = function.at("parameters"); + std::string name = function.at("name"); if (name == "python" || name == "ipython") { if (!parameters.contains("type")) { throw std::runtime_error("Missing type in python tool"); @@ -789,9 +879,9 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); builder.resolve_refs(parameters); tool_rules.push_back(builder.add_schema(name + "-call", { {"type", "object"}, @@ -839,9 +929,9 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) if (!parse_json(it, end, call)) { throw std::runtime_error("Failed to parse json tool call"); } - const auto & arguments = call["arguments"]; + const auto & arguments = call.at("arguments"); 
             result.tool_calls.push_back({
-                call["name"],
+                call.at("name"),
                 arguments.dump(),
                 // arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
                 /* id= */ "",
@@ -878,53 +968,78 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
         }
         data.grammar = json_schema_to_grammar(inputs.json_schema);
     } else {
-        data.grammar = inputs.grammar.empty();
+        data.grammar = inputs.grammar;
     }
     return data;
 }
 
 common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
-    auto has_tools = !inputs.tools.is_null() && inputs.tool_choice != "none";
-    LOG_DBG("[%s] has_tools=%s\n", __func__, has_tools ? "true" : "false");
+    const auto & src = tmpl.source();
+    const auto & caps = tmpl.original_caps();
 
-    if (has_tools && !inputs.grammar.empty()) {
-        throw std::runtime_error("Cannot specify grammar with tools");
+    if (inputs.tools.is_array()) {
+        if (inputs.tool_choice != "none" && !inputs.grammar.empty()) {
+            throw std::runtime_error("Cannot specify grammar with tools");
+        }
+        if (caps.supports_tool_calls && !caps.supports_tools) {
+            LOG_WRN("Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
+        }
     }
 
-    const auto & src = tmpl.source();
+    // DeepSeek R1: use handler in all cases except json schema (thinking / tools).
+    if (src.find("<|tool▁calls▁begin|>") != std::string::npos && inputs.json_schema.is_null()) {
+        return common_chat_params_init_deepseek_r1(tmpl, inputs);
+    }
+
+    // Command R7B: use handler in all cases except json schema (thinking / tools).
+    if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
+        return common_chat_params_init_command_r7b(tmpl, inputs);
+    }
+
+    // Use generic handler when mixing tools + JSON schema.
+    // TODO: support that mix in handlers below.
+    if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) {
+        return common_chat_params_init_generic(tmpl, inputs);
+    }
+
+    // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
     if (src.find(">>>all") != std::string::npos) {
-        // Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
         return common_chat_params_init_functionary_v3_2(tmpl, inputs);
     }
+
+    // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
     if (src.find(" functools[") != std::string::npos) {
-        // Firefunction v2 requires datetime and functions in the context, even w/o tools.
         return common_chat_params_init_firefunction_v2(tmpl, inputs);
     }
 
-    if (!has_tools) {
+    // Plain handler (no tools)
+    if (inputs.tools.is_null() || inputs.tool_choice == "none") {
         return common_chat_params_init_without_tools(tmpl, inputs);
     }
 
+    // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
     if (src.find("<tool_call>") != std::string::npos) {
         return common_chat_params_init_hermes_2_pro(tmpl, inputs);
     }
+
+    // Functionary v3.1 (w/ tools)
     if (src.find("<|start_header_id|>") != std::string::npos && src.find("ipython<|end_header_id|>") != std::string::npos) {
         auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
         return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
     }
-    if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
-        return common_chat_params_init_deepseek_r1(tmpl, inputs);
-    }
+
+    // Mistral Nemo (w/ tools)
     if (src.find("[TOOL_CALLS]") != std::string::npos) {
         return common_chat_params_init_mistral_nemo(tmpl, inputs);
     }
-    if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) {
-        return common_chat_params_init_command_r7b(tmpl, inputs);
-    }
+
+    // Generic fallback
     return common_chat_params_init_generic(tmpl, inputs);
 }
 
@@ -949,7 +1064,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
         case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
             return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true);
         case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
-            return common_chat_parse_deepseek_r1(input);
+            return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ false);
+        case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING:
+            return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ true);
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
             return common_chat_parse_functionary_v3_2(input);
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
@@ -959,7 +1076,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
         case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
             return common_chat_parse_firefunction_v2(input);
         case COMMON_CHAT_FORMAT_COMMAND_R7B:
-            return common_chat_parse_command_r7b(input);
+            return common_chat_parse_command_r7b(input, /* extract_reasoning= */ false);
+        case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING:
+            return common_chat_parse_command_r7b(input, /* extract_reasoning= */ true);
         default:
             throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
     }
diff --git a/common/chat.hpp b/common/chat.hpp
index 33e64a430d51e..ba1632f669cf7 100644
--- a/common/chat.hpp
+++ b/common/chat.hpp
@@ -19,6 +19,7 @@ struct common_chat_inputs {
     bool stream;
     std::string grammar;
     bool add_generation_prompt = true;
+    bool extract_reasoning = true;
 };
 
 enum common_chat_format {
@@ -28,11 +29,13 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_LLAMA_3_X,
     COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
     COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+    COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
     COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
     COMMON_CHAT_FORMAT_HERMES_2_PRO,
     COMMON_CHAT_FORMAT_COMMAND_R7B,
+    COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
 
     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };
diff --git a/common/common.h b/common/common.h
index b208d0c7ece59..98b9a4464787a 100644
--- a/common/common.h
+++ b/common/common.h
@@ -140,6 +140,7 @@ struct common_params_sampling {
     int32_t dry_allowed_length     =     2; // tokens extending repetitions beyond this receive penalty
     int32_t dry_penalty_last_n     =    -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
     int32_t mirostat               =     0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float   top_n_sigma            = -1.00f; // -1.0 = disabled
     float   mirostat_tau           =  5.00f; // target entropy
     float   mirostat_eta           =  0.10f; // learning rate
     bool    ignore_eos             = false;
@@ -202,6 +203,11 @@ struct common_params_vocoder {
     bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
 };
 
+enum common_reasoning_format {
+    COMMON_REASONING_FORMAT_NONE,
+    COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
+};
+
 struct common_params {
     int32_t n_predict             =    -1; // new tokens to predict
     int32_t n_ctx                 =  4096; // context size
@@ -292,6 +298,7 @@ struct common_params {
     bool kl_divergence    = false; // compute KL divergence
 
     bool usage            = false; // print usage
+    bool completion       = false; // print source-able completion script
     bool use_color        = false; // use color to distinguish generations and inputs
     bool special          = false; // enable special token output
     bool interactive      = false; // interactive mode
@@ -346,6 +353,7 @@ struct common_params {
     std::string chat_template = ""; // NOLINT
     bool use_jinja = false; // NOLINT
     bool enable_chat_template = true;
+    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
 
     std::vector<std::string> api_keys;
@@ -424,13 +432,13 @@ bool set_process_priority(enum ggml_sched_priority prio);
 //
 
 #ifdef __GNUC__
-#ifdef __MINGW32__
-#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
-#else
-#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
-#endif
+#    if defined(__MINGW32__) && !defined(__clang__)
+#        define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#    else
+#        define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#    endif
 #else
-#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
+#    define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
 #endif
 
 LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
@@ -623,7 +631,7 @@ struct common_chat_msg {
     std::string role;
     std::string content;
     std::vector<common_tool_call> tool_calls;
-    std::string tool_plan = "";
+    std::string reasoning_content = "";
 };
 
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
diff --git a/common/log.cpp b/common/log.cpp
index 4bfbecf15e314..52b31470c46bd 100644
--- a/common/log.cpp
+++ b/common/log.cpp
@@ -1,5 +1,6 @@
 #include "log.h"
 
+#include <chrono>
 #include <condition_variable>
 #include <cstdarg>
 #include <cstdio>
diff --git a/common/log.h b/common/log.h
index 4ebc6314b25d6..c56bb50d95db0 100644
--- a/common/log.h
+++ b/common/log.h
@@ -15,7 +15,7 @@
 
 #ifndef __GNUC__
 #    define LOG_ATTRIBUTE_FORMAT(...)
-#elif defined(__MINGW32__)
+#elif defined(__MINGW32__) && !defined(__clang__)
 #    define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
 #else
 #    define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
diff --git a/common/sampling.cpp b/common/sampling.cpp
index e4b21ca1011dd..37a0d9c85ae30 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -134,11 +134,11 @@ std::string common_params_sampling::print() const {
     snprintf(result, sizeof(result),
             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
             "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
-            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
+            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
             penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
             dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
-            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
+            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
             mirostat, mirostat_eta, mirostat_tau);
 
     return std::string(result);
@@ -151,12 +151,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     lparams.no_perf = params.no_perf;
 
-    std::vector<const char *> trigger_words;
-    trigger_words.reserve(params.grammar_trigger_words.size());
-    for (const auto & str : params.grammar_trigger_words) {
-        trigger_words.push_back(str.word.c_str());
-    }
-
     struct llama_sampler * grmr;
     if (params.grammar.compare(0, 11, "%llguidance") == 0) {
 #ifdef LLAMA_USE_LLGUIDANCE
@@ -165,6 +159,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif // LLAMA_USE_LLGUIDANCE
     } else {
+        std::vector<const char *> trigger_words;
+        trigger_words.reserve(params.grammar_trigger_words.size());
+        for (const auto & str : params.grammar_trigger_words) {
+            trigger_words.push_back(str.word.c_str());
+        }
+
         grmr = params.grammar_lazy ?
             llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
                                             trigger_words.data(), trigger_words.size(),
@@ -188,45 +188,51 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             params.logit_bias.data()));
 
     if (params.mirostat == 0) {
-        for (const auto & cnstr : params.samplers) {
-            switch (cnstr) {
-                case COMMON_SAMPLER_TYPE_DRY:
-                    {
-                        std::vector<const char *> c_breakers;
-                        c_breakers.reserve(params.dry_sequence_breakers.size());
-                        for (const auto & str : params.dry_sequence_breakers) {
-                            c_breakers.push_back(str.c_str());
+        if (params.top_n_sigma >= 0) {
+            llama_sampler_chain_add(result->chain, llama_sampler_init_top_k       (params.top_k));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_temp        (params.temp));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
+        } else {
+            for (const auto & cnstr : params.samplers) {
+                switch (cnstr) {
+                    case COMMON_SAMPLER_TYPE_DRY:
+                        {
+                            std::vector<const char *> c_breakers;
+                            c_breakers.reserve(params.dry_sequence_breakers.size());
+                            for (const auto & str : params.dry_sequence_breakers) {
+                                c_breakers.push_back(str.c_str());
+                            }
+
+                            llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                         }
-
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
-                    }
-                    break;
-                case COMMON_SAMPLER_TYPE_TOP_K:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
-                    break;
-                case COMMON_SAMPLER_TYPE_TOP_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
-                    break;
-                case COMMON_SAMPLER_TYPE_MIN_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
-                    break;
-                case COMMON_SAMPLER_TYPE_XTC:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
-                    break;
-                case COMMON_SAMPLER_TYPE_TYPICAL_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
-                    break;
-                case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
-                    break;
-                case COMMON_SAMPLER_TYPE_INFILL:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (vocab));
-                    break;
-                case COMMON_SAMPLER_TYPE_PENALTIES:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
-                    break;
-                default:
-                    GGML_ASSERT(false && "unknown sampler type");
+                        break;
+                    case COMMON_SAMPLER_TYPE_TOP_K:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
+                        break;
+                    case COMMON_SAMPLER_TYPE_TOP_P:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
+                        break;
+                    case COMMON_SAMPLER_TYPE_MIN_P:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
+                        break;
+                    case COMMON_SAMPLER_TYPE_XTC:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+ break; + case COMMON_SAMPLER_TYPE_TYPICAL_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_TEMPERATURE: + llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + break; + case COMMON_SAMPLER_TYPE_INFILL: + llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); + break; + case COMMON_SAMPLER_TYPE_PENALTIES: + llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + break; + default: + GGML_ASSERT(false && "unknown sampler type"); + } } } llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); diff --git a/common/speculative.h b/common/speculative.h index 50ec0344618aa..2baf99fc78ce1 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -9,7 +9,7 @@ struct common_speculative_params { int n_draft = 16; // max drafted tokens int n_reuse = 256; - float p_min = 0.9f; // min probabiliy required to accept a token in the draft + float p_min = 0.9f; // min probability required to accept a token in the draft }; struct common_speculative * common_speculative_init(struct llama_context * ctx_dft); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 018a2a588ae9d..8b7c75d85a6f5 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -558,7 +558,7 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: # NOTE: this function is generated by convert_hf_to_gguf_update.py # do not modify it manually! - # ref: https://github.com/ggerganov/llama.cpp/pull/6920 + # ref: https://github.com/ggml-org/llama.cpp/pull/6920 # Marker: Start get_vocab_base_pre def get_vocab_base_pre(self, tokenizer) -> str: # encoding this string and hashing the resulting tokens would (hopefully) give us a unique identifier that @@ -708,7 +708,7 @@ def get_vocab_base_pre(self, tokenizer) -> str: logger.warning("** - the model has not been added to convert_hf_to_gguf_update.py yet") logger.warning("** - the pre-tokenization config has changed upstream") logger.warning("** Check your model files and convert_hf_to_gguf_update.py and update them accordingly.") - logger.warning("** ref: https://github.com/ggerganov/llama.cpp/pull/6920") + logger.warning("** ref: https://github.com/ggml-org/llama.cpp/pull/6920") logger.warning("**") logger.warning(f"** chkhsh: {chkhsh}") logger.warning("**************************************************************************************") @@ -2835,7 +2835,7 @@ def set_vocab(self): if chat_eos_token_id is not None: # For the chat model, we replace the eos with '<|im_end|>'. # TODO: this is a hack, should be fixed - # https://github.com/ggerganov/llama.cpp/pull/6745#issuecomment-2067687048 + # https://github.com/ggml-org/llama.cpp/pull/6745#issuecomment-2067687048 special_vocab.special_token_ids["eos"] = chat_eos_token_id logger.warning(f"Replace eos:{old_eos} with a special token:{chat_eos_token_id}" " in chat mode so that the conversation can end normally.") diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index cea34413f441f..fa4989a80c544 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -8,7 +8,7 @@ # provide the necessary information to llama.cpp via the GGUF header in order to implement # the same pre-tokenizer. 
# -# ref: https://github.com/ggerganov/llama.cpp/pull/6920 +# ref: https://github.com/ggml-org/llama.cpp/pull/6920 # # Instructions: # @@ -246,7 +246,7 @@ def get_vocab_base_pre(self, tokenizer) -> str: logger.warning("** - the model has not been added to convert_hf_to_gguf_update.py yet") logger.warning("** - the pre-tokenization config has changed upstream") logger.warning("** Check your model files and convert_hf_to_gguf_update.py and update them accordingly.") - logger.warning("** ref: https://github.com/ggerganov/llama.cpp/pull/6920") + logger.warning("** ref: https://github.com/ggml-org/llama.cpp/pull/6920") logger.warning("**") logger.warning(f"** chkhsh: {{chkhsh}}") logger.warning("**************************************************************************************") diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index 6dea14a2329e8..bdc991533b4e0 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -395,7 +395,7 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]: logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor") if ".embed_tokens.weight" in name or ".lm_head.weight" in name: logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning") - logger.error("Please refer to https://github.com/ggerganov/llama.cpp/pull/9948") + logger.error("Please refer to https://github.com/ggml-org/llama.cpp/pull/9948") sys.exit(1) if base_name in tensor_map: @@ -419,7 +419,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # some archs may have the same tensor for lm_head and output (tie word embeddings) # in this case, adapters targeting lm_head will fail when using llama-export-lora # therefore, we ignore them for now - # see: https://github.com/ggerganov/llama.cpp/issues/9065 + # see: https://github.com/ggml-org/llama.cpp/issues/9065 if name == "lm_head.weight" and len(dest) == 0: raise ValueError("lm_head is present in adapter, but is ignored in base model") for dest_name, dest_data in dest: diff --git a/docs/android.md b/docs/android.md index 47530c6c1d478..d2a835653fe5d 100644 --- a/docs/android.md +++ b/docs/android.md @@ -12,7 +12,7 @@ $ apt update && apt upgrade -y $ apt install git cmake ``` -Then, follow the [build instructions](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md), specifically for CMake. +Then, follow the [build instructions](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md), specifically for CMake. Once the binaries are built, download your model of choice (e.g., from Hugging Face). It's recommended to place it in the `~/` directory for best performance: diff --git a/docs/backend/OPENCL.md b/docs/backend/OPENCL.md new file mode 100644 index 0000000000000..2a946dc8df0ff --- /dev/null +++ b/docs/backend/OPENCL.md @@ -0,0 +1,205 @@ +# llama.cpp for OpenCL + +- [Background](#background) +- [OS](#os) +- [Hardware](#hardware) +- [DataType Supports](#datatype-supports) +- [Model Preparation](#model-preparation) +- [CMake Options](#cmake-options) +- [Android](#android) +- [Windows 11 Arm64](#windows-11-arm64) +- [Known Issue](#known-issues) +- [TODO](#todo) + +## Background + +OpenCL (Open Computing Language) is an open, royalty-free standard for cross-platform, parallel programming of diverse accelerators found in supercomputers, cloud servers, personal computers, mobile devices and embedded platforms. 
OpenCL specifies a programming language (based on C99) for programming these devices and application programming interfaces (APIs) to control the platform and execute programs on the compute devices. Similar to CUDA, OpenCL has been widely used to program GPUs and is supported by most GPU vendors.
+
+### Llama.cpp + OpenCL
+
+The llama.cpp OpenCL backend is designed primarily to enable llama.cpp on **Qualcomm Adreno GPUs**. Thanks to the portability of OpenCL, the backend can also run on certain Intel GPUs, although the performance is not optimal.
+
+## OS
+
+| OS      | Status  | Verified                                 |
+|---------|---------|------------------------------------------|
+| Android | Support | Snapdragon 8 Gen 3, Snapdragon 8 Elite   |
+| Windows | Support | Windows 11 Arm64 with Snapdragon X Elite |
+| Linux   | Support | Ubuntu 22.04 WSL2 with Intel 12700H      |
+
+## Hardware
+
+### Adreno GPU
+
+**Verified devices**
+
+| Adreno GPU                      | Status  |
+|:-------------------------------:|:-------:|
+| Adreno 750 (Snapdragon 8 Gen 3) | Support |
+| Adreno 830 (Snapdragon 8 Elite) | Support |
+| Adreno X85 (Snapdragon X Elite) | Support |
+
+## DataType Supports
+
+| DataType | Status                     |
+|:--------:|:--------------------------:|
+| Q4_0     | Support                    |
+| Q6_K     | Support, but not optimized |
+
+## Model Preparation
+
+You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation.
+
+Currently we support `Q4_0` quantization and have optimized for it. To achieve the best performance on Adreno GPUs, add `--pure` to `llama-quantize`. For example,
+
+```sh
+./llama-quantize --pure ggml-model-qwen2.5-3b-f16.gguf ggml-model-qwen-3b-Q4_0.gguf Q4_0
+```
+
+Since `Q6_K` is also supported, `Q4_0` quantization without `--pure` will also work. However, the performance will be worse compared to pure `Q4_0` quantization.
+
+## CMake Options
+
+The OpenCL backend has the following CMake options that control the behavior of the backend.
+
+| CMake options                    | Default value | Description                               |
+|:--------------------------------:|:-------------:|:------------------------------------------|
+| `GGML_OPENCL_EMBED_KERNELS`      | `ON`          | Embed OpenCL kernels into the executable. |
+| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON`          | Use kernels optimized for Adreno.         |
+
+## Android
+
+Ubuntu 22.04 is used for targeting Android. Make sure the following tools are accessible from the command line:
+
+* Git
+* CMake 3.29
+* Ninja
+* Python3
+
+### I. Setup Environment
+
+1. **Install NDK**
+
+```sh
+cd ~
+wget https://dl.google.com/android/repository/commandlinetools-linux-8512546_latest.zip && \
+unzip commandlinetools-linux-8512546_latest.zip && \
+mkdir -p ~/android-sdk/cmdline-tools && \
+mv cmdline-tools latest && \
+mv latest ~/android-sdk/cmdline-tools/ && \
+rm -rf commandlinetools-linux-8512546_latest.zip
+
+yes | ~/android-sdk/cmdline-tools/latest/bin/sdkmanager "ndk;26.3.11579264"
+```
+
+2. **Install OpenCL Headers and Library**
+
+```sh
+mkdir -p ~/dev/llm
+cd ~/dev/llm
+
+git clone https://github.com/KhronosGroup/OpenCL-Headers && \
+cd OpenCL-Headers && \
+cp -r CL ~/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
+
+cd ~/dev/llm
+
+git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && \
+cd OpenCL-ICD-Loader && \
+mkdir build_ndk26 && cd build_ndk26 && \
+cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release \
+    -DCMAKE_TOOLCHAIN_FILE=$HOME/android-sdk/ndk/26.3.11579264/build/cmake/android.toolchain.cmake \
+    -DOPENCL_ICD_LOADER_HEADERS_DIR=$HOME/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include \
+    -DANDROID_ABI=arm64-v8a \
+    -DANDROID_PLATFORM=24 \
+    -DANDROID_STL=c++_shared && \
+ninja && \
+cp libOpenCL.so ~/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android
+```
+
+### II. Build llama.cpp
+
+```sh
+cd ~/dev/llm
+
+git clone https://github.com/ggml-org/llama.cpp && \
+cd llama.cpp && \
+mkdir build-android && cd build-android
+
+cmake .. -G Ninja \
+    -DCMAKE_TOOLCHAIN_FILE=$HOME/android-sdk/ndk/26.3.11579264/build/cmake/android.toolchain.cmake \
+    -DANDROID_ABI=arm64-v8a \
+    -DANDROID_PLATFORM=android-28 \
+    -DBUILD_SHARED_LIBS=OFF \
+    -DGGML_OPENCL=ON
+
+ninja
+```
+
+## Windows 11 Arm64
+
+A Snapdragon X Elite device with Windows 11 Arm64 is used. Make sure the following tools are accessible from the command line:
+
+* Git
+* CMake 3.29
+* Clang 19
+* Ninja
+* Visual Studio 2022
+
+PowerShell is used for the following instructions.
+
+### I. Setup Environment
+
+1. **Install OpenCL Headers and Library**
+
+```powershell
+mkdir -p ~/dev/llm
+
+cd ~/dev/llm
+git clone https://github.com/KhronosGroup/OpenCL-Headers && cd OpenCL-Headers
+mkdir build && cd build
+cmake .. -G Ninja `
+    -DBUILD_TESTING=OFF `
+    -DOPENCL_HEADERS_BUILD_TESTING=OFF `
+    -DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF `
+    -DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl"
+cmake --build . --target install
+
+cd ~/dev/llm
+git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && cd OpenCL-ICD-Loader
+mkdir build && cd build
+cmake .. -G Ninja `
+    -DCMAKE_BUILD_TYPE=Release `
+    -DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" `
+    -DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl"
+cmake --build . --target install
+```
+
+### II. Build llama.cpp
+
+```powershell
+
+mkdir -p ~/dev/llm
+cd ~/dev/llm
+
+git clone https://github.com/ggml-org/llama.cpp && cd llama.cpp
+mkdir build && cd build
+
+cmake .. -G Ninja `
+    -DCMAKE_TOOLCHAIN_FILE="$HOME/dev/llm/llama.cpp/cmake/arm64-windows-llvm.cmake" `
+    -DCMAKE_BUILD_TYPE=Release `
+    -DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" `
+    -DBUILD_SHARED_LIBS=OFF `
+    -DGGML_OPENCL=ON
+ninja
+```
+
+## Known Issues
+
+- Qwen2.5 0.5B model produces gibberish output with Adreno kernels.
+
+## TODO
+
+- Fix Qwen2.5 0.5B
+- Optimization for Q6_K
+- Support and optimization for Q4_K
diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md
index 89ddbd669afa0..0cb39e7927cd6 100644
--- a/docs/backend/SYCL.md
+++ b/docs/backend/SYCL.md
@@ -36,8 +36,8 @@ The following release is verified with good quality:
 
 |Commit ID|Tag|Release|Verified Platform| Update date|
 |-|-|-|-|-|
-|3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggerganov/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19|
-|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggerganov/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1||
+|3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19|
+|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1||
 
 ## News
 
@@ -58,7 +58,7 @@ The following release is verified with good quality:
 
 - 2024.3
   - Release binary files of Windows.
   - A blog is published: **Run LLM on all Intel GPUs Using llama.cpp**: [intel.com](https://www.intel.com/content/www/us/en/developer/articles/technical/run-llm-on-all-gpus-using-llama-cpp-artical.html) or [medium.com](https://medium.com/@jianyu_neo/run-llm-on-all-intel-gpus-using-llama-cpp-fd2e2dcbd9bd).
-  - New base line is ready: [tag b2437](https://github.com/ggerganov/llama.cpp/tree/b2437).
+  - New base line is ready: [tag b2437](https://github.com/ggml-org/llama.cpp/tree/b2437).
   - Support multiple cards: **--split-mode**: [none|layer]; not support [row], it's on developing.
   - Support to assign main GPU by **--main-gpu**, replace $GGML_SYCL_DEVICE.
   - Support detecting all GPUs with level-zero and same top **Max compute units**.
diff --git a/docs/build.md b/docs/build.md
index afb7a040218db..69480aa0849a5 100644
--- a/docs/build.md
+++ b/docs/build.md
@@ -3,7 +3,7 @@
 **To get the Code:**
 
 ```bash
-git clone https://github.com/ggerganov/llama.cpp
+git clone https://github.com/ggml-org/llama.cpp
 cd llama.cpp
 ```
 
@@ -46,7 +46,7 @@ cmake --build build --config Release
     ```
 
 - Building for Windows (x86, x64 and arm64) with MSVC or clang as compilers:
-  - Install Visual Studio 2022, e.g. via the [Community Edition](https://visualstudio.microsoft.com/de/vs/community/). In the installer, select at least the following options (this also automatically installs the required additional tools like CMake,...):
+  - Install Visual Studio 2022, e.g. via the [Community Edition](https://visualstudio.microsoft.com/vs/community/). In the installer, select at least the following options (this also automatically installs the required additional tools like CMake,...):
     - Tab Workload: Desktop-development with C++
     - Tab Components (select quickly via search): C++-_CMake_ Tools for Windows, _Git_ for Windows, C++-_Clang_ Compiler for Windows, MS-Build Support for LLVM-Toolset (clang)
   - Please remember to always use a Developer Command Prompt / PowerShell for VS2022 for git, build, test
diff --git a/docs/cuda-fedora.md b/docs/cuda-fedora.md
index 9c88b769415cf..75cd2b499d086 100644
--- a/docs/cuda-fedora.md
+++ b/docs/cuda-fedora.md
@@ -248,7 +248,7 @@ You have successfully set up CUDA on Fedora within a toolbox environment using t
 
 - **Building `llama.cpp`:**
 
-  - With CUDA installed, you can follow these [build instructions for `llama.cpp`](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) to compile it with CUDA support.
+  - With CUDA installed, you can follow these [build instructions for `llama.cpp`](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md) to compile it with CUDA support.
   - Ensure that any CUDA-specific build flags or paths are correctly set in your build configuration.
- **Using the Toolbox Environment:** diff --git a/docs/development/HOWTO-add-model.md b/docs/development/HOWTO-add-model.md index 8fcd7081130f2..78c6f76077a2b 100644 --- a/docs/development/HOWTO-add-model.md +++ b/docs/development/HOWTO-add-model.md @@ -104,16 +104,16 @@ Note: to debug the inference graph: you can use [llama-eval-callback](/examples/ ## GGUF specification -https://github.com/ggerganov/ggml/blob/master/docs/gguf.md +https://github.com/ggml-org/ggml/blob/master/docs/gguf.md ## Resources -- YaRN RoPE scaling https://github.com/ggerganov/llama.cpp/pull/2268 -- support Baichuan serial models https://github.com/ggerganov/llama.cpp/pull/3009 -- support attention bias https://github.com/ggerganov/llama.cpp/pull/4283 -- Mixtral support https://github.com/ggerganov/llama.cpp/pull/4406 -- BERT embeddings https://github.com/ggerganov/llama.cpp/pull/5423 -- Grok-1 support https://github.com/ggerganov/llama.cpp/pull/6204 -- Command R Plus support https://github.com/ggerganov/llama.cpp/pull/6491 -- support arch DBRX https://github.com/ggerganov/llama.cpp/pull/6515 -- How to convert HuggingFace model to GGUF format https://github.com/ggerganov/llama.cpp/discussions/2948 +- YaRN RoPE scaling https://github.com/ggml-org/llama.cpp/pull/2268 +- support Baichuan serial models https://github.com/ggml-org/llama.cpp/pull/3009 +- support attention bias https://github.com/ggml-org/llama.cpp/pull/4283 +- Mixtral support https://github.com/ggml-org/llama.cpp/pull/4406 +- BERT embeddings https://github.com/ggml-org/llama.cpp/pull/5423 +- Grok-1 support https://github.com/ggml-org/llama.cpp/pull/6204 +- Command R Plus support https://github.com/ggml-org/llama.cpp/pull/6491 +- support arch DBRX https://github.com/ggml-org/llama.cpp/pull/6515 +- How to convert HuggingFace model to GGUF format https://github.com/ggml-org/llama.cpp/discussions/2948 diff --git a/docs/docker.md b/docs/docker.md index dac9a9ec164ff..343146dbd214f 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -7,21 +7,21 @@ ## Images We have three Docker images available for this project: -1. `ghcr.io/ggerganov/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`) -2. `ghcr.io/ggerganov/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`) -3. `ghcr.io/ggerganov/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`) +1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`) +2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`) +3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`) Additionally, there the following images, similar to the above: -- `ghcr.io/ggerganov/llama.cpp:full-cuda`: Same as `full` but compiled with CUDA support. (platforms: `linux/amd64`) -- `ghcr.io/ggerganov/llama.cpp:light-cuda`: Same as `light` but compiled with CUDA support. (platforms: `linux/amd64`) -- `ghcr.io/ggerganov/llama.cpp:server-cuda`: Same as `server` but compiled with CUDA support. 
(platforms: `linux/amd64`) -- `ghcr.io/ggerganov/llama.cpp:full-rocm`: Same as `full` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`) -- `ghcr.io/ggerganov/llama.cpp:light-rocm`: Same as `light` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`) -- `ghcr.io/ggerganov/llama.cpp:server-rocm`: Same as `server` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`) -- `ghcr.io/ggerganov/llama.cpp:full-musa`: Same as `full` but compiled with MUSA support. (platforms: `linux/amd64`) -- `ghcr.io/ggerganov/llama.cpp:light-musa`: Same as `light` but compiled with MUSA support. (platforms: `linux/amd64`) -- `ghcr.io/ggerganov/llama.cpp:server-musa`: Same as `server` but compiled with MUSA support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:full-cuda`: Same as `full` but compiled with CUDA support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:light-cuda`: Same as `light` but compiled with CUDA support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:server-cuda`: Same as `server` but compiled with CUDA support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:full-rocm`: Same as `full` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`) +- `ghcr.io/ggml-org/llama.cpp:light-rocm`: Same as `light` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`) +- `ghcr.io/ggml-org/llama.cpp:server-rocm`: Same as `server` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`) +- `ghcr.io/ggml-org/llama.cpp:full-musa`: Same as `full` but compiled with MUSA support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:light-musa`: Same as `light` but compiled with MUSA support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:server-musa`: Same as `server` but compiled with MUSA support. (platforms: `linux/amd64`) The GPU enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](../.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](../.github/workflows/docker.yml). If you need different settings (for example, a different CUDA, ROCm or MUSA library, you'll need to build the images locally for now). @@ -32,25 +32,25 @@ The easiest way to download the models, convert them to ggml and optimize them i Replace `/path/to/models` below with the actual path where you downloaded the models. ```bash -docker run -v /path/to/models:/models ghcr.io/ggerganov/llama.cpp:full --all-in-one "/models/" 7B +docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:full --all-in-one "/models/" 7B ``` On completion, you are ready to play! 
```bash -docker run -v /path/to/models:/models ghcr.io/ggerganov/llama.cpp:full --run -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 +docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:full --run -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 ``` or with a light image: ```bash -docker run -v /path/to/models:/models ghcr.io/ggerganov/llama.cpp:light -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 +docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:light -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 ``` or with a server image: ```bash -docker run -v /path/to/models:/models -p 8000:8000 ghcr.io/ggerganov/llama.cpp:server -m /models/7B/ggml-model-q4_0.gguf --port 8000 --host 0.0.0.0 -n 512 +docker run -v /path/to/models:/models -p 8000:8000 ghcr.io/ggml-org/llama.cpp:server -m /models/7B/ggml-model-q4_0.gguf --port 8000 --host 0.0.0.0 -n 512 ``` ## Docker With CUDA @@ -69,7 +69,7 @@ You may want to pass in some different `ARGS`, depending on the CUDA environment The defaults are: -- `CUDA_VERSION` set to `12.6.0` +- `CUDA_VERSION` set to `12.4.0` - `CUDA_DOCKER_ARCH` set to the cmake build default, which includes all the supported architectures The resulting images, are essentially the same as the non-CUDA images: @@ -104,7 +104,7 @@ You may want to pass in some different `ARGS`, depending on the MUSA environment The defaults are: -- `MUSA_VERSION` set to `rc3.1.0` +- `MUSA_VERSION` set to `rc3.1.1` The resulting images, are essentially the same as the non-MUSA images: diff --git a/docs/install.md b/docs/install.md index 10a568506835b..0e23a2c9e7ae1 100644 --- a/docs/install.md +++ b/docs/install.md @@ -7,7 +7,7 @@ On Mac and Linux, the homebrew package manager can be used via ```sh brew install llama.cpp ``` -The formula is automatically updated with new `llama.cpp` releases. More info: https://github.com/ggerganov/llama.cpp/discussions/7668 +The formula is automatically updated with new `llama.cpp` releases. More info: https://github.com/ggml-org/llama.cpp/discussions/7668 ## Nix diff --git a/docs/llguidance.md b/docs/llguidance.md index 792d20704036d..cda787b14de04 100644 --- a/docs/llguidance.md +++ b/docs/llguidance.md @@ -13,13 +13,15 @@ cmake -B build -DLLAMA_LLGUIDANCE=ON make -C build -j ``` +For Windows use `cmake --build build --config Release` instead of `make`. + This requires the Rust compiler and the `cargo` tool to be [installed](https://www.rust-lang.org/tools/install). ## Interface There are no new command-line arguments or modifications to `common_params`. When enabled, grammars starting with `%llguidance` are passed to LLGuidance instead of the [current](../grammars/README.md) llama.cpp grammars. Additionally, JSON Schema requests (e.g., using the `-j` argument in `llama-cli`) are also passed to LLGuidance. -For your existing GBNF grammars, you can use [gbnf_to_lark.py script](https://github.com/guidance-ai/llguidance/blob/main/scripts/gbnf_to_lark.py) to convert them to LLGuidance Lark-like format. +For your existing GBNF grammars, you can use [gbnf_to_lark.py script](https://github.com/guidance-ai/llguidance/blob/main/python/llguidance/gbnf_to_lark.py) to convert them to LLGuidance Lark-like format. 
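A minimal sketch of how that routing works, mirroring the `%llguidance` prefix check visible in the `common/sampling.cpp` hunk earlier in this diff (the helper name here is hypothetical, not part of llama.cpp):

```cpp
#include <iostream>
#include <string>

// Grammars that start with "%llguidance" are handed to LLGuidance;
// everything else is treated as a llama.cpp GBNF grammar.
static bool is_llguidance_grammar(const std::string & grammar) {
    return grammar.compare(0, 11, "%llguidance") == 0; // 11 == strlen("%llguidance")
}

int main() {
    std::cout << is_llguidance_grammar("%llguidance {}\nstart: \"yes\" | \"no\"") << "\n"; // 1
    std::cout << is_llguidance_grammar("root ::= \"yes\" | \"no\"") << "\n";               // 0
}
```

Anything without the prefix keeps flowing through the existing GBNF path, so LLGuidance stays opt-in per grammar.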
 
 ## Performance
 
diff --git a/examples/cvector-generator/README.md b/examples/cvector-generator/README.md
index be4dd5250f15f..6d5fd74ad8ca0 100644
--- a/examples/cvector-generator/README.md
+++ b/examples/cvector-generator/README.md
@@ -3,9 +3,9 @@ This example demonstrates how to generate a control vector using gguf models.
 
 Related PRs:
-- [Add support for control vectors](https://github.com/ggerganov/llama.cpp/pull/5970)
-- (Issue) [Generate control vector using llama.cpp](https://github.com/ggerganov/llama.cpp/issues/6880)
-- [Add cvector-generator example](https://github.com/ggerganov/llama.cpp/pull/7514)
+- [Add support for control vectors](https://github.com/ggml-org/llama.cpp/pull/5970)
+- (Issue) [Generate control vector using llama.cpp](https://github.com/ggml-org/llama.cpp/issues/6880)
+- [Add cvector-generator example](https://github.com/ggml-org/llama.cpp/pull/7514)
 
 ## Examples
diff --git a/examples/imatrix/README.md b/examples/imatrix/README.md
index 9c056986b3ed6..9aa2b20347927 100644
--- a/examples/imatrix/README.md
+++ b/examples/imatrix/README.md
@@ -1,7 +1,7 @@
 # llama.cpp/examples/imatrix
 
-Compute an importance matrix for a model and given text dataset. Can be used during quantization to enchance the quality of the quantized models.
-More information is available here: https://github.com/ggerganov/llama.cpp/pull/4861
+Compute an importance matrix for a model and given text dataset. Can be used during quantization to enhance the quality of the quantized models.
+More information is available here: https://github.com/ggml-org/llama.cpp/pull/4861
 
 ## Usage
diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp
index e335ecc74b8fe..b584b992faa7f 100644
--- a/examples/imatrix/imatrix.cpp
+++ b/examples/imatrix/imatrix.cpp
@@ -3,6 +3,7 @@
 #include "log.h"
 #include "llama.h"
 
+#include <chrono>
 #include <cmath>
 #include <cstdio>
 #include <cstring>
@@ -99,7 +100,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
     const float * data = is_host ? (const float *) src1->data : m_src1_data.data();
 
     // this has been adapted to the new format of storing merged experts in a single 3d tensor
-    // ref: https://github.com/ggerganov/llama.cpp/pull/6387
+    // ref: https://github.com/ggml-org/llama.cpp/pull/6387
     if (t->op == GGML_OP_MUL_MAT_ID) {
         // ids -> [n_experts_used, n_tokens]
         // src1 -> [cols, n_expert_used, n_tokens]
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index fc58135fe5fa8..cbcbfcee861ee 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -876,8 +876,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
 struct test {
     static const std::string build_commit;
     static const int         build_number;
-    static const std::string cpu_info;
-    static const std::string gpu_info;
+    const std::string cpu_info;
+    const std::string gpu_info;
     std::string model_filename;
     std::string model_type;
     uint64_t model_size;
@@ -903,7 +903,10 @@ struct test {
     std::string test_time;
     std::vector<uint64_t> samples_ns;
 
-    test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) {
+    test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) :
+        cpu_info(get_cpu_info()),
+        gpu_info(get_gpu_info()) {
+
         model_filename = inst.model;
         char buf[128];
         llama_model_desc(lmodel, buf, sizeof(buf));
@@ -1058,8 +1061,6 @@ struct test {
 
 const std::string test::build_commit = LLAMA_COMMIT;
 const int         test::build_number = LLAMA_BUILD_NUMBER;
-const std::string test::cpu_info     = get_cpu_info();
-const std::string test::gpu_info     = get_gpu_info();
 
 struct printer {
     virtual ~printer() {}
diff --git a/examples/llama.android/llama/src/main/cpp/CMakeLists.txt b/examples/llama.android/llama/src/main/cpp/CMakeLists.txt
index 2de496574f54a..6119fe09b0cb6 100644
--- a/examples/llama.android/llama/src/main/cpp/CMakeLists.txt
+++ b/examples/llama.android/llama/src/main/cpp/CMakeLists.txt
@@ -14,7 +14,7 @@ project("llama-android")
 #include(FetchContent)
 #FetchContent_Declare(
 #        llama
-#        GIT_REPOSITORY https://github.com/ggerganov/llama.cpp
+#        GIT_REPOSITORY https://github.com/ggml-org/llama.cpp
 #        GIT_TAG master
 #)
diff --git a/examples/llama.swiftui/README.md b/examples/llama.swiftui/README.md
index 96cf743d48202..f717886d661ce 100644
--- a/examples/llama.swiftui/README.md
+++ b/examples/llama.swiftui/README.md
@@ -3,9 +3,9 @@
 Local inference of llama.cpp on an iPhone. This is a sample app that can be used as a starting point for more advanced projects.
-For usage instructions and performance stats, check the following discussion: https://github.com/ggerganov/llama.cpp/discussions/4508 +For usage instructions and performance stats, check the following discussion: https://github.com/ggml-org/llama.cpp/discussions/4508 -![image](https://github.com/ggerganov/llama.cpp/assets/1991296/2b40284f-8421-47a2-b634-74eece09a299) +![image](https://github.com/ggml-org/llama.cpp/assets/1991296/2b40284f-8421-47a2-b634-74eece09a299) Video demonstration: diff --git a/examples/llama.vim b/examples/llama.vim index 57eb2a9772d51..af3fd3935d765 100644 --- a/examples/llama.vim +++ b/examples/llama.vim @@ -39,7 +39,7 @@ " " :call llama#init() " -" more info: https://github.com/ggerganov/llama.cpp/pull/9787 +" more info: https://github.com/ggml-org/llama.cpp/pull/9787 " " colors (adjust to your liking) diff --git a/examples/llava/README-minicpmo2.6.md b/examples/llava/README-minicpmo2.6.md index 8713a43d64fd8..8f591506dbbb0 100644 --- a/examples/llava/README-minicpmo2.6.md +++ b/examples/llava/README-minicpmo2.6.md @@ -26,7 +26,7 @@ python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model ``` Build llama.cpp using `CMake`: -https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md +https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md ```bash cmake -B build diff --git a/examples/llava/README-minicpmv2.5.md b/examples/llava/README-minicpmv2.5.md index 1c8498ff9e151..b0e72a0fa7a78 100644 --- a/examples/llava/README-minicpmv2.5.md +++ b/examples/llava/README-minicpmv2.5.md @@ -6,7 +6,7 @@ Download [MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V- Clone llama.cpp: ```bash -git clone https://github.com/ggerganov/llama.cpp +git clone https://github.com/ggml-org/llama.cpp cd llama.cpp ``` diff --git a/examples/lookahead/README.md b/examples/lookahead/README.md index a69a471b47d39..aab3cd0ca49b9 100644 --- a/examples/lookahead/README.md +++ b/examples/lookahead/README.md @@ -4,4 +4,4 @@ Demonstration of lookahead decoding technique: https://lmsys.org/blog/2023-11-21-lookahead-decoding/ -More info: https://github.com/ggerganov/llama.cpp/pull/4207 +More info: https://github.com/ggml-org/llama.cpp/pull/4207 diff --git a/examples/lookup/README.md b/examples/lookup/README.md index 71c345c037a2f..07d73849b0601 100644 --- a/examples/lookup/README.md +++ b/examples/lookup/README.md @@ -8,5 +8,5 @@ The key parameters for lookup decoding are `ngram_min`, `ngram_max` and `n_draft More info: -https://github.com/ggerganov/llama.cpp/pull/4484 -https://github.com/ggerganov/llama.cpp/issues/4226 +https://github.com/ggml-org/llama.cpp/pull/4484 +https://github.com/ggml-org/llama.cpp/issues/4226 diff --git a/examples/main/README.md b/examples/main/README.md index ceaed42f649df..f7c2497294ab5 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -1,6 +1,6 @@ # llama.cpp/examples/main -This example program allows you to use various LLaMA language models easily and efficiently. It is specifically designed to work with the [llama.cpp](https://github.com/ggerganov/llama.cpp) project, which provides a plain C/C++ implementation with optional 4-bit quantization support for faster, lower memory inference, and is optimized for desktop CPUs. This program can be used to perform various inference tasks with LLaMA models, including generating text based on user-provided prompts and chat-like interactions with reverse prompts. +This example program allows you to use various LLaMA language models easily and efficiently. 
It is specifically designed to work with the [llama.cpp](https://github.com/ggml-org/llama.cpp) project, which provides a plain C/C++ implementation with optional 4-bit quantization support for faster, lower memory inference, and is optimized for desktop CPUs. This program can be used to perform various inference tasks with LLaMA models, including generating text based on user-provided prompts and chat-like interactions with reverse prompts.
 
 ## Table of Contents
@@ -121,7 +121,7 @@ When --in-prefix or --in-suffix options are enabled the chat template ( --chat-t
 
 ### Chat templates
 
- `--chat-template JINJA_TEMPLATE`: This option sets a custom jinja chat template. It accepts a string, not a file name. Default: template taken from model's metadata. Llama.cpp only supports [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template). These include llama2, llama3, gemma, monarch, chatml, orion, vicuna, vicuna-orca, deepseek, command-r, zephyr. When --in-prefix or --in-suffix options are enabled the chat template ( --chat-template ) is disabled.
+ `--chat-template JINJA_TEMPLATE`: This option sets a custom jinja chat template. It accepts a string, not a file name. Default: template taken from model's metadata. Llama.cpp only supports [some pre-defined templates](https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template). These include llama2, llama3, gemma, monarch, chatml, orion, vicuna, vicuna-orca, deepseek, command-r, zephyr. When --in-prefix or --in-suffix options are enabled the chat template ( --chat-template ) is disabled.
 
 Example usage: `--chat-template gemma`
@@ -265,6 +265,14 @@ Being experimental and unique, XTC is disabled by default. The recommended combi
 
 Example usage: `--xtc-probability 0.5 --xtc-threshold 0.1`
 
+### Top-nσ Sampling
+
+- `--top-nsigma N`: Limit the next token selection to a subset of tokens with pre-softmax logits that are within n * σ less than the max logit (default: -1, -1 = disabled).
+
+Top-nσ sampling is a text generation method that selects tokens based on a statistical threshold in pre-softmax logits. It works by only sampling from tokens with logits that are within n * σ of the maximum logit. This method helps maintain a stable sampling space regardless of temperature scaling, allowing it to perform well on reasoning tasks even at high temperatures. Without complex probability manipulation, it efficiently filters tokens directly on the pre-softmax logits. A higher value for top-nsigma (e.g., 5) will take more noisy tokens into consideration, while a lower value (e.g., 1) will focus on the more informative region of the sampling space.
+
+Example usage: `--top-nsigma 1`
+
 ### Logit Bias
 
 - `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion.
diff --git a/examples/passkey/README.md b/examples/passkey/README.md
index 2b8e910f9658d..2f19597c48d7f 100644
--- a/examples/passkey/README.md
+++ b/examples/passkey/README.md
@@ -5,8 +5,8 @@ models ability to recall information from long contexts.
See the following PRs for more info: -- https://github.com/ggerganov/llama.cpp/pull/3856 -- https://github.com/ggerganov/llama.cpp/pull/4810 +- https://github.com/ggml-org/llama.cpp/pull/3856 +- https://github.com/ggml-org/llama.cpp/pull/4810 ### Usage diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 31c436f13976b..8c413f7d66e6d 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -3,6 +3,7 @@ #include "log.h" #include "llama.h" +#include #include #include #include diff --git a/examples/pydantic_models_to_grammar_examples.py b/examples/pydantic_models_to_grammar_examples.py index eb000d5ccba24..f94b82ca47570 100755 --- a/examples/pydantic_models_to_grammar_examples.py +++ b/examples/pydantic_models_to_grammar_examples.py @@ -23,7 +23,7 @@ def create_completion(host, prompt, gbnf_grammar): """Calls the /completion API on llama-server. See - https://github.com/ggerganov/llama.cpp/tree/HEAD/examples/server#api-endpoints + https://github.com/ggml-org/llama.cpp/tree/HEAD/examples/server#api-endpoints """ print(f" Request:\n Grammar:\n{textwrap.indent(gbnf_grammar, ' ')}\n Prompt:\n{textwrap.indent(prompt.rstrip(), ' ')}") headers = {"Content-Type": "application/json"} diff --git a/examples/quantize/README.md b/examples/quantize/README.md index f9cce7b213334..992d00e21b4fe 100644 --- a/examples/quantize/README.md +++ b/examples/quantize/README.md @@ -69,22 +69,22 @@ Several quantization methods are supported. They differ in the resulting model d | 13B | ms/tok @ 8th | - | 73 | 82 | 98 | 105 | 128 | | 13B | bits/weight | 16.0 | 4.5 | 5.0 | 5.5 | 6.0 | 8.5 | -- [k-quants](https://github.com/ggerganov/llama.cpp/pull/1684) +- [k-quants](https://github.com/ggml-org/llama.cpp/pull/1684) - recent k-quants improvements and new i-quants - - [#2707](https://github.com/ggerganov/llama.cpp/pull/2707) - - [#2807](https://github.com/ggerganov/llama.cpp/pull/2807) - - [#4773 - 2-bit i-quants (inference)](https://github.com/ggerganov/llama.cpp/pull/4773) - - [#4856 - 2-bit i-quants (inference)](https://github.com/ggerganov/llama.cpp/pull/4856) - - [#4861 - importance matrix](https://github.com/ggerganov/llama.cpp/pull/4861) - - [#4872 - MoE models](https://github.com/ggerganov/llama.cpp/pull/4872) - - [#4897 - 2-bit quantization](https://github.com/ggerganov/llama.cpp/pull/4897) - - [#4930 - imatrix for all k-quants](https://github.com/ggerganov/llama.cpp/pull/4930) - - [#4951 - imatrix on the GPU](https://github.com/ggerganov/llama.cpp/pull/4957) - - [#4969 - imatrix for legacy quants](https://github.com/ggerganov/llama.cpp/pull/4969) - - [#4996 - k-quants tuning](https://github.com/ggerganov/llama.cpp/pull/4996) - - [#5060 - Q3_K_XS](https://github.com/ggerganov/llama.cpp/pull/5060) - - [#5196 - 3-bit i-quants](https://github.com/ggerganov/llama.cpp/pull/5196) - - [quantization tuning](https://github.com/ggerganov/llama.cpp/pull/5320), [another one](https://github.com/ggerganov/llama.cpp/pull/5334), and [another one](https://github.com/ggerganov/llama.cpp/pull/5361) + - [#2707](https://github.com/ggml-org/llama.cpp/pull/2707) + - [#2807](https://github.com/ggml-org/llama.cpp/pull/2807) + - [#4773 - 2-bit i-quants (inference)](https://github.com/ggml-org/llama.cpp/pull/4773) + - [#4856 - 2-bit i-quants (inference)](https://github.com/ggml-org/llama.cpp/pull/4856) + - [#4861 - importance matrix](https://github.com/ggml-org/llama.cpp/pull/4861) + - [#4872 - MoE models](https://github.com/ggml-org/llama.cpp/pull/4872) + - [#4897 - 
2-bit quantization](https://github.com/ggml-org/llama.cpp/pull/4897) + - [#4930 - imatrix for all k-quants](https://github.com/ggml-org/llama.cpp/pull/4930) + - [#4951 - imatrix on the GPU](https://github.com/ggml-org/llama.cpp/pull/4957) + - [#4969 - imatrix for legacy quants](https://github.com/ggml-org/llama.cpp/pull/4969) + - [#4996 - k-quants tuning](https://github.com/ggml-org/llama.cpp/pull/4996) + - [#5060 - Q3_K_XS](https://github.com/ggml-org/llama.cpp/pull/5060) + - [#5196 - 3-bit i-quants](https://github.com/ggml-org/llama.cpp/pull/5196) + - [quantization tuning](https://github.com/ggml-org/llama.cpp/pull/5320), [another one](https://github.com/ggml-org/llama.cpp/pull/5334), and [another one](https://github.com/ggml-org/llama.cpp/pull/5361) **Llama 2 7B** diff --git a/examples/retrieval/README.md b/examples/retrieval/README.md index bc5f22e2ff156..6938a1e96ee35 100644 --- a/examples/retrieval/README.md +++ b/examples/retrieval/README.md @@ -3,7 +3,7 @@ Demonstration of simple retrieval technique based on cosine similarity More info: -https://github.com/ggerganov/llama.cpp/pull/6193 +https://github.com/ggml-org/llama.cpp/pull/6193 ### How to use diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index 1b7cc8c1328e4..aee90388e4fb3 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -5,7 +5,7 @@ option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF) include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}) if (MINGW) - # fix: https://github.com/ggerganov/llama.cpp/actions/runs/9651004652/job/26617901362?pr=8006 + # fix: https://github.com/ggml-org/llama.cpp/actions/runs/9651004652/job/26617901362?pr=8006 add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER}) endif() diff --git a/examples/server/README.md b/examples/server/README.md index d0b262f0e1f75..a2ae614d7c489 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -7,14 +7,14 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. **Features:** * LLM inference of F16 and quantized models on GPU and CPU * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes - * Reranking endoint (WIP: https://github.com/ggerganov/llama.cpp/pull/9510) + * Reranking endpoint (WIP: https://github.com/ggml-org/llama.cpp/pull/9510) * Parallel decoding with multi-user support * Continuous batching * Multimodal (wip) * Monitoring endpoints * Schema-constrained JSON response format -The project is under active development, and we are [looking for feedback and contributors](https://github.com/ggerganov/llama.cpp/issues/4216). +The project is under active development, and we are [looking for feedback and contributors](https://github.com/ggml-org/llama.cpp/issues/4216). ## Usage @@ -65,7 +65,7 @@ The project is under active development, and we are [looking for feedback and co | `-np, --parallel N` | number of parallel sequences to decode (default: 1)
(env: LLAMA_ARG_N_PARALLEL) | | `--mlock` | force system to keep model in RAM rather than swapping or compressing
(env: LLAMA_ARG_MLOCK) | | `--no-mmap` | do not memory-map model (slower load but may reduce pageouts if not using mlock)
(env: LLAMA_ARG_NO_MMAP) | -| `--numa TYPE` | attempt optimizations that help on some NUMA systems
- distribute: spread execution evenly over all nodes
- isolate: only spawn threads on CPUs on the node that execution started on
- numactl: use the CPU map provided by numactl
if run without this previously, it is recommended to drop the system page cache before using this
see https://github.com/ggerganov/llama.cpp/issues/1437
(env: LLAMA_ARG_NUMA) | +| `--numa TYPE` | attempt optimizations that help on some NUMA systems
- distribute: spread execution evenly over all nodes
- isolate: only spawn threads on CPUs on the node that execution started on
- numactl: use the CPU map provided by numactl
if run without this option previously, it is recommended to drop the system page cache before using it
see https://github.com/ggml-org/llama.cpp/issues/1437
(env: LLAMA_ARG_NUMA) | | `-dev, --device <dev1,dev2,..>` | comma-separated list of devices to use for offloading (none = don't offload)
use --list-devices to see a list of available devices
(env: LLAMA_ARG_DEVICE) | | `--list-devices` | print list of available devices and exit | | `-ngl, --gpu-layers, --n-gpu-layers N` | number of layers to store in VRAM
(env: LLAMA_ARG_N_GPU_LAYERS) | @@ -127,6 +127,7 @@ The project is under active development, and we are [looking for feedback and co | `--grammar-file FNAME` | file to read grammar from | | `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object
For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead | | `--jinja` | Enable experimental Jinja templating engine (required for tool use) | +| `--reasoning-format FORMAT` | Controls extraction of model thinking traces and the format / field in which they are returned (default: `deepseek`; allowed values: `deepseek`, `none`; requires `--jinja`). `none` will leave thinking traces inline in `message.content` in a model-specific format, while `deepseek` will return them separately under `message.reasoning_content` | **Example-specific params** @@ -177,7 +178,7 @@ Example usage of docker compose with environment variables: ```yml services: llamacpp-server: - image: ghcr.io/ggerganov/llama.cpp:server + image: ghcr.io/ggml-org/llama.cpp:server ports: - 8080:8080 volumes: @@ -272,10 +273,10 @@ You can consume the endpoints with Postman or NodeJS with axios library. You can ### Docker ```bash -docker run -p 8080:8080 -v /path/to/models:/models ghcr.io/ggerganov/llama.cpp:server -m models/7B/ggml-model.gguf -c 512 --host 0.0.0.0 --port 8080 +docker run -p 8080:8080 -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:server -m models/7B/ggml-model.gguf -c 512 --host 0.0.0.0 --port 8080 # or, with CUDA: -docker run -p 8080:8080 -v /path/to/models:/models --gpus all ghcr.io/ggerganov/llama.cpp:server-cuda -m models/7B/ggml-model.gguf -c 512 --host 0.0.0.0 --port 8080 --n-gpu-layers 99 +docker run -p 8080:8080 -v /path/to/models:/models --gpus all ghcr.io/ggml-org/llama.cpp:server-cuda -m models/7B/ggml-model.gguf -c 512 --host 0.0.0.0 --port 8080 --n-gpu-layers 99 ``` ## Testing with CURL @@ -1065,7 +1066,7 @@ print(completion.choices[0].text) ### POST `/v1/chat/completions`: OpenAI-compatible Chat Completions API -Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used. +Given a ChatML-formatted JSON description in `messages`, it returns the predicted completion. Both synchronous and streaming modes are supported, so scripted and interactive applications work fine. While no strong claim of compatibility with the OpenAI API spec is made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used.
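As a rough C++ counterpart to the curl examples elsewhere in this README, here is a minimal sketch of calling this endpoint with cpp-httplib, the library vendored as `examples/server/httplib.h`. It assumes a llama-server instance already listening on localhost:8080; the prompt text is borrowed from the README's own examples:

```cpp
// Minimal non-streaming /v1/chat/completions request using cpp-httplib.
#include <iostream>
#include "httplib.h"

int main() {
    httplib::Client cli("http://localhost:8080");

    const char *body = R"({
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user",   "content": "Write a limerick about python exceptions"}
        ]
    })";

    // The response body is a JSON document in the OpenAI
    // chat-completion format.
    auto res = cli.Post("/v1/chat/completions", body, "application/json");
    if (res && res->status == 200) {
        std::cout << res->body << std::endl;
    } else {
        std::cerr << "request failed" << std::endl;
    }
    return 0;
}
```

With `"stream": true` in the body, the server returns incremental chunks instead of a single JSON document.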
*Options:* @@ -1119,7 +1120,7 @@ curl http://localhost:8080/v1/chat/completions \ *Tool call support* -[Function calling](https://platform.openai.com/docs/guides/function-calling) is supported for all models (see https://github.com/ggerganov/llama.cpp/pull/9639): +[Function calling](https://platform.openai.com/docs/guides/function-calling) is supported for all models (see https://github.com/ggml-org/llama.cpp/pull/9639): - Requires `--jinja` flag - Native tool call formats supported: @@ -1136,61 +1137,252 @@ curl http://localhost:8080/v1/chat/completions \ | Template | Format | |----------|--------| - | CohereForAI-c4ai-command-r-plus-default.jinja | generic tool calls | - | CohereForAI-c4ai-command-r-plus-rag.jinja | generic tool calls | - | CohereForAI-c4ai-command-r-plus-tool_use.jinja | generic tool calls | - | MiniMaxAI-MiniMax-Text-01.jinja | generic tool calls | - | NexaAIDev-Octopus-v2.jinja | generic tool calls | - | NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja | generic tool calls | - | NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | hermes 2 pro tool calls | - | NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja | generic tool calls | - | NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja | hermes 2 pro tool calls | - | NousResearch-Hermes-3-Llama-3.1-70B-default.jinja | generic tool calls | - | NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja | hermes 2 pro tool calls | - | OrionStarAI-Orion-14B-Chat.jinja | generic tool calls | - | Qwen-QwQ-32B-Preview.jinja | hermes 2 pro tool calls | - | Qwen-Qwen2-7B-Instruct.jinja | generic tool calls | - | Qwen-Qwen2-VL-7B-Instruct.jinja | generic tool calls | - | Qwen-Qwen2.5-7B-Instruct.jinja | hermes 2 pro tool calls | - | Qwen-Qwen2.5-Math-7B-Instruct.jinja | hermes 2 pro tool calls | - | TheBloke-FusionNet_34Bx2_MoE-AWQ.jinja | generic tool calls | - | abacusai-Fewshot-Metamath-OrcaVicuna-Mistral.jinja | generic tool calls | - | bofenghuang-vigogne-2-70b-chat.jinja | generic tool calls | - | databricks-dbrx-instruct.jinja | generic tool calls | - | deepseek-ai-DeepSeek-Coder-V2-Instruct.jinja | generic tool calls | - | deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | deepseek r1 tool calls | - | deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja | deepseek r1 tool calls | - | deepseek-ai-DeepSeek-R1-Distill-Qwen-7B.jinja | deepseek r1 tool calls | - | deepseek-ai-DeepSeek-V2.5.jinja | deepseek r1 tool calls | - | deepseek-ai-deepseek-coder-33b-instruct.jinja | generic tool calls | - | google-gemma-2-2b-it.jinja | generic tool calls | - | google-gemma-7b-it.jinja | generic tool calls | - | indischepartij-MiniCPM-3B-OpenHermes-2.5-v2.jinja | generic tool calls | - | mattshumer-Reflection-Llama-3.1-70B.jinja | generic tool calls | - | meetkai-functionary-medium-v3.2.jinja | functionary v3.2 tool calls | - | meta-llama-Llama-3.1-8B-Instruct.jinja | llama 3.x tool calls (w/ builtin tools) | - | meta-llama-Llama-3.2-3B-Instruct.jinja | llama 3.x tool calls | - | meta-llama-Llama-3.3-70B-Instruct.jinja | llama 3.x tool calls (w/ builtin tools) | - | meta-llama-Meta-Llama-3.1-8B-Instruct.jinja | llama 3.x tool calls (w/ builtin tools) | - | microsoft-Phi-3-medium-4k-instruct.jinja | generic tool calls | - | microsoft-Phi-3-mini-4k-instruct.jinja | generic tool calls | - | microsoft-Phi-3-small-8k-instruct.jinja | generic tool calls | - | microsoft-Phi-3.5-mini-instruct.jinja | generic tool calls | - | microsoft-Phi-3.5-vision-instruct.jinja | generic tool calls | - | mistralai-Mistral-7B-Instruct-v0.2.jinja | generic tool calls | 
- | mistralai-Mistral-Large-Instruct-2407.jinja | mistral nemo tool calls | - | mistralai-Mistral-Large-Instruct-2411.jinja | generic tool calls | - | mistralai-Mistral-Nemo-Instruct-2407.jinja | mistral nemo tool calls | - | mistralai-Mixtral-8x7B-Instruct-v0.1.jinja | generic tool calls | - | mlabonne-AlphaMonarch-7B.jinja | generic tool calls | - | nvidia-Llama-3.1-Nemotron-70B-Instruct-HF.jinja | llama 3.x tool calls (w/ builtin tools) | - | openchat-openchat-3.5-0106.jinja | generic tool calls | - | teknium-OpenHermes-2.5-Mistral-7B.jinja | generic tool calls | + | Almawave-Velvet-14B.jinja | Hermes 2 Pro | + | AtlaAI-Selene-1-Mini-Llama-3.1-8B.jinja | Llama 3.x | + | CohereForAI-aya-expanse-8b.jinja | Generic | + | CohereForAI-c4ai-command-r-plus-default.jinja | Generic | + | CohereForAI-c4ai-command-r-plus-rag.jinja | Generic | + | CohereForAI-c4ai-command-r-plus-tool_use.jinja | Generic | + | CohereForAI-c4ai-command-r7b-12-2024-default.jinja | Command R7B (extract reasoning) | + | CohereForAI-c4ai-command-r7b-12-2024-rag.jinja | Command R7B (extract reasoning) | + | CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja | Command R7B (extract reasoning) | + | CohereForAI-c4ai-command-r7b-12-2024.jinja | Generic | + | DavieLion-Llama-3.2-1B-SPIN-iter3.jinja | Generic | + | Delta-Vector-Rei-12B.jinja | Mistral Nemo | + | EpistemeAI-Mistral-Nemo-Instruct-12B-Philosophy-Math.jinja | Mistral Nemo | + | FlofloB-83k_continued_pretraining_Qwen2.5-0.5B-Instruct_Unsloth_merged_16bit.jinja | Hermes 2 Pro | + | FlofloB-test_continued_pretraining_Phi-3-mini-4k-instruct_Unsloth_merged_16bit.jinja | Generic | + | HelpingAI-HAI-SER.jinja | Generic | + | HuggingFaceTB-SmolLM2-1.7B-Instruct.jinja | Generic | + | HuggingFaceTB-SmolLM2-135M-Instruct.jinja | Generic | + | HuggingFaceTB-SmolLM2-360M-Instruct.jinja | Generic | + | INSAIT-Institute-BgGPT-Gemma-2-27B-IT-v1.0.jinja | Generic | + | Ihor-Text2Graph-R1-Qwen2.5-0.5b.jinja | Hermes 2 Pro | + | Infinigence-Megrez-3B-Instruct.jinja | Generic | + | Josephgflowers-TinyLlama_v1.1_math_code-world-test-1.jinja | Generic | + | LGAI-EXAONE-EXAONE-3.5-2.4B-Instruct.jinja | Generic | + | LGAI-EXAONE-EXAONE-3.5-7.8B-Instruct.jinja | Generic | + | LatitudeGames-Wayfarer-12B.jinja | Generic | + | Magpie-Align-Llama-3-8B-Magpie-Align-v0.1.jinja | Generic | + | Magpie-Align-Llama-3.1-8B-Magpie-Align-v0.1.jinja | Generic | + | MaziyarPanahi-calme-3.2-instruct-78b.jinja | Generic | + | MiniMaxAI-MiniMax-Text-01.jinja | Generic | + | MiniMaxAI-MiniMax-VL-01.jinja | Generic | + | NaniDAO-deepseek-r1-qwen-2.5-32B-ablated.jinja | DeepSeek R1 (extract reasoning) | + | NexaAIDev-Octopus-v2.jinja | Generic | + | NousResearch-Hermes-2-Pro-Llama-3-8B-default.jinja | Generic | + | NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | Hermes 2 Pro | + | NousResearch-Hermes-2-Pro-Mistral-7B-default.jinja | Generic | + | NousResearch-Hermes-2-Pro-Mistral-7B-tool_use.jinja | Hermes 2 Pro | + | NousResearch-Hermes-3-Llama-3.1-70B-default.jinja | Generic | + | NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja | Hermes 2 Pro | + | NovaSky-AI-Sky-T1-32B-Flash.jinja | Hermes 2 Pro | + | NovaSky-AI-Sky-T1-32B-Preview.jinja | Hermes 2 Pro | + | OnlyCheeini-greesychat-turbo.jinja | Generic | + | Orenguteng-Llama-3.1-8B-Lexi-Uncensored-V2.jinja | Llama 3.x | + | OrionStarAI-Orion-14B-Chat.jinja | Generic | + | PowerInfer-SmallThinker-3B-Preview.jinja | Generic | + | PrimeIntellect-INTELLECT-1-Instruct.jinja | Generic | + | Qwen-QVQ-72B-Preview.jinja | Generic | + | 
Qwen-QwQ-32B-Preview.jinja | Hermes 2 Pro | + | Qwen-Qwen1.5-7B-Chat.jinja | Generic | + | Qwen-Qwen2-7B-Instruct.jinja | Generic | + | Qwen-Qwen2-VL-72B-Instruct.jinja | Generic | + | Qwen-Qwen2-VL-7B-Instruct.jinja | Generic | + | Qwen-Qwen2.5-0.5B.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-1.5B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-14B-Instruct-1M.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-14B.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-32B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-32B.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-3B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-72B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-7B-Instruct-1M.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-7B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-7B.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-Coder-32B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-Coder-7B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-Math-1.5B.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-Math-7B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-VL-3B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-VL-72B-Instruct.jinja | Hermes 2 Pro | + | Qwen-Qwen2.5-VL-7B-Instruct.jinja | Hermes 2 Pro | + | RWKV-Red-Team-ARWKV-7B-Preview-0.1.jinja | Hermes 2 Pro | + | SakanaAI-TinySwallow-1.5B-Instruct.jinja | Hermes 2 Pro | + | SakanaAI-TinySwallow-1.5B.jinja | Hermes 2 Pro | + | Sao10K-70B-L3.3-Cirrus-x1.jinja | Llama 3.x | + | SentientAGI-Dobby-Mini-Leashed-Llama-3.1-8B.jinja | Llama 3.x | + | SentientAGI-Dobby-Mini-Unhinged-Llama-3.1-8B.jinja | Llama 3.x | + | Steelskull-L3.3-Damascus-R1.jinja | Llama 3.x | + | Steelskull-L3.3-MS-Nevoria-70b.jinja | Llama 3.x | + | Steelskull-L3.3-Nevoria-R1-70b.jinja | Llama 3.x | + | THUDM-glm-4-9b-chat.jinja | Generic | + | THUDM-glm-edge-1.5b-chat.jinja | Generic | + | Tarek07-Progenitor-V1.1-LLaMa-70B.jinja | Llama 3.x | + | TheBloke-FusionNet_34Bx2_MoE-AWQ.jinja | Generic | + | TinyLlama-TinyLlama-1.1B-Chat-v1.0.jinja | Generic | + | UCLA-AGI-Mistral7B-PairRM-SPPO-Iter3.jinja | Generic | + | ValiantLabs-Llama3.1-8B-Enigma.jinja | Llama 3.x | + | abacusai-Fewshot-Metamath-OrcaVicuna-Mistral.jinja | Generic | + | ai21labs-AI21-Jamba-1.5-Large.jinja | Generic | + | allenai-Llama-3.1-Tulu-3-405B-SFT.jinja | Generic | + | allenai-Llama-3.1-Tulu-3-405B.jinja | Generic | + | allenai-Llama-3.1-Tulu-3-8B.jinja | Generic | + | arcee-ai-Virtuoso-Lite.jinja | Hermes 2 Pro | + | arcee-ai-Virtuoso-Medium-v2.jinja | Hermes 2 Pro | + | arcee-ai-Virtuoso-Small-v2.jinja | Hermes 2 Pro | + | avemio-GRAG-NEMO-12B-ORPO-HESSIAN-AI.jinja | Generic | + | bespokelabs-Bespoke-Stratos-7B.jinja | Hermes 2 Pro | + | bfuzzy1-acheron-m1a-llama.jinja | Generic | + | bofenghuang-vigogne-2-70b-chat.jinja | Generic | + | bytedance-research-UI-TARS-72B-DPO.jinja | Generic | + | bytedance-research-UI-TARS-7B-DPO.jinja | Generic | + | bytedance-research-UI-TARS-7B-SFT.jinja | Generic | + | carsenk-phi3.5_mini_exp_825_uncensored.jinja | Generic | + | cyberagent-DeepSeek-R1-Distill-Qwen-14B-Japanese.jinja | DeepSeek R1 (extract reasoning) | + | cyberagent-DeepSeek-R1-Distill-Qwen-32B-Japanese.jinja | DeepSeek R1 (extract reasoning) | + | databricks-dbrx-instruct.jinja | Generic | + | deepseek-ai-DeepSeek-Coder-V2-Instruct.jinja | Generic | + | deepseek-ai-DeepSeek-Coder-V2-Lite-Base.jinja | Generic | + | deepseek-ai-DeepSeek-Coder-V2-Lite-Instruct.jinja | Generic | + | deepseek-ai-DeepSeek-R1-Distill-Llama-70B.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) | 
+ | deepseek-ai-DeepSeek-R1-Distill-Qwen-1.5B.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-R1-Distill-Qwen-14B.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-R1-Distill-Qwen-7B.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-R1-Zero.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-V2-Lite.jinja | Generic | + | deepseek-ai-DeepSeek-V2.5.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-DeepSeek-V3.jinja | DeepSeek R1 (extract reasoning) | + | deepseek-ai-deepseek-coder-33b-instruct.jinja | Generic | + | deepseek-ai-deepseek-coder-6.7b-instruct.jinja | Generic | + | deepseek-ai-deepseek-coder-7b-instruct-v1.5.jinja | Generic | + | deepseek-ai-deepseek-llm-67b-chat.jinja | Generic | + | deepseek-ai-deepseek-llm-7b-chat.jinja | Generic | + | dicta-il-dictalm2.0-instruct.jinja | Generic | + | ehristoforu-Falcon3-8B-Franken-Basestruct.jinja | Hermes 2 Pro | + | fireworks-ai-llama-3-firefunction-v2.jinja | FireFunction v2 | + | godlikehhd-alpaca_data_sampled_ifd_new_5200.jinja | Hermes 2 Pro | + | godlikehhd-alpaca_data_score_max_0.7_2600.jinja | Hermes 2 Pro | + | google-gemma-2-27b-it.jinja | Generic | + | google-gemma-2-2b-it.jinja | Generic | + | google-gemma-2-2b-jpn-it.jinja | Generic | + | google-gemma-7b-it.jinja | Generic | + | huihui-ai-DeepSeek-R1-Distill-Llama-70B-abliterated.jinja | DeepSeek R1 (extract reasoning) | + | huihui-ai-DeepSeek-R1-Distill-Llama-8B-abliterated.jinja | DeepSeek R1 (extract reasoning) | + | huihui-ai-DeepSeek-R1-Distill-Qwen-14B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) | + | huihui-ai-DeepSeek-R1-Distill-Qwen-32B-abliterated.jinja | DeepSeek R1 (extract reasoning) | + | huihui-ai-DeepSeek-R1-Distill-Qwen-7B-abliterated-v2.jinja | DeepSeek R1 (extract reasoning) | + | huihui-ai-Qwen2.5-14B-Instruct-1M-abliterated.jinja | Hermes 2 Pro | + | ibm-granite-granite-3.1-8b-instruct.jinja | Generic | + | indischepartij-MiniCPM-3B-OpenHermes-2.5-v2.jinja | Generic | + | inflatebot-MN-12B-Mag-Mell-R1.jinja | Generic | + | jinaai-ReaderLM-v2.jinja | Generic | + | kms7530-chemeng_qwen-math-7b_24_1_100_1_nonmath.jinja | Hermes 2 Pro | + | knifeayumu-Cydonia-v1.3-Magnum-v4-22B.jinja | Mistral Nemo | + | langgptai-qwen1.5-7b-chat-sa-v0.1.jinja | Generic | + | lightblue-DeepSeek-R1-Distill-Qwen-7B-Japanese.jinja | DeepSeek R1 (extract reasoning) | + | mattshumer-Reflection-Llama-3.1-70B.jinja | Generic | + | meetkai-functionary-medium-v3.1.jinja | Functionary v3.1 Llama 3.1 | + | meetkai-functionary-medium-v3.2.jinja | Functionary v3.2 | + | meta-llama-Llama-2-7b-chat-hf.jinja | Generic | + | meta-llama-Llama-3.1-8B-Instruct.jinja | Llama 3.x | + | meta-llama-Llama-3.2-11B-Vision-Instruct.jinja | Llama 3.x | + | meta-llama-Llama-3.2-1B-Instruct.jinja | Llama 3.x | + | meta-llama-Llama-3.2-3B-Instruct.jinja | Llama 3.x | + | meta-llama-Llama-3.3-70B-Instruct.jinja | Llama 3.x | + | meta-llama-Meta-Llama-3-8B-Instruct.jinja | Generic | + | meta-llama-Meta-Llama-3.1-8B-Instruct.jinja | Llama 3.x | + | microsoft-Phi-3-medium-4k-instruct.jinja | Generic | + | microsoft-Phi-3-mini-4k-instruct.jinja | Generic | + | microsoft-Phi-3-small-8k-instruct.jinja | Generic | + | microsoft-Phi-3.5-mini-instruct.jinja | Generic | + | microsoft-Phi-3.5-vision-instruct.jinja | Generic | + | microsoft-phi-4.jinja | Generic | + | 
migtissera-Tess-3-Mistral-Nemo-12B.jinja | Generic | + | ministral-Ministral-3b-instruct.jinja | Generic | + | mistralai-Codestral-22B-v0.1.jinja | Generic | + | mistralai-Mistral-7B-Instruct-v0.1.jinja | Generic | + | mistralai-Mistral-7B-Instruct-v0.2.jinja | Generic | + | mistralai-Mistral-7B-Instruct-v0.3.jinja | Mistral Nemo | + | mistralai-Mistral-Large-Instruct-2407.jinja | Mistral Nemo | + | mistralai-Mistral-Large-Instruct-2411.jinja | Generic | + | mistralai-Mistral-Nemo-Instruct-2407.jinja | Mistral Nemo | + | mistralai-Mistral-Small-24B-Instruct-2501.jinja | Generic | + | mistralai-Mixtral-8x7B-Instruct-v0.1.jinja | Generic | + | mkurman-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro | + | mlabonne-AlphaMonarch-7B.jinja | Generic | + | mlx-community-Josiefied-Qwen2.5-0.5B-Instruct-abliterated-v1-float32.jinja | Hermes 2 Pro | + | mlx-community-Qwen2.5-VL-7B-Instruct-8bit.jinja | Hermes 2 Pro | + | mobiuslabsgmbh-DeepSeek-R1-ReDistill-Qwen-1.5B-v1.1.jinja | DeepSeek R1 (extract reasoning) | + | netcat420-MFANNv0.20.jinja | Generic | + | netcat420-MFANNv0.24.jinja | Generic | + | netease-youdao-Confucius-o1-14B.jinja | Hermes 2 Pro | + | nvidia-AceMath-7B-RM.jinja | Hermes 2 Pro | + | nvidia-Eagle2-1B.jinja | Hermes 2 Pro | + | nvidia-Eagle2-9B.jinja | Hermes 2 Pro | + | nvidia-Llama-3.1-Nemotron-70B-Instruct-HF.jinja | Llama 3.x | + | onnx-community-DeepSeek-R1-Distill-Qwen-1.5B-ONNX.jinja | DeepSeek R1 (extract reasoning) | + | open-thoughts-OpenThinker-7B.jinja | Hermes 2 Pro | + | openchat-openchat-3.5-0106.jinja | Generic | + | pankajmathur-orca_mini_v6_8b.jinja | Generic | + | princeton-nlp-Mistral-7B-Base-SFT-RDPO.jinja | Generic | + | princeton-nlp-Mistral-7B-Instruct-DPO.jinja | Generic | + | princeton-nlp-Mistral-7B-Instruct-RDPO.jinja | Generic | + | prithivMLmods-Bellatrix-Tiny-1.5B-R1.jinja | Hermes 2 Pro | + | prithivMLmods-Bellatrix-Tiny-1B-R1.jinja | Llama 3.x | + | prithivMLmods-Bellatrix-Tiny-1B-v3.jinja | Generic | + | prithivMLmods-Bellatrix-Tiny-3B-R1.jinja | Llama 3.x | + | prithivMLmods-Blaze-14B-xElite.jinja | Generic | + | prithivMLmods-Calcium-Opus-14B-Elite2-R1.jinja | Hermes 2 Pro | + | prithivMLmods-Calme-Ties-78B.jinja | Generic | + | prithivMLmods-Calme-Ties2-78B.jinja | Generic | + | prithivMLmods-Calme-Ties3-78B.jinja | Generic | + | prithivMLmods-ChemQwen2-vL.jinja | Generic | + | prithivMLmods-GWQ2b.jinja | Generic | + | prithivMLmods-LatexMind-2B-Codec.jinja | Generic | + | prithivMLmods-Llama-3.2-6B-AlgoCode.jinja | Llama 3.x | + | prithivMLmods-Megatron-Opus-14B-Exp.jinja | Hermes 2 Pro | + | prithivMLmods-Megatron-Opus-14B-Stock.jinja | Hermes 2 Pro | + | prithivMLmods-Megatron-Opus-7B-Exp.jinja | Hermes 2 Pro | + | prithivMLmods-Omni-Reasoner-Merged.jinja | Hermes 2 Pro | + | prithivMLmods-Omni-Reasoner4-Merged.jinja | Hermes 2 Pro | + | prithivMLmods-Primal-Opus-14B-Optimus-v1.jinja | Hermes 2 Pro | + | prithivMLmods-QwQ-Math-IO-500M.jinja | Hermes 2 Pro | + | prithivMLmods-Qwen-7B-Distill-Reasoner.jinja | DeepSeek R1 (extract reasoning) | + | prithivMLmods-Qwen2.5-1.5B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro | + | prithivMLmods-Qwen2.5-14B-DeepSeek-R1-1M.jinja | Hermes 2 Pro | + | prithivMLmods-Qwen2.5-32B-DeepSeek-R1-Instruct.jinja | Hermes 2 Pro | + | prithivMLmods-Qwen2.5-7B-DeepSeek-R1-1M.jinja | Hermes 2 Pro | + | prithivMLmods-Triangulum-v2-10B.jinja | Hermes 2 Pro | + | qingy2024-Falcon3-2x10B-MoE-Instruct.jinja | Hermes 2 Pro | + | rubenroy-Zurich-14B-GCv2-5m.jinja | Hermes 2 Pro | + | rubenroy-Zurich-7B-GCv2-5m.jinja | Hermes 
2 Pro | + | silma-ai-SILMA-Kashif-2B-Instruct-v1.0.jinja | Generic | + | simplescaling-s1-32B.jinja | Hermes 2 Pro | + | sometimesanotion-Lamarck-14B-v0.7.jinja | Hermes 2 Pro | + | sonthenguyen-zephyr-sft-bnb-4bit-DPO-mtbr-180steps.jinja | Generic | + | sthenno-tempesthenno-icy-0130.jinja | Generic | + | sumink-qwft.jinja | Hermes 2 Pro | + | teknium-OpenHermes-2.5-Mistral-7B.jinja | Generic | + | thirdeyeai-elevate360m.jinja | Generic | + | tiiuae-Falcon3-10B-Instruct.jinja | Hermes 2 Pro | + | unsloth-DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit.jinja | DeepSeek R1 (extract reasoning) | + | unsloth-DeepSeek-R1-Distill-Llama-8B.jinja | DeepSeek R1 (extract reasoning) | + | unsloth-DeepSeek-R1.jinja | DeepSeek R1 (extract reasoning) | + | unsloth-Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit.jinja | Generic | + | upstage-solar-pro-preview-instruct.jinja | Generic | + | whyhow-ai-PatientSeek.jinja | Generic | + | xwen-team-Xwen-72B-Chat.jinja | Hermes 2 Pro | + | xwen-team-Xwen-7B-Chat.jinja | Hermes 2 Pro | This table can be generated with: ```bash ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null + ```
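Because every format in the table above is surfaced through the same OpenAI-style response shape, a client can stay format-agnostic. Below is a sketch of pulling a tool call out of the response JSON with nlohmann::json (the single-header JSON library used by the server example); the response literal is a hand-written assumption of the schema, mirroring the `python` tool from the curl example that follows:

```cpp
// Sketch: extracting an OpenAI-style tool call from a chat response.
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // Assumed response shape when the model decides to call a `python` tool.
    json response = json::parse(R"({
        "choices": [{
            "message": {
                "tool_calls": [{
                    "function": {
                        "name": "python",
                        "arguments": "{\"code\": \"print('hello world')\"}"
                    }
                }]
            }
        }]
    })");

    const auto &message = response["choices"][0]["message"];
    if (message.contains("tool_calls")) {
        for (const auto &call : message["tool_calls"]) {
            const std::string name = call["function"]["name"];
            // `arguments` is itself a JSON document encoded as a string.
            const json args = json::parse(call["function"]["arguments"].get<std::string>());
            std::cout << "tool: " << name << ", code: " << args["code"] << std::endl;
        }
    }
    return 0;
}
```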
@@ -1202,11 +1394,20 @@ curl http://localhost:8080/v1/chat/completions \ ```shell # Native support: + llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M + # Native support for DeepSeek R1 works best w/ our own template (official template buggy) + + llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L \ + --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja + + llama-server --jinja -fa -hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q4_K_M \ + --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja + # Native support requires the right template for these GGUFs: llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \ @@ -1236,17 +1437,17 @@ curl http://localhost:8080/v1/chat/completions \ { "type":"function", "function":{ - "name":"get_current_weather", - "description":"Get the current weather in a given location", + "name":"python", + "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", "parameters":{ "type":"object", "properties":{ - "location":{ + "code":{ "type":"string", - "description":"The city and state, e.g. San Francisco, CA" + "description":"The code to run in the ipython interpreter." } }, - "required":["location"] + "required":["code"] } } } @@ -1254,7 +1455,7 @@ curl http://localhost:8080/v1/chat/completions \ "messages": [ { "role": "user", - "content": "What is the weather like in Istanbul?." + "content": "Print a hello world message with python." } ] }' @@ -1398,7 +1599,7 @@ Apart from error types supported by OAI, we also have custom types that are spec ### Legacy completion web UI -A new chat-based UI has replaced the old completion-based since [this PR](https://github.com/ggerganov/llama.cpp/pull/10175). If you want to use the old completion, start the server with `--path ./examples/server/public_legacy` +A new chat-based UI has replaced the old completion-based one since [this PR](https://github.com/ggml-org/llama.cpp/pull/10175). If you want to use the old completion UI, start the server with `--path ./examples/server/public_legacy` For example: diff --git a/examples/server/httplib.h b/examples/server/httplib.h index c2f12dd2ac8d6..593beb50150b1 100644 --- a/examples/server/httplib.h +++ b/examples/server/httplib.h @@ -1,14 +1,14 @@ // // httplib.h // -// Copyright (c) 2024 Yuji Hirose. All rights reserved. +// Copyright (c) 2025 Yuji Hirose. All rights reserved.
// MIT License // #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.18.5" +#define CPPHTTPLIB_VERSION "0.19.0" /* * Configuration @@ -66,6 +66,10 @@ #define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND 0 #endif +#ifndef CPPHTTPLIB_CLIENT_MAX_TIMEOUT_MSECOND +#define CPPHTTPLIB_CLIENT_MAX_TIMEOUT_MSECOND 0 +#endif + #ifndef CPPHTTPLIB_IDLE_INTERVAL_SECOND #define CPPHTTPLIB_IDLE_INTERVAL_SECOND 0 #endif @@ -189,6 +193,7 @@ using ssize_t = long; #endif using socket_t = SOCKET; +using socklen_t = int; #ifdef CPPHTTPLIB_USE_POLL #define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) #endif @@ -218,7 +223,9 @@ using socket_t = SOCKET; #include #include #include +#ifndef __VMS #include +#endif #include #include #include @@ -662,6 +669,8 @@ struct Request { ContentProvider content_provider_; bool is_chunked_content_provider_ = false; size_t authorization_count_ = 0; + std::chrono::time_point start_time_ = + std::chrono::steady_clock::time_point::min(); }; struct Response { @@ -735,6 +744,8 @@ class Stream { virtual void get_local_ip_and_port(std::string &ip, int &port) const = 0; virtual socket_t socket() const = 0; + virtual time_t duration() const = 0; + ssize_t write(const char *ptr); ssize_t write(const std::string &s); }; @@ -838,6 +849,16 @@ using Logger = std::function; using SocketOptions = std::function; +namespace detail { + +bool set_socket_opt_impl(socket_t sock, int level, int optname, + const void *optval, socklen_t optlen); +bool set_socket_opt(socket_t sock, int level, int optname, int opt); +bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec, + time_t usec); + +} // namespace detail + void default_socket_options(socket_t sock); const char *status_message(int status); @@ -1421,6 +1442,10 @@ class ClientImpl { template void set_write_timeout(const std::chrono::duration &duration); + void set_max_timeout(time_t msec); + template + void set_max_timeout(const std::chrono::duration &duration); + void set_basic_auth(const std::string &username, const std::string &password); void set_bearer_token_auth(const std::string &token); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT @@ -1529,6 +1554,7 @@ class ClientImpl { time_t read_timeout_usec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND; time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND; time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND; + time_t max_timeout_msec_ = CPPHTTPLIB_CLIENT_MAX_TIMEOUT_MSECOND; std::string basic_auth_username_; std::string basic_auth_password_; @@ -1583,9 +1609,6 @@ class ClientImpl { bool send_(Request &req, Response &res, Error &error); Result send_(Request &&req); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - bool is_ssl_peer_could_be_closed(SSL *ssl) const; -#endif socket_t create_client_socket(Error &error) const; bool read_response_line(Stream &strm, const Request &req, Response &res) const; @@ -1611,8 +1634,10 @@ class ClientImpl { std::string adjust_host_string(const std::string &host) const; - virtual bool process_socket(const Socket &socket, - std::function callback); + virtual bool + process_socket(const Socket &socket, + std::chrono::time_point start_time, + std::function callback); virtual bool is_ssl() const; }; @@ -1854,6 +1879,10 @@ class Client { template void set_write_timeout(const std::chrono::duration &duration); + void set_max_timeout(time_t msec); + template + void set_max_timeout(const std::chrono::duration &duration); + void set_basic_auth(const std::string &username, const std::string &password); void 
set_bearer_token_auth(const std::string &token); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT @@ -1971,12 +2000,16 @@ class SSLClient final : public ClientImpl { void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override; void shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully); - bool process_socket(const Socket &socket, - std::function callback) override; + bool + process_socket(const Socket &socket, + std::chrono::time_point start_time, + std::function callback) override; bool is_ssl() const override; - bool connect_with_proxy(Socket &sock, Response &res, bool &success, - Error &error); + bool connect_with_proxy( + Socket &sock, + std::chrono::time_point start_time, + Response &res, bool &success, Error &error); bool initialize_ssl(Socket &socket, Error &error); bool load_certs(); @@ -2053,20 +2086,45 @@ inline uint64_t Response::get_header_value_u64(const std::string &key, return detail::get_header_value_u64(headers, key, def, id); } -inline void default_socket_options(socket_t sock) { - int opt = 1; +namespace detail { + +inline bool set_socket_opt_impl(socket_t sock, int level, int optname, + const void *optval, socklen_t optlen) { + return setsockopt(sock, level, optname, #ifdef _WIN32 - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, - reinterpret_cast(&opt), sizeof(opt)); + reinterpret_cast(optval), #else -#ifdef SO_REUSEPORT - setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, - reinterpret_cast(&opt), sizeof(opt)); + optval, +#endif + optlen) == 0; +} + +inline bool set_socket_opt(socket_t sock, int level, int optname, int optval) { + return set_socket_opt_impl(sock, level, optname, &optval, sizeof(optval)); +} + +inline bool set_socket_opt_time(socket_t sock, int level, int optname, + time_t sec, time_t usec) { +#ifdef _WIN32 + auto timeout = static_cast(sec * 1000 + usec / 1000); #else - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, - reinterpret_cast(&opt), sizeof(opt)); + timeval timeout; + timeout.tv_sec = static_cast(sec); + timeout.tv_usec = static_cast(usec); #endif + return set_socket_opt_impl(sock, level, optname, &timeout, sizeof(timeout)); +} + +} // namespace detail + +inline void default_socket_options(socket_t sock) { + detail::set_socket_opt(sock, SOL_SOCKET, +#ifdef SO_REUSEPORT + SO_REUSEPORT, +#else + SO_REUSEADDR, #endif + 1); } inline const char *status_message(int status) { @@ -2238,6 +2296,14 @@ inline void ClientImpl::set_write_timeout( duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); } +template +inline void ClientImpl::set_max_timeout( + const std::chrono::duration &duration) { + auto msec = + std::chrono::duration_cast(duration).count(); + set_max_timeout(msec); +} + template inline void Client::set_connection_timeout( const std::chrono::duration &duration) { @@ -2256,6 +2322,12 @@ Client::set_write_timeout(const std::chrono::duration &duration) { cli_->set_write_timeout(duration); } +template +inline void +Client::set_max_timeout(const std::chrono::duration &duration) { + cli_->set_max_timeout(duration); +} + /* * Forward declarations and types that will be part of the .h file if split into * .h + .cc. 
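The hunks above thread a new `max_timeout_msec_` and a per-request `start_time_` through the client, so that a whole request can be capped in wall-clock time on top of the existing per-operation read/write timeouts. A usage sketch against the patched header (host and path are placeholders):

```cpp
// Usage sketch for the max-timeout machinery added by this patch: each
// individual socket operation still honours the read/write timeouts,
// but the request as a whole is additionally capped at max_timeout.
#include <chrono>
#include <iostream>
#include <string>
#include "httplib.h"

int main() {
    httplib::Client cli("http://localhost:8080"); // placeholder host

    cli.set_read_timeout(std::chrono::seconds(30)); // per read() call
    cli.set_max_timeout(std::chrono::seconds(5));   // whole request, new in this patch

    // A slow streaming endpoint can now be abandoned after ~5s total,
    // even though every individual read stays under the 30s read timeout.
    auto res = cli.Get("/health"); // placeholder path
    std::cout << (res ? std::to_string(res->status)
                      : std::string("timed out or failed"))
              << std::endl;
    return 0;
}
```

Internally, `calc_actual_timeout` (defined later in this diff) shrinks each select/poll wait to whatever remains of the max-timeout budget, computed from the stream's `duration()` since `start_time_`.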
@@ -2330,10 +2402,12 @@ void split(const char *b, const char *e, char d, void split(const char *b, const char *e, char d, size_t m, std::function fn); -bool process_client_socket(socket_t sock, time_t read_timeout_sec, - time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec, - std::function callback); +bool process_client_socket( + socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time, + std::function callback); socket_t create_client_socket(const std::string &host, const std::string &ip, int port, int address_family, bool tcp_nodelay, @@ -2381,6 +2455,7 @@ class BufferStream final : public Stream { void get_remote_ip_and_port(std::string &ip, int &port) const override; void get_local_ip_and_port(std::string &ip, int &port) const override; socket_t socket() const override; + time_t duration() const override; const std::string &get_buffer() const; @@ -2547,7 +2622,7 @@ inline bool is_obs_text(char c) { return 128 <= static_cast(c); } inline bool is_field_vchar(char c) { return is_vchar(c) || is_obs_text(c); } inline bool is_field_content(const std::string &s) { - if (s.empty()) { return false; } + if (s.empty()) { return true; } if (s.size() == 1) { return is_field_vchar(s[0]); @@ -3171,60 +3246,43 @@ inline ssize_t send_socket(socket_t sock, const void *ptr, size_t size, }); } -inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { +template +inline ssize_t select_impl(socket_t sock, time_t sec, time_t usec) { #ifdef CPPHTTPLIB_USE_POLL - struct pollfd pfd_read; - pfd_read.fd = sock; - pfd_read.events = POLLIN; + struct pollfd pfd; + pfd.fd = sock; + pfd.events = (Read ? POLLIN : POLLOUT); auto timeout = static_cast(sec * 1000 + usec / 1000); - return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + return handle_EINTR([&]() { return poll(&pfd, 1, timeout); }); #else #ifndef _WIN32 if (sock >= FD_SETSIZE) { return -1; } #endif - fd_set fds; + fd_set fds, *rfds, *wfds; FD_ZERO(&fds); FD_SET(sock, &fds); + rfds = (Read ? &fds : nullptr); + wfds = (Read ? 
nullptr : &fds); timeval tv; tv.tv_sec = static_cast(sec); tv.tv_usec = static_cast(usec); return handle_EINTR([&]() { - return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); + return select(static_cast(sock + 1), rfds, wfds, nullptr, &tv); }); #endif } -inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { -#ifdef CPPHTTPLIB_USE_POLL - struct pollfd pfd_read; - pfd_read.fd = sock; - pfd_read.events = POLLOUT; - - auto timeout = static_cast(sec * 1000 + usec / 1000); - - return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); -#else -#ifndef _WIN32 - if (sock >= FD_SETSIZE) { return -1; } -#endif - - fd_set fds; - FD_ZERO(&fds); - FD_SET(sock, &fds); - - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); +inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { + return select_impl(sock, sec, usec); +} - return handle_EINTR([&]() { - return select(static_cast(sock + 1), nullptr, &fds, nullptr, &tv); - }); -#endif +inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { + return select_impl(sock, sec, usec); } inline Error wait_until_socket_is_ready(socket_t sock, time_t sec, @@ -3298,7 +3356,10 @@ inline bool is_socket_alive(socket_t sock) { class SocketStream final : public Stream { public: SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, - time_t write_timeout_sec, time_t write_timeout_usec); + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec = 0, + std::chrono::time_point start_time = + std::chrono::steady_clock::time_point::min()); ~SocketStream() override; bool is_readable() const override; @@ -3308,6 +3369,7 @@ class SocketStream final : public Stream { void get_remote_ip_and_port(std::string &ip, int &port) const override; void get_local_ip_and_port(std::string &ip, int &port) const override; socket_t socket() const override; + time_t duration() const override; private: socket_t sock_; @@ -3315,6 +3377,8 @@ class SocketStream final : public Stream { time_t read_timeout_usec_; time_t write_timeout_sec_; time_t write_timeout_usec_; + time_t max_timeout_msec_; + const std::chrono::time_point start_time; std::vector read_buff_; size_t read_buff_off_ = 0; @@ -3326,9 +3390,12 @@ class SocketStream final : public Stream { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT class SSLSocketStream final : public Stream { public: - SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, - time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec); + SSLSocketStream( + socket_t sock, SSL *ssl, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, time_t max_timeout_msec = 0, + std::chrono::time_point start_time = + std::chrono::steady_clock::time_point::min()); ~SSLSocketStream() override; bool is_readable() const override; @@ -3338,6 +3405,7 @@ class SSLSocketStream final : public Stream { void get_remote_ip_and_port(std::string &ip, int &port) const override; void get_local_ip_and_port(std::string &ip, int &port) const override; socket_t socket() const override; + time_t duration() const override; private: socket_t sock_; @@ -3346,6 +3414,8 @@ class SSLSocketStream final : public Stream { time_t read_timeout_usec_; time_t write_timeout_sec_; time_t write_timeout_usec_; + time_t max_timeout_msec_; + const std::chrono::time_point start_time; }; #endif @@ -3416,13 +3486,15 @@ process_server_socket(const std::atomic &svr_sock, socket_t sock, }); } -inline bool 
process_client_socket(socket_t sock, time_t read_timeout_sec, - time_t read_timeout_usec, - time_t write_timeout_sec, - time_t write_timeout_usec, - std::function callback) { +inline bool process_client_socket( + socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time, + std::function callback) { SocketStream strm(sock, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); + write_timeout_sec, write_timeout_usec, max_timeout_msec, + start_time); return callback(strm); } @@ -3569,26 +3641,10 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, } #endif - if (tcp_nodelay) { - auto opt = 1; -#ifdef _WIN32 - setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, - reinterpret_cast(&opt), sizeof(opt)); -#else - setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, - reinterpret_cast(&opt), sizeof(opt)); -#endif - } + if (tcp_nodelay) { set_socket_opt(sock, IPPROTO_TCP, TCP_NODELAY, 1); } if (rp->ai_family == AF_INET6) { - auto opt = ipv6_v6only ? 1 : 0; -#ifdef _WIN32 - setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, - reinterpret_cast(&opt), sizeof(opt)); -#else - setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, - reinterpret_cast(&opt), sizeof(opt)); -#endif + set_socket_opt(sock, IPPROTO_IPV6, IPV6_V6ONLY, ipv6_v6only ? 1 : 0); } if (socket_options) { socket_options(sock); } @@ -3731,36 +3787,10 @@ inline socket_t create_client_socket( } set_nonblocking(sock2, false); - - { -#ifdef _WIN32 - auto timeout = static_cast(read_timeout_sec * 1000 + - read_timeout_usec / 1000); - setsockopt(sock2, SOL_SOCKET, SO_RCVTIMEO, - reinterpret_cast(&timeout), sizeof(timeout)); -#else - timeval tv; - tv.tv_sec = static_cast(read_timeout_sec); - tv.tv_usec = static_cast(read_timeout_usec); - setsockopt(sock2, SOL_SOCKET, SO_RCVTIMEO, - reinterpret_cast(&tv), sizeof(tv)); -#endif - } - { - -#ifdef _WIN32 - auto timeout = static_cast(write_timeout_sec * 1000 + - write_timeout_usec / 1000); - setsockopt(sock2, SOL_SOCKET, SO_SNDTIMEO, - reinterpret_cast(&timeout), sizeof(timeout)); -#else - timeval tv; - tv.tv_sec = static_cast(write_timeout_sec); - tv.tv_usec = static_cast(write_timeout_usec); - setsockopt(sock2, SOL_SOCKET, SO_SNDTIMEO, - reinterpret_cast(&tv), sizeof(tv)); -#endif - } + set_socket_opt_time(sock2, SOL_SOCKET, SO_RCVTIMEO, read_timeout_sec, + read_timeout_usec); + set_socket_opt_time(sock2, SOL_SOCKET, SO_SNDTIMEO, write_timeout_sec, + write_timeout_usec); error = Error::Success; return true; @@ -4212,22 +4242,21 @@ inline bool parse_header(const char *beg, const char *end, T fn) { if (!key_len) { return false; } auto key = std::string(beg, key_end); - auto val = case_ignore::equal(key, "Location") - ? std::string(p, end) - : decode_url(std::string(p, end), false); - - // NOTE: From RFC 9110: - // Field values containing CR, LF, or NUL characters are - // invalid and dangerous, due to the varying ways that - // implementations might parse and interpret those - // characters; a recipient of CR, LF, or NUL within a field - // value MUST either reject the message or replace each of - // those characters with SP before further processing or - // forwarding of that message. - static const std::string CR_LF_NUL("\r\n\0", 3); - if (val.find_first_of(CR_LF_NUL) != std::string::npos) { return false; } - - fn(key, val); + // auto val = (case_ignore::equal(key, "Location") || + // case_ignore::equal(key, "Referer")) + // ? 
std::string(p, end) + // : decode_url(std::string(p, end), false); + auto val = std::string(p, end); + + if (!detail::fields::is_field_value(val)) { return false; } + + if (case_ignore::equal(key, "Location") || + case_ignore::equal(key, "Referer")) { + fn(key, val); + } else { + fn(key, decode_url(val, false)); + } + return true; } @@ -4263,7 +4292,7 @@ inline bool read_headers(Stream &strm, Headers &headers) { auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; if (!parse_header(line_reader.ptr(), end, - [&](const std::string &key, std::string &val) { + [&](const std::string &key, const std::string &val) { headers.emplace(key, val); })) { return false; @@ -4312,7 +4341,7 @@ inline bool read_content_without_length(Stream &strm, uint64_t r = 0; for (;;) { auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); - if (n <= 0) { return true; } + if (n <= 0) { return false; } if (!out(buf, static_cast(n), r, 0)) { return false; } r += static_cast(n); @@ -4456,9 +4485,9 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, ret = read_content_without_length(strm, out); } else { auto is_invalid_value = false; - auto len = get_header_value_u64(x.headers, "Content-Length", - std::numeric_limits::max(), - 0, is_invalid_value); + auto len = get_header_value_u64( + x.headers, "Content-Length", + (std::numeric_limits::max)(), 0, is_invalid_value); if (is_invalid_value) { ret = false; @@ -5380,10 +5409,14 @@ write_multipart_ranges_data(Stream &strm, const Request &req, Response &res, inline bool expect_content(const Request &req) { if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || - req.method == "PRI" || req.method == "DELETE") { + req.method == "DELETE") { + return true; + } + if (req.has_header("Content-Length") && + req.get_header_value_u64("Content-Length") > 0) { return true; } - // TODO: check if Content-Length is set + if (is_chunked_transfer_encoding(req.headers)) { return true; } return false; } @@ -5428,9 +5461,76 @@ inline std::string SHA_256(const std::string &s) { inline std::string SHA_512(const std::string &s) { return message_digest(s, EVP_sha512()); } -#endif -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::pair make_digest_authentication_header( + const Request &req, const std::map &auth, + size_t cnonce_count, const std::string &cnonce, const std::string &username, + const std::string &password, bool is_proxy = false) { + std::string nc; + { + std::stringstream ss; + ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; + nc = ss.str(); + } + + std::string qop; + if (auth.find("qop") != auth.end()) { + qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else if (qop.find("auth") != std::string::npos) { + qop = "auth"; + } else { + qop.clear(); + } + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + + std::string response; + { + auto H = algo == "SHA-256" ? detail::SHA_256 + : algo == "SHA-512" ? detail::SHA_512 + : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { A2 += ":" + H(req.body); } + + if (qop.empty()) { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); + } else { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } + } + + auto opaque = (auth.find("opaque") != auth.end()) ? 
auth.at("opaque") : ""; + + auto field = "Digest username=\"" + username + "\", realm=\"" + + auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + + "\", uri=\"" + req.path + "\", algorithm=" + algo + + (qop.empty() ? ", response=\"" + : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + + cnonce + "\", response=\"") + + response + "\"" + + (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} + +inline bool is_ssl_peer_could_be_closed(SSL *ssl, socket_t sock) { + detail::set_nonblocking(sock, true); + auto se = detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + char buf[1]; + return !SSL_peek(ssl, buf, 1) && + SSL_get_error(ssl, 0) == SSL_ERROR_ZERO_RETURN; +} + #ifdef _WIN32 // NOTE: This code came up with the following stackoverflow post: // https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store @@ -5569,68 +5669,6 @@ class WSInit { static WSInit wsinit_; #endif -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -inline std::pair make_digest_authentication_header( - const Request &req, const std::map &auth, - size_t cnonce_count, const std::string &cnonce, const std::string &username, - const std::string &password, bool is_proxy = false) { - std::string nc; - { - std::stringstream ss; - ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; - nc = ss.str(); - } - - std::string qop; - if (auth.find("qop") != auth.end()) { - qop = auth.at("qop"); - if (qop.find("auth-int") != std::string::npos) { - qop = "auth-int"; - } else if (qop.find("auth") != std::string::npos) { - qop = "auth"; - } else { - qop.clear(); - } - } - - std::string algo = "MD5"; - if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } - - std::string response; - { - auto H = algo == "SHA-256" ? detail::SHA_256 - : algo == "SHA-512" ? detail::SHA_512 - : detail::MD5; - - auto A1 = username + ":" + auth.at("realm") + ":" + password; - - auto A2 = req.method + ":" + req.path; - if (qop == "auth-int") { A2 += ":" + H(req.body); } - - if (qop.empty()) { - response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); - } else { - response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + - ":" + qop + ":" + H(A2)); - } - } - - auto opaque = (auth.find("opaque") != auth.end()) ? auth.at("opaque") : ""; - - auto field = "Digest username=\"" + username + "\", realm=\"" + - auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + - "\", uri=\"" + req.path + "\", algorithm=" + algo + - (qop.empty() ? ", response=\"" - : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + - cnonce + "\", response=\"") + - response + "\"" + - (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); - - auto key = is_proxy ? 
"Proxy-Authorization" : "Authorization"; - return std::make_pair(key, field); -} -#endif - inline bool parse_www_authenticate(const Response &res, std::map &auth, bool is_proxy) { @@ -5949,20 +5987,45 @@ inline ssize_t Stream::write(const std::string &s) { namespace detail { +inline void calc_actual_timeout(time_t max_timeout_msec, + time_t duration_msec, time_t timeout_sec, + time_t timeout_usec, time_t &actual_timeout_sec, + time_t &actual_timeout_usec) { + auto timeout_msec = (timeout_sec * 1000) + (timeout_usec / 1000); + + auto actual_timeout_msec = + std::min(max_timeout_msec - duration_msec, timeout_msec); + + actual_timeout_sec = actual_timeout_msec / 1000; + actual_timeout_usec = (actual_timeout_msec % 1000) * 1000; +} + // Socket stream implementation -inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, - time_t read_timeout_usec, - time_t write_timeout_sec, - time_t write_timeout_usec) +inline SocketStream::SocketStream( + socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time) : sock_(sock), read_timeout_sec_(read_timeout_sec), read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), - write_timeout_usec_(write_timeout_usec), read_buff_(read_buff_size_, 0) {} + write_timeout_usec_(write_timeout_usec), + max_timeout_msec_(max_timeout_msec), start_time(start_time), + read_buff_(read_buff_size_, 0) {} inline SocketStream::~SocketStream() = default; inline bool SocketStream::is_readable() const { - return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + if (max_timeout_msec_ <= 0) { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + } + + time_t read_timeout_sec; + time_t read_timeout_usec; + calc_actual_timeout(max_timeout_msec_, duration(), read_timeout_sec_, + read_timeout_usec_, read_timeout_sec, read_timeout_usec); + + return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; } inline bool SocketStream::is_writable() const { @@ -6039,6 +6102,12 @@ inline void SocketStream::get_local_ip_and_port(std::string &ip, inline socket_t SocketStream::socket() const { return sock_; } +inline time_t SocketStream::duration() const { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_time) + .count(); +} + // Buffer stream implementation inline bool BufferStream::is_readable() const { return true; } @@ -6067,6 +6136,8 @@ inline void BufferStream::get_local_ip_and_port(std::string & /*ip*/, inline socket_t BufferStream::socket() const { return 0; } +inline time_t BufferStream::duration() const { return 0; } + inline const std::string &BufferStream::get_buffer() const { return buffer; } inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { @@ -6869,35 +6940,10 @@ inline bool Server::listen_internal() { break; } - { -#ifdef _WIN32 - auto timeout = static_cast(read_timeout_sec_ * 1000 + - read_timeout_usec_ / 1000); - setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, - reinterpret_cast(&timeout), sizeof(timeout)); -#else - timeval tv; - tv.tv_sec = static_cast(read_timeout_sec_); - tv.tv_usec = static_cast(read_timeout_usec_); - setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, - reinterpret_cast(&tv), sizeof(tv)); -#endif - } - { - -#ifdef _WIN32 - auto timeout = static_cast(write_timeout_sec_ * 1000 + - write_timeout_usec_ / 1000); - setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, - reinterpret_cast(&timeout), sizeof(timeout)); -#else - timeval tv; 
- tv.tv_sec = static_cast(write_timeout_sec_); - tv.tv_usec = static_cast(write_timeout_usec_); - setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, - reinterpret_cast(&tv), sizeof(tv)); -#endif - } + detail::set_socket_opt_time(sock, SOL_SOCKET, SO_RCVTIMEO, + read_timeout_sec_, read_timeout_usec_); + detail::set_socket_opt_time(sock, SOL_SOCKET, SO_SNDTIMEO, + write_timeout_sec_, write_timeout_usec_); if (!task_queue->enqueue( [this, sock]() { process_and_close_socket(sock); })) { @@ -7157,14 +7203,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr, #endif #endif - // Check if the request URI doesn't exceed the limit - if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { - Headers dummy; - detail::read_headers(strm, dummy); - res.status = StatusCode::UriTooLong_414; - return write_response(strm, close_connection, req, res); - } - // Request line and headers if (!parse_request_line(line_reader.ptr(), req) || !detail::read_headers(strm, req.headers)) { @@ -7172,6 +7210,14 @@ Server::process_request(Stream &strm, const std::string &remote_addr, return write_response(strm, close_connection, req, res); } + // Check if the request URI doesn't exceed the limit + if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = StatusCode::UriTooLong_414; + return write_response(strm, close_connection, req, res); + } + if (req.get_header_value("Connection") == "close") { connection_closed = true; } @@ -7363,6 +7409,7 @@ inline void ClientImpl::copy_settings(const ClientImpl &rhs) { read_timeout_usec_ = rhs.read_timeout_usec_; write_timeout_sec_ = rhs.write_timeout_sec_; write_timeout_usec_ = rhs.write_timeout_usec_; + max_timeout_msec_ = rhs.max_timeout_msec_; basic_auth_username_ = rhs.basic_auth_username_; basic_auth_password_ = rhs.basic_auth_password_; bearer_token_auth_token_ = rhs.bearer_token_auth_token_; @@ -7509,18 +7556,6 @@ inline bool ClientImpl::send(Request &req, Response &res, Error &error) { return ret; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -inline bool ClientImpl::is_ssl_peer_could_be_closed(SSL *ssl) const { - detail::set_nonblocking(socket_.sock, true); - auto se = detail::scope_exit( - [&]() { detail::set_nonblocking(socket_.sock, false); }); - - char buf[1]; - return !SSL_peek(ssl, buf, 1) && - SSL_get_error(ssl, 0) == SSL_ERROR_ZERO_RETURN; -} -#endif - inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { { std::lock_guard guard(socket_mutex_); @@ -7535,7 +7570,9 @@ inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT if (is_alive && is_ssl()) { - if (is_ssl_peer_could_be_closed(socket_.ssl)) { is_alive = false; } + if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + is_alive = false; + } } #endif @@ -7560,7 +7597,8 @@ inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { auto &scli = static_cast(*this); if (!proxy_host_.empty() && proxy_port_ != -1) { auto success = false; - if (!scli.connect_with_proxy(socket_, res, success, error)) { + if (!scli.connect_with_proxy(socket_, req.start_time_, res, success, + error)) { return success; } } @@ -7606,7 +7644,7 @@ inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { } }); - ret = process_socket(socket_, [&](Stream &strm) { + ret = process_socket(socket_, req.start_time_, [&](Stream &strm) { return handle_request(strm, req, res, close_connection, error); }); @@ -8015,6 +8053,9 @@ inline Result 
ClientImpl::send_with_content_provider( req.headers = headers; req.path = path; req.progress = progress; + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } auto error = Error::Success; @@ -8041,7 +8082,7 @@ inline bool ClientImpl::process_request(Stream &strm, Request &req, if (is_ssl()) { auto is_proxy_enabled = !proxy_host_.empty() && proxy_port_ != -1; if (!is_proxy_enabled) { - if (is_ssl_peer_could_be_closed(socket_.ssl)) { + if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { error = Error::SSLPeerCouldBeClosed_; return false; } @@ -8166,12 +8207,14 @@ inline ContentProviderWithoutLength ClientImpl::get_multipart_content_provider( }; } -inline bool -ClientImpl::process_socket(const Socket &socket, - std::function callback) { +inline bool ClientImpl::process_socket( + const Socket &socket, + std::chrono::time_point start_time, + std::function callback) { return detail::process_client_socket( socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, std::move(callback)); + write_timeout_usec_, max_timeout_msec_, start_time, + std::move(callback)); } inline bool ClientImpl::is_ssl() const { return false; } @@ -8195,6 +8238,9 @@ inline Result ClientImpl::Get(const std::string &path, const Headers &headers, req.path = path; req.headers = headers; req.progress = std::move(progress); + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } return send_(std::move(req)); } @@ -8260,6 +8306,9 @@ inline Result ClientImpl::Get(const std::string &path, const Headers &headers, return content_receiver(data, data_length); }; req.progress = std::move(progress); + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } return send_(std::move(req)); } @@ -8305,6 +8354,9 @@ inline Result ClientImpl::Head(const std::string &path, req.method = "HEAD"; req.headers = headers; req.path = path; + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } return send_(std::move(req)); } @@ -8722,6 +8774,9 @@ inline Result ClientImpl::Delete(const std::string &path, req.headers = headers; req.path = path; req.progress = progress; + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } if (!content_type.empty()) { req.set_header("Content-Type", content_type); } req.body.assign(body, content_length); @@ -8769,6 +8824,9 @@ inline Result ClientImpl::Options(const std::string &path, req.method = "OPTIONS"; req.headers = headers; req.path = path; + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } return send_(std::move(req)); } @@ -8822,6 +8880,10 @@ inline void ClientImpl::set_write_timeout(time_t sec, time_t usec) { write_timeout_usec_ = usec; } +inline void ClientImpl::set_max_timeout(time_t msec) { + max_timeout_msec_ = msec; +} + inline void ClientImpl::set_basic_auth(const std::string &username, const std::string &password) { basic_auth_username_ = username; @@ -9004,11 +9066,7 @@ inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock, (void)(sock); SSL_shutdown(ssl); #else - timeval tv; - tv.tv_sec = 1; - tv.tv_usec = 0; - setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, - reinterpret_cast(&tv), sizeof(tv)); + detail::set_socket_opt_time(sock, SOL_SOCKET, SO_RCVTIMEO, 1, 0); auto ret = SSL_shutdown(ssl); while (ret == 0) { @@ -9060,12 +9118,14 @@ inline bool process_server_socket_ssl( } template -inline bool -process_client_socket_ssl(SSL *ssl, 
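Each client verb now stamps `req.start_time_` once, immediately before the request is sent, and only when `max_timeout_msec_ > 0`, so the clock read costs nothing when the feature is off; the streams later measure the elapsed budget from that stamp. A Python sketch of the pattern (hypothetical names; `time.monotonic()` plays the role of `std::chrono::steady_clock`):

```python
import time

class Request:
    start_time: float | None = None

def stamp(req: Request, max_timeout_msec: int) -> None:
    # mirrors: if (max_timeout_msec_ > 0) { req.start_time_ = steady_clock::now(); }
    if max_timeout_msec > 0:
        req.start_time = time.monotonic()

def duration_msec(req: Request) -> int:
    # mirrors SocketStream::duration(): milliseconds elapsed since the stamp
    return int((time.monotonic() - req.start_time) * 1000)

req = Request()
stamp(req, 5000)  # as after cli.set_max_timeout(5000) followed by cli.Get(...)
```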
socket_t sock, time_t read_timeout_sec, - time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec, T callback) { +inline bool process_client_socket_ssl( + SSL *ssl, socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time, T callback) { SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); + write_timeout_sec, write_timeout_usec, + max_timeout_msec, start_time); return callback(strm); } @@ -9078,27 +9138,37 @@ class SSLInit { }; // SSL socket stream implementation -inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl, - time_t read_timeout_sec, - time_t read_timeout_usec, - time_t write_timeout_sec, - time_t write_timeout_usec) +inline SSLSocketStream::SSLSocketStream( + socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time) : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), - write_timeout_usec_(write_timeout_usec) { + write_timeout_usec_(write_timeout_usec), + max_timeout_msec_(max_timeout_msec), start_time(start_time) { SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); } inline SSLSocketStream::~SSLSocketStream() = default; inline bool SSLSocketStream::is_readable() const { - return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + if (max_timeout_msec_ <= 0) { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + } + + time_t read_timeout_sec; + time_t read_timeout_usec; + calc_actual_timeout(max_timeout_msec_, duration(), read_timeout_sec_, + read_timeout_usec_, read_timeout_sec, read_timeout_usec); + + return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; } inline bool SSLSocketStream::is_writable() const { return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && - is_socket_alive(sock_); + is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_); } inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { @@ -9129,8 +9199,9 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { } } return ret; + } else { + return -1; } - return -1; } inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { @@ -9176,6 +9247,12 @@ inline void SSLSocketStream::get_local_ip_and_port(std::string &ip, inline socket_t SSLSocketStream::socket() const { return sock_; } +inline time_t SSLSocketStream::duration() const { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_time) + .count(); +} + static SSLInit sslinit_; } // namespace detail @@ -9416,16 +9493,22 @@ inline bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { } // Assumes that socket_mutex_ is locked and that there are no requests in flight -inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, - bool &success, Error &error) { +inline bool SSLClient::connect_with_proxy( + Socket &socket, + std::chrono::time_point start_time, + Response &res, bool &success, Error &error) { success = true; Response proxy_res; if (!detail::process_client_socket( socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, + start_time, 
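`is_writable()` on the SSL stream now also rejects sockets whose TLS peer has already sent a close notify, reusing the check that moved from `ClientImpl` into `detail::is_ssl_peer_could_be_closed`: peek one byte in non-blocking mode and look for a clean shutdown. A sketch of that logic (Python pseudocode over an assumed OpenSSL-like wrapper, not the real `ssl`-module API):

```python
def is_ssl_peer_could_be_closed(ssl_conn, sock) -> bool:
    sock.set_blocking(False)      # detail::set_nonblocking(sock, true)
    try:
        n = ssl_conn.peek(1)      # SSL_peek(ssl, buf, 1): does not consume data
        # zero bytes plus SSL_ERROR_ZERO_RETURN means the peer sent close_notify
        return n <= 0 and ssl_conn.get_error(n) == "SSL_ERROR_ZERO_RETURN"
    finally:
        sock.set_blocking(True)   # the C++ code restores blocking mode via scope_exit
```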
[&](Stream &strm) { Request req2; req2.method = "CONNECT"; req2.path = host_and_port_; + if (max_timeout_msec_ > 0) { + req2.start_time_ = std::chrono::steady_clock::now(); + } return process_request(strm, req2, proxy_res, false, error); })) { // Thread-safe to close everything because we are assuming there are no @@ -9445,7 +9528,8 @@ inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, proxy_res = Response(); if (!detail::process_client_socket( socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, + start_time, [&](Stream &strm) { Request req3; req3.method = "CONNECT"; req3.path = host_and_port_; @@ -9453,6 +9537,9 @@ inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, req3, auth, 1, detail::random_string(10), proxy_digest_auth_username_, proxy_digest_auth_password_, true)); + if (max_timeout_msec_ > 0) { + req3.start_time_ = std::chrono::steady_clock::now(); + } return process_request(strm, req3, proxy_res, false, error); })) { // Thread-safe to close everything because we are assuming there are @@ -9608,13 +9695,15 @@ inline void SSLClient::shutdown_ssl_impl(Socket &socket, assert(socket.ssl == nullptr); } -inline bool -SSLClient::process_socket(const Socket &socket, - std::function callback) { +inline bool SSLClient::process_socket( + const Socket &socket, + std::chrono::time_point start_time, + std::function callback) { assert(socket.ssl); return detail::process_client_socket_ssl( socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, std::move(callback)); + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, start_time, + std::move(callback)); } inline bool SSLClient::is_ssl() const { return true; } diff --git a/examples/server/public/index.html.gz b/examples/server/public/index.html.gz index 141e8092057ac..1925b334b463f 100644 Binary files a/examples/server/public/index.html.gz and b/examples/server/public/index.html.gz differ diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 38d96ddd6b213..c287eab23f8b6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -42,7 +42,7 @@ enum stop_type { STOP_TYPE_LIMIT, }; -// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283 +// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 enum slot_state { SLOT_STATE_IDLE, SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future @@ -173,6 +173,7 @@ struct slot_params { {"grammar_trigger_words", grammar_trigger_words}, {"grammar_trigger_tokens", sampling.grammar_trigger_tokens}, {"preserved_tokens", sampling.preserved_tokens}, + {"chat_format", common_chat_format_name(oaicompat_chat_format)}, {"samplers", samplers}, {"speculative.n_max", speculative.n_max}, {"speculative.n_min", speculative.n_min}, @@ -724,9 +725,19 @@ struct server_task_result_cmpl_final : server_task_result { msg.content = content; } - json tool_calls; + json message { + {"role", "assistant"}, + }; + if (!msg.reasoning_content.empty()) { + message["reasoning_content"] = msg.reasoning_content; + } + if (msg.content.empty() && !msg.tool_calls.empty()) { + message["content"] = json(); + } else { + message["content"] = msg.content; + } if (!msg.tool_calls.empty()) { - tool_calls = json::array(); + auto tool_calls = json::array(); for (const auto & tc : 
msg.tool_calls) { tool_calls.push_back({ {"type", "function"}, @@ -737,15 +748,7 @@ struct server_task_result_cmpl_final : server_task_result { {"id", tc.id}, }); } - } - - json message { - {"content", msg.content}, - {"tool_calls", tool_calls}, - {"role", "assistant"}, - }; - if (!msg.tool_plan.empty()) { - message["tool_plan"] = msg.tool_plan; + message["tool_calls"] = tool_calls; } json choice { @@ -1600,6 +1603,10 @@ struct server_queue { while (true) { std::unique_lock lock(mutex_tasks); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } if (queue_tasks.empty()) { lock.unlock(); break; @@ -1620,11 +1627,11 @@ struct server_queue { QUE_DBG("%s", "waiting for new tasks\n"); { std::unique_lock lock(mutex_tasks); + if (!running) { + QUE_DBG("%s", "terminate\n"); + return; + } if (queue_tasks.empty()) { - if (!running) { - QUE_DBG("%s", "terminate\n"); - return; - } condition_tasks.wait(lock, [&]{ return (!queue_tasks.empty() || !running); }); @@ -2069,8 +2076,8 @@ struct server_context { if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { // Might be better to reject the request with a 400 ? + SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.params.n_predict, slot.n_predict); slot.params.n_predict = slot.n_predict; - SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict); } if (slot.params.ignore_eos && has_eos_token) { @@ -2275,7 +2282,7 @@ struct server_context { for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { result.probs.push_back({ cur_p->data[i].id, - common_detokenize(ctx, {cur_p->data[i].id}, special), + common_token_to_piece(ctx, cur_p->data[i].id, special), cur_p->data[i].p }); } @@ -2297,7 +2304,7 @@ struct server_context { for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { result.probs.push_back({ cur[i].id, - common_detokenize(ctx, {cur[i].id}, special), + common_token_to_piece(ctx, cur[i].id, special), cur[i].p }); } @@ -3649,7 +3656,7 @@ int main(int argc, char ** argv) { }, { {"name", "n_busy_slots_per_decode"}, {"help", "Average number of busy slots per llama_decode() call"}, - {"value", (float) res_metrics->n_busy_slots_total / (float) res_metrics->n_decode_total} + {"value", (float) res_metrics->n_busy_slots_total / std::max((float) res_metrics->n_decode_total, 1.f)} }}}, {"gauge", {{ {"name", "prompt_tokens_seconds"}, @@ -4056,7 +4063,7 @@ int main(int argc, char ** argv) { } auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates); + json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates); return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, @@ -4069,7 +4076,7 @@ int main(int argc, char ** argv) { // same with handle_chat_completions, but without inference part const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) { auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates); + json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates); res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); }; @@ -4430,6 +4437,7 @@ int main(int argc, char ** argv) { // clean up function, to be called before exit auto clean_up = [&svr]() { + SRV_INF("%s: cleaning up before exit...\n", __func__); 
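The rebuilt `message` object changes the OpenAI-compatible response in three ways: `content` becomes JSON `null` when the model produced only tool calls (the updated tests assert exactly this), `reasoning_content` is emitted only when non-empty, and `tool_calls` is omitted rather than serialized as `null` when there are none; the separate `tool_plan` field is gone. Illustrative result shapes (Python dicts, values made up):

```python
# tool-call-only turn: content is null, tool_calls present
msg_tool_call = {
    "role": "assistant",
    "content": None,
    "tool_calls": [{
        "type": "function",
        "function": {"name": "get_weather", "arguments": '{"location": "Istanbul"}'},
        "id": "call_6789",
    }],
}

# reasoning-model turn: reasoning_content split out, no tool_calls key at all
msg_reasoning = {
    "role": "assistant",
    "reasoning_content": "I need to calculate the sum of 102 and 7 ...",
    "content": "The sum of 102 and 7 is 109.",
}
```

The same server.cpp range also guards the `n_busy_slots_per_decode` metric against a zero denominator with `std::max(..., 1.f)`, and moves the `running` check in `server_queue` ahead of the queue-empty check so `terminate()` always unblocks the loop.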
svr->stop(); llama_backend_free(); }; @@ -4446,10 +4454,6 @@ int main(int argc, char ** argv) { } if (!was_bound) { - //LOG_ERROR("couldn't bind HTTP server socket", { - // {"hostname", params.hostname}, - // {"port", params.port}, - //}); LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port); clean_up(); return 1; @@ -4466,7 +4470,7 @@ int main(int argc, char ** argv) { if (!ctx_server.load_model(params)) { clean_up(); - t.join(); + // t.join(); // FIXME: see below LOG_ERR("%s: exiting due to model loading error\n", __func__); return 1; } @@ -4490,13 +4494,10 @@ int main(int argc, char ** argv) { }); shutdown_handler = [&](int) { + // this will unblock start_loop() ctx_server.queue_tasks.terminate(); }; - LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port); - - ctx_server.queue_tasks.start_loop(); - #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; sigint_action.sa_handler = signal_handler; @@ -4511,8 +4512,13 @@ int main(int argc, char ** argv) { SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif + LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port); + + // this call blocks the main thread until queue_tasks.terminate() is called + ctx_server.queue_tasks.start_loop(); + clean_up(); - t.join(); + // t.join(); // FIXME: http thread may stuck if there is an on-going request. we don't need to care about this for now as the HTTP connection will already be closed at this point, but it's better to fix this return 0; } diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py index 4a551404f22a9..ba3367b4f332d 100644 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -92,6 +92,7 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] + assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] assert expected_function_name == tool_call["function"]["name"] actual_arguments = tool_call["function"]["arguments"] @@ -155,11 +156,11 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, (TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + # (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), (TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), - (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + # (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), (TEST_TOOL, "success", 
"bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), @@ -175,7 +176,7 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, (TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), - (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), + # (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), # TODO: fix these # (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), # (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), @@ -214,6 +215,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] + assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] assert expected_function_name == tool_call["function"]["name"] actual_arguments = tool_call["function"]["arguments"] @@ -273,7 +275,6 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t @pytest.mark.slow @pytest.mark.parametrize("hf_repo,template_override", [ - ("bartowski/c4ai-command-r7b-12-2024-GGUF:Q4_K_M", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")), ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), @@ -298,13 +299,16 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")), + + ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), - # ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), ]) -def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None): +def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None): global server n_predict = 512 server.n_slots = 1 @@ -323,6 +327,7 @@ def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None) res = server.make_request("POST", "/chat/completions", data={ "max_tokens": n_predict, "messages": [ + {"role": "system", "content": "You are a chatbot that uses tools/functions. 
Dont overthink things."}, {"role": "user", "content": "What is the weather in Istanbul?"}, ], "tools": [WEATHER_TOOL], @@ -332,6 +337,7 @@ def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None) tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] + assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"] actual_arguments = json.loads(tool_call["function"]["arguments"]) assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}" @@ -340,22 +346,166 @@ def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None) assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}' +@pytest.mark.slow +@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [ + (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + (None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), + (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + ("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + + # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value) + ("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + ("[\\s\\S]*?\\*\\*0\\.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), +]) +def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None): + global server + # n_predict = 512 + server.n_slots = 1 + server.jinja = True + server.n_ctx = 8192 * 2 + server.n_predict = n_predict + server.model_hf_repo = hf_repo + server.model_hf_file = None + if isinstance(template_override, tuple): + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + elif isinstance(template_override, str): + server.chat_template = template_override + server.start(timeout_seconds=TIMEOUT_SERVER_START) + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things, and provide very concise answers. Do not explain your reasoning to the user. 
Provide any numerical values back to the user with at most two decimals."}, + {"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_6789", + "type": "function", + "function": { + "name": "calculate", + "arguments": "{\"expression\":\"sin(30 * pi / 180)\"}" + } + } + ] + }, + { + "role": "tool", + "name": "calculate", + "content": 0.55644242476, + "tool_call_id": "call_6789" + } + ], + "tools": [ + { + "type":"function", + "function":{ + "name":"calculate", + "description":"A calculator function that computes values of arithmetic expressions in the Python syntax", + "parameters":{ + "type":"object", + "properties":{ + "expression":{ + "type":"string", + "description":"An arithmetic expression to compute the value of (Python syntad, assuming all floats)" + } + }, + "required":["expression"] + } + } + } + ] + }, timeout=TIMEOUT_HTTP_REQUEST) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls is None, f'Expected no tool call in {choice["message"]}' + content = choice["message"].get("content") + assert content is not None, f'Expected content in {choice["message"]}' + if result_override is not None: + assert re.match(result_override, content), f'Expected {result_override}, got {content}' + else: + assert re.match('^[\\s\\S]*?The (y[ -])?coordinate [\\s\\S]*?is (approximately )?0\\.56\\b|^0\\.56$', content), \ + f'Expected something like "The y coordinate is 0.56.", got {content}' + + +@pytest.mark.slow +@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [ + (128, 'deepseek', "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + + (1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + (1024, 'none', "\n?I need[\\s\\S]*?\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + + (1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), +]) +def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): + global server + server.n_slots = 1 + server.reasoning_format = reasoning_format + server.jinja = True + server.n_ctx = 8192 * 2 + server.n_predict = n_predict + server.model_hf_repo = hf_repo + server.model_hf_file = None + if isinstance(template_override, tuple): + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." 
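`test_thoughts` pins down the new `--reasoning-format` switch: with `deepseek` the template's think block is extracted into a separate `reasoning_content` field, while with `none` the raw thinking text stays inline in `content` and no `reasoning_content` is returned. Roughly (Python dicts, wording illustrative):

```python
# --reasoning-format deepseek
deepseek_style = {
    "role": "assistant",
    "reasoning_content": "I need to calculate the sum of 102 and 7 ...",
    "content": "To find the sum of 102 and 7 ...",
}

# --reasoning-format none: thoughts are not split out of the content
none_style = {
    "role": "assistant",
    "content": "I need to calculate the sum of 102 and 7 ...\nTo find the sum ...",
}
```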
+ elif isinstance(template_override, str): + server.chat_template = template_override + server.start(timeout_seconds=TIMEOUT_SERVER_START) + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "user", "content": "What's the sum of 102 and 7?"}, + ] + }, timeout=TIMEOUT_HTTP_REQUEST) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}' + + content = choice["message"].get("content") + if expect_content is None: + assert content is None, f'Expected no content in {choice["message"]}' + else: + assert re.match(expect_content, content), f'Expected {expect_content}, got {content}' + + reasoning_content = choice["message"].get("reasoning_content") + if expect_reasoning_content is None: + assert reasoning_content is None, f'Expected no reasoning content in {choice["message"]}' + else: + assert re.match(expect_reasoning_content, reasoning_content), f'Expected {expect_reasoning_content}, got {reasoning_content}' + + @pytest.mark.slow @pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [ + (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", "chatml"), + (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)), (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), - (None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), - ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), + ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + (None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), - ('{"code":"print("}', "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), - ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), + (None, "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), @@ -371,15 +521,13 @@ def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None) # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. 
(None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), - - # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), ]) -def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): +def test_hello_world(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): global server server.n_slots = 1 server.jinja = True server.n_ctx = 8192 - server.n_predict = 128 + server.n_predict = 512 # High because of DeepSeek R1 server.model_hf_repo = hf_repo server.model_hf_file = None if isinstance(template_override, tuple): @@ -406,6 +554,7 @@ def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] + assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"] actual_arguments = tool_call["function"]["arguments"] if expected_arguments_override is not None: diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index 97d650a9c0cd0..c247ed30139b7 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -78,6 +78,7 @@ class ServerProcess: draft_max: int | None = None no_webui: bool | None = None jinja: bool | None = None + reasoning_format: Literal['deepseek', 'none'] | None = None chat_template: str | None = None chat_template_file: str | None = None @@ -172,6 +173,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: server_args.append("--no-webui") if self.jinja: server_args.append("--jinja") + if self.reasoning_format is not None: + server_args.extend(("--reasoning-format", self.reasoning_format)) if self.chat_template: server_args.extend(["--chat-template", self.chat_template]) if self.chat_template_file: diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 5f97df5fde639..b5aebebba4ac7 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -367,10 +367,10 @@ inline std::string format_chat(const common_chat_template & tmpl, const std::vec } } } else { - throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); + throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); } } else { - throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); + throw std::runtime_error("Missing 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); } chat.push_back({role, content, /* tool_calls= */ {}}); @@ -578,6 +578,7 @@ static json oaicompat_completion_params_parse(const json & body) { static json oaicompat_completion_params_parse( const json & body, /* openai api json semantics */ bool use_jinja, + common_reasoning_format reasoning_format, const common_chat_templates & chat_templates) { json llama_params; @@ -633,9 +634,10 @@ static json oaicompat_completion_params_parse( throw std::runtime_error("Cannot use custom grammar constraints with tools."); } common_chat_inputs inputs; - inputs.messages = body.at("messages"); - inputs.tools = tools; - inputs.tool_choice = tool_choice; + inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; + inputs.messages = body.at("messages"); + inputs.tools 
= tools; + inputs.tool_choice = tool_choice; inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); diff --git a/examples/server/webui/package-lock.json b/examples/server/webui/package-lock.json index c6c5de3c0c97e..056592dd4775d 100644 --- a/examples/server/webui/package-lock.json +++ b/examples/server/webui/package-lock.json @@ -13,6 +13,7 @@ "@vscode/markdown-it-katex": "^1.1.1", "autoprefixer": "^10.4.20", "daisyui": "^4.12.14", + "dexie": "^4.0.11", "highlight.js": "^11.10.0", "katex": "^0.16.15", "postcss": "^8.4.49", @@ -2338,6 +2339,12 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/dexie": { + "version": "4.0.11", + "resolved": "https://registry.npmjs.org/dexie/-/dexie-4.0.11.tgz", + "integrity": "sha512-SOKO002EqlvBYYKQSew3iymBoN2EQ4BDw/3yprjh7kAfFzjBYkaMNa/pZvcA7HSWlcKSQb9XhPe3wKyQ0x4A8A==", + "license": "Apache-2.0" + }, "node_modules/didyoumean": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz", diff --git a/examples/server/webui/package.json b/examples/server/webui/package.json index 3be2b14de084b..8c833d0241e1c 100644 --- a/examples/server/webui/package.json +++ b/examples/server/webui/package.json @@ -16,6 +16,7 @@ "@vscode/markdown-it-katex": "^1.1.1", "autoprefixer": "^10.4.20", "daisyui": "^4.12.14", + "dexie": "^4.0.11", "highlight.js": "^11.10.0", "katex": "^0.16.15", "postcss": "^8.4.49", diff --git a/examples/server/webui/src/components/ChatMessage.tsx b/examples/server/webui/src/components/ChatMessage.tsx index ec72196baf0a6..68be7c751db54 100644 --- a/examples/server/webui/src/components/ChatMessage.tsx +++ b/examples/server/webui/src/components/ChatMessage.tsx @@ -3,6 +3,7 @@ import { useAppContext } from '../utils/app.context'; import { Message, PendingMessage } from '../utils/types'; import { classNames } from '../utils/misc'; import MarkdownDisplay, { CopyButton } from './MarkdownDisplay'; +import { ChevronLeftIcon, ChevronRightIcon } from '@heroicons/react/24/outline'; interface SplitMessage { content: PendingMessage['content']; @@ -12,17 +13,24 @@ interface SplitMessage { export default function ChatMessage({ msg, + siblingLeafNodeIds, + siblingCurrIdx, id, - scrollToBottom, + onRegenerateMessage, + onEditMessage, + onChangeSibling, isPending, }: { msg: Message | PendingMessage; + siblingLeafNodeIds: Message['id'][]; + siblingCurrIdx: number; id?: string; - scrollToBottom: (requiresNearBottom: boolean) => void; + onRegenerateMessage(msg: Message): void; + onEditMessage(msg: Message, content: string): void; + onChangeSibling(sibling: Message['id']): void; isPending?: boolean; }) { - const { viewingConversation, replaceMessageAndGenerate, config } = - useAppContext(); + const { viewingChat, config } = useAppContext(); const [editingContent, setEditingContent] = useState(null); const timings = useMemo( () => @@ -37,6 +45,8 @@ export default function ChatMessage({ : null, [msg.timings] ); + const nextSibling = siblingLeafNodeIds[siblingCurrIdx + 1]; + const prevSibling = siblingLeafNodeIds[siblingCurrIdx - 1]; // for reasoning model, we split the message into content and thought // TODO: implement this as remark/rehype plugin in the future @@ -64,13 +74,7 @@ export default function ChatMessage({ return { content: actualContent, thought, isThinking }; }, [msg]); - if (!viewingConversation) return null; - - 
const regenerate = async () => {
-    replaceMessageAndGenerate(viewingConversation.id, msg.id, undefined, () =>
-      scrollToBottom(true)
-    );
-  };
+  if (!viewingChat) return null;

   return (
@@ -105,13 +109,12 @@ export default function ChatMessage({
@@ -196,10 +199,35 @@ export default function ChatMessage({
           {msg.content !== null && (
+ {siblingLeafNodeIds && siblingLeafNodeIds.length > 1 && ( +
+ + + {siblingCurrIdx + 1} / {siblingLeafNodeIds.length} + + +
+ )} {/* user message */} {msg.role === 'user' && ( )} - )} +
)}
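`ChatMessage` now receives everything it needs to page between branch siblings: `siblingLeafNodeIds` holds one leaf id per sibling branch and `siblingCurrIdx` marks the branch being displayed, so the arrow buttons simply hand the neighbouring leaf id to `onChangeSibling`. A Python sketch of the pager state (mirrors the `{siblingCurrIdx + 1} / {siblingLeafNodeIds.length}` indicator; names follow the TSX):

```python
def sibling_nav(sibling_leaf_node_ids: list[int], sibling_curr_idx: int):
    """Previous leaf, 'current / total' label, next leaf for one message row."""
    prev_leaf = sibling_leaf_node_ids[sibling_curr_idx - 1] if sibling_curr_idx > 0 else None
    next_leaf = (sibling_leaf_node_ids[sibling_curr_idx + 1]
                 if sibling_curr_idx + 1 < len(sibling_leaf_node_ids) else None)
    label = f"{sibling_curr_idx + 1} / {len(sibling_leaf_node_ids)}"
    return prev_leaf, label, next_leaf

assert sibling_nav([103, 205, 307], 1) == (103, "2 / 3", 307)
```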
diff --git a/examples/server/webui/src/components/ChatScreen.tsx b/examples/server/webui/src/components/ChatScreen.tsx index dbc683ed15cd2..2636c1e7bca52 100644 --- a/examples/server/webui/src/components/ChatScreen.tsx +++ b/examples/server/webui/src/components/ChatScreen.tsx @@ -1,28 +1,59 @@ -import { useEffect, useState } from 'react'; -import { useAppContext } from '../utils/app.context'; -import StorageUtils from '../utils/storage'; -import { useNavigate } from 'react-router'; +import { useEffect, useMemo, useState } from 'react'; +import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context'; import ChatMessage from './ChatMessage'; -import { CanvasType, PendingMessage } from '../utils/types'; -import { classNames } from '../utils/misc'; +import { CanvasType, Message, PendingMessage } from '../utils/types'; +import { classNames, throttle } from '../utils/misc'; import CanvasPyInterpreter from './CanvasPyInterpreter'; +import StorageUtils from '../utils/storage'; -export default function ChatScreen() { - const { - viewingConversation, - sendMessage, - isGenerating, - stopGenerating, - pendingMessages, - canvasData, - } = useAppContext(); - const [inputMsg, setInputMsg] = useState(''); - const navigate = useNavigate(); +/** + * A message display is a message node with additional information for rendering. + * For example, siblings of the message node are stored as their last node (aka leaf node). + */ +export interface MessageDisplay { + msg: Message | PendingMessage; + siblingLeafNodeIds: Message['id'][]; + siblingCurrIdx: number; + isPending?: boolean; +} - const currConvId = viewingConversation?.id ?? ''; - const pendingMsg: PendingMessage | undefined = pendingMessages[currConvId]; +function getListMessageDisplay( + msgs: Readonly, + leafNodeId: Message['id'] +): MessageDisplay[] { + const currNodes = StorageUtils.filterByLeafNodeId(msgs, leafNodeId, true); + const res: MessageDisplay[] = []; + const nodeMap = new Map(); + for (const msg of msgs) { + nodeMap.set(msg.id, msg); + } + // find leaf node from a message node + const findLeafNode = (msgId: Message['id']): Message['id'] => { + let currNode: Message | undefined = nodeMap.get(msgId); + while (currNode) { + if (currNode.children.length === 0) break; + currNode = nodeMap.get(currNode.children.at(-1) ?? -1); + } + return currNode?.id ?? -1; + }; + // traverse the current nodes + for (const msg of currNodes) { + const parentNode = nodeMap.get(msg.parent ?? 
-1); + if (!parentNode) continue; + const siblings = parentNode.children; + if (msg.type !== 'root') { + res.push({ + msg, + siblingLeafNodeIds: siblings.map(findLeafNode), + siblingCurrIdx: siblings.indexOf(msg.id), + }); + } + } + return res; +} - const scrollToBottom = (requiresNearBottom: boolean) => { +const scrollToBottom = throttle( + (requiresNearBottom: boolean, delay: number = 80) => { const mainScrollElem = document.getElementById('main-scroll'); if (!mainScrollElem) return; const spaceToBottom = @@ -32,36 +63,107 @@ export default function ChatScreen() { if (!requiresNearBottom || spaceToBottom < 50) { setTimeout( () => mainScrollElem.scrollTo({ top: mainScrollElem.scrollHeight }), - 1 + delay ); } - }; + }, + 80 +); + +export default function ChatScreen() { + const { + viewingChat, + sendMessage, + isGenerating, + stopGenerating, + pendingMessages, + canvasData, + replaceMessageAndGenerate, + } = useAppContext(); + const [inputMsg, setInputMsg] = useState(''); + + // keep track of leaf node for rendering + const [currNodeId, setCurrNodeId] = useState(-1); + const messages: MessageDisplay[] = useMemo(() => { + if (!viewingChat) return []; + else return getListMessageDisplay(viewingChat.messages, currNodeId); + }, [currNodeId, viewingChat]); + + const currConvId = viewingChat?.conv.id ?? null; + const pendingMsg: PendingMessage | undefined = + pendingMessages[currConvId ?? '']; - // scroll to bottom when conversation changes useEffect(() => { - scrollToBottom(false); - }, [viewingConversation?.id]); + // reset to latest node when conversation changes + setCurrNodeId(-1); + // scroll to bottom when conversation changes + scrollToBottom(false, 1); + }, [currConvId]); + + const onChunk: CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => { + if (currLeafNodeId) { + setCurrNodeId(currLeafNodeId); + } + scrollToBottom(true); + }; const sendNewMessage = async () => { - if (inputMsg.trim().length === 0 || isGenerating(currConvId)) return; - const convId = viewingConversation?.id ?? StorageUtils.getNewConvId(); + if (inputMsg.trim().length === 0 || isGenerating(currConvId ?? '')) return; const lastInpMsg = inputMsg; setInputMsg(''); - if (!viewingConversation) { - // if user is creating a new conversation, redirect to the new conversation - navigate(`/chat/${convId}`); - } scrollToBottom(false); - // auto scroll as message is being generated - const onChunk = () => scrollToBottom(true); - if (!(await sendMessage(convId, inputMsg, onChunk))) { + setCurrNodeId(-1); + // get the last message node + const lastMsgNodeId = messages.at(-1)?.msg.id ?? null; + if (!(await sendMessage(currConvId, lastMsgNodeId, inputMsg, onChunk))) { // restore the input message if failed setInputMsg(lastInpMsg); } }; + const handleEditMessage = async (msg: Message, content: string) => { + if (!viewingChat) return; + setCurrNodeId(msg.id); + scrollToBottom(false); + await replaceMessageAndGenerate( + viewingChat.conv.id, + msg.parent, + content, + onChunk + ); + setCurrNodeId(-1); + scrollToBottom(false); + }; + + const handleRegenerateMessage = async (msg: Message) => { + if (!viewingChat) return; + setCurrNodeId(msg.parent); + scrollToBottom(false); + await replaceMessageAndGenerate( + viewingChat.conv.id, + msg.parent, + null, + onChunk + ); + setCurrNodeId(-1); + scrollToBottom(false); + }; + const hasCanvas = !!canvasData; + // due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. 
appears once in the saved conversation and once in the pendingMsg) + const pendingMsgDisplay: MessageDisplay[] = + pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id + ? [ + { + msg: pendingMsg, + siblingLeafNodeIds: [], + siblingCurrIdx: 0, + isPending: true, + }, + ] + : []; + return (
{/* placeholder to shift the message to the bottom */} - {viewingConversation ? '' : 'Send a message to start'} + {viewingChat ? '' : 'Send a message to start'}
- {viewingConversation?.messages.map((msg) => ( + {[...messages, ...pendingMsgDisplay].map((msg) => ( ))} - - {pendingMsg && ( - - )}
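`getListMessageDisplay` walks the currently viewed path and, for each node, maps every sibling (each child of the same parent) to the leaf of its own branch via `findLeafNode`, which keeps descending into the newest child. A compact Python rendering of that helper, run on a two-branch tree like the one drawn later in the `types.ts` docs (structure as in the diff; ids invented):

```python
def find_leaf_node(node_map: dict[int, dict], msg_id: int) -> int:
    """Descend from msg_id into the last (newest) child until a leaf is reached."""
    curr = node_map.get(msg_id)
    while curr and curr["children"]:
        curr = node_map.get(curr["children"][-1])
    return curr["id"] if curr else -1

tree = {m["id"]: m for m in [
    {"id": 0, "children": [1, 4]},  # root with two branches
    {"id": 1, "children": [2]}, {"id": 2, "children": [3]}, {"id": 3, "children": []},
    {"id": 4, "children": [5]}, {"id": 5, "children": []},
]}
assert find_leaf_node(tree, 1) == 3  # leaf of the first branch
assert find_leaf_node(tree, 4) == 5  # leaf of the second branch
```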
{/* chat input */} @@ -118,10 +215,10 @@ export default function ChatScreen() { id="msg-input" dir="auto" > - {isGenerating(currConvId) ? ( + {isGenerating(currConvId ?? '') ? ( diff --git a/examples/server/webui/src/components/Header.tsx b/examples/server/webui/src/components/Header.tsx index 505350313a2fc..cbee394ba2e0c 100644 --- a/examples/server/webui/src/components/Header.tsx +++ b/examples/server/webui/src/components/Header.tsx @@ -25,12 +25,12 @@ export default function Header() { ); }, [selectedTheme]); - const { isGenerating, viewingConversation } = useAppContext(); - const isCurrConvGenerating = isGenerating(viewingConversation?.id ?? ''); + const { isGenerating, viewingChat } = useAppContext(); + const isCurrConvGenerating = isGenerating(viewingChat?.conv.id ?? ''); const removeConversation = () => { - if (isCurrConvGenerating || !viewingConversation) return; - const convId = viewingConversation.id; + if (isCurrConvGenerating || !viewingChat) return; + const convId = viewingChat?.conv.id; if (window.confirm('Are you sure to delete this conversation?')) { StorageUtils.remove(convId); navigate('/'); @@ -38,9 +38,9 @@ export default function Header() { }; const downloadConversation = () => { - if (isCurrConvGenerating || !viewingConversation) return; - const convId = viewingConversation.id; - const conversationJson = JSON.stringify(viewingConversation, null, 2); + if (isCurrConvGenerating || !viewingChat) return; + const convId = viewingChat?.conv.id; + const conversationJson = JSON.stringify(viewingChat, null, 2); const blob = new Blob([conversationJson], { type: 'application/json' }); const url = URL.createObjectURL(blob); const a = document.createElement('a'); @@ -75,38 +75,41 @@ export default function Header() { {/* action buttons (top right) */}
-
- {/* "..." button */} - - {/* dropdown menu */} - -
+ + + + + {/* dropdown menu */} + +
+ )} +
))}
-              Conversations are saved to browser's localStorage
+              Conversations are saved to browser's IndexedDB
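The storage layer behind this label change moves from one localStorage blob per conversation to two IndexedDB tables managed by Dexie (schema declared further down in `storage.ts`). Roughly, the persisted records look like this (Python-dict sketch of the shapes; ids and values illustrative):

```python
conversation = {
    "id": "conv-1738000000000",    # primary key ('&id'); 'lastModified' is indexed
    "lastModified": 1738000000000,
    "currNode": 1738000000001,     # leaf message id of the branch being viewed
    "name": "What is the weather in Istanbul?",
}
message = {
    "id": 1738000000001,           # '&id', also indexed by convId and [convId+id]
    "convId": "conv-1738000000000",
    "type": "root",                # 'root' | 'text'; the root node is never rendered
    "timestamp": 1738000000001,
    "role": "system",
    "content": "",
    "parent": -1,                  # the root node has no parent
    "children": [],
}
```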
diff --git a/examples/server/webui/src/utils/app.context.tsx b/examples/server/webui/src/utils/app.context.tsx index af6bd885fc689..f2c935e1f4de6 100644 --- a/examples/server/webui/src/utils/app.context.tsx +++ b/examples/server/webui/src/utils/app.context.tsx @@ -5,6 +5,7 @@ import { Conversation, Message, PendingMessage, + ViewingChat, } from './types'; import StorageUtils from './storage'; import { @@ -13,24 +14,25 @@ import { getSSEStreamAsync, } from './misc'; import { BASE_URL, CONFIG_DEFAULT, isDev } from '../Config'; -import { matchPath, useLocation } from 'react-router'; +import { matchPath, useLocation, useNavigate } from 'react-router'; interface AppContextValue { // conversations and messages - viewingConversation: Conversation | null; + viewingChat: ViewingChat | null; pendingMessages: Record; isGenerating: (convId: string) => boolean; sendMessage: ( - convId: string, + convId: string | null, + leafNodeId: Message['id'] | null, content: string, - onChunk?: CallbackGeneratedChunk + onChunk: CallbackGeneratedChunk ) => Promise; stopGenerating: (convId: string) => void; replaceMessageAndGenerate: ( convId: string, - origMsgId: Message['id'], - content?: string, - onChunk?: CallbackGeneratedChunk + parentNodeId: Message['id'], // the parent node of the message to be replaced + content: string | null, + onChunk: CallbackGeneratedChunk ) => Promise; // canvas @@ -44,23 +46,33 @@ interface AppContextValue { setShowSettings: (show: boolean) => void; } -// for now, this callback is only used for scrolling to the bottom of the chat -type CallbackGeneratedChunk = () => void; +// this callback is used for scrolling to the bottom of the chat and switching to the last node +export type CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => void; // eslint-disable-next-line @typescript-eslint/no-explicit-any const AppContext = createContext({} as any); +const getViewingChat = async (convId: string): Promise => { + const conv = await StorageUtils.getOneConversation(convId); + if (!conv) return null; + return { + conv: conv, + // all messages from all branches, not filtered by last node + messages: await StorageUtils.getMessages(convId), + }; +}; + export const AppContextProvider = ({ children, }: { children: React.ReactElement; }) => { const { pathname } = useLocation(); + const navigate = useNavigate(); const params = matchPath('/chat/:convId', pathname); const convId = params?.params?.convId; - const [viewingConversation, setViewingConversation] = - useState(null); + const [viewingChat, setViewingChat] = useState(null); const [pendingMessages, setPendingMessages] = useState< Record >({}); @@ -75,12 +87,12 @@ export const AppContextProvider = ({ useEffect(() => { // also reset the canvas data setCanvasData(null); - const handleConversationChange = (changedConvId: string) => { + const handleConversationChange = async (changedConvId: string) => { if (changedConvId !== convId) return; - setViewingConversation(StorageUtils.getOneConversation(convId)); + setViewingChat(await getViewingChat(changedConvId)); }; StorageUtils.onConversationChanged(handleConversationChange); - setViewingConversation(StorageUtils.getOneConversation(convId ?? '')); + getViewingChat(convId ?? 
'').then(setViewingChat); return () => { StorageUtils.offConversationChanged(handleConversationChange); }; @@ -118,23 +130,39 @@ export const AppContextProvider = ({ const generateMessage = async ( convId: string, - onChunk?: CallbackGeneratedChunk + leafNodeId: Message['id'], + onChunk: CallbackGeneratedChunk ) => { if (isGenerating(convId)) return; const config = StorageUtils.getConfig(); - const currConversation = StorageUtils.getOneConversation(convId); + const currConversation = await StorageUtils.getOneConversation(convId); if (!currConversation) { throw new Error('Current conversation is not found'); } + const currMessages = StorageUtils.filterByLeafNodeId( + await StorageUtils.getMessages(convId), + leafNodeId, + false + ); const abortController = new AbortController(); setAbort(convId, abortController); + if (!currMessages) { + throw new Error('Current messages are not found'); + } + + const pendingId = Date.now() + 1; let pendingMsg: PendingMessage = { - id: Date.now() + 1, + id: pendingId, + convId, + type: 'text', + timestamp: pendingId, role: 'assistant', content: null, + parent: leafNodeId, + children: [], }; setPending(convId, pendingMsg); @@ -144,7 +172,7 @@ export const AppContextProvider = ({ ...(config.systemMessage.length === 0 ? [] : [{ role: 'system', content: config.systemMessage } as APIMessage]), - ...normalizeMsgsForAPI(currConversation?.messages ?? []), + ...normalizeMsgsForAPI(currMessages), ]; if (config.excludeThoughtOnReq) { messages = filterThoughtFromMsgs(messages); @@ -205,8 +233,7 @@ export const AppContextProvider = ({ const lastContent = pendingMsg.content || ''; if (addedContent) { pendingMsg = { - id: pendingMsg.id, - role: 'assistant', + ...pendingMsg, content: lastContent + addedContent, }; } @@ -221,7 +248,7 @@ export const AppContextProvider = ({ }; } setPending(convId, pendingMsg); - onChunk?.(); + onChunk(); // don't need to switch node for pending message } } catch (err) { setPending(convId, null); @@ -236,37 +263,53 @@ export const AppContextProvider = ({ } } - if (pendingMsg.content) { - StorageUtils.appendMsg(currConversation.id, { - id: pendingMsg.id, - content: pendingMsg.content, - role: pendingMsg.role, - timings: pendingMsg.timings, - }); + if (pendingMsg.content !== null) { + await StorageUtils.appendMsg(pendingMsg as Message, leafNodeId); } setPending(convId, null); - onChunk?.(); // trigger scroll to bottom + onChunk(pendingId); // trigger scroll to bottom and switch to the last node }; const sendMessage = async ( - convId: string, + convId: string | null, + leafNodeId: Message['id'] | null, content: string, - onChunk?: CallbackGeneratedChunk + onChunk: CallbackGeneratedChunk ): Promise => { - if (isGenerating(convId) || content.trim().length === 0) return false; + if (isGenerating(convId ?? 
'') || content.trim().length === 0) return false; - StorageUtils.appendMsg(convId, { - id: Date.now(), - role: 'user', - content, - }); + if (convId === null || convId.length === 0 || leafNodeId === null) { + const conv = await StorageUtils.createConversation( + content.substring(0, 256) + ); + convId = conv.id; + leafNodeId = conv.currNode; + // if user is creating a new conversation, redirect to the new conversation + navigate(`/chat/${convId}`); + } + + const now = Date.now(); + const currMsgId = now; + StorageUtils.appendMsg( + { + id: currMsgId, + timestamp: now, + type: 'text', + convId, + role: 'user', + content, + parent: leafNodeId, + children: [], + }, + leafNodeId + ); + onChunk(currMsgId); try { - await generateMessage(convId, onChunk); + await generateMessage(convId, currMsgId, onChunk); return true; } catch (_) { - // rollback - StorageUtils.popMsg(convId); + // TODO: rollback } return false; }; @@ -279,22 +322,33 @@ export const AppContextProvider = ({ // if content is undefined, we remove last assistant message const replaceMessageAndGenerate = async ( convId: string, - origMsgId: Message['id'], - content?: string, - onChunk?: CallbackGeneratedChunk + parentNodeId: Message['id'], // the parent node of the message to be replaced + content: string | null, + onChunk: CallbackGeneratedChunk ) => { if (isGenerating(convId)) return; - StorageUtils.filterAndKeepMsgs(convId, (msg) => msg.id < origMsgId); - if (content) { - StorageUtils.appendMsg(convId, { - id: Date.now(), - role: 'user', - content, - }); + if (content !== null) { + const now = Date.now(); + const currMsgId = now; + StorageUtils.appendMsg( + { + id: currMsgId, + timestamp: now, + type: 'text', + convId, + role: 'user', + content, + parent: parentNodeId, + children: [], + }, + parentNodeId + ); + parentNodeId = currMsgId; } + onChunk(parentNodeId); - await generateMessage(convId, onChunk); + await generateMessage(convId, parentNodeId, onChunk); }; const saveConfig = (config: typeof CONFIG_DEFAULT) => { @@ -306,7 +360,7 @@ export const AppContextProvider = ({ !!x.toLowerCase; @@ -23,7 +22,7 @@ export async function* getSSEStreamAsync(fetchResponse: Response) { .pipeThrough(new TextLineStream()); // @ts-expect-error asyncIterator complains about type, but it should work for await (const line of asyncIterator(lines)) { - if (isDev) console.log({ line }); + //if (isDev) console.log({ line }); if (line.startsWith('data:') && !line.endsWith('[DONE]')) { const data = JSON.parse(line.slice(5)); yield data; @@ -55,7 +54,7 @@ export const copyStr = (textToCopy: string) => { /** * filter out redundant fields upon sending to API */ -export function normalizeMsgsForAPI(messages: Message[]) { +export function normalizeMsgsForAPI(messages: Readonly) { return messages.map((msg) => { return { role: msg.role, @@ -88,3 +87,23 @@ export function classNames(classes: Record): string { export const delay = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)); + +export const throttle = ( + callback: (...args: T) => void, + delay: number +) => { + let isWaiting = false; + + return (...args: T) => { + if (isWaiting) { + return; + } + + callback(...args); + isWaiting = true; + + setTimeout(() => { + isWaiting = false; + }, delay); + }; +}; diff --git a/examples/server/webui/src/utils/storage.ts b/examples/server/webui/src/utils/storage.ts index 8c03fa7815f89..1dfc9d9799311 100644 --- a/examples/server/webui/src/utils/storage.ts +++ b/examples/server/webui/src/utils/storage.ts @@ -2,7 +2,8 @@ // format: { [convId]: { id: string, 
lastModified: number, messages: [...] } } import { CONFIG_DEFAULT } from '../Config'; -import { Conversation, Message } from './types'; +import { Conversation, Message, TimingReport } from './types'; +import Dexie, { Table } from 'dexie'; const event = new EventTarget(); @@ -17,84 +18,153 @@ const dispatchConversationChange = (convId: string) => { ); }; +const db = new Dexie('LlamacppWebui') as Dexie & { + conversations: Table; + messages: Table; +}; + +// https://dexie.org/docs/Version/Version.stores() +db.version(1).stores({ + // Unlike SQL, you don’t need to specify all properties but only the one you wish to index. + conversations: '&id, lastModified', + messages: '&id, convId, [convId+id], timestamp', +}); + // convId is a string prefixed with 'conv-' const StorageUtils = { /** * manage conversations */ - getAllConversations(): Conversation[] { - const res = []; - for (const key in localStorage) { - if (key.startsWith('conv-')) { - res.push(JSON.parse(localStorage.getItem(key) ?? '{}')); - } - } - res.sort((a, b) => b.lastModified - a.lastModified); - return res; + async getAllConversations(): Promise { + await migrationLStoIDB().catch(console.error); // noop if already migrated + return (await db.conversations.toArray()).sort( + (a, b) => b.lastModified - a.lastModified + ); }, /** * can return null if convId does not exist */ - getOneConversation(convId: string): Conversation | null { - return JSON.parse(localStorage.getItem(convId) || 'null'); + async getOneConversation(convId: string): Promise { + return (await db.conversations.where('id').equals(convId).first()) ?? null; }, /** - * if convId does not exist, create one + * get all message nodes in a conversation */ - appendMsg(convId: string, msg: Message): void { - if (msg.content === null) return; - const conv = StorageUtils.getOneConversation(convId) || { - id: convId, - lastModified: Date.now(), - messages: [], - }; - conv.messages.push(msg); - conv.lastModified = Date.now(); - localStorage.setItem(convId, JSON.stringify(conv)); - dispatchConversationChange(convId); + async getMessages(convId: string): Promise { + return await db.messages.where({ convId }).toArray(); }, /** - * Get new conversation id + * use in conjunction with getMessages to filter messages by leafNodeId + * includeRoot: whether to include the root node in the result + * if node with leafNodeId does not exist, return the path with the latest timestamp */ - getNewConvId(): string { - return `conv-${Date.now()}`; + filterByLeafNodeId( + msgs: Readonly, + leafNodeId: Message['id'], + includeRoot: boolean + ): Readonly { + const res: Message[] = []; + const nodeMap = new Map(); + for (const msg of msgs) { + nodeMap.set(msg.id, msg); + } + let startNode: Message | undefined = nodeMap.get(leafNodeId); + if (!startNode) { + // if not found, we return the path with the latest timestamp + let latestTime = -1; + for (const msg of msgs) { + if (msg.timestamp > latestTime) { + startNode = msg; + latestTime = msg.timestamp; + } + } + } + // traverse the path from leafNodeId to root + // startNode can never be undefined here + let currNode: Message | undefined = startNode; + while (currNode) { + if (currNode.type !== 'root' || (currNode.type === 'root' && includeRoot)) + res.push(currNode); + currNode = nodeMap.get(currNode.parent ?? 
   /**
-   * remove conversation by id
+   * create a new conversation with a default root node
    */
-  remove(convId: string): void {
-    localStorage.removeItem(convId);
-    dispatchConversationChange(convId);
+  async createConversation(name: string): Promise<Conversation> {
+    const now = Date.now();
+    const msgId = now;
+    const conv: Conversation = {
+      id: `conv-${now}`,
+      lastModified: now,
+      currNode: msgId,
+      name,
+    };
+    await db.conversations.add(conv);
+    // create a root node
+    await db.messages.add({
+      id: msgId,
+      convId: conv.id,
+      type: 'root',
+      timestamp: now,
+      role: 'system',
+      content: '',
+      parent: -1,
+      children: [],
+    });
+    return conv;
   },
   /**
-   * remove all conversations
+   * if convId does not exist, throw an error
    */
-  filterAndKeepMsgs(
-    convId: string,
-    predicate: (msg: Message) => boolean
-  ): void {
-    const conv = StorageUtils.getOneConversation(convId);
-    if (!conv) return;
-    conv.messages = conv.messages.filter(predicate);
-    conv.lastModified = Date.now();
-    localStorage.setItem(convId, JSON.stringify(conv));
+  async appendMsg(
+    msg: Exclude<Message, 'parent' | 'children'>,
+    parentNodeId: Message['id']
+  ): Promise<void> {
+    if (msg.content === null) return;
+    const { convId } = msg;
+    await db.transaction('rw', db.conversations, db.messages, async () => {
+      const conv = await StorageUtils.getOneConversation(convId);
+      const parentMsg = await db.messages
+        .where({ convId, id: parentNodeId })
+        .first();
+      // update the currNode of conversation
+      if (!conv) {
+        throw new Error(`Conversation ${convId} does not exist`);
+      }
+      if (!parentMsg) {
+        throw new Error(
+          `Parent message ID ${parentNodeId} does not exist in conversation ${convId}`
+        );
+      }
+      await db.conversations.update(convId, {
+        lastModified: Date.now(),
+        currNode: msg.id,
+      });
+      // update parent
+      await db.messages.update(parentNodeId, {
+        children: [...parentMsg.children, msg.id],
+      });
+      // create message
+      await db.messages.add({
+        ...msg,
+        parent: parentNodeId,
+        children: [],
+      });
+    });
     dispatchConversationChange(convId);
   },
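// Sketch of how editing produces a sibling branch (illustrative; `editedMsg`
// is an assumed existing node): appending the replacement under the *parent*
// of the edited message keeps the old subtree reachable as an alternative.
const editNow = Date.now();
await StorageUtils.appendMsg(
  {
    id: editNow,
    convId: editedMsg.convId,
    type: 'text',
    timestamp: editNow,
    role: 'user',
    content: 'edited prompt',
    parent: editedMsg.parent,
    children: [],
  },
  editedMsg.parent // same parent as the message being replaced
);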
   /**
-   * remove last message from conversation
+   * remove conversation by id
    */
-  popMsg(convId: string): Message | undefined {
-    const conv = StorageUtils.getOneConversation(convId);
-    if (!conv) return;
-    const msg = conv.messages.pop();
-    conv.lastModified = Date.now();
-    if (conv.messages.length === 0) {
-      StorageUtils.remove(convId);
-    } else {
-      localStorage.setItem(convId, JSON.stringify(conv));
-    }
+  async remove(convId: string): Promise<void> {
+    await db.transaction('rw', db.conversations, db.messages, async () => {
+      await db.conversations.delete(convId);
+      await db.messages.where({ convId }).delete();
+    });
     dispatchConversationChange(convId);
-    return msg;
   },
 
   // event listeners
@@ -136,3 +206,79 @@
 };
 
 export default StorageUtils;
+
+// Migration from localStorage to IndexedDB
+
+// these are old types, LS prefix stands for LocalStorage
+interface LSConversation {
+  id: string; // format: `conv-{timestamp}`
+  lastModified: number; // timestamp from Date.now()
+  messages: LSMessage[];
+}
+interface LSMessage {
+  id: number;
+  role: 'user' | 'assistant' | 'system';
+  content: string;
+  timings?: TimingReport;
+}
+async function migrationLStoIDB() {
+  if (localStorage.getItem('migratedToIDB')) return;
+  const res: LSConversation[] = [];
+  for (const key in localStorage) {
+    if (key.startsWith('conv-')) {
+      res.push(JSON.parse(localStorage.getItem(key) ?? '{}'));
+    }
+  }
+  if (res.length === 0) return;
+  await db.transaction('rw', db.conversations, db.messages, async () => {
+    let migratedCount = 0;
+    for (const conv of res) {
+      const { id: convId, lastModified, messages } = conv;
+      const firstMsg = messages[0];
+      const lastMsg = messages.at(-1);
+      if (messages.length < 2 || !firstMsg || !lastMsg) {
+        console.log(
+          `Skipping conversation ${convId} with ${messages.length} messages`
+        );
+        continue;
+      }
+      const name = firstMsg.content ?? '(no messages)';
+      await db.conversations.add({
+        id: convId,
+        lastModified,
+        currNode: lastMsg.id,
+        name,
+      });
+      const rootId = messages[0].id - 2;
+      await db.messages.add({
+        id: rootId,
+        convId: convId,
+        type: 'root',
+        timestamp: rootId,
+        role: 'system',
+        content: '',
+        parent: -1,
+        children: [firstMsg.id],
+      });
+      for (let i = 0; i < messages.length; i++) {
+        const msg = messages[i];
+        await db.messages.add({
+          ...msg,
+          type: 'text',
+          convId: convId,
+          timestamp: msg.id,
+          parent: i === 0 ? rootId : messages[i - 1].id,
+          children: i === messages.length - 1 ? [] : [messages[i + 1].id],
+        });
+      }
+      migratedCount++;
+      console.log(
+        `Migrated conversation ${convId} with ${messages.length} messages`
+      );
+    }
+    console.log(
+      `Migrated ${migratedCount} conversations from localStorage to IndexedDB`
+    );
+    localStorage.setItem('migratedToIDB', '1');
+  });
+}
diff --git a/examples/server/webui/src/utils/types.ts b/examples/server/webui/src/utils/types.ts
index 7cd12b40aea1d..e85049f201cc8
--- a/examples/server/webui/src/utils/types.ts
+++ b/examples/server/webui/src/utils/types.ts
@@ -5,11 +5,46 @@ export interface TimingReport {
   predicted_ms: number;
 }
 
+/**
+ * What is conversation "branching"? It is a feature that allows the user to edit an old message in the history, while still keeping the conversation flow.
+ * Inspired by ChatGPT / Claude / Hugging Chat: when you edit a message, a new branch of the conversation is created, and the old message is still visible.
+ *
+ * We use the same node-based structure as other chat UIs, where each message has a parent and children. A "root" message is the first message in a conversation, which will not be displayed in the UI.
+ *
+ * root
+ * ├── message 1
+ * │     └── message 2
+ * │           └── message 3
+ * └── message 4
+ *       └── message 5
+ *
+ * In the above example, assuming that the user wants to edit message 2, a new branch will be created:
+ *
+ * ├── message 2
+ * │     └── message 3
+ * └── message 6
+ *
+ * Messages 2 and 6 are siblings, and message 6 is the new branch.
+ *
+ * We only need to know the last node (aka leaf) to get the current branch. In the above example, message 5 is the leaf of the branch containing messages 4 and 5.
+ *
+ * For the implementation:
+ * - StorageUtils.getMessages() returns the list of all nodes
+ * - StorageUtils.filterByLeafNodeId() filters the list of nodes from a given leaf node
+ */
+
+// Note: the terms "message" and "node" are used interchangeably in this context
 export interface Message {
   id: number;
+  convId: string;
+  type: 'text' | 'root';
+  timestamp: number; // timestamp from Date.now()
   role: 'user' | 'assistant' | 'system';
   content: string;
   timings?: TimingReport;
+  // node-based system for branching
+  parent: Message['id'];
+  children: Message['id'][];
 }
 
 export type APIMessage = Pick<Message, 'role' | 'content'>;
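// A hypothetical pair of nodes matching the tree diagram above (all field
// values invented for illustration): the hidden root plus one user message.
const root: Message = {
  id: 1, convId: 'conv-1', type: 'root', timestamp: 1,
  role: 'system', content: '', parent: -1, children: [2],
};
const message1: Message = {
  id: 2, convId: 'conv-1', type: 'text', timestamp: 2,
  role: 'user', content: 'message 1', parent: 1, children: [],
};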
@@ -17,7 +52,13 @@ export type APIMessage = Pick<Message, 'role' | 'content'>;
 export interface Conversation {
   id: string; // format: `conv-{timestamp}`
   lastModified: number; // timestamp from Date.now()
-  messages: Message[];
+  currNode: Message['id']; // the current message node being viewed
+  name: string;
+}
+
+export interface ViewingChat {
+  conv: Readonly<Conversation>;
+  messages: Readonly<Message[]>;
 }
 
 export type PendingMessage = Omit<Message, 'content'> & {
diff --git a/examples/simple-cmake-pkg/README.md b/examples/simple-cmake-pkg/README.md
index 8b30049e247cc..d7430cc9c2083
--- a/examples/simple-cmake-pkg/README.md
+++ b/examples/simple-cmake-pkg/README.md
@@ -1,6 +1,6 @@
 # llama.cpp/example/simple-cmake-pkg
 
-This program builds [simple](../simple) using a relocatable CMake package. It serves as an example of using the `find_package()` CMake command to conveniently include [llama.cpp](https://github.com/ggerganov/llama.cpp) in projects which live outside of the source tree.
+This program builds [simple](../simple) using a relocatable CMake package. It serves as an example of using the `find_package()` CMake command to conveniently include [llama.cpp](https://github.com/ggml-org/llama.cpp) in projects which live outside of the source tree.
 
 ## Building
 
@@ -13,7 +13,7 @@ When hardware acceleration libraries are used (e.g. CUDA, Metal, Vulkan, etc.),
 ### Build llama.cpp and install to llama.cpp/inst
 
 ```sh
-git clone https://github.com/ggerganov/llama.cpp
+git clone https://github.com/ggml-org/llama.cpp
 cd llama.cpp
 cmake -S . -B build
 cmake --build build
diff --git a/examples/speculative/README.md b/examples/speculative/README.md
index a6608c5fe8e3a..36ab3708629d2
--- a/examples/speculative/README.md
+++ b/examples/speculative/README.md
@@ -4,6 +4,6 @@ Demonstration of speculative decoding and tree-based speculative decoding techni
 
 More info:
 
-- https://github.com/ggerganov/llama.cpp/pull/2926
-- https://github.com/ggerganov/llama.cpp/pull/3624
-- https://github.com/ggerganov/llama.cpp/pull/5625
+- https://github.com/ggml-org/llama.cpp/pull/2926
+- https://github.com/ggml-org/llama.cpp/pull/3624
+- https://github.com/ggml-org/llama.cpp/pull/5625
diff --git a/flake.nix b/flake.nix
index 26a2588169101..0b5edf911fd06
--- a/flake.nix
+++ b/flake.nix
@@ -36,7 +36,7 @@
   #     ```
   #     nixConfig = {
   #       extra-substituters = [
-  #         # Populated by the CI in ggerganov/llama.cpp
+  #         # Populated by the CI in ggml-org/llama.cpp
   #         "https://llama-cpp.cachix.org"
   #
   #         # A development cache for nixpkgs imported with `config.cudaSupport = true`.
@@ -56,11 +56,11 @@
   #     };
   #     ```
 
-  # For inspection, use `nix flake show github:ggerganov/llama.cpp` or the nix repl:
+  # For inspection, use `nix flake show github:ggml-org/llama.cpp` or the nix repl:
   #
   # ```bash
   # ❯ nix repl
-  # nix-repl> :lf github:ggerganov/llama.cpp
+  # nix-repl> :lf github:ggml-org/llama.cpp
   # Added 13 variables.
  # nix-repl> outputs.apps.x86_64-linux.quantize
  # { program = "/nix/store/00000000000000000000000000000000-llama.cpp/bin/llama-quantize"; type = "app"; }
@@ -176,7 +176,7 @@
       #
       # We could test all outputs e.g. as `checks = config.packages`.
       #
-      # TODO: Build more once https://github.com/ggerganov/llama.cpp/issues/6346 has been addressed
+      # TODO: Build more once https://github.com/ggml-org/llama.cpp/issues/6346 has been addressed
       checks = {
         inherit (config.packages) default vulkan;
       };
diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h
index 3aa71badb5fb0..d23c6b262e202
--- a/ggml/include/ggml-cpu.h
+++ b/ggml/include/ggml-cpu.h
@@ -8,7 +8,7 @@ extern "C" {
 #endif
 
     // the compute plan that needs to be prepared for ggml_graph_compute()
-    // since https://github.com/ggerganov/ggml/issues/287
+    // since https://github.com/ggml-org/ggml/issues/287
     struct ggml_cplan {
         size_t    work_size; // size of work buffer, calculated by `ggml_graph_plan()`
         uint8_t * work_data; // work buffer, to be allocated by the caller before calling `ggml_graph_compute()`
diff --git a/ggml/include/ggml-metal.h b/ggml/include/ggml-metal.h
index 669c1f84ae6e3..a610694423483
--- a/ggml/include/ggml-metal.h
+++ b/ggml/include/ggml-metal.h
@@ -45,7 +45,7 @@ GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);
 
 GGML_DEPRECATED(
     GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
-    "obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713");
+    "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713");
 
 GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 5bd8d9c8b5023..dd0c6a96eaee8
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -198,7 +198,7 @@
 
 #ifndef __GNUC__
 #    define GGML_ATTRIBUTE_FORMAT(...)
-#elif defined(__MINGW32__)
+#elif defined(__MINGW32__) && !defined(__clang__)
 #    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
 #else
 #    define GGML_ATTRIBUTE_FORMAT(...)
__attribute__((format(printf, __VA_ARGS__))) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index f13fd4dea6f41..6c02b69ea239a 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -473,7 +473,6 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128) 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, GGML_TABLE_END() -//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics GGML_TABLE_BEGIN(uint64_t, ksigns64, 128) 0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff, 0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff, @@ -508,7 +507,6 @@ GGML_TABLE_BEGIN(uint64_t, ksigns64, 128) 0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff, 0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff, GGML_TABLE_END() -//#endif GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 27ec149356543..0315dc2575e3a 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -562,6 +562,41 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) { return __lasx_xvpickev_b(tmp1, tmp); } +static inline __m256i lasx_madd_h_b(__m256i a, __m256i b) { + __m256i tmp1, tmp2; + tmp1 = __lasx_xvmulwev_h_b(a, b); + tmp2 = __lasx_xvmulwod_h_b(a, b); + return __lasx_xvadd_h(tmp1, tmp2); +} + +static inline __m256i lasx_xvrepl128vei_h(__m256i a, const unsigned int b) { + switch (b) { + case 0: return __lasx_xvrepl128vei_h(a, 0); + case 1: return __lasx_xvrepl128vei_h(a, 1); + case 2: return __lasx_xvrepl128vei_h(a, 2); + case 3: return __lasx_xvrepl128vei_h(a, 3); + case 4: return __lasx_xvrepl128vei_h(a, 4); + case 5: return __lasx_xvrepl128vei_h(a, 5); + case 6: return __lasx_xvrepl128vei_h(a, 6); + case 7: return __lasx_xvrepl128vei_h(a, 7); + default: __builtin_unreachable(); + } +} + +static inline __m256i lasx_xvandi_b_bit(__m256i a, const unsigned int b) { + switch (b) { + case 0: return __lasx_xvandi_b(a, 1 << 0); + case 1: return __lasx_xvandi_b(a, 1 << 1); + case 2: return __lasx_xvandi_b(a, 1 << 2); + case 3: return __lasx_xvandi_b(a, 1 << 3); + case 4: return __lasx_xvandi_b(a, 1 << 4); + case 5: return __lasx_xvandi_b(a, 1 << 5); + case 6: return __lasx_xvandi_b(a, 1 << 6); + case 7: return __lasx_xvandi_b(a, 1 << 7); + default: __builtin_unreachable(); + } +} + // multiply int8_t, add results pairwise twice static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { // Get absolute values of x vectors @@ -656,13 +691,8 @@ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) // multiply int8_t, add results pairwise twice and return as float vector static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { - - // Get absolute values of x vectors - const __m256i ax = __lasx_xvsigncov_b(x, x); - // Sign the values of the y vectors - const __m256i sy = __lasx_xvsigncov_b(x, y); - - return mul_sum_us8_pairs_float(ax, sy); + const __m256i dot = lasx_madd_h_b(x, y); + return sum_i16_pairs_float(dot); } static inline __m128i packNibbles( __m256i bytes ) { @@ -742,7 +772,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); } } -#elif defined(__wasm_simd128__) +#elif defined __wasm_simd128__ for (int i = 0; i < nb; i++) { v128_t srcv [8]; v128_t asrcv[8]; @@ -1030,7 +1060,7 @@ void 
quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); } -#elif defined(__wasm_simd128__) +#elif defined __wasm_simd128__ for (int i = 0; i < nb; i++) { v128_t srcv [8]; v128_t asrcv[8]; @@ -1644,7 +1674,87 @@ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -1 //===================================== Q8_K ============================================== void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { +#ifdef __wasm_simd128__ + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + block_q8_K * restrict yc = y; // Cast to proper type + + for (int i = 0; i < nb; i++) { + const float * x_block = x + i * QK_K; + + v128_t min_vec = wasm_v128_load(x_block); + v128_t max_vec = min_vec; + + for (int j = 4; j < QK_K; j += 4) { + v128_t x_vec = wasm_v128_load(x_block + j); + max_vec = wasm_f32x4_pmax(max_vec, x_vec); + min_vec = wasm_f32x4_pmin(min_vec, x_vec); + } + max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1)); + max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2)); + min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1)); + min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2)); + float max = wasm_f32x4_extract_lane(max_vec, 0); + float min = wasm_f32x4_extract_lane(min_vec, 0); + float amax = -min > max ? min : max; + + if (amax == 0.0f) { + yc[i].d = 0.0f; + const v128_t zero = wasm_i8x16_splat(0); + for (int j = 0; j < QK_K; j += 16) { + wasm_v128_store(yc[i].qs + j, zero); + } + continue; + } + + const float iscale = -127.0f / amax; + const v128_t scale_vec = wasm_f32x4_splat(iscale); + + // Process 16 elements per iteration + for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) { + // Load and quantize 16 floats + v128_t x0 = wasm_v128_load(x_block + j); + v128_t x1 = wasm_v128_load(x_block + j + 4); + v128_t x2 = wasm_v128_load(x_block + j + 8); + v128_t x3 = wasm_v128_load(x_block + j + 12); + + v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec)); + v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec)); + v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec)); + v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec)); + + // Convert to i32 with saturation + v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0); + v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1); + v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2); + v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3); + + // Pack into 16 i8 values + v128_t i8 = wasm_i8x16_narrow_i16x8( + wasm_i16x8_narrow_i32x4(i0, i1), + wasm_i16x8_narrow_i32x4(i2, i3) + ); + wasm_v128_store(yc[i].qs + j, i8); + + // Calculate bsums using SIMD + v128_t sum16 = wasm_i16x8_add( + wasm_i16x8_extend_low_i8x16(i8), + wasm_i16x8_extend_high_i8x16(i8) + ); + v128_t sum32 = wasm_i32x4_add( + wasm_i32x4_extend_low_i16x8(sum16), + wasm_i32x4_extend_high_i16x8(sum16) + ); + sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1)); + sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2)); + yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0); + } + + yc[i].d = 1.0f / iscale; + } +#else quantize_row_q8_K_ref(x, y, k); +#endif } //===================================== Dot products ================================= @@ -2002,6 +2112,94 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined 
__wasm_simd128__ + v128_t sumv = wasm_f32x4_splat(0.0f); + + const v128_t m4b = wasm_i8x16_splat(0x0F); + const v128_t s8b = wasm_i8x16_splat(0x8); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // Load and process x0 + v128_t v0_0 = wasm_v128_load(x0->qs); + v128_t v0_0l = wasm_v128_and(v0_0, m4b); + v128_t v0_0h = wasm_u8x16_shr(v0_0, 4); + v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b); + v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b); + + // Load y0 vectors + v128_t y0_l = wasm_v128_load(y0->qs); + v128_t y0_h = wasm_v128_load(y0->qs + 16); + + // Extend to i16x8 and compute dot products + v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls); + v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls); + v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs); + v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs); + + v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l); + v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l); + v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h); + v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h); + + v128_t dp0 = wasm_i32x4_add( + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx0l, dy0ll), + wasm_i32x4_dot_i16x8(dx0h, dy0lh) + ), + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx0hl, dy0hl), + wasm_i32x4_dot_i16x8(dx0hh, dy0hh) + ) + ); + + // Load and process x1 + v128_t v0_1 = wasm_v128_load(x1->qs); + v128_t v0_1l = wasm_v128_and(v0_1, m4b); + v128_t v0_1h = wasm_u8x16_shr(v0_1, 4); + v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b); + v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b); + + // Load y1 vectors + v128_t y1_l = wasm_v128_load(y1->qs); + v128_t y1_h = wasm_v128_load(y1->qs + 16); + + // Extend to i16x8 and compute dot products + v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls); + v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls); + v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs); + v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs); + + v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l); + v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l); + v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h); + v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h); + + v128_t dp1 = wasm_i32x4_add( + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx1l, dy1ll), + wasm_i32x4_dot_i16x8(dx1h, dy1lh) + ), + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx1hl, dy1hl), + wasm_i32x4_dot_i16x8(dx1hh, dy1hh) + ) + ); + + // Accumulate results with scaling + float scale0 = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d); + float scale1 = GGML_FP16_TO_FP32(x1->d) * GGML_FP16_TO_FP32(y1->d); + + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0))); + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2688,10 +2886,10 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__wasm_simd128__) +#elif defined __wasm_simd128__ v128_t sumv = wasm_f32x4_splat(0.0f); - uint32_t qh; + uint32_t qh_; uint64_t tmp[4]; // TODO: check if unrolling this is better @@ -2702,12 +2900,12 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t 
bs, const void * r const v128_t m4b = wasm_i8x16_splat(0x0F); // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); + memcpy(&qh_, x0->qh, sizeof(qh_)); - tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_1[(qh >> 24) ]; + tmp[0] = table_b2b_1[(qh_ >> 0) & 0xFF]; + tmp[1] = table_b2b_1[(qh_ >> 8) & 0xFF]; + tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF]; + tmp[3] = table_b2b_1[(qh_ >> 24) ]; const v128_t qhl = wasm_v128_load(tmp + 0); const v128_t qhh = wasm_v128_load(tmp + 2); @@ -3049,12 +3247,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; -#elif defined(__wasm_simd128__) +#elif defined __wasm_simd128__ v128_t sumv = wasm_f32x4_splat(0.0f); float summs = 0.0f; - uint32_t qh; + uint32_t qh_; uint64_t tmp[4]; // TODO: check if unrolling this is better @@ -3067,12 +3265,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r const v128_t m4b = wasm_i8x16_splat(0x0F); // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); + memcpy(&qh_, x0->qh, sizeof(qh_)); - tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_0[(qh >> 24) ]; + tmp[0] = table_b2b_0[(qh_ >> 0) & 0xFF]; + tmp[1] = table_b2b_0[(qh_ >> 8) & 0xFF]; + tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF]; + tmp[3] = table_b2b_0[(qh_ >> 24) ]; const v128_t qhl = wasm_v128_load(tmp + 0); const v128_t qhh = wasm_v128_load(tmp + 2); @@ -3565,6 +3763,45 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined __wasm_simd128__ + v128_t sumv = wasm_f32x4_splat(0.0f); + + for (; ib < nb; ++ib) { + const block_q8_0 * restrict x0 = &x[ib]; + const block_q8_0 * restrict y0 = &y[ib]; + + const v128_t x0_0 = wasm_v128_load(x0->qs); + const v128_t x0_1 = wasm_v128_load(x0->qs + 16); + const v128_t y0_0 = wasm_v128_load(y0->qs); + const v128_t y0_1 = wasm_v128_load(y0->qs + 16); + + // Extend 8-bit to 16-bit + const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0); + const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0); + const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1); + const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1); + + const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0); + const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0); + const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1); + const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1); + + // Compute dot products + const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l); + const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h); + const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l); + const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h); + + // Sum all dot products + const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1)); + + // Convert to float and accumulate + const float scale = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d); + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -4439,6 +4676,106 @@ void 
ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc); +#elif defined __wasm_simd128__ + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + // Vectorized summs calculation + v128_t summs_vec = wasm_i32x4_splat(0); + { + v128_t sc_vec = wasm_v128_load(sc); + v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4); + + v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper); + v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper); + + v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]); + v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]); + + summs_vec = wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1), + wasm_i32x4_dot_i16x8(sc_high, bsums2)), + summs_vec + ); + + summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1)); + summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2)); + } + int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0); + + // Vectorized isum calculation + int32_t isum = 0; + const uint8_t * sc_ptr = sc; + const int k_iters = QK_K/128; + + for (int k = 0; k < k_iters; ++k) { + v128_t isum_vec = wasm_i32x4_splat(0); + int shift = 0; + + for (int j = 0; j < 4; ++j) { + const int d0 = (sc_ptr[0] & 0xF); + const int d1 = (sc_ptr[1] & 0xF); + sc_ptr += 2; + + // Process first 16 elements + v128_t q2_0 = wasm_v128_load(q2); + v128_t q8_0 = wasm_v128_load(q8); + v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift); + v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03)); + + // Process next 16 elements + v128_t q2_1 = wasm_v128_load(q2 + 16); + v128_t q8_1 = wasm_v128_load(q8 + 16); + v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift); + v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03)); + + // Calculate dot products + v128_t p0 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q8_0), + wasm_i16x8_extend_low_i8x16(q2_bits_0) + ); + v128_t p1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q8_0), + wasm_i16x8_extend_high_i8x16(q2_bits_0) + ); + v128_t p2 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q8_1), + wasm_i16x8_extend_low_i8x16(q2_bits_1) + ); + v128_t p3 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q8_1), + wasm_i16x8_extend_high_i8x16(q2_bits_1) + ); + + // Accumulate scaled results + v128_t scaled = wasm_i32x4_add( + wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)), + wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1)) + ); + + isum_vec = wasm_i32x4_add(isum_vec, scaled); + q8 += 32; + shift += 2; + } + q2 += 32; + + // Horizontal sum of isum_vec + isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1)); + isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2)); + isum += wasm_i32x4_extract_lane(isum_vec, 0); + } + + const float dall = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf += dall * isum - dmin * summs; + } + + *s = sumf; + #elif defined __riscv_v_intrinsic float sumf = 0; @@ -4658,9 +4995,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r #elif defined __loongarch_asx - const __m256i m3 = __lasx_xvreplgr2vr_b(3); - const __m128i m4 = __lsx_vreplgr2vr_b(0xF); - __m256 acc = (__m256)__lasx_xvldi(0); for (int i = 0; i < nb; ++i) { @@ -4671,18 +5005,15 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, 
const void * r const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0); - const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4); - const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4); - const __m256i mins = lasx_ext8_16(mins8); + const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0); + const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf); + const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4)); const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0)); acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc); - const __m256i all_scales = lasx_ext8_16(scales8); - const __m128i l_scales = lasx_extracti128(all_scales, 0); - const __m128i h_scales = lasx_extracti128(all_scales, 1); - const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)}; + const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; + const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask)); __m256i sumi = __lasx_xvldi(0); @@ -4695,20 +5026,20 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q2_0 = __lasx_xvand_v(q2bits, m3); - const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3); - const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3); - const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3); + const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3); + const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3); + const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3); + const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6); - __m256i p0 = lasx_maddubs_h(q2_0, q8_0); - __m256i p1 = lasx_maddubs_h(q2_1, q8_1); - __m256i p2 = lasx_maddubs_h(q2_2, q8_2); - __m256i p3 = lasx_maddubs_h(q2_3, q8_3); + __m256i p0 = lasx_madd_h_b(q2_0, q8_0); + __m256i p1 = lasx_madd_h_b(q2_1, q8_1); + __m256i p2 = lasx_madd_h_b(q2_2, q8_2); + __m256i p3 = lasx_madd_h_b(q2_3, q8_3); - p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0); - p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1); - p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2); - p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3); + p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0); + p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1); + p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2); + p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3); p0 = __lasx_xvadd_w(p0, p1); p2 = __lasx_xvadd_w(p2, p3); @@ -5121,6 +5452,94 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc); +#elif defined __wasm_simd128__ + int8_t aux8[QK_K]; + float sums[8] = {0}; + uint32_t auxs[4]; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + // Process blocks with SIMD + int8_t * a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int shift = 0; shift <= 6; shift += 2) { + 
v128_t v_m = wasm_i8x16_splat(m); + for (int l = 0; l < 32; l += 16) { + v128_t v_q3 = wasm_v128_load(q3 + l); + v128_t v_shift = wasm_i8x16_shr(v_q3, shift); + v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03)); + + v128_t v_hm = wasm_v128_load(hm + l); + v128_t v_mask = wasm_v128_and(v_hm, v_m); + v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0)); + + v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask))); + wasm_v128_store(a + l, v_low2); + } + a += 32; + m <<= 1; + } + q3 += 32; + } + + // Extract scales + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + const int8_t * scales = (const int8_t *)auxs; + + // SIMD dot product with register accumulators + v128_t v_acc0 = wasm_i32x4_splat(0); + v128_t v_acc1 = wasm_i32x4_splat(0); + a = aux8; + for (int j = 0; j < QK_K/16; ++j) { + const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32); + + // Process 16 elements per iteration + for (int k = 0; k < 2; ++k) { + const v128_t v_q8 = wasm_i16x8_load8x8(q8); + const v128_t v_a = wasm_i16x8_load8x8(a); + + v128_t v_prod = wasm_i16x8_mul(v_q8, v_a); + v_prod = wasm_i16x8_mul(v_prod, v_scale); + + v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod)); + v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod)); + + q8 += 8; + a += 8; + } + } + + // Accumulate results + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const v128_t v_d = wasm_f32x4_splat(d); + v128_t v_sum = wasm_f32x4_add( + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d), + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d) + ); + + // Accumulate into sums vector + wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum)); + } + + // Horizontal sum + v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4)); + sumf = wasm_f32x4_extract_lane(v_sum, 0) + + wasm_f32x4_extract_lane(v_sum, 1) + + wasm_f32x4_extract_lane(v_sum, 2) + + wasm_f32x4_extract_lane(v_sum, 3); + + *s = sumf; + #elif defined __riscv_v_intrinsic uint32_t aux[3]; @@ -5376,8 +5795,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r #elif defined __loongarch_asx - const __m256i m3 = __lasx_xvreplgr2vr_b(3); - const __m256i mone = __lasx_xvreplgr2vr_b(1); const __m128i m32 = __lsx_vreplgr2vr_b(32); __m256 acc = (__m256)__lasx_xvldi(0); @@ -5397,10 +5814,9 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); scales128 = __lsx_vsub_b(scales128, m32); - const __m256i all_scales = lasx_ext8_16(scales128); - const __m128i l_scales = lasx_extracti128(all_scales, 0); - const __m128i h_scales = lasx_extracti128(all_scales, 1); - const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)}; + + const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; + const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask)); // high bit const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0); @@ -5408,35 +5824,23 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r // integer 
accumulator __m256i sumi = __lasx_xvldi(0); - int bit = 0; - int is = 0; - __m256i xvbit; - - for (int j = 0; j < QK_K/128; ++j) { // load low 2 bits const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32; - xvbit = __lasx_xvreplgr2vr_h(bit); // prepare low and high bits - const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3); - const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); - ++bit; - - xvbit = __lasx_xvreplgr2vr_h(bit); - const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3); - const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); - ++bit; - - xvbit = __lasx_xvreplgr2vr_h(bit); - const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3); - const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); - ++bit; - - xvbit = __lasx_xvreplgr2vr_h(bit); - const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3); - const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); - ++bit; + const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3); + const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3); + const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3); + const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6); + const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2); + const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2); + const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2); + const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2); + const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0); + const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1); + const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2); + const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3); // load Q8 quants const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; @@ -5444,29 +5848,16 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h, - // and then subtract. 
The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0); - __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1); - __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2); - __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3); - - __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0); - __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1); - __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2); - __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3); - - p16_0 = __lasx_xvsub_h(p16_0, q8s_0); - p16_1 = __lasx_xvsub_h(p16_1, q8s_1); - p16_2 = __lasx_xvsub_h(p16_2, q8s_2); - p16_3 = __lasx_xvsub_h(p16_3, q8s_3); + __m256i p16_0 = lasx_madd_h_b(q8_0, q3_0); + __m256i p16_1 = lasx_madd_h_b(q8_1, q3_1); + __m256i p16_2 = lasx_madd_h_b(q8_2, q3_2); + __m256i p16_3 = lasx_madd_h_b(q8_3, q3_3); // multiply with scales - p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); - p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); - p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); - p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0); + p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1); + p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2); + p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3); // accumulate p16_0 = __lasx_xvadd_w(p16_0, p16_1); @@ -5474,7 +5865,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2)); } // multiply with block scale and accumulate - acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME + acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); } *s = hsum_float_8(acc); @@ -5646,7 +6037,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r } } *s = sumf; -#elif __ARM_NEON +#elif defined __ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); const int32x4_t mzero = vdupq_n_s32(0); @@ -5709,6 +6100,107 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = sumf; +#elif defined __wasm_simd128__ + const uint8_t * scales = (const uint8_t*)&utmp[0]; + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Process scales and mins + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + // Sum mins * q8sums + int32_t sumi = 0; + const int16_t * restrict q8sums = y[i].bsums; + const uint8_t * m = (const uint8_t *)&utmp[2]; + for (int j = 0; j < 16; j += 2) { + sumi += (q8sums[j] + q8sums[j+1]) * m[j/2]; + } + sumf -= dmin * sumi; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + // Load 64 4-bit weights (32 bytes) + const v128_t q4x0 = wasm_v128_load(q4); + const v128_t q4x1 = wasm_v128_load(q4 + 16); + q4 += 32; + + // Split into low/high nibbles + const v128_t q4l0 = wasm_v128_and(q4x0, 
wasm_i8x16_splat(0x0F)); + const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4); + const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F)); + const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4); + + // Load 64 8-bit values (64 bytes) + const v128_t q8x0 = wasm_v128_load(q8); + const v128_t q8x1 = wasm_v128_load(q8 + 16); + const v128_t q8x2 = wasm_v128_load(q8 + 32); + const v128_t q8x3 = wasm_v128_load(q8 + 48); + q8 += 64; + + // Low nibble products + v128_t vacc1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4l0), + wasm_i16x8_extend_low_i8x16(q8x0) + ); + vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4l0), + wasm_i16x8_extend_high_i8x16(q8x0) + )); + vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4l1), + wasm_i16x8_extend_low_i8x16(q8x1) + )); + vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4l1), + wasm_i16x8_extend_high_i8x16(q8x1) + )); + + // High nibble products + v128_t vacc2 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4h0), + wasm_i16x8_extend_low_i8x16(q8x2) + ); + vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4h0), + wasm_i16x8_extend_high_i8x16(q8x2) + )); + vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4h1), + wasm_i16x8_extend_low_i8x16(q8x3) + )); + vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4h1), + wasm_i16x8_extend_high_i8x16(q8x3) + )); + + // Accumulate scaled results + int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) + + wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3); + sumi1 += vacc1_sum * scales[2*j]; + + int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) + + wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3); + sumi2 += vacc2_sum * scales[2*j+1]; + } + + sumf += d * (sumi1 + sumi2); + } + + *s = sumf; + #elif defined __AVX2__ const __m256i m4 = _mm256_set1_epi8(0xF); @@ -6066,11 +6558,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = vec_extract(vsumf0, 0); #elif defined __loongarch_asx - GGML_UNUSED(kmask1); - GGML_UNUSED(kmask2); - GGML_UNUSED(kmask3); - - const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); __m256 acc = (__m256)__lasx_xvldi(0); __m128 acc_m = (__m128)__lsx_vldi(0); @@ -6090,33 +6577,34 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0])); + const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128); + const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0); const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); - const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s); + const __m128i prod = lsx_madd_h(mins128, q8s); acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m); - const __m128i sc128 = lasx_extracti128(mins_and_scales, 0); - const __m256i scales = lasx_insertf128(sc128, sc128); + const __m256i scales = lasx_insertf128(scales128, scales128); __m256i sumi = __lasx_xvldi(0); for (int j = 0; j < QK_K/64; ++j) { - 
const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1)); + const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0); + const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1); const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; - const __m256i q4l = __lasx_xvand_v(q4bits, m4); - const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4); + const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf); + const __m256i q4h = __lasx_xvsrli_b(q4bits, 4); const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - __m256i p16l = lasx_maddubs_h(q4l, q8l); + __m256i p16l = lasx_madd_h_b(q4l, q8l); p16l = lasx_madd_h(scale_l, p16l); const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - __m256i p16h = lasx_maddubs_h(q4h, q8h); + __m256i p16h = lasx_madd_h_b(q4h, q8h); p16h = lasx_madd_h(scale_h, p16h); const __m256i sumj = __lasx_xvadd_w(p16l, p16h); @@ -6459,6 +6947,118 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc) + summs; +#elif defined __wasm_simd128__ + //const uint8_t * scales = (const uint8_t*)&utmp[0]; + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + // Process scales and mins + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + // Sum mins * q8sums + int32_t sumi_mins = 0; + const int16_t * restrict q8sums = y[i].bsums; + const uint8_t * m = (const uint8_t *)&utmp[2]; + for (int j = 0; j < 16; j += 2) { + sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2]; + } + sumf -= dmin * sumi_mins; // Correct subtraction + + v128_t qh0 = wasm_v128_load(qh); + v128_t qh1 = wasm_v128_load(qh + 16); + const uint8_t * sc = (const uint8_t *)utmp; + + int32_t sumi = 0; + + for (int j = 0; j < QK_K/64; ++j) { + const int shift = j * 2; + v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift); + v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift); + + v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4); + v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3); + v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4); + v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3); + + v128_t q5_0 = wasm_v128_load(q5); + v128_t q5_1 = wasm_v128_load(q5 + 16); + q5 += 32; + + v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0); + v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0); + v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1); + v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1); + + v128_t q8_0 = wasm_v128_load(q8); + v128_t q8_1 = wasm_v128_load(q8 + 16); + v128_t q8_2 = wasm_v128_load(q8 + 32); + v128_t q8_3 = wasm_v128_load(q8 + 48); + q8 += 64; + + // Process low quants + v128_t pl0 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5l_0), + wasm_i16x8_extend_low_i8x16(q8_0) + ); + pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8( + 
wasm_i16x8_extend_high_i8x16(q5l_0), + wasm_i16x8_extend_high_i8x16(q8_0) + )); + v128_t pl1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5l_1), + wasm_i16x8_extend_low_i8x16(q8_1) + ); + pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5l_1), + wasm_i16x8_extend_high_i8x16(q8_1) + )); + v128_t sum_low = wasm_i32x4_add(pl0, pl1); + + // Process high quants + v128_t ph0 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5h_0), + wasm_i16x8_extend_low_i8x16(q8_2) + ); + ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5h_0), + wasm_i16x8_extend_high_i8x16(q8_2) + )); + v128_t ph1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5h_1), + wasm_i16x8_extend_low_i8x16(q8_3) + ); + ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5h_1), + wasm_i16x8_extend_high_i8x16(q8_3) + )); + v128_t sum_high = wasm_i32x4_add(ph0, ph1); + + // Accumulate with scale factors + int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) + + wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3); + int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) + + wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3); + + sumi += sl * sc[2*j] + sh * sc[2*j+1]; + } + + sumf += d * sumi; + } + + *s = sumf; + #elif defined __riscv_v_intrinsic const uint8_t * scales = (const uint8_t*)&utmp[0]; @@ -6681,19 +7281,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = vec_extract(vsumf0, 0); #elif defined __loongarch_asx - GGML_UNUSED(kmask1); - GGML_UNUSED(kmask2); - GGML_UNUSED(kmask3); - - const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); - const __m128i mzero = __lsx_vldi(0); - const __m256i mone = __lasx_xvreplgr2vr_b(1); __m256 acc = (__m256)__lasx_xvldi(0); + __m128 acc_m = (__m128)__lsx_vldi(0); - float summs = 0.f; - - for (int i = 0; i < nb; ++i) { + for (int i = 0; i < nb; ++i) { const uint8_t * restrict q5 = x[i].qs; const int8_t * restrict q8 = y[i].qs; @@ -6708,49 +7300,40 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r utmp[2] = uaux; utmp[0] &= kmask1; - const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0])); + const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128); + const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0); const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); - const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s); - const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero); - summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check + const __m128i prod = lsx_madd_h(mins128, q8s); + acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m); - const __m128i sc128 = lasx_extracti128(mins_and_scales, 0); - const __m256i scales = lasx_insertf128(sc128, sc128); + const __m256i scales = lasx_insertf128(scales128, scales128); const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0); - __m256i hmask = mone; __m256i sumi = __lasx_xvldi(0); - int bit = 0; - __m256i xvbit; - for (int j = 0; j < QK_K/64; ++j) { - const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_1 = lasx_shuffle_b(scales, 
get_scale_shuffle_k4(2*j+1)); + const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0); + const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1); const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32; - xvbit = __lasx_xvreplgr2vr_h(bit++); - const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4); - const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4); - const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0); - hmask = __lasx_xvslli_h(hmask, 1); - - xvbit = __lasx_xvreplgr2vr_h(bit++); - const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4); - const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4); - const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1); - hmask = __lasx_xvslli_h(hmask, 1); + const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf); + const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4); + const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef); + const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef); + const __m256i q5_0 = __lasx_xvor_v(q5l_0, q5h_0); + const __m256i q5_1 = __lasx_xvor_v(q5l_1, q5h_1); const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0); - __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1); + __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0); + __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1); p16_0 = lasx_madd_h(scale_0, p16_0); p16_1 = lasx_madd_h(scale_1, p16_1); @@ -6764,7 +7347,10 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r } - *s = hsum_float_8(acc) + summs; + acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8)); + acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4)); + + *s = hsum_float_8(acc) + ((v4f32)acc_m)[0]; #else @@ -7122,6 +7708,85 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc); +#elif defined __wasm_simd128__ + int8_t aux8[QK_K] __attribute__((aligned(16))); + int32_t aux32[8] __attribute__((aligned(16))) = {0}; + float sums[8] __attribute__((aligned(16))) = {0}; + + for (int i = 0; i < nb; ++i) { + // Unpack 6-bit quantized data into aux8 (unchanged) + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + int8_t * a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + + const int8_t * restrict a_ptr = aux8; + const int8_t * restrict q8 = y[i].qs; + v128_t acc0 = wasm_i32x4_splat(0); + v128_t acc1 = wasm_i32x4_splat(0); + + for (int j = 0; j < QK_K/16; ++j) { + const int scale = x[i].scales[j]; + const v128_t vscale = wasm_i32x4_splat(scale); + + // Load 16 elements from a and q8 + const v128_t a_vec = wasm_v128_load(a_ptr); + const v128_t q8_vec = wasm_v128_load(q8); + + // Process low 8 elements + v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec); + v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec); + v128_t prod_low = wasm_i16x8_mul(a_low, q8_low); + v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low); + v128_t prod_lo_hi = 
wasm_i32x4_extend_high_i16x8(prod_low); + + // Process high 8 elements + v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec); + v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec); + v128_t prod_high = wasm_i16x8_mul(a_high, q8_high); + v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high); + v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high); + + // Scale and accumulate + prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale); + prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale); + prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale); + prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale); + + acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo)); + acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi)); + + a_ptr += 16; + q8 += 16; + } + + // Store accumulated results + wasm_v128_store(&aux32[0], acc0); + wasm_v128_store(&aux32[4], acc1); + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) { + sums[l] += d * aux32[l]; + } + } + + // Sum final results + float sumf = 0; + for (int l = 0; l < 8; ++l) { + sumf += sums[l]; + } + *s = sumf; + #elif defined __riscv_v_intrinsic float sumf = 0; @@ -7346,8 +8011,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r #elif defined __loongarch_asx - const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); - const __m256i m2 = __lasx_xvreplgr2vr_b(3); const __m256i m32s = __lasx_xvreplgr2vr_b(32); __m256 acc = (__m256)__lasx_xvldi(0); @@ -7360,58 +8023,42 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; - const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0); + const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0); + const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; + const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask)); __m256i sumi = __lasx_xvldi(0); - int is = 0; - for (int j = 0; j < QK_K/128; ++j) { - const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0)); - const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1)); - const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2)); - const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3)); - is += 4; - const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32; - const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4); - const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4); - const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4); - const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4); + const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4); + const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2); + const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4); + const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2); - const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0); - const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1); - const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2); - const __m256i q4_3 = 
__lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3); + const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0); + const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1); + const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2); + const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3); const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0); - __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1); - __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2); - __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3); - - __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0); - __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1); - __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2); - __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3); - - p16_0 = __lasx_xvsub_h(p16_0, q8s_0); - p16_1 = __lasx_xvsub_h(p16_1, q8s_1); - p16_2 = __lasx_xvsub_h(p16_2, q8s_2); - p16_3 = __lasx_xvsub_h(p16_3, q8s_3); + __m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0); + __m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1); + __m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2); + __m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3); - p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0); - p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1); - p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2); - p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3); + p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0); + p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1); + p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2); + p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3); sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3)); @@ -9736,13 +10383,9 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { } #elif defined(__loongarch_asx) static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { - const __m256i ax = __lasx_xvsigncov_b(x, x); - const __m256i sy = __lasx_xvsigncov_b(x, y); - __m256i tmp1, tmp2, tmp3; - tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy); - tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy); - tmp3 = __lasx_xvadd_h(tmp1, tmp2); - return __lasx_xvsat_h(tmp3, 15); + const __m256i a = __lasx_xvmulwev_h_b(x, y); + const __m256i b = __lasx_xvmulwod_h_b(x, y); + return __lasx_xvadd_h(a, b); } #endif @@ -10792,67 +11435,31 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * #elif defined(__loongarch_asx) const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); - const __m128i m4b = __lsx_vreplgr2vr_b(0x0f); __m256 accum = (__m256)__lasx_xvldi(0); - __m256i tmp1; - __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask; - mask_8f = __lsx_vreplgr2vr_b(0x8f); for (int ibl = 0; ibl < nb; ++ibl) { const uint8_t * qs = x[ibl].qs; const int8_t * q8 = y[ibl].qs; uint16_t sh = x[ibl].scales_h; __m256i sumi1 = __lasx_xvldi(0); __m256i sumi2 = __lasx_xvldi(0); - __m128i zero = __lsx_vldi(0); for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16; - const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16; + const __m128i q4bits_1 = __lsx_vld((const 
__m128i*)qs, 0); qs += 16; + const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16; const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f); - tmp0 = __lsx_vori_b(tmp2, 0x10); - mask = __lsx_vsle_b(zero, tmp2); - tmp3 = __lsx_vand_v(tmp0, mask); - tmp3 = __lsx_vshuf_b(values128, zero, tmp3); - - tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f); - tmp0 = __lsx_vori_b(tmp2, 0x10); - mask = __lsx_vsle_b(zero, tmp2); - tmp4 = __lsx_vand_v(tmp0, mask); - tmp4 = __lsx_vshuf_b(values128, zero, tmp4); - - const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4); - - tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f); - tmp0 = __lsx_vori_b(tmp2, 0x10); - mask = __lsx_vsle_b(zero, tmp2); - tmp3 = __lsx_vand_v(tmp0, mask); - tmp3 = __lsx_vshuf_b(values128, zero, tmp3); - - tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f); - tmp0 = __lsx_vori_b(tmp2, 0x10); - mask = __lsx_vsle_b(zero, tmp2); - tmp4 = __lsx_vand_v(tmp0, mask); - tmp4 = __lsx_vshuf_b(values128, zero, tmp4); - - const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4); - + const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)), + __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf))); + const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)), + __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf))); const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; sh >>= 4; - __m256i tmp5, tmp6; - tmp1 = __lasx_xvreplgr2vr_h(ls1); - tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1); - tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1); - const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6); - tmp1 = __lasx_xvreplgr2vr_h(ls2); - tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1); - tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1); - const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6); + const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1)); + const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2)); sumi1 = __lasx_xvadd_w(p_1, sumi1); sumi2 = __lasx_xvadd_w(p_2, sumi2); } diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index fdb430a43cc21..dbef5df2111c6 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -7,10 +7,8 @@ #include "ggml-cpu-impl.h" #include "ggml-cpu.h" #include "ggml-impl.h" -#include "ggml-quants.h" #include "ggml-cpu-quants.h" #include "ggml-threading.h" -#include "amx/amx.h" #include "ggml.h" #if defined(_MSC_VER) || defined(__MINGW32__) @@ -1291,7 +1289,7 @@ struct ggml_threadpool { atomic_int n_graph; // incremented when there is work to be done (i.e each graph) atomic_int GGML_CACHE_ALIGN n_barrier; atomic_int GGML_CACHE_ALIGN n_barrier_passed; - atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. + atomic_int GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. 
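The only functional change in this hunk is the added `GGML_CACHE_ALIGN` on `current_chunk`: the counter is hammered with relaxed `fetch_add` from every worker thread, and without the alignment it can share a cache line with the neighbouring threadpool fields, so each increment invalidates a line other threads are reading (false sharing). A minimal standalone sketch of the idea, assuming a 64-byte cache line and using standard C++ atomics in place of ggml's macros:

```cpp
#include <atomic>
#include <thread>
#include <vector>

struct WorkQueue {
    // Without the alignment this counter could share a cache line with
    // neighbouring fields; every fetch_add would then bounce that line
    // between cores even though the neighbours never change.
    alignas(64) std::atomic<int> current_chunk{0};
};

// Each worker pops the next chunk index until the range is exhausted.
static void worker(WorkQueue & q, int nchunks) {
    for (;;) {
        const int chunk = q.current_chunk.fetch_add(1, std::memory_order_relaxed);
        if (chunk >= nchunks) {
            break;
        }
        // ... process chunk ...
    }
}

int main() {
    WorkQueue q;
    std::vector<std::thread> workers;
    for (int i = 0; i < 4; ++i) {
        workers.emplace_back(worker, std::ref(q), 64);
    }
    for (auto & t : workers) {
        t.join();
    }
}
```

The same chunk-stealing loop reappears below in `ggml_compute_forward_mul_mat_id`, which is why the patch later allocates one cache-line-aligned counter per expert (`atomic_current_chunk`).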
// these are atomic as an annotation for thread-sanitizer atomic_bool stop; // Used for stopping the threadpool altogether @@ -1818,7 +1816,7 @@ inline static float ggml_silu_f32(float x) { #if __FINITE_MATH_ONLY__ #error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix" -#error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461" +#error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461" #endif #if defined(__ARM_NEON) && defined(__aarch64__) @@ -7490,6 +7488,7 @@ UseGgmlGemm1:; if (src1->type != vec_dot_type) { char * wdata = params->wdata; + const size_t nbw0 = ggml_type_size(vec_dot_type); const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); const size_t nbw2 = nbw1*ne11; const size_t nbw3 = nbw2*ne12; @@ -7497,6 +7496,7 @@ UseGgmlGemm1:; assert(params->wsize >= ne13*nbw3); GGML_ASSERT(src1->type == GGML_TYPE_F32); + #if 0 for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i11 = ith; i11 < ne11; i11 += nth) { @@ -7506,6 +7506,20 @@ UseGgmlGemm1:; } } } + #else + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + size_t bs = ggml_blck_size(vec_dot_type); + int64_t ne10_block_start = (ith * ne10/bs) / nth; + int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth; + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0), + (ne10_block_end - ne10_block_start) * bs); + } + } + } + #endif } if (ith == 0) { @@ -7560,7 +7574,7 @@ UseGgmlGemm2:; int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size; // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread. - // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915 + // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggml-org/llama.cpp/pull/6915 // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that. 
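The NUMA remark above pairs with the condition in the next context line: if carving the output into fixed-size chunks would leave fewer than four chunks per thread, or the machine is NUMA, the grid collapses to one statically assigned chunk per thread along the larger dimension. A standalone sketch of that heuristic, with hypothetical names (`pick_chunk_grid`, and `is_numa` standing in for `ggml_is_numa()`) and the chunk sizes taken from the `mul_mat_id` path further down in this patch:

```cpp
#include <cstdint>

struct ChunkGrid { int64_t nchunk0, nchunk1; };

static ChunkGrid pick_chunk_grid(int64_t nr0, int64_t nr1, int nth, bool is_numa) {
    // Wider chunks when one dimension degenerates to a single row/column.
    const int64_t chunk_size = (nr0 == 1 || nr1 == 1) ? 64 : 16;

    int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
    int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;

    // Too few chunks to be worth stealing, or a NUMA host: fall back to
    // exactly one chunk per thread along the larger dimension.
    if (nchunk0 * nchunk1 < nth * 4 || is_numa) {
        nchunk0 = nr0 > nr1 ? nth : 1;
        nchunk1 = nr0 > nr1 ? 1 : nth;
    }
    return {nchunk0, nchunk1};
}
```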
if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) { // distribute the thread work across the inner or outer loop based on which one is larger @@ -7593,7 +7607,6 @@ UseGgmlGemm2:; if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) { num_rows_per_vec_dot = 1; } - ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); if (nth >= nchunk0 * nchunk1) { @@ -7606,6 +7619,84 @@ UseGgmlGemm2:; // ggml_compute_forward_mul_mat_id +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)] + +struct mmid_row_mapping { + int32_t i1; + int32_t i2; +}; + +static void ggml_compute_forward_mul_mat_id_one_chunk( + struct ggml_tensor * dst, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * ids, + const int64_t cur_a, + const int64_t ir0_start, + const int64_t ir0_end, + const int64_t ir1_start, + const int64_t ir1_end, + const char * src0_cur, + const struct mmid_row_mapping * matrix_rows, + const size_t row_size, + const bool src1_cont, + const void * wdata) { + + GGML_TENSOR_BINARY_OP_LOCALS + + const enum ggml_type type = src0->type; + + ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot; + enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type; + + const int64_t blck_0 = 16; + const int64_t blck_1 = 16; + + float tmp[16]; + + for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { + for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { + for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) { + const int64_t _i12 = ir1; // logical row index for this expert + + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12); + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char *) wdata + + (src1_cont || src1->type != vec_dot_type + ? 
(i11 + i12*ne11)*row_size + : (i11*nb11 + i12*nb12)); + + float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2)); + + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) { + vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1); + } + + memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float)); + } + } + } +} + +static void * incr_ptr_aligned(void ** p, size_t size, size_t align) { + + void * ptr = *p; + ptr = (void *) GGML_PAD((uintptr_t) ptr, align); + *p = (void *) ((char *) ptr + size); + return ptr; +} + static void ggml_compute_forward_mul_mat_id( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -7623,7 +7714,6 @@ static void ggml_compute_forward_mul_mat_id( const bool src1_cont = ggml_is_contiguous(src1); - ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot; enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type; ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float; @@ -7641,21 +7731,27 @@ static void ggml_compute_forward_mul_mat_id( const int n_ids = ids->ne[0]; // n_expert_used const int n_as = ne02; // n_expert - char * wdata_src1_end = (src1->type == vec_dot_type) ? - (char *) params->wdata : - (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); + void * wdata_cur = params->wdata; - struct mmid_row_mapping { - int32_t i1; - int32_t i2; - }; + if (src1->type != vec_dot_type) { + incr_ptr_aligned(&wdata_cur, ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); + } + + int64_t * matrix_row_counts = // [n_as] + incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t)); + + struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]] + incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t)); - int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] - struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11] + char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as] + incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE); + + GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata)); if (src1->type != vec_dot_type) { char * wdata = params->wdata; + const size_t nbw0 = ggml_type_size(vec_dot_type); const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); const size_t nbw2 = nbw1*ne11; const size_t nbw3 = nbw2*ne12; @@ -7663,19 +7759,32 @@ static void ggml_compute_forward_mul_mat_id( assert(params->wsize >= ne13*nbw3); GGML_ASSERT(src1->type == GGML_TYPE_F32); +#if 0 for (int64_t i13 = 0; i13 < ne13; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = ith; i11 < ne11; i11 += nth) { + for (int64_t i12 = ith; i12 < ne12; i12 += nth) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), ne10); } } } +#else + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + size_t bs = ggml_blck_size(vec_dot_type); + int64_t ne10_block_start = (ith * ne10/bs) / nth; + int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth; + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + 
ne10_block_start*nbw0), + (ne10_block_end - ne10_block_start) * bs); + } + } + } +#endif } -#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] - if (ith == 0) { // initialize matrix_row_counts memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); @@ -7693,9 +7802,14 @@ static void ggml_compute_forward_mul_mat_id( } } + // reset current_chunk + for (int cur_a = ith; cur_a < n_as; cur_a += nth) { + atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a); + *current_chunk_ctr = nth; + } + ggml_barrier(params->threadpool); - // compute each matrix multiplication in sequence for (int cur_a = 0; cur_a < n_as; ++cur_a) { const int64_t cne1 = matrix_row_counts[cur_a]; @@ -7703,84 +7817,64 @@ static void ggml_compute_forward_mul_mat_id( continue; } - const char * src0_cur = (const char *) src0->data + cur_a*nb02; - - const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const char * src0_cur = (const char *) src0->data + cur_a * nb02; + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); - const int64_t nr0 = ne01; // src0 rows - const int64_t nr1 = cne1; // src1 rows - - // distribute the thread work across the inner or outer loop based on which one is larger - - const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows - const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows - - const int64_t ith0 = ith % nth0; - const int64_t ith1 = ith / nth0; - - const int64_t dr0 = (nr0 + nth0 - 1)/nth0; - const int64_t dr1 = (nr1 + nth1 - 1)/nth1; + const int64_t nr0 = ne01; + const int64_t nr1 = cne1; - const int64_t ir010 = dr0*ith0; - const int64_t ir011 = MIN(ir010 + dr0, nr0); - - const int64_t ir110 = dr1*ith1; - const int64_t ir111 = MIN(ir110 + dr1, nr1); + int chunk_size = 16; + if (nr0 == 1 || nr1 == 1) { + chunk_size = 64; + } - // threads with no work simply yield (not sure if it helps) - //if (ir010 >= ir011 || ir110 >= ir111) { - // sched_yield(); - // continue; - //} +#if defined(__aarch64__) + // disable for ARM + const bool disable_chunking = true; +#else + // disable for NUMA + const bool disable_chunking = ggml_is_numa(); +#endif // defined(__aarch64__) - // block-tiling attempt - const int64_t blck_0 = 16; - const int64_t blck_1 = 16; + int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size; + int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size; - // attempt to reduce false-sharing (does not seem to make a difference) - float tmp[16]; + if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) { + nchunk0 = nr0 > nr1 ? nth : 1; + nchunk1 = nr0 > nr1 ? 
1 : nth; + } - for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { - for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { - for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { - const int64_t _i12 = ir1; // logical row index for this expert + const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; + const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; - struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12); - const int id = row_mapping.i1; // selected expert index + int current_chunk = ith; - const int64_t i11 = id % ne11; - const int64_t i12 = row_mapping.i2; // row index in src1 + atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a); - const int64_t i1 = id; // selected expert index - const int64_t i2 = i12; // row + while (current_chunk < nchunk0 * nchunk1) { + const int64_t ith0 = current_chunk % nchunk0; + const int64_t ith1 = current_chunk / nchunk0; - // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides - // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using - // the original src1 data pointer, so we should index using the indices directly - // TODO: this is a bit of a hack, we should probably have a better way to handle this - const char * src1_col = (const char *) wdata + - (src1_cont || src1->type != vec_dot_type - ? (i11 + i12*ne11)*row_size - : (i11*nb11 + i12*nb12)); + const int64_t ir0_start = dr0 * ith0; + const int64_t ir0_end = MIN(ir0_start + dr0, nr0); - float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2)); + const int64_t ir1_start = dr1 * ith1; + const int64_t ir1_end = MIN(ir1_start + dr1, nr1); - //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); - //} + ggml_compute_forward_mul_mat_id_one_chunk( + dst, src0, src1, ids, cur_a, + ir0_start, ir0_end, ir1_start, ir1_end, + src0_cur, matrix_rows, row_size, src1_cont, wdata + ); - for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1); - } - - memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); - } + if (nth >= nchunk0 * nchunk1) { + break; } + + current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed); } } - -#undef MMID_MATRIX_ROW } // ggml_compute_forward_out_prod @@ -9074,10 +9168,6 @@ static void ggml_compute_forward_clamp_f32( const struct ggml_tensor * src0 = dst->src[0]; - if (params->ith != 0) { - return; - } - float min; float max; memcpy(&min, (float *) dst->op_params + 0, sizeof(float)); @@ -13717,14 +13807,19 @@ struct ggml_cplan ggml_graph_plan( cur = 0; const struct ggml_tensor * src0 = node->src[0]; const struct ggml_tensor * src1 = node->src[1]; + const struct ggml_tensor * ids = node->src[2]; const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type; + const int n_as = src0->ne[2]; + // src1 if (src1->type != vec_dot_type) { - cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)); + cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)) + sizeof(int64_t); } - const int n_as = src0->ne[2]; - cur += GGML_PAD(cur, sizeof(int64_t)); // align - cur += n_as * sizeof(int64_t); // matrix_row_counts - cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows + // matrix_row_counts + cur += n_as * sizeof(int64_t) + sizeof(int64_t); + // matrix_rows + cur += 
n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping) + sizeof(int64_t); + // atomic_current_chunk + cur += CACHE_LINE_SIZE*n_as + CACHE_LINE_SIZE; } break; case GGML_OP_OUT_PROD: { diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 2ccb4b472d612..be4eadcd021f6 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -284,14 +284,14 @@ struct ggml_backend_cpu_device_context { &hKey) == ERROR_SUCCESS) { DWORD cpu_brand_size = 0; if (RegQueryValueExA(hKey, - TEXT("ProcessorNameString"), + "ProcessorNameString", NULL, NULL, NULL, &cpu_brand_size) == ERROR_SUCCESS) { description.resize(cpu_brand_size); if (RegQueryValueExA(hKey, - TEXT("ProcessorNameString"), + "ProcessorNameString", NULL, NULL, (LPBYTE)&description[0], // NOLINT @@ -534,9 +534,6 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r if (ggml_cpu_has_dotprod()) { features.push_back({ "DOTPROD", "1" }); } - if (ggml_cpu_has_matmul_int8()) { - features.push_back({ "MATMUL_INT8", "1" }); - } if (ggml_cpu_get_sve_cnt() > 0) { static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt()); features.push_back({ "SVE_CNT", sve_cnt.c_str() }); diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index c22a662876c4a..e0482c59377fd 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -280,14 +280,6 @@ template <> inline __m256bh load(const float *p) { } #endif -//////////////////////////////////////////////////////////////////////////////////////////////////// -// CONSTANTS - -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; -static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl); -#endif - //////////////////////////////////////////////////////////////////////////////////////////////////// // FLOATING POINT MATRIX MULTIPLICATION @@ -614,6 +606,14 @@ class tinyBLAS_Q0_AVX { TC *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + const int8_t kvalues_iq4nl[16] = { + -127, -104, -83, -65, + -49, -35, -22, -10, + 1, 13, 25, 38, + 53, 69, 89, 113 + }; + + iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl); } void matmul(int64_t m, int64_t n) { @@ -1038,6 +1038,7 @@ class tinyBLAS_Q0_AVX { const int64_t ldc; const int ith; const int nth; + __m128i iq4nlt; }; #endif // __AVX__ diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 119fd39b8e437..682640b5208b9 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -15,9 +15,9 @@ if (CUDAToolkit_FOUND) if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") set(CMAKE_CUDA_ARCHITECTURES "native") elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) - set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75") + set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80") else() - set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75") + set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75;80") endif() endif() message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 174916bc970d7..4a92d35f9f40c 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -41,12 +41,13 @@ #define CUDART_HMAX 11070 // CUDA 11.7, min. 
ver. for which __hmax and __hmax2 are known to work (may be higher than needed) #define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons -#define GGML_CUDA_CC_PASCAL 600 -#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products -#define GGML_CUDA_CC_VOLTA 700 -#define GGML_CUDA_CC_TURING 750 -#define GGML_CUDA_CC_AMPERE 800 -#define GGML_CUDA_CC_OFFSET_AMD 0x1000000 +#define GGML_CUDA_CC_PASCAL 600 +#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products +#define GGML_CUDA_CC_VOLTA 700 +#define GGML_CUDA_CC_TURING 750 +#define GGML_CUDA_CC_AMPERE 800 +#define GGML_CUDA_CC_ADA_LOVELACE 890 +#define GGML_CUDA_CC_OFFSET_AMD 0x1000000 // GCN/CNDA, wave size is 64 #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 @@ -71,6 +72,47 @@ #define GGML_CUDA_CC_QY1 210 #define GGML_CUDA_CC_QY2 220 +#ifdef __CUDA_ARCH_LIST__ +constexpr bool ggml_cuda_has_arch_impl(int) { + return false; +} + +template +constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) { + return arch == first || ggml_cuda_has_arch_impl(arch, rest...); +} + +constexpr bool ggml_cuda_has_arch(const int arch) { + return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__); +} + +constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) { + if (cur == 0) { + GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch); + } + return cur; +} + +template +constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) { + if (first <= arch && first > cur) { + return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...); + } else { + return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...); + } +} + +constexpr int ggml_cuda_highest_compiled_arch(const int arch) { + return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__); +} +#else +static int ggml_cuda_highest_compiled_arch(const int arch) { + return arch; +} +#endif // __CUDA_ARCH_LIST__ + +// --------------------------------------------------------------------------------------------------------- + #define MATRIX_ROW_PADDING 512 // last row of quant. 
matrices is a multiple of this to avoid out-of-bounds memory accesses #if defined(_MSC_VER) @@ -124,11 +166,11 @@ static const char * cu_get_error_str(CUresult err) { #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str) #endif -#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) +#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) #define GGML_CUDA_ASSUME(x) __builtin_assume(x) #else #define GGML_CUDA_ASSUME(x) -#endif // CUDART_VERSION >= 11100 +#endif // CUDART_VERSION >= 11010 #ifdef GGML_CUDA_F16 typedef half dfloat; // dequantize float @@ -158,22 +200,44 @@ typedef float2 dfloat2; #define NEW_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#define CP_ASYNC_AVAILABLE +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) #define FLASH_ATTN_AVAILABLE #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) -static constexpr bool fast_fp16_available(const int cc) { +static bool fp16_available(const int cc) { + return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL; +} + +static bool fast_fp16_available(const int cc) { + return fp16_available(cc) && cc != 610; +} + +// To be used for feature selection of external libraries, e.g. cuBLAS. +static bool fast_fp16_hardware_available(const int cc) { return cc >= GGML_CUDA_CC_PASCAL && cc != 610; } -// Any FP16 tensor cores are available. -static constexpr bool fp16_mma_available(const int cc) { +// Any FP16 tensor core instructions are available for ggml code. +static bool fp16_mma_available(const int cc) { + return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA; +} + +// To be used for feature selection of external libraries, e.g. cuBLAS. +static bool fp16_mma_hardware_available(const int cc) { return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA; } // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. 
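The `__CUDA_ARCH_LIST__` helpers introduced at the top of this file are a compile-time fold over the list of architectures the binary was built for, so feature checks such as `fp16_mma_available` can answer "what is the best compiled arch this device can actually run" rather than "what does the hardware support". The same fold works as ordinary constexpr C++; the sketch below hardcodes a sample arch list where nvcc would inject the real one, and simply returns the base-case value instead of aborting on an empty match:

```cpp
// Base case: no remaining candidates, keep the best arch found so far.
// (The real code aborts here when cur == 0, i.e. nothing was compiled.)
constexpr int highest_compiled_arch_impl(int /*arch*/, int cur) {
    return cur;
}

template <typename... Archs>
constexpr int highest_compiled_arch_impl(int arch, int cur, int first, Archs... rest) {
    // Keep `first` if it runs on `arch` and beats the current best.
    return (first <= arch && first > cur)
        ? highest_compiled_arch_impl(arch, first, rest...)
        : highest_compiled_arch_impl(arch, cur,   rest...);
}

constexpr int highest_compiled_arch(int arch) {
    // Stands in for __CUDA_ARCH_LIST__, e.g. a build for sm_52..sm_80.
    return highest_compiled_arch_impl(arch, 0, 520, 610, 700, 750, 800);
}

// An Ada card (cc 8.9) runs the closest compiled arch, here the sm_80 build:
static_assert(highest_compiled_arch(890) == 800, "");
static_assert(highest_compiled_arch(750) == 750, "");
```

`highest_compiled_arch(890)` resolving to 800 lines up with the default `CMAKE_CUDA_ARCHITECTURES` gaining `80` earlier in this patch: Ampere and Ada devices then get real tensor-core kernels instead of a PTX fallback.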
-static constexpr bool new_mma_available(const int cc) { - return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING; +static bool new_mma_available(const int cc) { + return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING; +} + +static bool cp_async_available(const int cc) { + return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE; } static constexpr __device__ int ggml_cuda_get_physical_warp_size() { diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 5b0dfacefc9da..795b720d60bcb 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -599,7 +599,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { case GGML_TYPE_Q5_1: return dequantize_block_cuda; case GGML_TYPE_Q8_0: - if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_PASCAL) { + if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) { return dequantize_block_q8_0_f16_cuda; } return dequantize_block_cuda; diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh new file mode 100644 index 0000000000000..51aa41e7e60ef --- /dev/null +++ b/ggml/src/ggml-cuda/cp-async.cuh @@ -0,0 +1,46 @@ +// Simplified API for asynchronous data loading. + +#include "common.cuh" + +// Copies data from global to shared memory, cg == cache global. +// Both the src and dst pointers must be aligned to 16 bit. +// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int. +// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared. +// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements. +template +static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) { + static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload"); +#ifdef CP_ASYNC_AVAILABLE +#if CUDART_VERSION >= 11040 + if (preload == 256) { + asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else if (preload == 128) { + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else if (preload == 64) { + asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else +#endif // CUDART_VERSION >= 11040 + { + asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } +#else + GGML_UNUSED(dst); + GGML_UNUSED(src); + NO_DEVICE_CODE; +#endif // CP_ASYNC_AVAILABLE +} + +// Makes each thread wait until its asynchronous data copies are done. +// This does NOT provide any additional synchronization. +// In particular, when copying data with multiple warps a call to __syncthreads will be needed. 
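What these primitives buy in the attention kernel below is a two-stage software pipeline: the copy of the next K/V tile is issued while the math on the current tile runs, with `cp_async_wait_all()` (defined just below) plus `__syncthreads()` fencing each hand-off. A CPU-side sketch of the same ping-pong structure, with `memcpy` standing in for the asynchronous copy (there is no real overlap here; the point is the buffer rotation and the preload before the loop):

```cpp
#include <cstddef>
#include <cstring>

// Double-buffered tile processing: while buf[cur] is consumed, buf[nxt]
// already holds (or on a GPU: is receiving) the next tile. On the GPU the
// memcpy below would be cp.async.cg issues followed by cp_async_wait_all().
void process_tiles(const float * src, float * out, std::size_t ntiles, std::size_t tile_elems) {
    float buf[2][256]; // sketch assumes tile_elems <= 256
    int cur = 0;

    // Preload the first tile, mirroring the preload before the main loop.
    std::memcpy(buf[cur], src, tile_elems * sizeof(float));

    for (std::size_t t = 0; t < ntiles; ++t) {
        const int nxt = cur ^ 1;
        if (t + 1 < ntiles) {
            // Issue the next copy before consuming the current tile.
            std::memcpy(buf[nxt], src + (t + 1) * tile_elems, tile_elems * sizeof(float));
        }
        // Consume the current tile (stands in for the KQ/VKQ matrix math).
        float acc = 0.0f;
        for (std::size_t i = 0; i < tile_elems; ++i) {
            acc += buf[cur][i];
        }
        out[t] = acc;
        cur = nxt;
    }
}
```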
+static __device__ __forceinline__ void cp_async_wait_all() { +#ifdef CP_ASYNC_AVAILABLE + asm volatile("cp.async.wait_all;"); +#else + NO_DEVICE_CODE; +#endif // CP_ASYNC_AVAILABLE +} diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index d40ee2da41887..fefbd319baf76 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -716,7 +716,9 @@ void launch_fattn( ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); - const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int nsm = ggml_cuda_info().devices[id].nsm; ggml_cuda_pool_alloc K_f16(pool); ggml_cuda_pool_alloc V_f16(pool); @@ -768,13 +770,14 @@ void launch_fattn( dim3 blocks_num; if (parallel_blocks == 0) { // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. - const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm; - const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total; - const bool short_context = K->ne[1] < 4096; + const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm); + const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves); const int nblocks_stream_k = 2*nsm; - blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k; + const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE; + + blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; blocks_num.y = 1; blocks_num.z = 1; @@ -827,7 +830,7 @@ void launch_fattn( CUDA_CHECK(cudaGetLastError()); if constexpr (parallel_blocks == 0) { - if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles. + if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. const dim3 block_dim_combine(D, 1, 1); const dim3 blocks_num_combine = blocks_num; diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 05bc91a3b8d6f..d777f5413ed97 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1,242 +1,195 @@ #include "common.cuh" +#include "cp-async.cuh" #include "mma.cuh" #include "fattn-common.cuh" -template -static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( - const float2 * const __restrict__ Q_f2, - const half2 * const __restrict__ K_h2, - const half2 * const __restrict__ V_h2, - const half * const __restrict__ maskh, - float2 * const __restrict__ dstk, - float2 * const __restrict__ dstk_fixup, - const float scale, - const float slope, - const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int nb31, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3, - const int jt, - const int kb0_start, - const int kb0_stop) { -#ifdef NEW_MMA_AVAILABLE - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. 
+using namespace ggml_cuda_mma; - typedef mma_A_I16K8 mma_A; - typedef mma_B_J8K8 mma_B; - typedef mma_C_I16J8 mma_C_KQ; - typedef mma_C_I16J8 mma_C_VKQ; - - static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps"); - constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column. - - static_assert(D % nwarps == 0, "bad D"); - static_assert(KQ_stride % nwarps == 0, "bad KQ_stride"); +typedef tile<16, 8, half2> tile_A; +typedef tile< 8, 8, half2> tile_B; +typedef tile<16, 8, float> tile_C_KQ; +typedef tile<16, 4, half2> tile_C_VKQ; +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) { constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. - extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements. - const int stride_Q = nb01 / sizeof(float2); - const int stride_KV = nb11 / sizeof(half2); - const int stride_mask = nb31 / sizeof(half); + // If cp.async is available, load up to the highest power of 2 in D asynchronously: +#ifdef CP_ASYNC_AVAILABLE + static_assert(D >= 64 && D < 512, "bad D"); + constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128); - mma_B Q_B[D/(2*mma_B::K)]; - mma_C_VKQ VKQ_C[D/mma_C_VKQ::I]; + const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV); - float2 KQ_rowsum = {0.0f, 0.0f}; - float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f}; - float2 KQ_max_scale = {0.0f, 0.0f}; + constexpr int preload = 64; + constexpr int h2_per_chunk = 16/sizeof(half2); + constexpr int chunks_per_row = k0_sync_start / h2_per_chunk; + constexpr int stride_i = WARP_SIZE / chunks_per_row; +#pragma unroll + for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row); + const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk; - // Temporarily load Q data into tile_KV, will be loaded into registers afterwards. - // The loading is done with decreasing granularity for D for better memory bandwidth. - const half2 scale_h2 = make_half2(scale, scale); + cp_async_cg_16(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k); + } +#else + constexpr int k0_sync_start = 0; +#endif // CP_ASYNC_AVAILABLE + static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start"); + + // If D is not a power of 2, the rest is loaded synchronously. + // K/V data is loaded with decreasing granularity for D for better memory bandwidth. + static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); #pragma unroll for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_j = WARP_SIZE / stride_k; + const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; - if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { - break; + if (k0_start == k0_stop || k0_stop <= k0_sync_start) { + continue; } #pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) { - const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 
0 : threadIdx.x / stride_k); + for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - if (jt*ncols + j < ne01) { -#pragma unroll - for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - - const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k]; - tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); - } - } else { #pragma unroll - for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f); - } + tile_KV[i*D2_padded + k] = KV[i*stride_KV + k]; } } } +} - __syncthreads(); - - { - const int j0 = (threadIdx.y / np) * mma_B::J; +template +static __device__ __forceinline__ void flash_attn_ext_f16_iter( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half * const __restrict__ maskh, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne01, + const int ne02, + const int stride_Q, + const int stride_KV, + const int stride_mask, + const int jt, + half2 * const __restrict__ tile_K, + half2 * const __restrict__ tile_V, + const tile_B * const __restrict__ Q_B, + tile_C_VKQ * const __restrict__ VKQ_C, + float2 & KQ_max, + float2 & KQ_rowsum, + const int kb0) { +#ifdef NEW_MMA_AVAILABLE + constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column. + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. -#pragma unroll - for (int k0 = 0; k0 < D/2; k0 += mma_B::K) { - Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded); - } - } + const int k_VKQ_0 = kb0*KQ_stride; + tile_C_KQ KQ_C[KQ_stride/(np*tile_C_KQ::I)]; +#ifdef CP_ASYNC_AVAILABLE + cp_async_wait_all(); __syncthreads(); + flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); +#else + flash_attn_ext_f16_load_tile(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV); + __syncthreads(); +#endif // CP_ASYNC_AVAILABLE - // Iterate over ne11 == previous tokens: - for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) { - const int k_VKQ_0 = kb0*KQ_stride; - mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)]; - - // Load K data into tile with decreasing granularity for D for better memory bandwidth: - static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); -#pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) { - const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - -#pragma unroll - for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) { - const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? 
threadIdx.x : threadIdx.x % stride_k); - - tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ]; - } - } - } - - __syncthreads(); - - // Calculate tile of KQ: + // Calculate tile of KQ: #pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I; + for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) { - mma_A K_A; - K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded); - KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]); - } + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) { + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded); + mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, ((tile_B *) Q_B)[k_KQ_0/tile_A::J]); } + } - __syncthreads(); +#ifndef CP_ASYNC_AVAILABLE + __syncthreads(); // Only needed if tile_K == tile_V. +#endif // CP_ASYNC_AVAILABLE - if (use_logit_softcap) { - static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); + if (use_logit_softcap) { + static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) { + for (int i = 0; i < KQ_stride/(np*tile_C_KQ::I); ++i) { #pragma unroll - for (int l = 0; l < mma_C_KQ::ne; ++l) { - KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); - } + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); } } + } - if (maskh) { - static_assert(KQ_stride % (np *mma_C_KQ::I) == 0, "bad loop size"); - static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size"); + if (maskh) { + static_assert(KQ_stride % (np *tile_C_KQ::I) == 0, "bad loop size"); + static_assert(ncols % (nwarps/np*tile_C_KQ::J) == 0, "bad loop size"); #pragma unroll - for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) { - const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I; + for (int i00 = 0; i00 < KQ_stride; i00 += np*tile_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; #pragma unroll - for (int l = 0; l < mma_C_KQ::ne; ++l) { - const int i = i0 + mma_C_KQ::get_i(l); - const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l); + for (int l = 0; l < tile_C_KQ::ne; ++l) { + const int i = i0 + tile_C_KQ::get_i(l); + const int j = (threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l); - KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]); - } + KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]); } } + } - // Calculate softmax for each KQ column using the current max. value. - // The divisor is stored in KQ_rowsum and will be applied at the end. - float2 KQ_max_new = KQ_max; - static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. 
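The two comments above summarize the online-softmax bookkeeping this block implements: only a running maximum and a running row sum are kept per KQ column, everything accumulated so far is rescaled by `exp(old_max - new_max)` whenever the maximum grows, and the division by the row sum is deferred to the very end of the kernel. A scalar sketch of the same scheme for a single column:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Online softmax over score tiles, fused with a weighted sum of values.
// `acc` stands in for the VKQ accumulator; the divisor is applied once at
// the end, exactly as KQ_rowsum is in the kernel above.
float online_softmax_dot(const std::vector<std::vector<float>> & score_tiles,
                         const std::vector<std::vector<float>> & value_tiles) {
    float running_max = -1e30f;
    float running_sum = 0.0f;
    float acc = 0.0f;

    for (std::size_t t = 0; t < score_tiles.size(); ++t) {
        float tile_max = running_max;
        for (float s : score_tiles[t]) {
            tile_max = std::max(tile_max, s);
        }

        // Rescale everything accumulated so far to the new maximum.
        const float scale = std::exp(running_max - tile_max);
        running_sum *= scale;
        acc *= scale;
        running_max = tile_max;

        for (std::size_t i = 0; i < score_tiles[t].size(); ++i) {
            const float p = std::exp(score_tiles[t][i] - running_max);
            running_sum += p;
            acc += p * value_tiles[t][i];
        }
    }
    return acc / running_sum; // the deferred divisor
}
```

Because `exp(s - running_max) <= 1`, the partial sums stay bounded regardless of how large the raw logits get, which is the whole point of tracking the maximum online.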
+ float2 KQ_max_new = KQ_max; + static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) { + for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) { #pragma unroll - for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) { - KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]); - KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]); - } + for (int l0 = 0; l0 < tile_C_KQ::ne; l0 += 2) { + KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]); + KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]); } + } - // Values per KQ column are spread across 8 threads, does not need full warp reduce: + // Values per KQ column are spread across 8 threads, does not need full warp reduce: #pragma unroll - for (int offset = 16; offset > 2; offset >>= 1) { - KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE)); - KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE)); - } - - { - const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y); - KQ_max_scale = make_float2(expf(diff.x), expf(diff.y)); - if (diff.x <= SOFTMAX_FTZ_THRESHOLD) { - KQ_max_scale.x = 0.0f; - } - if (diff.y <= SOFTMAX_FTZ_THRESHOLD) { - KQ_max_scale.y = 0.0f; - } - KQ_max = KQ_max_new; - } + for (int offset = 16; offset > 2; offset >>= 1) { + KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE)); + KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE)); + } - float2 KQ_rowsum_add = make_float2(0.0f, 0.0f); - static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); + float2 KQ_rowsum_add = make_float2(0.0f, 0.0f); + static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) { + for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) { #pragma unroll - for (int l = 0; l < mma_C_KQ::ne; ++l) { - const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y; - const float diff = KQ_C[k].x[l] - KQ_max_l; - KQ_C[k].x[l] = expf(diff); - if (diff <= SOFTMAX_FTZ_THRESHOLD) { - KQ_C[k].x[l] = 0.0f; - } + for (int l = 0; l < tile_C_KQ::ne; ++l) { + const float KQ_max_l = l % 2 == 0 ? 
KQ_max_new.x : KQ_max_new.y; + const float diff = KQ_C[k].x[l] - KQ_max_l; + KQ_C[k].x[l] = expf(diff); - if (l % 2 == 0) { - KQ_rowsum_add.x += KQ_C[k].x[l]; - } else { - KQ_rowsum_add.y += KQ_C[k].x[l]; - } + if (l % 2 == 0) { + KQ_rowsum_add.x += KQ_C[k].x[l]; + } else { + KQ_rowsum_add.y += KQ_C[k].x[l]; } } + } + + { + const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y); + const float2 KQ_max_scale = make_float2(expf(diff.x), expf(diff.y)); + KQ_max = KQ_max_new; // Scale previous KQ_rowsum to account for a potential increase in KQ_max: KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x; @@ -244,60 +197,179 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y); #pragma unroll - for (int i = 0; i < D/mma_C_VKQ::I; ++i) { + for (int i = 0; i < D/tile_C_VKQ::I; ++i) { #pragma unroll - for (int l = 0; l < mma_C_VKQ::ne; ++l) { + for (int l = 0; l < tile_C_VKQ::ne; ++l) { VKQ_C[i].x[l] *= KQ_max_scale_h2; } } + } + + // Convert KQ C tiles into B tiles for VKQ calculation: + tile_B B[KQ_stride/(np*2*tile_B::J)]; + static_assert(KQ_stride % (np*2*tile_B::J) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_stride/(np*2*tile_B::J); ++k) { + B[k] = get_transposed(get_half2(KQ_C[k])); + } - // Convert KQ C tiles into B tiles for VKQ calculation: - mma_B B[KQ_stride/(np*2*mma_B::K)]; - static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size"); +#ifdef CP_ASYNC_AVAILABLE + cp_async_wait_all(); + __syncthreads(); + if (!last_iter) { + flash_attn_ext_f16_load_tile(K_h2 + (k_VKQ_0 + KQ_stride)*stride_KV, tile_K, stride_KV); + } +#else + flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); + __syncthreads(); +#endif // CP_ASYNC_AVAILABLE + + // Calculate VKQ tile: +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) { + static_assert((KQ_stride/2) % (np*tile_A::J) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) { - B[k] = KQ_C[k].to_mma_B(); + for (int k00 = 0; k00 < KQ_stride/2; k00 += np*tile_A::J) { + const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + + tile_A A; + load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); + mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + } + } + +#ifndef CP_ASYNC_AVAILABLE + __syncthreads(); // Only needed if tile_K == tile_V. +#endif // CP_ASYNC_AVAILABLE + +#else + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE +} + +template +static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half * const __restrict__ maskh, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne01, + const int ne02, + const int stride_Q, + const int stride_KV, + const int stride_mask, + const int jt, + const int kb0_start, + const int kb0_stop) { +#ifdef NEW_MMA_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps"); + constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column. 
+ + static_assert(D % nwarps == 0, "bad D"); + static_assert(KQ_stride % nwarps == 0, "bad KQ_stride"); + + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + + // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements: + extern __shared__ half2 tile_K[]; +#ifdef CP_ASYNC_AVAILABLE + half2 * tile_V = tile_K + KQ_stride*D2_padded; +#else + half2 * tile_V = tile_K; +#endif // CP_ASYNC_AVAILABLE + + tile_B Q_B[D/(2*tile_B::J)]; + tile_C_VKQ VKQ_C[D/tile_C_VKQ::I]; + + float2 KQ_rowsum = {0.0f, 0.0f}; + float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f}; + + // Temporarily load Q data into tile_K, will be loaded into registers afterwards. + // The loading is done with decreasing granularity for D for better memory bandwidth. + const half2 scale_h2 = make_half2(scale, scale); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_j = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + continue; + } + + if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { + break; } - // Load V data into tile with decreasing granularity for D for better memory bandwidth: - static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); #pragma unroll - for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i); - const int i0_stop = D/2 - (D/2) % (1*stride_i); - const int stride_k = WARP_SIZE / stride_i; + for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) { + const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + if (jt*ncols + j < ne01) { #pragma unroll - for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) { - const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i); + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k]; + tile_K[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); + } + } else { #pragma unroll - for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) { - const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i); + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? 
threadIdx.x : threadIdx.x % stride_k); - tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V]; + tile_K[j*D2_padded + k] = make_half2(0.0f, 0.0f); } } } + } - __syncthreads(); + __syncthreads(); - // Calculate VKQ tile: -#pragma unroll - for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) { - static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size"); -#pragma unroll - for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) { - const int k0 = k00 + (threadIdx.y % np)*mma_A::K; + { + const int j0 = (threadIdx.y / np) * tile_B::I; - mma_A A; - A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); - VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]); - } +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded); } + } + + __syncthreads(); + // Preload K data for first iteration when using cp_async: +#ifdef CP_ASYNC_AVAILABLE + flash_attn_ext_f16_load_tile(K_h2 + kb0_start*KQ_stride*stride_KV, tile_K, stride_KV); +#endif // CP_ASYNC_AVAILABLE + + // Iterate over ne11 == previous tokens: + for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { + constexpr bool last_iter = false; + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + } + { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. + constexpr bool last_iter = true; + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + } + + // With cp_async there is no __syncthreads at the end of the iter, + // there can be a race condition on shared memory access for combining/writing back results. +#ifdef CP_ASYNC_AVAILABLE + if (nwarps*tile_B::I > KQ_stride) { __syncthreads(); } +#endif // CP_ASYNC_AVAILABLE // Finally, sum up partial KQ rowsums. // The partial sums are spread across 8 threads each, does not need full reduce. @@ -310,26 +382,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Write VKQ accumulators to shared memory in column-major format. // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // Also for np > 1 the combination is done via these values in shared memory. - const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data + const int j_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // j combine write data #pragma unroll - for (int k0 = 0; k0 < D/2; k0 += mma_B::K) { - const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format. + for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. 
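One detail of the iteration loop above worth spelling out: the last K/V tile is peeled out of the loop so that last_iter is a compile-time template parameter rather than a runtime branch, which lets the cp.async prefetch of the next tile be compiled out of the tail instead of predicated. A stripped-down sketch of the shape, with hypothetical helper names:

    template <bool last_iter>
    __device__ void tile_step(const int kb0 /*, shared pipeline state ... */) {
        // ... consume the K/V tile that was prefetched for iteration kb0 ...
        if (!last_iter) {
            // prefetch tile kb0 + 1 (cp.async); statically dead in the peeled tail
        }
    }

    __device__ void all_tiles(const int kb0_start, const int kb0_stop) {
        for (int kb0 = kb0_start; kb0 < kb0_stop - 1; ++kb0) {
            tile_step<false>(kb0);
        }
        tile_step<true>(kb0_stop - 1); // kb0_start < kb0_stop is guaranteed by the caller
    }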
#pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int k = k0 + mma_B::get_k(l); + for (int l = 0; l < tile_B::ne; ++l) { + const int k = k0 + tile_B::get_j(l); - tile_KV[j_cwd*D2_padded + k] = B.x[l]; + tile_K[j_cwd*D2_padded + k] = B.x[l]; } } - const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset - const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta + const int j_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // j combine write meta offset + const int j_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum - if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) { + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. - ((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr; + ((float2 *) tile_K)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr; } __syncthreads(); @@ -337,11 +409,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( static_assert(np == 1 || np == 2 || np == 4, "bad np"); if (np == 1) { // No combination is needed, the meta data can be directly written from registers to VRAM. - if (needs_fixup && threadIdx.x < mma_B::J) { + if (needs_fixup && threadIdx.x < tile_B::I) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[j_cwm] = KQ_cmr; } - if (is_fixup && threadIdx.x < mma_B::J) { + if (is_fixup && threadIdx.x < tile_B::I) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[j_cwm] = KQ_cmr; } @@ -350,42 +422,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Warps with threadIdx.y % np != 0 must NOT return early. // All threads must return simultaneously to avoid race conditions with work on the next tile. - float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2; + float * meta_j = (float *) tile_K + (threadIdx.y*tile_B::I + threadIdx.x)*D2_padded + D/2; float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp. - if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { + if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) { KQ_cm = meta_j[0]; } float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps. #pragma unroll - for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) { + for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) { KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); } const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp. float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps. - if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { + if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) { KQ_crs = KQ_cms*meta_j[1]; } #pragma unroll - for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) { + for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) { KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } // Write back combined meta data: - if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { - meta_j[0] = KQ_cmn; // Combined max. KQ values. - meta_j[1] = KQ_crs; // Combined KQ rowsums. 
- meta_j[2] = KQ_cms; // KQ max scales per parallel warp. + if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) { + *((float2 *) meta_j) = make_float2(KQ_cms, KQ_crs); // Combined KQ max scale + rowsum. } - if (needs_fixup && threadIdx.x < mma_B::J) { + if (needs_fixup && threadIdx.x < tile_B::I) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; - dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } - if (is_fixup && threadIdx.x < mma_B::J) { + if (is_fixup && threadIdx.x < tile_B::I) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; - dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } } @@ -404,6 +474,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int k0_stop = D/2 - (D/2) % (1*stride_k); const int stride_j = WARP_SIZE / stride_k; + if (k0_start == k0_stop) { + continue; + } + if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { break; } @@ -411,12 +485,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) { const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J; + const int j_tile_K = (j_dst/tile_B::I)*(np*tile_B::I) + j_dst % tile_B::I; if (!is_fixup && jt*ncols + j_dst >= ne01) { continue; } - const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2; + const float * meta_j = (const float *) tile_K + j_tile_K*D2_padded + D/2; #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); @@ -424,8 +498,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float2 dstk_val = make_float2(0.0f, 0.0f); #pragma unroll for (int ip = 0; ip < np; ++ip) { - const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2]; - const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]); + const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*tile_B::I*D2_padded + 0]; + const float2 dstk_val_add = __half22float2(tile_K[(j_tile_K + ip*tile_B::I)*D2_padded + k]); dstk_val.x += dstk_val_add.x*KQ_crs; dstk_val.y += dstk_val_add.y*KQ_crs; } @@ -450,7 +524,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); } #else - NO_DEVICE_CODE; + NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } @@ -494,6 +568,11 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { +#ifndef NEW_MMA_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // NEW_MMA_AVAILABLE + // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -504,6 +583,10 @@ static __global__ void flash_attn_ext_f16( const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
+ const int stride_Q = nb01 / sizeof(float2); + const int stride_KV = nb11 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); + const int iter_k = ne11 / KQ_stride; const int iter_j = (ne01 + (ncols - 1)) / ncols; @@ -535,14 +618,12 @@ static __global__ void flash_attn_ext_f16( constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, - ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, - jt, kb0_start, kb0_stop); + ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, - ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, - jt, kb0_start, kb0_stop); + ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop); } kbc += iter_k; @@ -571,24 +652,27 @@ static __global__ void flash_attn_ext_f16( constexpr bool needs_fixup = false; flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, - ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, - jt, kb0_start, kb0_stop); + ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop); } template void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - typedef mma_A_I16K8 mma_A; - typedef mma_B_J8K8 mma_B; + typedef tile<16, 8, half2> tile_A; + typedef tile< 8, 8, half2> tile_B; - static_assert(D % mma_B::K == 0, "bad D"); - static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block"); + static_assert(D % tile_B::J == 0, "bad D"); + static_assert(cols_per_block % tile_B::I == 0, "bad cols_per_block"); const ggml_tensor * KQV = dst; + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + constexpr int KQ_stride = D <= 128 ? 64 : 32; + constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ? + cols_per_block/tile_B::J * KQ_stride/tile_A::I : (cols_per_block <= 8 ? 4 : 8); - constexpr int KQ_stride = D <= 128 ? 64 : 32; - constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ? - cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8); - constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half); + const int nrows_KQ = cp_async_available(cc) ? 
2*KQ_stride : KQ_stride; + const int nrows_combine = nwarps*tile_B::J; + const size_t nbytes_shared = std::max(nrows_KQ, nrows_combine) * (D + 8) * sizeof(half); float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 4dbaefdbafdf4..093ad70991b5a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -178,11 +178,11 @@ static ggml_cuda_device_info ggml_cuda_init() { int major_version = 0; size_t version_length = 0; if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) { - std::string version(version_length, '\0'); + std::vector<char> version(version_length+1, '\0'); if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) { - version.resize(::strlen(version.c_str())); + version.resize(::strlen(version.data())); int parsed_value = 0; - if (std::from_chars(version.c_str(), version.c_str() + version.length(), parsed_value).ec == std::errc()) { + if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) { major_version = parsed_value; } } @@ -1480,12 +1480,7 @@ static void ggml_cuda_op_mul_mat( const size_t nbytes_data = ggml_nbytes(src0); const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding); - // TODO: remove this for MUSA once the Guilty Lockup issue is resolved -#ifndef GGML_USE_MUSA CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream)); -#else // GGML_USE_MUSA - CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream)); -#endif // !GGML_USE_MUSA } // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared: @@ -1867,14 +1862,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor const int cc = ggml_cuda_info().devices[id].cc; use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); - any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); + any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc); } } else { const int cc = ggml_cuda_info().devices[ctx.device].cc; use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); - any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); + any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc); } // debug helpers @@ -2840,7 +2835,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { return false; } -#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) +#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); if (err != cudaSuccess) { // clear the error @@ -2852,8 +2847,10 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { } return true; #else + 
GGML_UNUSED(buffer); + GGML_UNUSED(size); return false; -#endif +#endif // CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) } void ggml_backend_cuda_unregister_host_buffer(void * buffer) { @@ -3205,8 +3202,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) { return true; } - const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; - return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; + return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) && + op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; } case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index bbc0a35ae5616..0a5656e4cb310 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -4,11 +4,12 @@ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction // // Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C. -// A is a row-major matrix with shape I x K. -// B is a column-major matrix with shape K x J. -// C is a column-major matrix with shape I x J. -// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements. -// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile. +// A is a row-major matrix with shape M x K. +// B is a column-major matrix with shape K x N. +// C is a column-major matrix with shape M x N. +// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns. +// Note that J is measured in physical 32 bit elements instead of logical elements. +// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile. // All matrix tiles have ne physical 32 bit elements per warp. // // As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. 
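To anchor the refactor that starts below: the per-shape structs (mma_A_I16K8, mma_B_J8K8, mma_C_I16J8, ...) are collapsed into a single tile<I, J, T> template in a ggml_cuda_mma namespace. Each of the 32 warp lanes owns ne = I*J/WARP_SIZE elements, and get_i/get_j map a lane's l-th element back to tile coordinates. A short usage sketch mirroring what the new load_generic does (fill_tile is an illustrative name):

    using namespace ggml_cuda_mma;

    // Fill a 16x8 half2 A tile from row-major memory: each lane stores
    // ne = 16*8/32 = 4 values, addressed through the layout helpers.
    __device__ void fill_tile(tile<16, 8, half2> & t, const half2 * src, const int stride) {
    #pragma unroll
        for (int l = 0; l < t.ne; ++l) {
            t.x[l] = src[t.get_i(l)*stride + t.get_j(l)];
        }
    }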
@@ -23,7 +24,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { #ifdef NEW_MMA_AVAILABLE asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" - : "+r"(ret) : "r"(x)); + : "=r"(ret) : "r"(x)); #else NO_DEVICE_CODE; #endif // defined(NEW_MMA_AVAILABLE) @@ -52,407 +53,267 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { #endif // CUDART_VERSION >= 11080 +static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) { + half2 ret; + *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x)); + return ret; +} -template <typename T> -struct mma_A_I16K4 { - static_assert(sizeof(T) == 4, "bad type size"); - - static constexpr int I = 16; - static constexpr int K = 4; - static constexpr int ne = 2; - - T x[ne]; +namespace ggml_cuda_mma { + + template <int I_, int J_, typename T> + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr int ne = I * J / WARP_SIZE; + T x[ne] = {0}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && (J == 4 || J == 8)) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 8) { + return (l / 2) * 8 + threadIdx.x / 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } - static __device__ __forceinline__ int get_i(const int l) { - const int ret = (l%2) * (I/2) + threadIdx.x / K; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < I); - return ret; - } + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return threadIdx.x % 4; + } else if constexpr (I == 8 && J == 8) { + return 4 * l + threadIdx.x % 4; + } else if constexpr (I == 16 && J == 8) { + return 2 * (threadIdx.x % 4) + l % 2; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + }; + + template <int I_, int J_> + struct tile<I_, J_, half2> { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr int ne = I * J / WARP_SIZE; + half2 x[ne] = {{0.0f, 0.0f}}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 8) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 4) { + return l * 8 + threadIdx.x / 4; + } else if constexpr (I == 16 && J == 8) { + return (l % 2) * 8 + threadIdx.x / 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } - static __device__ __forceinline__ int get_k(const int /* l */) { - const int ret = threadIdx.x % K; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < K); - return ret; - } + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 8) { + return l * 4 + threadIdx.x % 4; + } else if constexpr (I == 16 && J == 4) { + return threadIdx.x % 4; + } else if constexpr (I == 16 && J == 8) { + return (l / 2) * 4 + threadIdx.x % 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + }; - __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { + template <int I, int J> + static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) { + tile<I, J/2, half2> ret; #pragma unroll - for (int l = 0; l < ne; ++l) { - x[l] = xs0[get_i(l)*stride + get_k(l)]; + for (int l0 = 0; l0 < tile_float.ne; l0 += 2) { + ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); } - } - - __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { -#ifdef NEW_MMA_AVAILABLE - int * xi = (int *) x; - const int * xs = 
(const int *) xs0 + (threadIdx.x%I)*stride; - asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" - : "+r"(xi[0]), "+r"(xi[1]) - : "l"(xs)); -#else - load_generic(xs0, stride); -#endif // NEW_MMA_AVAILABLE - } -}; - -template <typename T> -struct mma_A_I16K8 { - static_assert(sizeof(T) == 4, "bad type size"); - - static constexpr int I = 16; - static constexpr int K = 8; - static constexpr int ne = 4; - - T x[ne]; - - static __device__ __forceinline__ int get_i(const int l) { - const int ret = (l%2) * (I/2) + threadIdx.x / (K/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < I); return ret; } - static __device__ __forceinline__ int get_k(const int l) { - const int ret = (l/2) * (K/2) + threadIdx.x % (K/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < K); + static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) { + tile<8, 8, half2> ret; + ret.x[0] = ggml_cuda_movmatrix(t.x[0]); + ret.x[1] = ggml_cuda_movmatrix(t.x[1]); + return ret; } - __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { + template <int I, int J, typename T> + static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) { #pragma unroll - for (int l = 0; l < ne; ++l) { - x[l] = xs0[get_i(l)*stride + get_k(l)]; + for (int l = 0; l < t.ne; ++l) { + t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; } } - __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { + template <typename T> + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) { #ifdef NEW_MMA_AVAILABLE - int * xi = (int * ) x; - const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); - asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" - : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + int * xi = (int *) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J; + asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); #else - GGML_UNUSED(xs0); - GGML_UNUSED(stride); - NO_DEVICE_CODE; + load_generic(t, xs0, stride); #endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) { + template <typename T> + static __device__ __forceinline__ void load_ldmatrix( + tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) { #ifdef NEW_MMA_AVAILABLE - int * xi = (int * ) x; - const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); - asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" - : "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3]) + int * xi = (int *) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride; + asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); #else - GGML_UNUSED(xs0); - GGML_UNUSED(stride); - NO_DEVICE_CODE; + load_generic(t, xs0, stride); #endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ void transpose() { - int * xi = (int *) x; - xi[0] = ggml_cuda_movmatrix(xi[0]); - - const int tmp = ggml_cuda_movmatrix(xi[1]); - xi[1] = ggml_cuda_movmatrix(xi[2]); - xi[2] = tmp; - - xi[3] = ggml_cuda_movmatrix(xi[3]); - } -}; - -template <typename T> -struct mma_B_J8K4 { - static_assert(sizeof(T) == 4, "bad type size"); - - static constexpr int J = 8; - static constexpr int K = 4; - static constexpr int ne = 1; - - T x[ne]; - - 
static __device__ __forceinline__ int get_j(const int /* l */) { - const int ret = threadIdx.x / K; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < J); - return ret; - } - - static __device__ __forceinline__ int get_k(const int /* l */) { - const int ret = threadIdx.x % K; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < K); - return ret; - } - - __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { -#pragma unroll - for (int l = 0; l < ne; ++l) { - x[l] = xs0[get_j(l)*stride + get_k(l)]; - } - } - - __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { + template + static __device__ __forceinline__ void load_ldmatrix( + tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { #ifdef NEW_MMA_AVAILABLE - int * xi = (int *) x; - const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride; - asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];" - : "+r"(xi[0]) : "l"(xs)); + int * xi = (int * ) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) + : "l"(xs)); #else - load_generic(xs0, stride); + load_generic(t, xs0, stride); #endif // NEW_MMA_AVAILABLE } -}; - -template -struct mma_B_J8K8 { - static_assert(sizeof(T) == 4, "bad type size"); - - static constexpr int J = 8; - static constexpr int K = 8; - static constexpr int ne = 2; - T x[ne]; - - static __device__ __forceinline__ int get_j(const int /* l */) { - const int ret = threadIdx.x / (K/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < J); - return ret; - } - - static __device__ __forceinline__ int get_k(const int l) { - const int ret = l * (K/2) + threadIdx.x % (K/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < K); - return ret; - } - - __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { -#pragma unroll - for (int l = 0; l < ne; ++l) { - x[l] = xs0[get_j(l)*stride + get_k(l)]; - } - } - - __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { + template + static __device__ __forceinline__ void load_ldmatrix_trans( + tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { #ifdef NEW_MMA_AVAILABLE - int * xi = (int *) x; - const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K; - asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" - : "+r"(xi[0]), "+r"(xi[1]) + int * xi = (int * ) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) : "l"(xs)); #else - load_generic(xs0, stride); + GGML_UNUSED(t); + GGML_UNUSED(xs0); + GGML_UNUSED(stride); + NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } -}; - -template -struct mma_C_I16J8 {}; - -template <> -struct mma_C_I16J8 { - static constexpr int I = 16; - static constexpr int J = 8; - static constexpr int ne = 4; - int x[ne] = {0}; - - static __device__ __forceinline__ int get_i(const int l) { - const int ret = (l/2) * (I/2) + threadIdx.x / (J/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < I); - return ret; - } - - static __device__ __forceinline__ int get_j(const int l) { - const int ret = 2 * (threadIdx.x % (J/2)) + l%2; - GGML_CUDA_ASSUME(ret >= 0); - 
GGML_CUDA_ASSUME(ret < J); - return ret; - } - - __device__ __forceinline__ void mma(const mma_A_I16K4 & mma_A, const mma_B_J8K4 & mma_B) { + static __device__ __forceinline__ void mma( + tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) { #ifdef NEW_MMA_AVAILABLE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" - : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) - : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0])); + : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0])); #else // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead: asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[0]), "+r"(x[1]) - : "r"(mma_A.x[0]), "r"(mma_B.x[0])); + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[0]), "r"(B.x[0])); asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[2]), "+r"(x[3]) - : "r"(mma_A.x[1]), "r"(mma_B.x[0])); + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[1]), "r"(B.x[0])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(mma_A); - GGML_UNUSED(mma_B); + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { + static __device__ __forceinline__ void mma( + tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) { #ifdef NEW_MMA_AVAILABLE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" - : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) - : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1])); + : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1])); #else // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead: asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[0]), "+r"(x[1]) - : "r"(mma_A.x[0]), "r"(mma_B.x[0])); + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[0]), "r"(B.x[0])); asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[2]), "+r"(x[3]) - : "r"(mma_A.x[1]), "r"(mma_B.x[0])); + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[1]), "r"(B.x[0])); asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[0]), "+r"(x[1]) - : "r"(mma_A.x[2]), "r"(mma_B.x[1])); + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[2]), "r"(B.x[1])); asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[2]), "+r"(x[3]) - : "r"(mma_A.x[3]), "r"(mma_B.x[1])); + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[3]), "r"(B.x[1])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(mma_A); - GGML_UNUSED(mma_B); + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } -}; - -template <> -struct mma_C_I16J8 { - static constexpr int I = 16; - static constexpr int J = 4; - static constexpr int ne = 2; - - half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}}; - - static __device__ __forceinline__ int get_i(const int l) { - const int ret = l * (I/2) + threadIdx.x / J; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < I); - return ret; - } - 
static __device__ __forceinline__ int get_j(const int /* l */) { - const int ret = threadIdx.x % J; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < J); - return ret; - } - - __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { + static __device__ __forceinline__ void mma( + tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { #ifdef NEW_MMA_AVAILABLE - int * Axi = (int *) mma_A.x; - int * Bxi = (int *) mma_B.x; - int * xi = (int *) x; + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" - : "+r"(xi[0]), "+r"(xi[1]) + : "+r"(Dxi[0]), "+r"(Dxi[1]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); #else // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" - : "+r"(xi[0]), "+r"(xi[1]) + : "+r"(Dxi[0]), "+r"(Dxi[1]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" - : "+r"(xi[0]), "+r"(xi[1]) + : "+r"(Dxi[0]), "+r"(Dxi[1]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(mma_A); - GGML_UNUSED(mma_B); + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ mma_B_J8K8 to_mma_B() { - mma_B_J8K8 mma_B; - - int * xi = (int *) x; - int * Bxi = (int *) mma_B.x; - Bxi[0] = ggml_cuda_movmatrix(xi[0]); - Bxi[1] = ggml_cuda_movmatrix(xi[1]); - - return mma_B; - } -}; - -template <> -struct mma_C_I16J8 { - static constexpr int I = 16; - static constexpr int J = 8; - static constexpr int ne = 4; - - float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f}; - - static __device__ __forceinline__ int get_i(const int l) { - const int ret = (l/2) * (I/2) + threadIdx.x / (J/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < I); - return ret; - } - - static __device__ __forceinline__ int get_j(const int l) { - const int ret = 2 * (threadIdx.x % (J/2)) + l%2; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < J); - return ret; - } - - __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { + static __device__ __forceinline__ void mma( + tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { #ifdef NEW_MMA_AVAILABLE - int * Axi = (int *) mma_A.x; - int * Bxi = (int *) mma_B.x; - int * xi = (int *) x; + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" - : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); #else // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" - : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 
{%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" - : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(mma_A); - GGML_UNUSED(mma_B); + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ mma_B_J8K8 to_mma_B() { - mma_B_J8K8 mma_B; - mma_B.x[0] = make_half2(x[0], x[1]); - mma_B.x[1] = make_half2(x[2], x[3]); - - int * Bxi = (int *) mma_B.x; - Bxi[0] = ggml_cuda_movmatrix(Bxi[0]); - Bxi[1] = ggml_cuda_movmatrix(Bxi[1]); - - return mma_B; - } - - __device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) { -#pragma unroll - for (int l = 0; l < ne; ++l) { - x[l] = xs0[get_j(l)*stride + get_i(l)]; - } - } -}; +} diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 45212f66c0098..10f2ebb1cb43d 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -18,7 +18,7 @@ void ggml_cuda_op_mul_mat_q( const int64_t stride00 = ne00 / ggml_blck_size(src0->type); int id = ggml_cuda_get_device(); - const int compute_capability = ggml_cuda_info().devices[id].cc; + const int cc = ggml_cuda_info().devices[id].cc; // the main device has a larger memory buffer to hold the results from all GPUs // nrows_dst == nrows of the matrix that the kernel writes into @@ -27,7 +27,8 @@ void ggml_cuda_op_mul_mat_q( // The stream-k decomposition is only faster for recent NVIDIA GPUs. // Also its fixup needs to allocate a temporary buffer in the memory pool. // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer. - const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11; + const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && + cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11; const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k}; switch (src0->type) { @@ -136,7 +137,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return true; } - if (cc < GGML_CUDA_CC_DP4A) { + if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) { return false; } @@ -145,8 +146,8 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { #endif //GGML_CUDA_FORCE_MMQ if (cc < GGML_CUDA_CC_OFFSET_AMD) { - return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } - return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 7a2c4d85b799f..0451c65f30277 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -7,6 +7,8 @@ #include #include +using namespace ggml_cuda_mma; + #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. 
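The NVIDIA branch of ggml_cuda_should_use_mmq above now keys off fp16_mma_hardware_available(cc) instead of a raw compute-capability compare. Paraphrased as a standalone sketch (the function name is illustrative, not from the PR):

    // Quantized MMQ kernels are preferred unless FP16 tensor cores exist in
    // hardware and the batch is large enough to amortize dequantizing to FP16.
    static bool should_use_mmq_nvidia_sketch(const bool fp16_mma_hw, const long ne11) {
        const long max_dp4a_batch = 64; // MMQ_DP4A_MAX_BATCH_SIZE
        return !fp16_mma_hw || ne11 < max_dp4a_batch;
    }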
#define MMQ_ITER_K 256 #define MMQ_NWARPS 8 @@ -86,12 +88,13 @@ struct tile_x_sizes { int sc; }; -static constexpr int get_mmq_x_max_host(const int cc) { +static int get_mmq_x_max_host(const int cc) { return new_mma_available(cc) ? 128 : + ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? #ifdef GGML_CUDA_FORCE_MMQ - cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64; + 128 : 64; #else - cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64; + MMQ_DP4A_MAX_BATCH_SIZE : 64; #endif // GGML_CUDA_FORCE_MMQ } @@ -119,8 +122,9 @@ static constexpr __device__ int get_mmq_x_max_device() { #endif // NEW_MMA_AVAILABLE } -static constexpr int get_mmq_y_host(const int cc) { - return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64); +static int get_mmq_y_host(const int cc) { + return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : + (ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64); } static constexpr __device__ int get_mmq_y_device() { @@ -645,15 +649,15 @@ template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - typedef mma_A_I16K8 mma_A; - typedef mma_B_J8K8 mma_B; - typedef mma_C_I16J8 mma_C; + typedef tile<16, 8, int> tile_A; + typedef tile< 8, 8, int> tile_B; + typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
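For reference, the quantity the q8_0 x q8_1 tensor-core path below accumulates per block pair, written out scalar-wise; the dA/dB factors in the hunk are exactly these per-block dequantization scales:

    #include <cstdint>

    // d_a * d_b * sum_k(q_a[k]*q_b[k]): the mma path computes the integer dot
    // product in int32 on tensor cores, then applies the scales in float.
    static float q8_block_dot_ref(const int8_t * qa, const float dA,
                                  const int8_t * qb, const float dB, const int n) {
        int32_t sumi = 0;
        for (int k = 0; k < n; ++k) {
            sumi += (int32_t) qa[k] * (int32_t) qb[k];
        }
        return dA * dB * (float) sumi;
    }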
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + 2*WARP_SIZE; @@ -661,8 +665,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const float * y_df = (const float *) y; const half2 * y_ds = (const half2 *) y; - mma_A A[ntx][WARP_SIZE/QI8_0]; - float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0]; + tile_A A[ntx][WARP_SIZE/QI8_0]; + float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0]; const int i0 = (threadIdx.y/ntx)*rows_per_warp; @@ -672,12 +676,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { const int k0 = k00 + k01; - A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { @@ -689,17 +693,17 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( } #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { - mma_B B; - float dB[mma_C::ne/2]; + tile_B B; + float dB[tile_C::ne/2]; - B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; @@ -710,12 +714,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma(A[n][k01/QI8_0], B); + tile_C C; + mma(C, A[n][k01/QI8_0], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2]; + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2]; } } } @@ -756,23 +760,23 @@ template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - typedef mma_A_I16K8 mma_A; - typedef mma_B_J8K8 mma_B; - typedef mma_C_I16J8 mma_C; + typedef tile<16, 8, int> tile_A; + typedef tile< 8, 8, int> tile_B; + typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
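The q8_1 variant below accumulates two terms per element because the x blocks carry {d, m} (value = d*q + m) while the y blocks carry {d, s} with s = d*sum(q); expanding the dot product yields the dmA.x*dsB.x*C and dmA.y*dsB.y additions in the inner loop. A scalar sketch; this reading of the block formats is inferred, so treat the formula as an assumption:

    #include <cstdint>

    // With a[k] = dA*qa[k] + mA and b[k] = dB*qb[k] (and sB = dB*sum(qb)):
    // dot(a, b) = dA*dB*sumi + mA*sB.
    static float q8_1_block_dot_ref(const int8_t * qa, const float dA, const float mA,
                                    const int8_t * qb, const float dB, const float sB,
                                    const int n) {
        int32_t sumi = 0;
        for (int k = 0; k < n; ++k) {
            sumi += (int32_t) qa[k] * (int32_t) qb[k];
        }
        return dA*dB*(float) sumi + mA*sB;
    }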
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE; const int * y_qs = (const int *) y + 4; const half2 * y_dm = (const half2 *) y; - mma_A A[ntx][WARP_SIZE/QI8_1]; - float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1]; + tile_A A[ntx][WARP_SIZE/QI8_1]; + float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1]; const int i0 = (threadIdx.y/ntx)*rows_per_warp; @@ -782,12 +786,12 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { const int k0 = k00 + k01; - A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { @@ -799,30 +803,30 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( } #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { - mma_B B; - float2 dsB[mma_C::ne/2]; + tile_B B; + float2 dsB[tile_C::ne/2]; - B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); } #pragma unroll for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma(A[n][k01/QI8_1], B); + tile_C C; + mma(C, A[n][k01/QI8_1], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l]; - sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y; + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y; } } } @@ -866,26 +870,26 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef NEW_MMA_AVAILABLE - typedef mma_A_I16K4 mma_A; - typedef mma_A_I16K8 mma_A_K8; - typedef mma_B_J8K4 mma_B; - typedef mma_C_I16J8 mma_C; + typedef tile<16, 4, int> tile_A; + typedef tile<16, 8, int> tile_A_8; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
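Just below, the A fragments are loaded with load_ldmatrix(((tile_A_8 *) A[n])[k01/8], ...): a single 16x8 ldmatrix fills two adjacent 16x4 tiles through a reinterpreting cast. That is only sound if the two layouts alias exactly; a compile-time check of the invariant (assuming WARP_SIZE == 32, so ne is 4 and 2 respectively):

    using namespace ggml_cuda_mma;

    // One K=8 int tile must occupy exactly the storage of two consecutive
    // K=4 int tiles for the casting load to be valid.
    static_assert(sizeof(tile<16, 8, int>) == 2*sizeof(tile<16, 4, int>),
        "a 16x8 int tile must alias two 16x4 int tiles");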
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + WARP_SIZE*2; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); - mma_A A[ntx][8]; - float dA[ntx][mma_C::ne/2][8]; + tile_A A[ntx][8]; + float dA[ntx][tile_C::ne/2][8]; #pragma unroll for (int n = 0; n < ntx; ++n) { @@ -893,12 +897,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { const int k0 = k00 + k01; - ((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) { @@ -910,32 +914,32 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( } #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { - mma_B B[2]; - float dB[mma_C::ne/2]; + tile_B B[2]; + float dB[tile_C::ne/2]; // Here load_generic is faster than load_ldmatrix. - B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); - B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); + load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; } #pragma unroll for (int n = 0; n < ntx; ++n) { - mma_C C[2]; - C[0].mma(A[n][k01/4 + 0], B[0]); - C[1].mma(A[n][k01/4 + 1], B[1]); + tile_C C[2]; + mma(C[0], A[n][k01/4 + 0], B[0]); + mma(C[1], A[n][k01/4 + 1], B[1]); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]); + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]); } } } @@ -1054,27 +1058,27 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef NEW_MMA_AVAILABLE - typedef mma_A_I16K4 mma_A; - typedef mma_A_I16K8 mma_A_K8; - typedef mma_B_J8K4 mma_B; - typedef mma_C_I16J8 mma_C; + typedef tile<16, 4, int> tile_A; + typedef tile<16, 8, int> tile_A_8; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
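In the q2_K hunk that follows, a tile A1 is filled with 0x01010101 (four int8 ones per 32-bit lane) and run through mma: against an all-ones A, the matrix product degenerates into per-column sums of B, which is what the minimum-correction terms (the mA[...] subtractions) need. The scalar equivalent:

    #include <cstdint>

    // A row of ones turns the int8 dot product into a plain column sum of B.
    static int32_t column_sum(const int8_t * qb, const int n) {
        int32_t sum = 0;
        for (int k = 0; k < n; ++k) {
            sum += (int32_t) qb[k]; // 1 * qb[k]
        }
        return sum;
    }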
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); - mma_A A[ntx][8]; - float dA[ntx][mma_C::ne/2][8]; - float mA[ntx][mma_C::ne/2][8]; + tile_A A[ntx][8]; + float dA[ntx][tile_C::ne/2][8]; + float mA[ntx][tile_C::ne/2][8]; #pragma unroll for (int n = 0; n < ntx; ++n) { @@ -1082,15 +1086,15 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { const int k0 = k00 + k01; - ((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); } } #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) { @@ -1105,58 +1109,58 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( } #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - float2 dB[mma_C::ne/2]; + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + float2 dB[tile_C::ne/2]; #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]); } #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { - mma_B B[2]; + tile_B B[2]; // Here load_generic is faster than load_ldmatrix. - B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); - B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); + load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); - mma_C Cm[2]; + tile_C Cm[2]; if (k01 >= WARP_SIZE * 3/4) { - mma_A A1; + tile_A A1; A1.x[0] = 0x01010101; A1.x[1] = 0x01010101; - Cm[0].mma(A1, B[0]); - Cm[1].mma(A1, B[1]); + mma(Cm[0], A1, B[0]); + mma(Cm[1], A1, B[1]); } #pragma unroll for (int n = 0; n < ntx; ++n) { - mma_C Cd[2]; + tile_C Cd[2]; - Cd[0].mma(A[n][k01/4 + 0], B[0]); - Cd[1].mma(A[n][k01/4 + 1], B[1]); + mma(Cd[0], A[n][k01/4 + 0], B[0]); + mma(Cd[1], A[n][k01/4 + 1], B[1]); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { + for (int l = 0; l < tile_C::ne; ++l) { float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1]; if (k01 >= WARP_SIZE * 3/4) { tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1]; } - sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? 
dB[l%2].x : dB[l%2].y); } } } #pragma unroll for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) { - float2 sB[mma_C::ne/2]; + float2 sB[tile_C::ne/2]; #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); } @@ -1164,9 +1168,9 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x; - sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y; + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y; } } } @@ -1706,15 +1710,15 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef NEW_MMA_AVAILABLE - typedef mma_A_I16K4 mma_A; - typedef mma_B_J8K4 mma_B; - typedef mma_C_I16J8 mma_C; + typedef tile<16, 4, int> tile_A; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + WARP_SIZE*2; @@ -1722,11 +1726,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); - mma_A A[ntx][8]; - int scA[ntx][mma_C::ne/2][8]; - float dA[ntx][mma_C::ne/2]; + tile_A A[ntx][8]; + int scA[ntx][tile_C::ne/2][8]; + float dA[ntx][tile_C::ne/2]; #pragma unroll for (int n = 0; n < ntx; ++n) { @@ -1734,8 +1738,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { const int k0 = k00 + k01; - A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); - A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K); + load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); + load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K); } #pragma unroll @@ -1743,8 +1747,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int k0 = k00 + k01; #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16]; const int8_t * sc = (const int8_t *) &sc_packed; @@ -1757,41 +1761,41 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + 
@@ -1706,15 +1710,15 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef NEW_MMA_AVAILABLE

-    typedef mma_A_I16K4 mma_A;
-    typedef mma_B_J8K4  mma_B;
-    typedef mma_C_I16J8 mma_C;
+    typedef tile<16, 4, int> tile_A;
+    typedef tile< 8, 4, int> tile_B;
+    typedef tile<16, 8, int> tile_C;

     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);

     const int * x_qs = (const int *) x;
     const float * x_df = (const float *) x_qs + WARP_SIZE*2;
@@ -1722,11 +1726,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int * y_qs = (const int *) y + 4;
     const float * y_df = (const float *) y;

-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);

-    mma_A A[ntx][8];
-    int   scA[ntx][mma_C::ne/2][8];
-    float  dA[ntx][mma_C::ne/2];
+    tile_A A[ntx][8];
+    int   scA[ntx][tile_C::ne/2][8];
+    float  dA[ntx][tile_C::ne/2];

 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -1734,8 +1738,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             const int k0 = k00 + k01;

-        A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
-        A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+        load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
+        load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
         }

 #pragma unroll
@@ -1743,8 +1747,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
             const int k0 = k00 + k01;

 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

             const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
             const int8_t * sc = (const int8_t *) &sc_packed;
@@ -1757,41 +1761,41 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     }

 #pragma unroll
-    for (int l = 0; l < mma_C::ne/2; ++l) {
-        const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+    for (int l = 0; l < tile_C::ne/2; ++l) {
+        const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

         dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
     }
     }

 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        float tmp[ntx][mma_C::ne] = {{0.0f}};
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+        float tmp[ntx][tile_C::ne] = {{0.0f}};

 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
-            mma_B B[2];
-            float dB[mma_C::ne/2];
+            tile_B B[2];
+            float dB[tile_C::ne/2];

             // Here load_generic is faster than load_ldmatrix.
-            B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0        + k01, MMQ_TILE_Y_K);
-            B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
+            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0         + k01, MMQ_TILE_Y_K);
+            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);

 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);

                 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
             }

 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C[2];
-                C[0].mma(A[n][k01/4 + 0], B[0]);
-                C[1].mma(A[n][k01/4 + 1], B[1]);
+                tile_C C[2];
+                mma(C[0], A[n][k01/4 + 0], B[0]);
+                mma(C[1], A[n][k01/4 + 1], B[1]);

 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
+                for (int l = 0; l < tile_C::ne; ++l) {
                     tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
                 }
             }
@@ -1800,8 +1804,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
+        for (int l = 0; l < tile_C::ne; ++l) {
+            sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
             }
         }
     }
@@ -2310,36 +2314,36 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_mma(
     const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {

-    typedef mma_C_I16J8 mma_C;
+    typedef tile<16, 8, int> tile_C;

     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);

 #ifdef NEW_MMA_AVAILABLE
-    static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+    static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
 #endif // NEW_MMA_AVAILABLE

 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne; ++l) {
+                const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);

                 if (j > j_max) {
                     continue;
                 }

-                const int i = i0 + n*mma_C::I + mma_C::get_i(l);
+                const int i = i0 + n*tile_C::I + tile_C::get_i(l);

                 if (need_check && i > i_max) {
                     continue;
                 }

-                dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
+                dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
             }
         }
     }
@@ -2828,7 +2832,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc);
     const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
-    const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
+    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;

     int mmq_x_best  = 0;
     int nparts_best = INT_MAX;
diff --git a/ggml/src/ggml-cuda/sum.cu b/ggml/src/ggml-cuda/sum.cu
index e0dafc1d2049d..f9589080a0c3b 100644
--- a/ggml/src/ggml-cuda/sum.cu
+++ b/ggml/src/ggml-cuda/sum.cu
@@ -1,6 +1,6 @@
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
 #define USE_CUB
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070

 #ifdef USE_CUB
 #include <cub/cub.cuh>
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 944d90af34432..087e7f58149f6 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -24,7 +24,7 @@
 #endif

 // create residency sets only on macOS >= 15.0
-#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \
+#if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \
     TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \
     TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \
     TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000
@@ -1983,7 +1983,7 @@ static void ggml_metal_encode_node(
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

     // TODO: add ggml_metal_kargs struct
-    // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
+    // TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
     [encoder setComputePipelineState:pipeline];
     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
     if (id_src1) {
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 44f04c909bfb2..83e7ac9f411ef 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -373,24 +373,33 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
 template <typename type4x4>
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
     const half d_all = xb->d;
-    device const uint8_t * ql = (device const uint8_t *)xb->ql;
-    device const uint8_t * qh = (device const uint8_t *)xb->qh;
+    device const uint16_t * ql = (device const uint16_t *)xb->ql;
+    device const uint16_t * qh = (device const uint16_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;

-    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
-    qh = qh + 32*(il/8) + 16*(il&1);
+    ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
+    qh = qh + 16*(il/8) + 8*(il&1);
     float sc = scales[(il%2) + 2 * ((il/2))];
     il = (il/2) & 3;

-    const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-    const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
-    const float coef = il>1 ? 1.f/16.f : 1.f;
+    const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
+    const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F;
     const float ml = d_all * sc * 32.f;
-    const float dl = d_all * sc * coef;
-    for (int i = 0; i < 16; ++i) {
-        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
-                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
-        reg[i/4][i%4] = dl * q - ml;
+    const float dl0 = d_all * sc;
+    const float dl1 = dl0 / 256.f;
+    const float dl2 = dl0 / (256.f * 256.f);
+    const float dl3 = dl0 / (256.f * 256.f * 256.f);
+    const uint8_t shr_h = il>2 ? 2 : 0;
+    const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
+    const uint8_t shr_l = il>1 ? 4 : 0;
+    for (int i = 0; i < 4; ++i) {
+        const uint32_t low  = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
+        const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
+        const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
+        reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml;
+        reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml;
+        reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml;
+        reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
     }
 }
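The rewritten loop above unpacks four q6_K quants at a time through one 32-bit word and folds the byte position into the scale instead of shifting the data: masking byte n in place leaves its value multiplied by 256^n, so dl_n = dl0 / 256^n recovers dl0 * byte_n exactly (the divisions are by powers of two, hence lossless for these small magnitudes). A standalone C++ check of that identity (not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
        const uint32_t q   = 0x04030201; // four packed byte values 1, 2, 3, 4
        const float    dl0 = 0.5f;       // arbitrary base scale

        // byte n masked in place equals byte_n * 256^n, so dividing the scale
        // by 256^n yields dl0 * byte_n without any shift instructions
        assert(dl0                           * (float)(q & 0xFF)       == dl0 * 1);
        assert(dl0 / 256.f                   * (float)(q & 0xFF00)     == dl0 * 2);
        assert(dl0 / (256.f * 256.f)         * (float)(q & 0xFF0000)   == dl0 * 3);
        assert(dl0 / (256.f * 256.f * 256.f) * (float)(q & 0xFF000000) == dl0 * 4);
        return 0;
    }

This is also why ql/qh are reinterpreted as uint16_t above: two 16-bit loads per word replace sixteen byte loads per 4x4 register block.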
"kernel_diag_mask_inf_8", &err), err)); CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program, "kernel_soft_max", &err), err)); CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4", &err), err)); + CL_CHECK((backend_ctx->kernel_soft_max_f16 = clCreateKernel(backend_ctx->program, "kernel_soft_max_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4_f16", &err), err)); CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f32", &err), err)); CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err)); CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err)); @@ -1044,8 +1047,16 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return true; case GGML_OP_DIAG_MASK_INF: return op->ne[3] == 1; - case GGML_OP_ROPE: + case GGML_OP_ROPE: { + const int mode = ((const int32_t *) op->op_params)[2]; + if (mode & GGML_ROPE_TYPE_MROPE) { + return false; + } + if (mode & GGML_ROPE_TYPE_VISION) { + return false; + } return true; + } default: return false; } @@ -3666,6 +3677,8 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + // Local size must be wave size. Each workgroup is a wave, working on a row, // where a row corresponds to leading dimension. int nth = MIN(32, ne00); @@ -3683,9 +3696,17 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c cl_kernel kernel; if (ne00%4 == 0) { - kernel = backend_ctx->kernel_soft_max_4; + if (use_f16) { + kernel = backend_ctx->kernel_soft_max_4_f16; + } else { + kernel = backend_ctx->kernel_soft_max_4; + } } else { - kernel = backend_ctx->kernel_soft_max; + if (use_f16) { + kernel = backend_ctx->kernel_soft_max_f16; + } else { + kernel = backend_ctx->kernel_soft_max; + } } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); @@ -3766,7 +3787,8 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const const int nb2 = dst ? dst->nb[2] : 0; const int nb3 = dst ? 
@@ -3766,7 +3787,8 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
     const int nb2 = dst ? dst->nb[2] : 0;
     const int nb3 = dst ? dst->nb[3] : 0;

-    GGML_ASSERT(ne10 == ne02);
+    GGML_ASSERT(ne10 % ne02 == 0);
+    GGML_ASSERT(ne10 >= ne02);

     int nth = MIN(64, ne00);
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl
index d1cdf709babc5..d3cfb2f91e130 100644
--- a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl
+++ b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl
@@ -679,6 +679,9 @@ kernel void kernel_diag_mask_inf_8(
 //------------------------------------------------------------------------------
 // softmax
 //------------------------------------------------------------------------------
+#ifdef ADRENO_GPU
+REQD_SUBGROUP_SIZE_64
+#endif
 kernel void kernel_soft_max(
         global float * src0,
         ulong offset0,
@@ -811,6 +814,141 @@ kernel void kernel_soft_max_4(
     }
 }

+#ifdef ADRENO_GPU
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_soft_max_f16(
+        global float * src0,
+        ulong offset0,
+        global half * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        float scale,
+        float max_bias,
+        float m0,
+        float m1,
+        int n_head_log2
+) {
+    src0 = (global float *)((global char *)src0 + offset0);
+    src1 = (global half  *)((global char *)src1 + offset1);
+    dst  = (global float *)((global char *)dst  + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    global half  * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0;
+    global float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        int h = i02;
+
+        float base = h < n_head_log2 ? m0 : m1;
+        int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    // parallel max
+    float lmax = -INFINITY;
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+    }
+    float max = sub_group_reduce_max(lmax);
+
+    // parallel sum
+    float lsum = 0.0f;
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);
+        lsum += exp_psrc0;
+        // Remember the result of exp here. exp is expensive, so we really do not
+        // wish to compute it twice.
+        pdst[i00] = exp_psrc0;
+    }
+
+    const float sum = sub_group_reduce_add(lsum);
+
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        pdst[i00] /= sum;
+    }
+}
+
+#ifdef ADRENO_GPU
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_soft_max_4_f16(
+        global float * src0,
+        ulong offset0,
+        global half * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        float scale,
+        float max_bias,
+        float m0,
+        float m1,
+        int n_head_log2
+) {
+    src0 = (global float *)((global char *)src0 + offset0);
+    src1 = (global half  *)((global char *)src1 + offset1);
+    dst  = (global float *)((global char *)dst  + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    global half4  * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0;
+    global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        int h = i02;
+
+        float base = h < n_head_log2 ? m0 : m1;
+        int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    // parallel max
+    float4 lmax4 = -INFINITY;
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));
+    }
+    float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3));
+
+    const float max = sub_group_reduce_max(lmax);
+
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)) - max);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
+    }
+    float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
+
+    const float sum = sub_group_reduce_add(lsum);
+
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        pdst4[i00] /= sum;
+    }
+}
+
 //------------------------------------------------------------------------------
 // kernel_rope
 //------------------------------------------------------------------------------
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index bffe95086af7d..131ee1ea044dd 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -222,6 +222,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_acc_f32;
     vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
     vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
+    vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat;
     vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
     vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
@@ -232,7 +233,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_cos_f32;
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
-    vk_pipeline pipeline_repeat_f32;
+    vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
     vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
     vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
     vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
@@ -251,12 +252,17 @@ struct vk_device_struct {
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
+    vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
+    vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32;
     vk_pipeline pipeline_sum_rows_f32;
+    vk_pipeline pipeline_argmax_f32;
+    vk_pipeline pipeline_count_equal_i32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
     vk_pipeline pipeline_rwkv_wkv6_f32;
+    vk_pipeline pipeline_opt_step_adamw_f32;

     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -494,6 +500,10 @@ struct vk_op_rope_push_constants {
     float corr_dims[2];
     float theta_scale;
    uint32_t has_ff;
+    uint32_t ne02;
+    uint32_t s1;
+    uint32_t s2;
+    int32_t sections[4];
 };

 struct vk_op_soft_max_push_constants {
@@ -1385,6 +1395,10 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
     uint32_t lut_size = 0;
     switch (src0_type) {
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+            lut_size = 2*2048;
+            break;
         case GGML_TYPE_IQ2_XXS:
             lut_size = 8*256;
             break;
@@ -1430,6 +1444,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");

     // some shaders have a minimum subgroup size
+    const uint32_t subgroup_size_8  = std::max(device->subgroup_size, 8u);
     const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
     const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
@@ -1492,13 +1507,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
         const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
         const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;

-        l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
-        m_warptile = { 128,  64,  64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
-        s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
+        l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
+        m_warptile = { 128,  64,  64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
+        s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };

-        l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
-        m_warptile_mmq = { 128,  64,  64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
-        s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
+        l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
+        m_warptile_mmq = { 128,  64,  64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
+        s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };

         l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
         m_mmq_wg_denoms = m_wg_denoms = { 64,  64, 1 };
@@ -1622,6 +1637,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
         //CREATE_FA(GGML_TYPE_Q4_K, q4_k)
         //CREATE_FA(GGML_TYPE_Q5_K, q5_k)
         //CREATE_FA(GGML_TYPE_Q6_K, q6_k)
+        //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s)
+        //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m)
         //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs)
         //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs)
         //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s)
@@ -1656,6 +1673,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
     CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
     CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) @@ -1675,6 +1694,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) @@ -1729,6 +1750,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -1748,6 +1771,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, 
matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -1773,13 +1798,15 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - 
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } else { CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); @@ -1792,13 +1819,15 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_NL, 
pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } #undef CREATE_MM2 #undef CREATE_MM @@ -1841,6 +1870,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -1864,6 +1895,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ2_S, 
pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); @@ -1905,13 +1938,15 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); @@ -1928,13 +1963,15 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, 
_id); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); #undef CREATE_MM } @@ -1964,6 +2001,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, 
i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); @@ -1984,6 +2023,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); 
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); @@ -2005,6 +2046,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); @@ -2025,6 +2068,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 
1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_S], "dequant_iq1_s", dequant_iq1_s_len, dequant_iq1_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_M], "dequant_iq1_m", dequant_iq1_m_len, dequant_iq1_m_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs", dequant_iq2_xs_len, dequant_iq2_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); @@ -2041,6 +2086,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S], "get_rows_iq1_s", get_rows_iq1_s_len, get_rows_iq1_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M], "get_rows_iq1_m", get_rows_iq1_m_len, get_rows_iq1_m_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs", get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -2056,6 +2103,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, 
sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S], "get_rows_iq1_s_f32", get_rows_iq1_s_f32_len, get_rows_iq1_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M], "get_rows_iq1_m_f32", get_rows_iq1_m_f32_len, get_rows_iq1_m_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32", get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -2102,6 +2151,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); @@ -2124,6 +2175,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -2141,19 +2193,29 @@ static void ggml_vk_load_shaders(vk_device& device) { 
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); } else { ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); } ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 
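// [annotation] argmax_f32, sum_rows_f32 and count_equal_i32 above are created
// with { device->subgroup_size } as their sole specialization constant; the new
// shaders declare layout(local_size_x_id = 0, ...) in, so the workgroup width is
// bound at pipeline-creation time to the device's subgroup size instead of being
// hard-coded in the GLSL.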
2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); @@ -2167,6 +2229,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + for (auto &c : compiles) { c.wait(); } @@ -3008,6 +3072,8 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: @@ -3062,6 +3128,8 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: @@ -3099,6 +3167,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: @@ -3148,6 +3218,8 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: @@ -3180,6 +3252,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: @@ -3722,6 +3796,12 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr } } +static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")"); + + ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); +} + static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")"); @@ -5128,6 +5208,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16; } return nullptr; + case GGML_OP_SUB: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32; + } + return nullptr; case GGML_OP_MUL: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ggml_are_same_shape(src0, src1) ? 
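// [annotation] ggml_vk_buffer_memset_async() above records the fill into the
// context's current command buffer, so the clear can be ordered with barriers
// inside a single submission; the GGML_OP_COUNT_EQUAL path further down relies
// on this to zero its accumulator right before the dispatch. Vulkan's
// fillBuffer writes a repeated 32-bit word and requires 4-byte-aligned offset
// and size, which holds for the 8-byte I64 result cell.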
ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32; @@ -5189,6 +5274,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_repeat_f32; } return nullptr; + case GGML_OP_REPEAT_BACK: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_repeat_back_f32; + } + return nullptr; case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: @@ -5258,6 +5348,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const { const int mode = ((const int32_t *) dst->op_params)[2]; const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; if (is_neox) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { @@ -5266,6 +5358,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { return ctx->device->pipeline_rope_neox_f16; } + } else if (is_mrope && !is_vision) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_multi_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_multi_f16; + } + } else if (is_vision) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_vision_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_vision_f16; + } } else { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_rope_norm_f32; @@ -5281,11 +5387,22 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_argsort_f32; } return nullptr; + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_sum_rows_f32; } return nullptr; + case GGML_OP_ARGMAX: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { + return ctx->device->pipeline_argmax_f32; + } + return nullptr; + case GGML_OP_COUNT_EQUAL: + if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) { + return ctx->device->pipeline_count_equal_i32; + } + return nullptr; case GGML_OP_IM2COL: if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_im2col_f32; @@ -5309,6 +5426,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rwkv_wkv6_f32; } return nullptr; + case GGML_OP_OPT_STEP_ADAMW: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_opt_step_adamw_f32; + } + return nullptr; case GGML_OP_LEAKY_RELU: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_leaky_relu_f32; @@ -5326,6 +5448,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_CPY: case GGML_OP_GET_ROWS: case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: @@ -5336,6 +5459,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_ROPE: return true; default: return false; @@ -5549,6 +5674,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case 
GGML_OP_RMS_NORM: case GGML_OP_SOFT_MAX: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: { const uint32_t nr = ggml_nrows(src0); if (nr > 262144) { @@ -5559,6 +5685,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { nr, 1, 1 }; } } break; + case GGML_OP_SUM: + // We use GGML_OP_SUM_ROWS with 1 row. + elements = { 1, 1, 1 }; + break; case GGML_OP_GROUP_NORM: { const uint32_t num_groups = dst->op_params[0]; @@ -5605,6 +5735,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { N * OC * OH * OW, 1, 1}; } break; case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_DIV: case GGML_OP_MUL: case GGML_OP_SCALE: @@ -5614,6 +5745,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: case GGML_OP_CPY: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: @@ -5674,6 +5806,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co // im2col uses only src1 and dst buffers ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else if (op == GGML_OP_COUNT_EQUAL) { + ggml_vk_sync_buffers(subctx); + // count_equal assumes that destination buffer is initialized with zeroes + ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz); + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); } else if (use_src2) { ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); @@ -5736,6 +5874,21 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const }, dryrun); } +static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, 
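// [annotation] GGML_OP_SUM is implemented by reusing the sum_rows pipeline: the
// elements = { 1, 1, 1 } case above dispatches a single workgroup, and the
// ggml_vk_sum() wrapper further down passes ggml_nelements(src0) as the row
// width, so the whole tensor is reduced as one logical row.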
ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); @@ -5894,6 +6047,111 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ); } +static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) { + const ggml_tensor * x = dst->src[0]; + const ggml_tensor * g = dst->src[1]; + const ggml_tensor * gm = dst->src[2]; + const ggml_tensor * gv = dst->src[3]; + const ggml_tensor * p = dst->src[4]; + + GGML_ASSERT(x->type == GGML_TYPE_F32); + GGML_ASSERT(g->type == GGML_TYPE_F32); + GGML_ASSERT(gm->type == GGML_TYPE_F32); + GGML_ASSERT(gv->type == GGML_TYPE_F32); + GGML_ASSERT(p->type == GGML_TYPE_F32); + GGML_ASSERT(dst->buffer != nullptr); + GGML_ASSERT(ggml_is_contiguous(x)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(gm)); + GGML_ASSERT(ggml_is_contiguous(gv)); + GGML_ASSERT(ggml_is_contiguous(p)); + GGML_ASSERT(ggml_are_same_shape(x, g)); + GGML_ASSERT(ggml_are_same_shape(x, gm)); + GGML_ASSERT(ggml_are_same_shape(x, gv)); + GGML_ASSERT(ggml_nelements(p) == 7); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW); + GGML_ASSERT(pipeline != nullptr); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context; + ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context; + ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context; + ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context; + ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context; + + ggml_vk_sync_buffers(subctx); + + vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr; + size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0; + bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, x->data, d_X, x_offset); + ggml_vk_host_get(ctx->device, g->data, d_G, g_offset); + ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset); + ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset); + ggml_vk_host_get(ctx->device, p->data, d_P, p_offset); + + X_uma = d_X != nullptr; + G_uma = d_G != nullptr; + GM_uma = d_GM != nullptr; + GV_uma = d_GV != nullptr; + P_uma = d_P != nullptr; + } + + if (!X_uma) { + d_X = x_buf_ctx->dev_buffer; + x_offset = vk_tensor_offset(x) + x->view_offs; + } + if (!G_uma) { + d_G = g_buf_ctx->dev_buffer; + g_offset = vk_tensor_offset(g) + g->view_offs; + } + if (!GM_uma) { + d_GM = gm_buf_ctx->dev_buffer; + gm_offset = vk_tensor_offset(gm) + gm->view_offs; + } + if (!GV_uma) { + d_GV = gv_buf_ctx->dev_buffer; + gv_offset = vk_tensor_offset(gv) + gv->view_offs; + } + if (!P_uma) { + d_P = p_buf_ctx->dev_buffer; + p_offset = vk_tensor_offset(p) + p->view_offs; + } + + const uint64_t x_size = ggml_nbytes(x); + const uint64_t g_size = ggml_nbytes(g); + const uint64_t gm_size = ggml_nbytes(gm); + const uint64_t gv_size = ggml_nbytes(gv); + const uint64_t p_size = ggml_nbytes(p); + + std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 }; +
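// [annotation] For reference, the per-element update the opt_step_adamw_f32
// shader is expected to apply, written as CPU-style pseudocode. The 7-float
// params layout (alpha, beta1, beta2, eps, wd, beta1h, beta2h) is an assumption
// carried over from ggml's CPU/CUDA implementations of GGML_OP_OPT_STEP_ADAMW,
// where beta1h/beta2h are precomputed bias-correction factors (the learning
// rate is already folded into beta1h):
//
//   gm[i] = gm[i] * beta1 + g[i] * (1.0f - beta1);          // first moment
//   gv[i] = gv[i] * beta2 + g[i] * g[i] * (1.0f - beta2);   // second moment
//   const float mh = gm[i] * beta1h;
//   const float vh = sqrtf(gv[i] * beta2h) + eps;
//   x[i]  = x[i] * (1.0f - alpha * wd) - mh / vh;           // decoupled weight decay
//
// x, gm and gv are updated in place, which is why all five tensors are bound as
// separate buffers in the dispatch below.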
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_X, x_offset, x_size }, + vk_subbuffer{ d_G, g_offset, g_size }, + vk_subbuffer{ d_GM, gm_offset, gm_size }, + vk_subbuffer{ d_GV, gv_offset, gv_size }, + vk_subbuffer{ d_P, p_offset, p_size }, + }, sizeof(vk_op_push_constants), &pc, elements); +} + +static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t n = ggml_nelements(dst->src[0]); + + ggml_vk_op_f32_opt_step_adamw( + ctx, subctx, dst, + { (uint32_t)n, 0, 0.0f, 0.0f }, + dryrun + ); +} + static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { int * op_params = (int *)dst->op_params; @@ -6027,6 +6285,20 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co }, dryrun); } +static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -6100,7 +6372,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { const int n_dims = ((int32_t *) dst->op_params)[1]; - // const int mode = ((int32_t *) dst->op_params)[2]; + const int mode = ((int32_t *) dst->op_params)[2]; // const int n_ctx = ((int32_t *) dst->op_params)[3]; const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; const float freq_base = ((float *) dst->op_params)[5]; @@ -6109,16 +6381,24 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons const float attn_factor = ((float *) dst->op_params)[8]; const float beta_fast = ((float *) dst->op_params)[9]; const float beta_slow = ((float *) dst->op_params)[10]; + int sections[4] {}; + if (mode & GGML_ROPE_TYPE_MROPE) { + memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4); + } float corr_dims[2]; ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims); + uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type); + uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type); + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, { 
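// [annotation] For M-RoPE the memcpy above forwards op_params[11..14] into
// sections[0..3]; these four lengths partition the rotated dimensions into
// per-axis groups (Qwen2-VL-style temporal/height/width sections), and for norm
// and neox RoPE they stay zero. GGML_ROPE_TYPE_VISION is assumed to include the
// MROPE bit, which is why the pipeline selection earlier tests
// (is_mrope && !is_vision) before falling back to the vision variant. The row
// strides s1/s2 passed here are in elements, not bytes.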
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, - src2 != nullptr, + src2 != nullptr, (uint32_t)src0->ne[2], s1, s2, + sections[0], sections[1], sections[2], sections[3], }, dryrun); } @@ -6141,10 +6421,22 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c }, dryrun); } +static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); } +static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { const int32_t s0 = dst->op_params[0]; const int32_t s1 = dst->op_params[1]; @@ -7009,9 +7301,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod } break; case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: case GGML_OP_GET_ROWS: case GGML_OP_ADD: case GGML_OP_ACC: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: @@ -7034,13 +7328,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: case GGML_OP_ARGSORT: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: case GGML_OP_FLASH_ATTN_EXT: + case GGML_OP_OPT_STEP_ADAMW: break; default: std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; @@ -7061,9 +7359,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod } else { switch (node->op) { case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: case GGML_OP_ACC: case GGML_OP_GET_ROWS: case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: @@ -7085,7 +7385,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_SOFT_MAX: case GGML_OP_ROPE: case GGML_OP_ARGSORT: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: @@ -7106,6 +7409,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_REPEAT: ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_REPEAT_BACK: + 
ggml_vk_repeat_back(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_ACC: ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun); @@ -7118,6 +7425,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_ADD: ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_SUB: + ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_MUL: ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun); @@ -7205,10 +7516,22 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_ARGSORT: ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_SUM: + ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_SUM_ROWS: ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_ARGMAX: + ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_COUNT_EQUAL: + ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_IM2COL: ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); @@ -7243,6 +7566,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_RWKV_WKV6: ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun); + break; + + case GGML_OP_OPT_STEP_ADAMW: + ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun); + break; default: return false; @@ -7294,6 +7622,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_ADD: case GGML_OP_ACC: case GGML_OP_GET_ROWS: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: @@ -7319,13 +7648,18 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_TRANSPOSE: case GGML_OP_NONE: case GGML_OP_ARGSORT: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_OPT_STEP_ADAMW: buf = tensor->buffer; break; @@ -7517,6 +7851,15 @@ static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggm } } +static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; + + uint32_t val32 = (uint32_t)value * 0x01010101; + ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size); +} + static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; @@ -7561,7 +7904,7 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer, /* .get_base = */ ggml_backend_vk_buffer_get_base, /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor, - /* .memset_tensor = */ NULL, + /* .memset_tensor = */ 
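// [annotation] ggml_backend_vk_buffer_memset_tensor() above widens the byte
// value by multiplying with 0x01010101 (e.g. 0xAB becomes 0xABABABAB) because
// the underlying Vulkan fill writes a repeated 32-bit word rather than a byte
// pattern.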
ggml_backend_vk_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, @@ -8056,6 +8399,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: @@ -8130,6 +8475,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm //case GGML_TYPE_Q4_K: //case GGML_TYPE_Q5_K: //case GGML_TYPE_Q6_K: + //case GGML_TYPE_IQ1_S: + //case GGML_TYPE_IQ1_M: //case GGML_TYPE_IQ2_XXS: //case GGML_TYPE_IQ2_XS: //case GGML_TYPE_IQ2_S: @@ -8153,6 +8500,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: @@ -8208,17 +8557,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } break; case GGML_OP_REPEAT: return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); + case GGML_OP_REPEAT_BACK: + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ROPE: - { - const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { - return false; - } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; - } - return ggml_is_contiguous(op->src[0]); - } case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -8231,6 +8572,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return ggml_is_contiguous(op->src[0]); case GGML_OP_ADD: case GGML_OP_ACC: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: @@ -8244,12 +8586,16 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_ARGSORT: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: + case GGML_OP_OPT_STEP_ADAMW: return true; default: return false; @@ -8522,8 +8868,6 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { ggml_tensor * src0 = tensor->src[0]; ggml_tensor * src1 = tensor->src[1]; - ggml_tensor * src2 = tensor->src[2]; - ggml_tensor * src3 = tensor->src[3]; struct ggml_init_params iparams = { /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul, @@ -8533,238 +8877,113 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { struct ggml_context * ggml_ctx = ggml_init(iparams); - struct ggml_tensor * src0_clone = nullptr; - struct ggml_tensor * src1_clone = nullptr; - struct ggml_tensor * src2_clone = nullptr; - struct ggml_tensor * src3_clone = nullptr; - struct ggml_tensor * tensor_clone = nullptr; - - size_t src0_size; - size_t src1_size; - size_t src2_size; - size_t src3_size; - - void * src0_buffer = nullptr; - void * src1_buffer = nullptr; - void * src2_buffer = nullptr; - void * src3_buffer = nullptr; - - if (src0 != nullptr) { - src0_clone = ggml_dup_tensor(ggml_ctx, src0); - - src0_size = ggml_nbytes(src0); + std::array<struct ggml_tensor *, 6> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; + std::array<size_t, 6> src_size = {0, 0, 0, 0, 0, 0}; +
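// [annotation] The copy-pasted src0..src3 staging blocks that follow are
// collapsed into a single loop over up to six sources. Besides removing the
// duplication, this lets ops with more inputs run through the CPU reference
// check: GGML_OP_RWKV_WKV6 consumes six sources and the new
// GGML_OP_OPT_STEP_ADAMW five.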
std::array<void *, 6> src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; + const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"}; - src0_buffer = malloc(src0_size); - src0_clone->data = src0_buffer; - if (ggml_backend_buffer_is_host(src0->buffer)) { - memcpy(src0_clone->data, src0->data, src0_size); - memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS); - } else if (ggml_backend_buffer_is_vk(src0->buffer)) { - ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; - vk_buffer& buffer_gpu = buf_ctx->dev_buffer; - uint64_t offset = vk_tensor_offset(src0) + src0->view_offs; - if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) { - for (int i3 = 0; i3 < src0->ne[3]; i3++) { - for (int i2 = 0; i2 < src0->ne[2]; i2++) { - const int idx = i3*src0->ne[2] + i2; - ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]); - } - } - - src0_clone->nb[0] = src0->nb[0]; - src0_clone->nb[1] = src0->nb[1]; - for (int i = 2; i < GGML_MAX_DIMS; i++) { - src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1]; - } - } else { - if (offset + src0_size >= buffer_gpu->size) { - src0_size = buffer_gpu->size - offset; - } - ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size); - memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS); - } - } else { - GGML_ABORT("fatal error"); - } - - if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(src0, "src0"); - } - } - if (src1 != nullptr) { - src1_clone = ggml_dup_tensor(ggml_ctx, src1); - - src1_size = ggml_nbytes(src1); - - src1_buffer = malloc(src1_size); - src1_clone->data = src1_buffer; - if (ggml_backend_buffer_is_host(src1->buffer)) { - memcpy(src1_clone->data, src1->data, src1_size); - memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); - } else if (ggml_backend_buffer_is_vk(src1->buffer)) { - ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; - vk_buffer& buffer_gpu = buf_ctx->dev_buffer; - uint64_t offset = vk_tensor_offset(src1) + src1->view_offs; - if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) { - for (int i3 = 0; i3 < src1->ne[3]; i3++) { - for (int i2 = 0; i2 < src1->ne[2]; i2++) { - const int idx = i3*src1->ne[2] + i2; - ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]); - } - } - - src1_clone->nb[0] = src1->nb[0]; - src1_clone->nb[1] = src1->nb[1]; - for (int i = 2; i < GGML_MAX_DIMS; i++) { - src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1]; - } - } else { - if (offset + src1_size >= buffer_gpu->size) { - src1_size = buffer_gpu->size - offset; - } - ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size); - memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); - } - } else { - GGML_ABORT("fatal error"); - } - - if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(src1, "src1"); - } - } - if (src2 != nullptr) { - src2_clone = ggml_dup_tensor(ggml_ctx, src2); - - src2_size = ggml_nbytes(src2); - - src2_buffer = malloc(src2_size); - src2_clone->data = src2_buffer; - if (ggml_backend_buffer_is_host(src2->buffer)) { - memcpy(src2_clone->data, src2->data, src2_size); - memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); - } else if
(ggml_backend_buffer_is_vk(src2->buffer)) { - ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src2->buffer->context; - vk_buffer& buffer_gpu = buf_ctx->dev_buffer; - uint64_t offset = vk_tensor_offset(src2) + src2->view_offs; - if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) { - for (int i3 = 0; i3 < src2->ne[3]; i3++) { - for (int i2 = 0; i2 < src2->ne[2]; i2++) { - const int idx = i3*src2->ne[2] + i2; - ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]); - } - } - - src2_clone->nb[0] = src2->nb[0]; - src2_clone->nb[1] = src2->nb[1]; - for (int i = 2; i < GGML_MAX_DIMS; i++) { - src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1]; - } - } else { - if (offset + src2_size >= buffer_gpu->size) { - src2_size = buffer_gpu->size - offset; - } - ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size); - memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); - } - } else { - GGML_ABORT("fatal error"); - } + struct ggml_tensor * tensor_clone = nullptr; - if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(src2, "src2"); + for (int i = 0; i < 6; i++) { + ggml_tensor * srci = tensor->src[i]; + if (srci == nullptr) { + continue; } - } - if (src3 != nullptr) { - src3_clone = ggml_dup_tensor(ggml_ctx, src3); + ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci); + size_t srci_size = ggml_nbytes(srci); - src3_size = ggml_nbytes(src3); + src_clone[i] = srci_clone; + src_size[i] = ggml_nbytes(srci); + src_buffer[i] = malloc(srci_size); - src3_buffer = malloc(src3_size); - src3_clone->data = src3_buffer; - if (ggml_backend_buffer_is_host(src3->buffer)) { - memcpy(src3_clone->data, src3->data, src3_size); - memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); - } else if (ggml_backend_buffer_is_vk(src3->buffer)) { - ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context; + srci_clone->data = src_buffer[i]; + if (ggml_backend_buffer_is_host(srci->buffer)) { + memcpy(srci_clone->data, srci->data, srci_size); + memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (ggml_backend_buffer_is_vk(srci->buffer)) { + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context; vk_buffer& buffer_gpu = buf_ctx->dev_buffer; - uint64_t offset = vk_tensor_offset(src3) + src3->view_offs; - if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) { - for (int i3 = 0; i3 < src3->ne[3]; i3++) { - for (int i2 = 0; i2 < src3->ne[2]; i2++) { - const int idx = i3*src3->ne[2] + i2; - ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]); + uint64_t offset = vk_tensor_offset(srci) + srci->view_offs; + if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) { + for (int i3 = 0; i3 < srci->ne[3]; i3++) { + for (int i2 = 0; i2 < srci->ne[2]; i2++) { + const int idx = i3*srci->ne[2] + i2; + ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]); } } - src3_clone->nb[0] = src3->nb[0]; - src3_clone->nb[1] = src3->nb[1]; + srci_clone->nb[0] = srci->nb[0]; + srci_clone->nb[1] = srci->nb[1]; for (int i = 2; i < GGML_MAX_DIMS; i++) { - src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1]; + srci_clone->nb[i] = 
srci_clone->nb[i - 1]*srci_clone->ne[i - 1]; } } else { - if (offset + src3_size >= buffer_gpu->size) { - src3_size = buffer_gpu->size - offset; + if (offset + srci_size >= buffer_gpu->size) { + srci_size = buffer_gpu->size - offset; } - ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size); - memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); + ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size); + memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); } } else { GGML_ABORT("fatal error"); } if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(src3, "src3"); + ggml_vk_print_tensor(srci, srci_name[i]); } } if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { const float *params = (const float *)tensor->op_params; - tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]); + tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]); } else if (tensor->op == GGML_OP_MUL_MAT) { - tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_MUL_MAT_ID) { - tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone); + tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); + } else if (tensor->op == GGML_OP_SUB) { + tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_MUL) { - tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_DIV) { - tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_CONCAT) { - tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params); + tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); } else if (tensor->op == GGML_OP_UPSCALE) { - tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); } else if (tensor->op == GGML_OP_SCALE) { - tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]); + tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]); } else if (tensor->op == GGML_OP_SQR) { - tensor_clone = ggml_sqr(ggml_ctx, src0_clone); + tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SIN) { - tensor_clone = ggml_sin(ggml_ctx, src0_clone); + tensor_clone = ggml_sin(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_COS) { - tensor_clone = ggml_cos(ggml_ctx, src0_clone); + tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_CLAMP) { - tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_PAD) { - tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]); + 
tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]); } else if (tensor->op == GGML_OP_REPEAT) { - tensor_clone = ggml_repeat(ggml_ctx, src0_clone, tensor); + tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor); + } else if (tensor->op == GGML_OP_REPEAT_BACK) { + tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor); } else if (tensor->op == GGML_OP_ADD) { - tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_ACC) { - tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); + tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); } else if (tensor->op == GGML_OP_NORM) { - tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); + tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_GROUP_NORM) { - tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); + tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_RMS_NORM) { - tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); + tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_SOFT_MAX) { if (src1 != nullptr) { - tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); } else { - tensor_clone = ggml_soft_max(ggml_ctx, src0_clone); + tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]); } } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { - tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params); + tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params); } else if (tensor->op == GGML_OP_ROPE) { const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; @@ -8776,23 +8995,28 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const float attn_factor = ((float *) tensor->op_params)[8]; const float beta_fast = ((float *) tensor->op_params)[9]; const float beta_slow = ((float *) tensor->op_params)[10]; - tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + if (mode & GGML_ROPE_TYPE_MROPE) { + int32_t *sections = ((int32_t *) tensor->op_params) + 11; + tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else { + tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } } else if (tensor->op == GGML_OP_UNARY) { switch (ggml_get_unary_op(tensor)) { case 
GGML_UNARY_OP_SILU: - tensor_clone = ggml_silu(ggml_ctx, src0_clone); + tensor_clone = ggml_silu(ggml_ctx, src_clone[0]); break; case GGML_UNARY_OP_GELU: - tensor_clone = ggml_gelu(ggml_ctx, src0_clone); + tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]); break; case GGML_UNARY_OP_GELU_QUICK: - tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone); + tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]); break; case GGML_UNARY_OP_RELU: - tensor_clone = ggml_relu(ggml_ctx, src0_clone); + tensor_clone = ggml_relu(ggml_ctx, src_clone[0]); break; case GGML_UNARY_OP_TANH: - tensor_clone = ggml_tanh(ggml_ctx, src0_clone); + tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]); break; default: std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; @@ -8800,28 +9024,34 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { if (src1 == nullptr) { - tensor_clone = ggml_dup(ggml_ctx, src0_clone); + tensor_clone = ggml_dup(ggml_ctx, src_clone[0]); tensor_clone->type = tensor->type; } else { - tensor_clone = ggml_cpy(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]); } } else if (tensor->op == GGML_OP_CONT) { - tensor_clone = ggml_cont_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); } else if (tensor->op == GGML_OP_RESHAPE) { - tensor_clone = ggml_reshape_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); } else if (tensor->op == GGML_OP_VIEW) { - tensor_clone = ggml_view_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]); + tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]); } else if (tensor->op == GGML_OP_PERMUTE) { int32_t * params = (int32_t *)tensor->op_params; - tensor_clone = ggml_permute(ggml_ctx, src0_clone, params[0], params[1], params[2], params[3]); + tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]); } else if (tensor->op == GGML_OP_TRANSPOSE) { - tensor_clone = ggml_transpose(ggml_ctx, src0_clone); + tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_GET_ROWS) { - tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_ARGSORT) { - tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params); + tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_SUM) { + tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SUM_ROWS) { - tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone); + tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_ARGMAX) { + tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_COUNT_EQUAL) { + tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]); } else 
if (tensor->op == GGML_OP_IM2COL) { const int32_t s0 = tensor->op_params[0]; const int32_t s1 = tensor->op_params[1]; @@ -8831,11 +9061,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const int32_t d1 = tensor->op_params[5]; const bool is_2D = tensor->op_params[6] == 1; - tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type); + tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type); } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { const int32_t dim = tensor->op_params[0]; const int32_t max_period = tensor->op_params[1]; - tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period); + tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period); } else if (tensor->op == GGML_OP_POOL_2D) { enum ggml_op_pool op = static_cast<enum ggml_op_pool>(tensor->op_params[0]); const int32_t k0 = tensor->op_params[1]; @@ -8845,13 +9075,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const int32_t p0 = tensor->op_params[5]; const int32_t p1 = tensor->op_params[6]; - tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1); + tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1); } else if (tensor->op == GGML_OP_LEAKY_RELU) { const float * op_params = (const float *)tensor->op_params; - tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false); + tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false); } else if (tensor->op == GGML_OP_RWKV_WKV6) { - tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], - tensor->src[4], tensor->src[5]); + tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4], src_clone[5]); + } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { + src_clone[0]->flags = src0->flags; + tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4]); } else { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; @@ -8873,11 +9107,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { memcpy(comp_result, tensor_clone->data, comp_size); memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS); - if (src0 != nullptr) { - free(src0_buffer); - } - if (src1 != nullptr) { - free(src1_buffer); + for (int i = 0; i < 6; i++) { + if (src_buffer[i] != nullptr) { + free(src_buffer[i]); + } } ggml_free(ggml_ctx); @@ -8941,6 +9174,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { } else if (tensor->type == GGML_TYPE_I32) { correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_I64) { + correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); + result = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); } else { std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp new file mode 100644 index 0000000000000..eaf4da341e348 --- /dev/null +++
b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp @@ -0,0 +1,51 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +shared FLOAT_TYPE tmpmax[BLOCK_SIZE]; +shared uint tmp[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + + if (col >= p.KX) { + return; + } + A_TYPE amax = data_a[row*p.KX + col]; + tmp[col] = col; + + for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) { + A_TYPE val = data_a[row*p.KX + i]; + if (val > amax) { + amax = val; + tmp[col] = i; + } + } + tmpmax[col] = amax; + + barrier(); + [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) { + if (col < s && col + s < p.KX) { + if (tmpmax[col] < tmpmax[col + s]) { + tmpmax[col] = tmpmax[col + s]; + tmp[col] = tmp[col + s]; + } + } + barrier(); + } + + if (col == 0) { + data_d[row] = D_TYPE(tmp[0]); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp index 9c9fe9626dbba..dbc7daa3328f6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -12,7 +12,7 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; #endif void main() { -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); if (gl_LocalInvocationIndex.x != 0) { return; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index 660811086d613..c813f14044eca 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -217,7 +217,7 @@ void quantize(uint dst_idx, uint src_idx) #endif void main() { -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); if (gl_LocalInvocationIndex.x != 0) { return; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp new file mode 100644 index 0000000000000..d9345497c73fd --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp @@ -0,0 +1,31 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "types.comp" +#include "generic_head.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) buffer D {D_TYPE data_d[];}; + +const uint CHUNK_SIZE = 512; + +void main() { + const uint base = gl_WorkGroupID.x * CHUNK_SIZE; + const uint col = gl_LocalInvocationID.x; + + uint count = 0; + [[unroll]] + for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) { + const uint idx = base + i + col; + if (idx >= p.KX) { + break; + } + count += 
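// [annotation on argmax.comp above] Each invocation first scans its row with a
// stride of BLOCK_SIZE, tracking a running (max, index) pair, then the
// shared-memory tree reduction halves the active width until thread 0 holds the
// row argmax written to data_d[row]. A sequential sketch of the same result
// (illustrative only, not part of the shader):
//
//   uint best = 0;
//   for (uint i = 1; i < KX; i++) {
//       if (row[i] > row[best]) {
//           best = i;
//       }
//   }
//
// Tie-breaking can differ from this sequential scan because the strided pass and
// the reduction visit candidates in a different order.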
uint(data_a[idx] == data_b[idx]); + } + + atomicAdd(data_d[0], D_TYPE(count)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index ecfdbfaa88cec..10318e8766074 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -88,6 +88,83 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_IQ1_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + const int i8 = int(iqs % 8); + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const float dl = float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint idxhi = bitfieldExtract(qh, 3 * int(ib8 & 3), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]); + // Signed bitfield extract. + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + return dl * (vec2(gvec) + delta); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + const int i8 = int(iqs % 8); + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const float dl = 2 * bitfieldExtract(qh, 12, 3) + 1; + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); + // Signed bitfield extract. + const ivec4 gvec = ivec4( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2), + bitfieldExtract(grid, 2 * (i8 + 2), 2), + bitfieldExtract(grid, 2 * (i8 + 3), 2) + ); + return dl * (vec4(gvec) + delta); +} +#endif + +#if defined(DATA_A_IQ1_M) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib8 = iqs / 8; + const uint ib16 = iqs / 16; + const int i8 = int(iqs % 8); + const uint sc = data_a[a_offset + ib].scales[iqs / 64]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + // Signed bitfield extract. + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + return dl * (vec2(gvec) + delta); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib8 = iqs / 8; + const uint ib16 = iqs / 16; + const int i8 = int(iqs % 8); + const uint sc = data_a[a_offset + ib].scales[iqs / 64]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + // Signed bitfield extract. 
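// [annotation] The IQ1 decoders above share one recipe: qs gives the low 8 bits
// and qh (IQ1_S) or a spare qh nibble (IQ1_M) the high 3 bits of an index into
// iq1s_grid, whose 16-bit entries pack eight 2-bit signed values from {-1, 0, 1};
// the signed bitfieldExtract recovers two or four of them per call. The group
// scale is odd by construction (2 * s + 1), and the sign bit selects +/-delta,
// assumed to be the IQ1S_DELTA / IQ1M_DELTA constants defined in ggml-common.h.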
+ const ivec4 gvec = ivec4( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2), + bitfieldExtract(grid, 2 * (i8 + 2), 2), + bitfieldExtract(grid, 2 * (i8 + 3), 2) + ); + return dl * (vec4(gvec) + delta); +} +#endif + #if defined(DATA_A_IQ2_XXS) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint ib32 = iqs / 32; @@ -357,7 +434,16 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#if defined(DATA_A_IQ1_M) +vec2 get_dm(uint ib, uint a_offset) { + const uint16_t[4] scales = data_a[a_offset + ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + return vec2(d, 0); +} +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) vec2 get_dm(uint ib, uint a_offset) { return vec2(float(data_a[a_offset + ib].d), 0); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 0eba37420118c..4770469eddcab 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -301,6 +301,56 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2 return ret; } +#if defined(DATA_A_IQ1_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S { + block_iq1_s block; +}; + +float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = idx / 32; + const uint ib8 = idx / 8; + + const uint qh = bl.block.qh[ib32]; + const uint qs = bl.block.qs[ib8]; + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]; + + float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta)); + return ret; +} +#endif + +#if defined(DATA_A_IQ1_M) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_M { + block_iq1_m block; +}; + +float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const u16vec4 scales = u16vec4(bl.block.scales[0], bl.block.scales[1], bl.block.scales[2], bl.block.scales[3]) >> 12; + const float16_t d = uint16BitsToHalf(scales.x | (scales.y << 4) | (scales.z << 8) | (scales.w << 12)); + const uint idx = coordInBlock[1]; + + const uint ib8 = idx / 8; + const uint ib16 = idx / 16; + const int i8 = int(idx % 8); + const uint sc = bl.block.scales[ib8 / 8]; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? 
-IQ1M_DELTA : IQ1M_DELTA;
+    const uint grid = iq1s_grid[qs | ((qh & 7) << 8)];
+
+    float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta));
+    return ret;
+}
+#endif
+
 #if defined(DATA_A_IQ2_XXS)
 layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS {
    block_iq2_xxs block;
@@ -512,6 +562,10 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
 #define dequantFuncA dequantFuncQ5_K
 #elif defined(DATA_A_Q6_K)
 #define dequantFuncA dequantFuncQ6_K
+#elif defined(DATA_A_IQ1_S)
+#define dequantFuncA dequantFuncIQ1_S
+#elif defined(DATA_A_IQ1_M)
+#define dequantFuncA dequantFuncIQ1_M
 #elif defined(DATA_A_IQ2_XXS)
 #define dequantFuncA dequantFuncIQ2_XXS
 #elif defined(DATA_A_IQ2_XS)
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp
new file mode 100644
index 0000000000000..39184ef582355
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp
@@ -0,0 +1,42 @@
+#version 450
+
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_iq1_m data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+    // Each thread handles 1 subblock (32 values with 2 scales)
+    const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
+
+    init_iq_shmem(gl_WorkGroupSize);
+
+    if (ib >= p.nel / 256) {
+        return;
+    }
+
+    const uint ib32 = gl_LocalInvocationID.x % 8;
+    const uint ib64 = ib32 / 2;
+    const uint b_idx = 256 * ib + 32 * ib32;
+
+    const uint16_t[4] scales = data_a[ib].scales;
+    const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
+    const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
+
+    const uint sc = data_a[ib].scales[ib64];
+    [[unroll]] for (int l = 0; l < 4; ++l) {
+        const uint ib16 = 2 * ib32 + l / 2;
+        const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
+        const uint qh = data_a[ib].qh[ib16] >> (4 * (l & 1));
+        const uint qs = data_a[ib].qs[4 * ib32 + l];
+        const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
+        const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
+        [[unroll]] for (int j = 0; j < 8; ++j) {
+            data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));
+        }
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp
new file mode 100644
index 0000000000000..fd1e4e30d252b
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp
@@ -0,0 +1,35 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_iq1_s data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+    // Each thread handles 1 subblock (32 values with 1 scale)
+    const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
+
+    init_iq_shmem(gl_WorkGroupSize);
+
+    if (ib >= p.nel / 256) {
+        return;
+    }
+
+    const uint ib32 = gl_LocalInvocationID.x % 8;
+    const uint b_idx = 256 * ib + 32 * ib32;
+
+    uint qh = data_a[ib].qh[ib32];
+    const float d = float(data_a[ib].d);
+    const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
+    const float delta = ((qh & 0x8000) != 0) ?
-IQ1S_DELTA : IQ1S_DELTA; + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qs = data_a[ib].qs[4 * ib32 + l]; + const uint hi = bitfieldExtract(qh, 3 * int(l), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (hi << 8)]); + [[unroll]] for (int j = 0; j < 8; ++j) { + data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta)); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index ba88ce79a21ae..df30355f635b8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -104,7 +104,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele #endif void main() { -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp index c16a2a9f605c5..c9f855687dcb5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -12,7 +12,7 @@ void main() { const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index d7e99727db184..31ecd9f81a85f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -133,7 +133,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { void main() { const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp new file mode 100644 index 0000000000000..e4acbd4f96261 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp @@ -0,0 +1,82 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint16_t[4] scales = data_a[ibi].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + + const uint sc = 
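+        // each 16-bit scale word holds four 3-bit scales (its top 4 bits were
+        // consumed by the fp16 d above); the shift selects this 32-value
+        // group's 6-bit pair of scales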
data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1)); + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1)); + const uint qs = data_a[ibi].qs[4 * ib32 + l]; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1); + + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int k = 0; k < 4; ++k) { + sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta, + fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum)); + } + temp[j][n] = fma(dl, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp new file mode 100644 index 0000000000000..309da0991ae63 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp @@ -0,0 +1,79 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint qh = data_a[ibi].qh[ib32]; + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? 
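+        // bit 15 of qh selects the sign of the constant ±0.125 (IQ1S_DELTA)
+        // offset shared by all values in this 32-value group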
-IQ1S_DELTA : IQ1S_DELTA; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qs = data_a[ibi].qs[4 * ib32 + l]; + const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int k = 0; k < 4; ++k) { + sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta, + fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum)); + } + temp[j][n] = fma(dl, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 33b2234e71df0..39657195cfc8d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -6,6 +6,9 @@ #ifdef FLOAT16 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #endif +#if defined(DATA_A_IQ1_M) +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#endif #ifdef COOPMAT #extension GL_KHR_cooperative_matrix : enable @@ -95,7 +98,7 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif void main() { -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif @@ -437,6 +440,56 @@ void main() { buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); +#elif defined(DATA_A_IQ1_S) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint ib8 = (idx % 128) / 4; + const int i8 = 2 * 
int(idx % 4); + + const float d = float(data_a[ib].d); + const uint qh = data_a[ib].qh[ib32]; + const uint qs = data_a[ib].qs[ib8]; + const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); + + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + const vec2 v = dl * (vec2(gvec) + delta); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ1_M) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib8 = (idx % 128) / 4; + const uint ib16 = ib8 / 2; + const int i8 = 2 * int(idx % 4); + + const uint16_t[4] scales = data_a[ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + const uint sc = scales[ib8 / 8]; + const uint qs = data_a[ib].qs[ib8]; + const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + const vec2 v = dl * (vec2(gvec) + delta); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); #elif defined(DATA_A_IQ2_XXS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 7e29bbfec7b33..66dd2c860d82d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -106,7 +106,7 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem #endif void main() { -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp new file mode 100644 index 0000000000000..e0214fe7645c2 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) buffer X {A_TYPE x[];}; +layout (binding = 1) readonly buffer G {A_TYPE grad[];}; +layout (binding = 2) buffer GM {A_TYPE gradm[];}; +layout (binding = 3) buffer GV {A_TYPE gradv[];}; +layout (binding = 4) readonly buffer P {float params[7];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float alpha = params[0]; + const float beta1 = params[1]; + const float beta2 = params[2]; + 
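+    // remaining host-supplied hyperparameters; beta1h and beta2h are expected
+    // to be the precomputed bias-correction factors 1/(1 - beta^t) for the
+    // current iteration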
const float eps = params[3]; + const float wd = params[4]; + const float beta1h = params[5]; + const float beta2h = params[6]; + + const float gi = grad[i]; + const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1); + const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2); + + gradm[i] = gmi; + gradv[i] = gvi; + + const float mh = gmi*beta1h; + const float vh = sqrt(gvi*beta2h) + eps; + + x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp new file mode 100644 index 0000000000000..d86279934f176 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp @@ -0,0 +1,37 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + // Destination multi-index (inlined dst_idx) + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; + + // Accumulate from sources + A_TYPE acc = A_TYPE(0); + for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) { + for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) { + for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) { + for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) { + acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00]; + } + } + } + } + + data_d[get_doffset() + d_idx] = D_TYPE(acc); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp index 574b51ca55389..38075b75557f9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp @@ -25,6 +25,10 @@ layout (push_constant) uniform parameter { float corr_dims[2]; float theta_scale; uint has_ff; + uint ne02; + uint s1; + uint s2; + int sections[4]; } p; float rope_yarn_ramp(const float low, const float high, const uint i0) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp new file mode 100644 index 0000000000000..4f5b1a0ecaf5d --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -0,0 +1,60 @@ +#version 450 + +#include "rope_head.comp" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + if (i0 >= p.n_dims) { + const uint i = row_dst*ne0 + i0; + + data_d[i + 0] = data_a[i + 0]; + data_d[i + 1] = data_a[i + 1]; + + return; + } + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + + const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < p.sections[0]) { + theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= p.sections[0] && sector < sec_w) { + 
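+        // sections[0..3] partition the rotary dims; this branch and the two
+        // below read positions from the higher planes of data_pos (offset by
+        // multiples of ne2), one plane per section, as in multimodal M-RoPE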
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w + p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + } + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims/2]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index 83b46b69b2a7f..db775c456cae8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -3,15 +3,18 @@ #include "rope_head.comp" void main() { - const uint col = gl_GlobalInvocationID.y * 2; - const uint row = gl_GlobalInvocationID.x; + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; - if (col >= p.ncols) { + if (i0 >= ne0) { return; } - if (col >= p.n_dims) { - const uint i = row*p.ncols + col; + const uint row_dst = gl_GlobalInvocationID.x; + + if (i0 >= p.n_dims) { + const uint i = row_dst*ne0 + i0; data_d[i + 0] = data_a[i + 0]; data_d[i + 1] = data_a[i + 1]; @@ -19,19 +22,22 @@ void main() { return; } - const uint i = row*p.ncols + col/2; - const uint i2 = row/p.p_delta_rows; + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; - const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); - const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f; + const float freq_factor = p.has_ff != 0 ? 
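+    // optional per-dimension frequency factors; 1.0f is used when no
+    // freq_factors tensor is provided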
data_ff[i0/2] : 1.0f; float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - const float x0 = float(data_a[i + 0]); - const float x1 = float(data_a[i + p.n_dims/2]); + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims/2]); - data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index e416ad9389706..4ad35e549d77f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -3,15 +3,18 @@ #include "rope_head.comp" void main() { - const uint col = gl_GlobalInvocationID.y * 2; - const uint row = gl_GlobalInvocationID.x; + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; - if (col >= p.ncols) { + if (i0 >= ne0) { return; } - if (col >= p.n_dims) { - const uint i = row*p.ncols + col; + const uint row_dst = gl_GlobalInvocationID.x; + + if (i0 >= p.n_dims) { + const uint i = row_dst*ne0 + i0; data_d[i + 0] = data_a[i + 0]; data_d[i + 1] = data_a[i + 1]; @@ -19,19 +22,22 @@ void main() { return; } - const uint i = row*p.ncols + col; - const uint i2 = row/p.p_delta_rows; + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0; - const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); - const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f; + const float freq_factor = p.has_ff != 0 ? 
data_ff[i0/2] : 1.0f; float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - const float x0 = float(data_a[i + 0]); - const float x1 = float(data_a[i + 1]); + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + 1]); - data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp new file mode 100644 index 0000000000000..cedacc4d14439 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -0,0 +1,47 @@ +#version 450 + +#include "rope_head.comp" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + + const int sect_dims = p.sections[0] + p.sections[1]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < p.sections[0]) { + const uint p0 = sector; + theta_base = data_pos[channel_x]*pow(p.theta_scale, p0); + } + else if (sector >= p.sections[0] && sector < sec_w) { + const uint p0 = sector - p.sections[0]; + theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0); + } + + const float freq_factor = p.has_ff != 0 ? 
data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp new file mode 100644 index 0000000000000..72353cc3296ed --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp @@ -0,0 +1,29 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include "types.comp" +#include "generic_binary_head.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index db643a54c8e78..dfa16cda516bd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -294,6 +294,187 @@ struct block_q6_K_packed16 // IQuants +#define QUANT_K_IQ1_S 256 +#define QUANT_R_IQ1_S 1 + +struct block_iq1_s { + float16_t d; + uint8_t qs[QUANT_K_IQ1_S/8]; + uint16_t qh[QUANT_K_IQ1_S/32]; +}; + +#define QUANT_K_IQ1_M 256 +#define QUANT_R_IQ1_M 1 + +struct block_iq1_m { + uint8_t qs[QUANT_K_IQ1_M/8]; + uint8_t qh[QUANT_K_IQ1_M/16]; + uint16_t scales[QUANT_K_IQ1_M/64]; +}; + +#if defined(DATA_A_IQ1_S) +#define QUANT_K QUANT_K_IQ1_S +#define QUANT_R QUANT_R_IQ1_S +#define A_TYPE block_iq1_s +#endif + +#if defined(DATA_A_IQ1_M) +#define QUANT_K QUANT_K_IQ1_M +#define QUANT_R QUANT_R_IQ1_M +#define A_TYPE block_iq1_m +#endif + +#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) +#define IQ1S_DELTA 0.125f +#define IQ1M_DELTA 0.125f + +// Packed IQ1S grid where every 2 vec8 are encoded on 32 bits (2 bits per coordinate). 
+const uint[1024] iq1s_grid_const = { + 0xfffdffff, 0xfff7fff0, 0xffccfff5, 0xffdfffc0, 0xffd7ffdd, 0xff30ffd5, 0xff03ff0c, 0xff10ff01, + 0xff7dff7f, 0xff75ff77, 0xff5fff40, 0xff57ff5d, 0xfcf3ff55, 0xfcccfcf0, 0xfcc1fcc3, 0xfcc5fcc4, + 0xfc3cfcd0, 0xfc34fc31, 0xfc00fc0d, 0xfc1cfc05, 0xfc11fc13, 0xfc70fc17, 0xfc43fc4c, 0xfc50fc41, + 0xfdfdfdff, 0xfdf5fdf7, 0xfddffdc0, 0xfdd7fddd, 0xfd30fdd5, 0xfd04fd0c, 0xfd14fd13, 0xfd7dfd7f, + 0xfd75fd77, 0xfd40fd4c, 0xfd5ffd44, 0xfd57fd5d, 0xf3ccfd55, 0xf3c1f3c3, 0xf33cf3d0, 0xf300f334, + 0xf313f305, 0xf34cf310, 0xf350f344, 0xf0f3f0fc, 0xf0f1f0f0, 0xf0c7f0c0, 0xf0d4f0c5, 0xf030f03f, + 0xf00ff035, 0xf003f00c, 0xf001f000, 0xf01ff004, 0xf010f01d, 0xf015f017, 0xf04cf07c, 0xf047f040, + 0xf05cf045, 0xf050f053, 0xf054f051, 0xf1c4f1c3, 0xf133f13c, 0xf10df10f, 0xf107f100, 0xf11cf11f, + 0xf114f111, 0xf14cf170, 0xf144f143, 0xf7fdf7ff, 0xf7f5f7f7, 0xf7dff7c0, 0xf7d7f7dd, 0xf730f7d5, + 0xf701f70c, 0xf77ff710, 0xf777f77d, 0xf740f775, 0xf75df75f, 0xf755f757, 0xf4ccf4f0, 0xf4c4f4c3, + 0xf4d0f4d3, 0xf40ff43c, 0xf400f40c, 0xf413f41c, 0xf44cf414, 0xf441f443, 0xf450f444, 0xf5fdf5ff, + 0xf5f5f5f7, 0xf5dff5c0, 0xf5d7f5dd, 0xf530f5d5, 0xf504f50c, 0xf510f51c, 0xf57df57f, 0xf577f570, + 0xf540f575, 0xf55df55f, 0xf555f557, 0xcfcccfcf, 0xcfc4cfc3, 0xcfd0cfd3, 0xcf33cf3c, 0xcf00cf0f, + 0xcf1ccf07, 0xcf10cf13, 0xcf4ccf14, 0xcf41cf43, 0xcf50cf5c, 0xccf3ccfc, 0xccf4ccf1, 0xcccdcccf, + 0xccc7ccc0, 0xccd3ccdc, 0xcc30ccd4, 0xcc0fcc35, 0xcc0dcc0c, 0xcc00cc03, 0xcc04cc01, 0xcc10cc1f, + 0xcc4dcc73, 0xcc5ccc40, 0xcdcccc53, 0xcdc1cdc3, 0xcd3fcdd0, 0xcd34cd31, 0xcd00cd0d, 0xcd05cd07, + 0xcd11cd13, 0xcd4ccd70, 0xcd41cd43, 0xc3fccd50, 0xc3f4c3f1, 0xc3c0c3c3, 0xc3c4c3c7, 0xc3d1c3dc, + 0xc330c33c, 0xc337c331, 0xc30cc335, 0xc300c303, 0xc304c301, 0xc310c31d, 0xc373c317, 0xc34fc374, + 0xc340c343, 0xc344c347, 0xc35cc345, 0xc350c353, 0xc0fdc354, 0xc0f5c0f0, 0xc0c3c0cc, 0xc0c1c0c0, + 0xc0dfc0c4, 0xc0d0c0dd, 0xc0d5c0d7, 0xc033c03c, 0xc031c030, 0xc00dc00c, 0xc000c003, 0xc004c001, + 0xc01cc005, 0xc010c013, 0xc014c011, 0xc07dc07f, 0xc070c073, 0xc075c077, 0xc04cc04f, 0xc040c043, + 0xc044c041, 0xc05fc045, 0xc050c05d, 0xc1f3c1fc, 0xc1f1c1f0, 0xc1c1c1c0, 0xc1c5c1c7, 0xc1d1c1dc, + 0xc13dc13f, 0xc130c133, 0xc135c137, 0xc100c10c, 0xc107c101, 0xc11cc104, 0xc110c113, 0xc114c117, + 0xc171c115, 0xc14dc175, 0xc153c140, 0xc7ccc154, 0xc7d0c7c1, 0xc733c73c, 0xc734c731, 0xc700c70f, + 0xc705c707, 0xc71cc71f, 0xc711c713, 0xc770c714, 0xc743c74c, 0xc4cfc750, 0xc4c0c4cd, 0xc4dcc4c5, + 0xc43dc4d0, 0xc430c433, 0xc40cc437, 0xc400c403, 0xc404c401, 0xc41fc405, 0xc415c410, 0xc44cc474, + 0xc440c44d, 0xc45cc447, 0xc454c451, 0xc5c1c5f4, 0xc5d1c5d3, 0xc531c533, 0xc50fc534, 0xc500c50d, + 0xc51cc507, 0xc514c511, 0xc54cc570, 0xc545c541, 0xdffddfff, 0xdff5dff7, 0xdfdfdfc0, 0xdfd0dfdd, + 0xdfd5dfd7, 0xdf0cdf30, 0xdf1cdf04, 0xdf7fdf10, 0xdf77df7d, 0xdf40df75, 0xdf5ddf5f, 0xdf57df50, + 0xdcf0df55, 0xdcc3dccc, 0xdcd0dcc4, 0xdc33dc3d, 0xdc00dc34, 0xdc05dc07, 0xdc13dc1c, 0xdc11dc10, + 0xdc4fdc70, 0xdc44dc41, 0xddfcdc50, 0xddf5ddf7, 0xddc0ddcc, 0xdddddddf, 0xddd5ddd7, 0xdd0cdd30, + 0xdd04dd01, 0xdd7cdd10, 0xdd75dd77, 0xdd40dd4c, 0xdd5ddd5f, 0xdd55dd57, 0xd3c3d3f0, 0xd3c4d3c1, + 0xd333d3d0, 0xd331d330, 0xd30dd334, 0xd307d300, 0xd311d305, 0xd34cd370, 0xd344d343, 0xd350d35c, + 0xd0c0d0f4, 0xd0d4d0dc, 0xd030d03f, 0xd00cd037, 0xd000d003, 0xd01dd004, 0xd017d010, 0xd04fd074, + 0xd040d043, 0xd045d047, 0xd053d05c, 0xd054d051, 0xd1cfd1f0, 0xd1c4d1cd, 0xd13cd1d0, 0xd100d134, + 0xd11cd11f, 0xd173d114, 0xd14fd171, 0xd7ffd145, 0xd7f7d7fd, 0xd7c0d7f5, 0xd7ddd7df, 
0xd7d5d7d7, + 0xd70cd730, 0xd710d703, 0xd77dd77f, 0xd775d777, 0xd75dd75f, 0xd755d757, 0xd4ccd4f4, 0xd4c4d4c3, + 0xd431d4d0, 0xd40dd434, 0xd41cd400, 0xd411d413, 0xd470d414, 0xd441d44f, 0xd453d444, 0xd5ffd450, + 0xd5f7d5fd, 0xd5dfd5f5, 0xd5d7d5dd, 0xd530d5d5, 0xd501d50c, 0xd510d504, 0xd57dd57f, 0xd575d577, + 0xd55fd540, 0xd557d55d, 0x3ff0d555, 0x3fc13fcc, 0x3f343fd0, 0x3f003f0d, 0x3f053f07, 0x3f133f1c, + 0x3f433f11, 0x3f5c3f44, 0x3cff3f51, 0x3cf33cfc, 0x3cf43cf1, 0x3cc03ccd, 0x3cc73cc1, 0x3cdc3cc5, + 0x3cd43cd1, 0x3c373c30, 0x3c0c3c35, 0x3c003c03, 0x3c043c01, 0x3c103c05, 0x3c153c17, 0x3c733c7c, + 0x3c4f3c71, 0x3c403c4d, 0x3c5c3c5f, 0x3df03c5d, 0x3dc33dcc, 0x3dd03dc1, 0x3d0d3d3c, 0x3d053d00, + 0x3d143d13, 0x3d433d74, 0x33fc3d50, 0x33c433c0, 0x333033d4, 0x33353337, 0x3303330c, 0x33013300, + 0x331d331c, 0x33173310, 0x337c3315, 0x33743371, 0x334d334f, 0x335f3340, 0x3354335c, 0x30fd30fc, + 0x30f530f0, 0x30c330cc, 0x30c130c0, 0x30df30c4, 0x30d530d0, 0x3033303c, 0x30313030, 0x300f3034, + 0x3003300c, 0x30013000, 0x30043007, 0x3013301c, 0x30113010, 0x307d3014, 0x30703073, 0x304c3077, + 0x30403043, 0x30443041, 0x30503045, 0x30553057, 0x31f031fc, 0x31c331f4, 0x31c731c0, 0x31dc31c5, + 0x31d431d3, 0x313d313f, 0x31373130, 0x310c310f, 0x3100310d, 0x31043101, 0x3110311d, 0x317c3117, + 0x31753170, 0x31403143, 0x3153315c, 0x37f03151, 0x37c037cc, 0x37d037c5, 0x3734373d, 0x3700370f, + 0x371c3707, 0x37113713, 0x37703714, 0x3743374c, 0x37443741, 0x34fc3750, 0x34f134f0, 0x34cf34f5, + 0x34c034c3, 0x34dc34c7, 0x34d134d3, 0x3430343f, 0x340c3435, 0x3403340d, 0x34013400, 0x341f3404, + 0x3410341d, 0x34153411, 0x34743471, 0x3440344d, 0x34473441, 0x3453345c, 0x34543451, 0x353335c1, + 0x35343531, 0x35073500, 0x35133505, 0x35433514, 0x0ffc3550, 0x0ff00ff3, 0x0ff40ff1, 0x0fc00fcd, + 0x0fdc0fc5, 0x0fd40fd3, 0x0f300f3f, 0x0f0c0f37, 0x0f000f03, 0x0f040f01, 0x0f170f10, 0x0f740f71, + 0x0f470f40, 0x0f5c0f5f, 0x0f540f51, 0x0cf70cf0, 0x0cf50cf4, 0x0cc30ccc, 0x0cc10cc0, 0x0cc40cc7, + 0x0cd00cdf, 0x0cd70cd1, 0x0c3c0cd5, 0x0c300c33, 0x0c340c31, 0x0c0c0c0f, 0x0c030c0d, 0x0c010c00, + 0x0c040c07, 0x0c1c0c05, 0x0c100c13, 0x0c140c11, 0x0c700c7d, 0x0c430c4c, 0x0c410c40, 0x0c5f0c44, + 0x0c550c50, 0x0df10dfc, 0x0dc00dcd, 0x0ddc0dc5, 0x0d3d0dd3, 0x0d350d30, 0x0d030d0c, 0x0d010d00, + 0x0d1d0d04, 0x0d700d10, 0x0d4d0d4f, 0x0d440d40, 0x0d530d45, 0x03f003f3, 0x03c303cc, 0x03c103c0, + 0x03c403c7, 0x03d003dc, 0x03d503d7, 0x0333033c, 0x03310330, 0x03350334, 0x030c030f, 0x03000303, + 0x03070301, 0x03050304, 0x031d031c, 0x03100313, 0x03140311, 0x0377037f, 0x034c0375, 0x03400343, + 0x03440341, 0x0353035c, 0x03550350, 0x00fd00fc, 0x00f000f3, 0x00f400f1, 0x00cc00cf, 0x00c300cd, + 0x00c100c0, 0x00c500c4, 0x00d300dc, 0x00d100d0, 0x003f00d4, 0x003d003c, 0x00300033, 0x00370031, + 0x000f0034, 0x000d000c, 0x00000003, 0x00070001, 0x00050004, 0x001c001f, 0x00100013, 0x00170011, + 0x00150014, 0x0073007c, 0x00740070, 0x004f0075, 0x0043004c, 0x00410040, 0x00440047, 0x0053005c, + 0x00510050, 0x01ff0054, 0x01fd01fc, 0x01f101f3, 0x01f401f7, 0x01c301cc, 0x01c701c0, 0x01df01c4, + 0x01dd01dc, 0x01d001d3, 0x01d701d1, 0x013c01d4, 0x01310130, 0x01340137, 0x010f0135, 0x010d010c, + 0x01000103, 0x01070101, 0x01050104, 0x0113011c, 0x01140110, 0x0170017d, 0x01770171, 0x01750174, + 0x0140014c, 0x015d0145, 0x01510150, 0x01540157, 0x07f007f3, 0x07f407f1, 0x07c007cf, 0x07dc07c7, + 0x073007d5, 0x07350737, 0x0703070c, 0x07010700, 0x07040707, 0x071d071f, 0x07100713, 0x0774077d, + 0x074d074f, 0x07470740, 0x0754075c, 0x04fd04fc, 0x04f504f0, 0x04c304cc, 0x04c104c0, 0x04d004c4, + 0x0433043c, 
0x04310430, 0x040f0434, 0x040d040c, 0x04000403, 0x04070401, 0x04050404, 0x0413041c, + 0x04110410, 0x047c0414, 0x04740470, 0x0443044c, 0x04410440, 0x04440447, 0x05f30450, 0x05c005f7, + 0x05df05c5, 0x05d105d0, 0x053005d4, 0x05340537, 0x0500050c, 0x05070501, 0x051d0504, 0x05170510, + 0x057c0515, 0x054d0575, 0x05410540, 0x05450547, 0x1ff0055c, 0x1fc11fc3, 0x1fd01fc4, 0x1f0f1f33, + 0x1f011f00, 0x1f051f07, 0x1f131f1c, 0x1f141f11, 0x1f411f7c, 0x1cfc1f50, 0x1cf11cf3, 0x1ccd1cf4, + 0x1cdc1cc0, 0x1cd11cdd, 0x1c301cd4, 0x1c0c1c34, 0x1c011c00, 0x1c101c04, 0x1c151c11, 0x1c751c73, + 0x1c401c4d, 0x1c511c5c, 0x1dcc1c54, 0x1dc41dc1, 0x1d3c1d3f, 0x1d001d31, 0x1d071d01, 0x1d701d1f, + 0x1d411d4c, 0x13cc1d50, 0x13c013cd, 0x13c513c1, 0x13d113dc, 0x133f13d4, 0x1330133d, 0x13351337, + 0x1303130c, 0x13011300, 0x13051304, 0x131d131f, 0x13731310, 0x13741370, 0x134d134f, 0x13401343, + 0x13471341, 0x135c1345, 0x13541353, 0x10f710f0, 0x10cc10f5, 0x10c110c0, 0x103310c4, 0x10311030, + 0x100f1034, 0x1003100c, 0x10011000, 0x101c1004, 0x10101013, 0x10141011, 0x10741071, 0x104c1075, + 0x10411040, 0x10451044, 0x1050105d, 0x10571051, 0x11f411fd, 0x11df11c0, 0x11d711d1, 0x113f11d4, + 0x11371130, 0x110c1135, 0x11001103, 0x11071101, 0x111f1105, 0x11171110, 0x117d117f, 0x11751170, + 0x11411143, 0x11441147, 0x1153115f, 0x11551151, 0x17c417c1, 0x173c17d0, 0x1700170d, 0x171c1705, + 0x17701714, 0x1747174c, 0x14fc1751, 0x14cf14f3, 0x14dc14c0, 0x14d114d3, 0x143f14d4, 0x1430143c, + 0x14371431, 0x1403140c, 0x14011400, 0x141f1404, 0x14151410, 0x1473147d, 0x14401475, 0x1453145c, + 0x14541450, 0x15c115cc, 0x153c15c7, 0x15341533, 0x1500150f, 0x15051507, 0x15101513, 0x15711514, + 0x15471543, 0x15511545, 0x7ffd7fff, 0x7ff57ff7, 0x7fdd7fdf, 0x7fd57fd7, 0x7f0f7f30, 0x7f037f0c, + 0x7f047f01, 0x7f7f7f10, 0x7f777f7d, 0x7f407f75, 0x7f5d7f5f, 0x7f557f57, 0x7ccc7cf0, 0x7cc17cc3, + 0x7cd07cc4, 0x7c337c3c, 0x7c0f7c34, 0x7c007c0d, 0x7c077c01, 0x7c137c04, 0x7c147c11, 0x7c747c70, + 0x7c417c43, 0x7c507c44, 0x7dfd7dff, 0x7df57df7, 0x7ddf7dc0, 0x7dd77ddd, 0x7d0c7dd5, 0x7d047d03, + 0x7d7f7d10, 0x7d777d7d, 0x7d407d75, 0x7d5d7d5f, 0x7d557d57, 0x73c473c3, 0x7333733c, 0x7300730c, + 0x731c7305, 0x73147313, 0x73447343, 0x70f470fc, 0x70c070cd, 0x70d170c5, 0x703f70d4, 0x7030703c, + 0x700c7037, 0x70007003, 0x70047001, 0x70107005, 0x70177011, 0x707c7015, 0x70717073, 0x704f7074, + 0x7040704d, 0x70517047, 0x71c171cc, 0x71d071c4, 0x7133713c, 0x71357134, 0x7100710f, 0x71057104, + 0x7111711c, 0x71707115, 0x7145714c, 0x77ff7153, 0x77f777fd, 0x77c077f5, 0x77dd77df, 0x77d577d7, + 0x7730773c, 0x7703770c, 0x77107704, 0x777f7714, 0x7777777d, 0x77407775, 0x775d775f, 0x77557757, + 0x74f174f0, 0x74c374cc, 0x74d074c1, 0x7433743c, 0x74347431, 0x740d740f, 0x74057400, 0x7413741c, + 0x74417470, 0x74507444, 0x75fd75ff, 0x75f575f7, 0x75df75c0, 0x75d775dd, 0x753075d5, 0x7503750c, + 0x757f7501, 0x7577757d, 0x75407575, 0x755d755f, 0x75557557, 0x4fcc4ff0, 0x4fc74fc1, 0x4fd04fc4, + 0x4f314f3c, 0x4f004f34, 0x4f054f07, 0x4f154f14, 0x4f4c4f70, 0x4f414f43, 0x4f504f44, 0x4cf34cfc, + 0x4cf44cf1, 0x4cc04ccf, 0x4cc54cc7, 0x4cd34cdc, 0x4cd44cd1, 0x4c304c3f, 0x4c0c4c0f, 0x4c004c03, + 0x4c044c01, 0x4c104c1d, 0x4c714c73, 0x4c404c4d, 0x4c5c4c47, 0x4c514c53, 0x4df04c54, 0x4dc34dcc, + 0x4dd04dc4, 0x4d314d33, 0x4d0f4d34, 0x4d004d0d, 0x4d114d07, 0x4d704d14, 0x4d414d43, 0x43fc4d54, + 0x43f143f3, 0x43c043cf, 0x43d143c7, 0x4335433f, 0x4303430c, 0x43014300, 0x43044307, 0x431c431f, + 0x4310431d, 0x43714373, 0x4343434d, 0x43474340, 0x4354435c, 0x40f040ff, 0x40f540f7, 0x40cc40cf, + 0x40c040c3, 0x40c440c1, 0x40d040dc, 
0x40d540d4, 0x4033403c, 0x40314030, 0x400f4034, 0x400d400c, + 0x40004003, 0x40074001, 0x40054004, 0x4013401c, 0x40114010, 0x407c4014, 0x40774070, 0x404d404c, + 0x40404043, 0x40444041, 0x405f4045, 0x4050405d, 0x40554057, 0x41f341fc, 0x41c041cf, 0x41df41c4, + 0x41d441d1, 0x41374130, 0x410c4134, 0x4100410d, 0x41044101, 0x41174110, 0x4173417d, 0x41754174, + 0x4143414d, 0x41534140, 0x41544151, 0x47c147f0, 0x47d047c4, 0x4731473c, 0x470d470f, 0x47014700, + 0x47134705, 0x47704710, 0x4741474c, 0x47504744, 0x44f144f3, 0x44cf44f4, 0x44c044cd, 0x44c544c7, + 0x44dc44df, 0x44d144d3, 0x443d443f, 0x44374430, 0x440c4435, 0x44004403, 0x44044401, 0x4410441d, + 0x44154411, 0x4473447c, 0x444d444f, 0x44454440, 0x4451445c, 0x45c045f0, 0x453345d0, 0x45344531, + 0x4500450f, 0x451c4507, 0x454c4570, 0x45404543, 0x5fff4541, 0x5ff75ffd, 0x5fc05ff5, 0x5fdd5fdf, + 0x5fd55fd7, 0x5f0c5f30, 0x5f015f03, 0x5f7f5f04, 0x5f775f7d, 0x5f405f75, 0x5f5d5f5f, 0x5f555f57, + 0x5cf45cf0, 0x5cc35ccc, 0x5cc45cc1, 0x5c315cc5, 0x5c0c5c34, 0x5c075c00, 0x5c1c5c05, 0x5c705c13, + 0x5c4d5c4f, 0x5c445c41, 0x5df75dfd, 0x5dcf5df5, 0x5ddd5dc4, 0x5dd55dd7, 0x5d0c5d30, 0x5d045d01, + 0x5d7f5d10, 0x5d775d7d, 0x5d405d75, 0x5d5d5d5f, 0x5d555d57, 0x53d053c4, 0x5333533c, 0x5303530f, + 0x53075300, 0x531c5305, 0x53115310, 0x53145317, 0x50f15370, 0x50cf50f4, 0x50c050cd, 0x50d150c7, + 0x503d50d4, 0x500c5030, 0x50005003, 0x50045001, 0x50155010, 0x5073507c, 0x50715070, 0x504d5074, + 0x50475040, 0x51cc51f0, 0x51c551c1, 0x51d051dc, 0x51315133, 0x510d5135, 0x51015100, 0x511f5107, + 0x5171511d, 0x5140514f, 0x51445141, 0x5153515c, 0x57ff5151, 0x57f757fd, 0x57df57f5, 0x57d757dd, + 0x570c57d5, 0x57015703, 0x577f5704, 0x5777577d, 0x57405775, 0x575d575f, 0x57555757, 0x54c354f0, + 0x54dc54c4, 0x543c54d0, 0x5400540f, 0x541c5405, 0x54145411, 0x5441544f, 0x55fd55ff, 0x55f555f7, + 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 +}; + +shared uint16_t iq1s_grid[2048]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq1s_grid_const.length(); i += wgsize.x) { + u16vec2 g = unpack16(iq1s_grid_const[i]); + iq1s_grid[2*i+0] = g.x; + iq1s_grid[2*i+1] = g.y; + } + barrier(); +} +#endif + #define QUANT_K_IQ2_XXS 256 #define QUANT_R_IQ2_XXS 1 @@ -380,6 +561,7 @@ const uvec2[256] iq2xxs_grid_const = { shared uvec2 iq2xxs_grid[256]; +#define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync @@ -547,6 +729,7 @@ const uvec2 iq2xs_grid_const[512] = { shared uvec2 iq2xs_grid[512]; +#define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync @@ -836,6 +1019,7 @@ const uvec2 iq2s_grid_const[1024] = { shared uvec2 iq2s_grid[1024]; +#define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync @@ -904,6 +1088,7 @@ const uint32_t iq3xxs_grid_const[256] = { shared uint32_t iq3xxs_grid[256]; +#define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync @@ -1011,6 +1196,7 @@ const uint32_t iq3s_grid_const[512] = { shared uint32_t iq3s_grid[512]; +#define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync @@ -1073,6 +1259,7 @@ const int8_t kvalues_iq4nl_const[16] = { shared FLOAT_TYPE kvalues_iq4nl[16]; +#define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { // copy the table into shared 
memory and sync diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 77e7e1148b49d..3128c3d507a61 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -55,6 +55,8 @@ const std::vector type_names = { "q4_k", "q5_k", "q6_k", + "iq1_s", + "iq1_m", "iq2_xxs", "iq2_xs", "iq2_s", @@ -182,6 +184,13 @@ std::string to_uppercase(const std::string& input) { return result; } +bool string_starts_with(const std::string& str, const std::string& prefix) { + if (prefix.size() > str.size()) { + return false; + } + return std::equal(prefix.begin(), prefix.end(), str.begin()); +} + bool string_ends_with(const std::string& str, const std::string& suffix) { if (suffix.size() > str.size()) { return false; @@ -387,7 +396,7 @@ void process_shaders() { for (const auto& tname : type_names) { // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); - std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; + std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); @@ -434,6 +443,8 @@ void process_shaders() { string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); @@ -443,6 +454,7 @@ void process_shaders() { string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -482,9 +494,19 @@ void process_shaders() { string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_vision_f16", 
"rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); + string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); @@ -496,6 +518,8 @@ void process_shaders() { string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + for (auto &c : compiles) { c.wait(); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 3b48615421187..e9f3420c294b8 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1379,7 +1379,7 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso (t0->nb[3] == t1->nb[3]); } -// check if t1 can be represented as a repeatition of t0 +// check if t1 can be represented as a repetition of t0 bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); diff --git a/gguf-py/README.md b/gguf-py/README.md index 2e513633d1c5a..dd4ab7bde763a 100644 --- a/gguf-py/README.md +++ b/gguf-py/README.md @@ -1,9 +1,9 @@ ## gguf -This is a Python package for writing binary files in the [GGUF](https://github.com/ggerganov/ggml/pull/302) +This is a Python package for writing binary files in the [GGUF](https://github.com/ggml-org/ggml/pull/302) (GGML Universal File) format. -See [convert_hf_to_gguf.py](https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py) +See [convert_hf_to_gguf.py](https://github.com/ggml-org/llama.cpp/blob/master/convert_hf_to_gguf.py) as an example for its usage. ## Installation @@ -13,17 +13,17 @@ pip install gguf ## API Examples/Simple Tools -[examples/writer.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/examples/writer.py) — Generates `example.gguf` in the current directory to demonstrate generating a GGUF file. Note that this file cannot be used as a model. +[examples/writer.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/examples/writer.py) — Generates `example.gguf` in the current directory to demonstrate generating a GGUF file. Note that this file cannot be used as a model. -[examples/reader.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/examples/reader.py) — Extracts and displays key-value pairs and tensor details from a GGUF file in a readable format. +[examples/reader.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/examples/reader.py) — Extracts and displays key-value pairs and tensor details from a GGUF file in a readable format. -[gguf/scripts/gguf_dump.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_dump.py) — Dumps a GGUF file's metadata to the console. 
+[gguf/scripts/gguf_dump.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_dump.py) — Dumps a GGUF file's metadata to the console. -[gguf/scripts/gguf_set_metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_set_metadata.py) — Allows changing simple metadata values in a GGUF file by key. +[gguf/scripts/gguf_set_metadata.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_set_metadata.py) — Allows changing simple metadata values in a GGUF file by key. -[gguf/scripts/gguf_convert_endian.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_convert_endian.py) — Allows converting the endianness of GGUF files. +[gguf/scripts/gguf_convert_endian.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_convert_endian.py) — Allows converting the endianness of GGUF files. -[gguf/scripts/gguf_new_metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_new_metadata.py) — Copies a GGUF file with added/modified/removed metadata values. +[gguf/scripts/gguf_new_metadata.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_new_metadata.py) — Copies a GGUF file with added/modified/removed metadata values. ## Development Maintainers who participate in development of this package are advised to install it in editable mode: diff --git a/gguf-py/gguf/scripts/gguf_dump.py b/gguf-py/gguf/scripts/gguf_dump.py index f95b4fd4827c6..20f23d729f4b9 100755 --- a/gguf-py/gguf/scripts/gguf_dump.py +++ b/gguf-py/gguf/scripts/gguf_dump.py @@ -181,7 +181,7 @@ def element_count_rounded_notation(count: int) -> str: def translate_tensor_name(name): words = name.split(".") - # Source: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#standardized-tensor-names + # Source: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#standardized-tensor-names abbreviation_dictionary = { 'token_embd': 'Token embedding', 'pos_embd': 'Position embedding', diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index 40d59b75ee04e..ae92d786a4068 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -47,7 +47,7 @@ def size_label(total_params: int, shared_params: int, expert_params: int, expert def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str: - # Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention + # Reference: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#gguf-naming-convention if base_name is not None: name = base_name.strip().replace(' ', '-').replace('/', '-') diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index f2645f92101db..2ef7d14ab15c0 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -127,7 +127,7 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool: self.merges = merges elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str): # New format since transformers 4.45 to support spaces in merges - # ref: https://github.com/ggerganov/llama.cpp/issues/9692 + # ref: https://github.com/ggml-org/llama.cpp/issues/9692 # TODO: internally store as the new format instead of converting to old if any(' ' in s for pair in merges for s in pair): logger.warning(f'Spaces in merges detected, encoding as 
{chr(ord(" ") + 256)!r}') diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index 78c6baa64a365..b4a47333dd15f 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -9,7 +9,7 @@ packages = [ ] readme = "README.md" homepage = "https://ggml.ai" -repository = "https://github.com/ggerganov/llama.cpp" +repository = "https://github.com/ggml-org/llama.cpp" keywords = ["ggml", "gguf", "llama.cpp"] classifiers = [ "Programming Language :: Python :: 3", diff --git a/grammars/README.md b/grammars/README.md index 9769540919f98..935213f5c1849 100644 --- a/grammars/README.md +++ b/grammars/README.md @@ -98,7 +98,7 @@ This guide provides a brief overview. Check out the GBNF files in this directory ## Troubleshooting -Grammars currently have performance gotchas (see https://github.com/ggerganov/llama.cpp/issues/4218). +Grammars currently have performance gotchas (see https://github.com/ggml-org/llama.cpp/issues/4218). ### Efficient optional repetitions @@ -126,7 +126,7 @@ You can use GBNF grammars: - in CLI, with [examples/json_schema_to_grammar.py](../examples/json_schema_to_grammar.py) - in JavaScript with [json-schema-to-grammar.mjs](../examples/server/public_legacy/json-schema-to-grammar.mjs) (this is used by the [server](../examples/server)'s Web UI) -Take a look at [tests](../tests/test-json-schema-to-grammar.cpp) to see which features are likely supported (you'll also find usage examples in https://github.com/ggerganov/llama.cpp/pull/5978, https://github.com/ggerganov/llama.cpp/pull/6659 & https://github.com/ggerganov/llama.cpp/pull/6555). +Take a look at [tests](../tests/test-json-schema-to-grammar.cpp) to see which features are likely supported (you'll also find usage examples in https://github.com/ggml-org/llama.cpp/pull/5978, https://github.com/ggml-org/llama.cpp/pull/6659 & https://github.com/ggml-org/llama.cpp/pull/6555). ```bash llama-cli \ @@ -185,10 +185,10 @@ Here is also a list of known limitations (contributions welcome): - `additionalProperties` defaults to `false` (produces faster grammars + reduces hallucinations). - `"additionalProperties": true` may produce keys that contain unescaped newlines. - Unsupported features are skipped silently. It is currently advised to use the command-line Python converter (see above) to see any warnings, and to inspect the resulting grammar / test it w/ [llama-gbnf-validator](../examples/gbnf-validator/gbnf-validator.cpp). 
-- Can't mix `properties` w/ `anyOf` / `oneOf` in the same type (https://github.com/ggerganov/llama.cpp/issues/7703) +- Can't mix `properties` w/ `anyOf` / `oneOf` in the same type (https://github.com/ggml-org/llama.cpp/issues/7703) - [prefixItems](https://json-schema.org/draft/2020-12/json-schema-core#name-prefixitems) is broken (but [items](https://json-schema.org/draft/2020-12/json-schema-core#name-items) works) - `minimum`, `exclusiveMinimum`, `maximum`, `exclusiveMaximum`: only supported for `"type": "integer"` for now, not `number` -- Nested `$ref`s are broken (https://github.com/ggerganov/llama.cpp/issues/8073) +- Nested `$ref`s are broken (https://github.com/ggml-org/llama.cpp/issues/8073) - [pattern](https://json-schema.org/draft/2020-12/json-schema-validation#name-pattern)s must start with `^` and end with `$` - Remote `$ref`s not supported in the C++ version (Python & JavaScript versions fetch https refs) - `string` [formats](https://json-schema.org/draft/2020-12/json-schema-validation#name-defined-formats) lack `uri`, `email` diff --git a/include/llama.h b/include/llama.h index f0191e54123bc..3750cd9e335ac 100644 --- a/include/llama.h +++ b/include/llama.h @@ -214,7 +214,7 @@ extern "C" { LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported }; - // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979) + // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) typedef struct llama_token_data { llama_token id; // token id float logit; // log-odds of the token @@ -308,7 +308,7 @@ extern "C" { }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations - // https://github.com/ggerganov/llama.cpp/pull/7544 + // https://github.com/ggml-org/llama.cpp/pull/7544 struct llama_context_params { uint32_t n_ctx; // text context, 0 = from model uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode @@ -321,7 +321,7 @@ extern "C" { enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings - // ref: https://github.com/ggerganov/llama.cpp/pull/2054 + // ref: https://github.com/ggml-org/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency, 0 = from model float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model @@ -386,7 +386,7 @@ extern "C" { struct llama_adapter_lora; // Helpers for getting default parameters - // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172) + // TODO: update API to start accepting pointers to params structs (https://github.com/ggml-org/llama.cpp/discussions/9172) LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void); LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void); @@ -1100,7 +1100,7 @@ extern "C" { /// Apply chat template. Inspired by hf apply_chat_template() on python. /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" - /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. 
See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template + /// NOTE: This function does not use a jinja parser. It only supports a pre-defined list of templates. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. /// @param chat Pointer to a list of multiple llama_chat_message /// @param n_msg Number of llama_chat_message in this chat @@ -1209,7 +1209,7 @@ extern "C" { /// @details Sorts candidate tokens by their logits in descending order and calculates probabilities based on logits. /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first. DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void), - "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)"); + "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)"); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); @@ -1217,7 +1217,7 @@ extern "C" { /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep); - /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 + /// @details Minimum P sampling as described in https://github.com/ggml-org/llama.cpp/pull/3841 LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep); /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. @@ -1232,6 +1232,9 @@ extern "C" { /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); + /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641 + LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n); + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. @@ -1260,7 +1263,7 @@ extern "C" { const char * grammar_str, const char * grammar_root); - /// @details Lazy grammar sampler, introduced in https://github.com/ggerganov/llama.cpp/pull/9639 + /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639 /// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in the near future. /// @param trigger_tokens A list of tokens that will trigger the grammar sampler.
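[Editor's aside, not part of the patch: `llama_sampler_init_top_n_sigma` declared above is the new public API in this change. A minimal usage sketch, assuming only the `llama_sampler_chain_*` helpers that llama.h already declares; the value `1.0f` is an arbitrary example, not a recommended default.]

```cpp
// Sketch: keep only candidates whose logits lie within n = 1 standard deviation
// of the maximum logit, then sample from the renormalized survivors.
llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_top_n_sigma(1.0f));
llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
// per token: llama_sampler_sample(chain, ctx, -1); when done: llama_sampler_free(chain);
```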
LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( diff --git a/models/templates/README.md b/models/templates/README.md new file mode 100644 index 0000000000000..72c30d1e1e08e --- /dev/null +++ b/models/templates/README.md @@ -0,0 +1,22 @@ +These templates can be updated with the following commands: + +```bash +./scripts/get_chat_template.py CohereForAI/c4ai-command-r-plus tool_use > models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja +./scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 default > models/templates/CohereForAI-c4ai-command-r7b-12-2024-default.jinja +./scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 rag > models/templates/CohereForAI-c4ai-command-r7b-12-2024-rag.jinja +./scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use > models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja +./scripts/get_chat_template.py deepseek-ai/DeepSeek-R1-Distill-Llama-8B > models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja +./scripts/get_chat_template.py deepseek-ai/DeepSeek-R1-Distill-Qwen-32B > models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja +./scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 > models/templates/fireworks-ai-llama-3-firefunction-v2.jinja +./scripts/get_chat_template.py google/gemma-2-2b-it > models/templates/google-gemma-2-2b-it.jinja +./scripts/get_chat_template.py meetkai/functionary-medium-v3.1 > models/templates/meetkai-functionary-medium-v3.1.jinja +./scripts/get_chat_template.py meetkai/functionary-medium-v3.2 > models/templates/meetkai-functionary-medium-v3.2.jinja +./scripts/get_chat_template.py meta-llama/Llama-3.1-8B-Instruct > models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja +./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct > models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja +./scripts/get_chat_template.py meta-llama/Llama-3.3-70B-Instruct > models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja +./scripts/get_chat_template.py microsoft/Phi-3.5-mini-instruct > models/templates/microsoft-Phi-3.5-mini-instruct.jinja +./scripts/get_chat_template.py mistralai/Mistral-Nemo-Instruct-2407 > models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja +./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja +./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja +./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja +``` \ No newline at end of file diff --git a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja index 02a1c3bce33f4..c2066bd7391c2 100644 --- a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja +++ b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja @@ -1 +1 @@ -{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if
message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %} \ No newline at end of file +{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool
%}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\n'}}{% endif %} \ No newline at end of file diff --git a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja index 2ebfe7c1e32ab..c2066bd7391c2 100644 --- a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja +++ b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja @@ -1,56 +1 @@ -{% if not add_generation_prompt is defined %} -{% set add_generation_prompt = false %} -{% endif %} -{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %} -{%- for message in messages %} -{%- if message['role'] == 'system' %} -{% set ns.system_prompt = message['content'] %} -{%- endif %} -{%- endfor %} -{{bos_token}} -{{ns.system_prompt}} -{%- for message in messages %} -{%- if message['role'] == 'user' %} -{%- set ns.is_tool = false -%} -{{'<|User|>' + message['content']}} -{%- endif %} -{%- if message['role'] == 'assistant' and message['content'] is none %} -{%- set ns.is_tool = false -%} -{%- for tool in message['tool_calls']%} -{%- if not ns.is_first %} -{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}} -{%- set ns.is_first = true -%} -{%- else %} -{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}} -{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} -{%- endif %} -{%- endfor %} -{%- endif %} -{%- if message['role'] == 'assistant' and message['content'] is not none %} -{%- if ns.is_tool %} -{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}} -{%- set ns.is_tool = false -%} -{%- else %} -{% set content = message['content'] %} -{% if '</think>' in content %} -{% set content = content.split('</think>')[-1] %} -{% endif %} -{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}} -{%- endif %} -{%- endif %} -{%- if message['role'] == 'tool' %} -{%- set ns.is_tool = true -%} -{%- if ns.is_output_first %} -{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} -{%- set ns.is_output_first = false %} -{%- else %} -{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} -{%- endif %} -{%- endif %} -{%- endfor -%} -{% if ns.is_tool %} -{{'<|tool▁outputs▁end|>'}} -{% endif %} -{% if add_generation_prompt and not ns.is_tool %} -{{'<|Assistant|>'}} -{% endif %} \ No newline at end of file +{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' +
'<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\n'}}{% endif %} \ No newline at end of file diff --git a/models/templates/llama-cpp-deepseek-r1.jinja b/models/templates/llama-cpp-deepseek-r1.jinja new file mode 100644 index 0000000000000..fcb1732eb8fe7 --- /dev/null +++ b/models/templates/llama-cpp-deepseek-r1.jinja @@ -0,0 +1,76 @@ +{%- if not add_generation_prompt is defined -%} + {%- set add_generation_prompt = false -%} +{%- endif -%} +{%- set ns = namespace(is_first=false, is_tool_outputs=false, is_output_first=true, system_prompt='') -%} +{%- for message in messages -%} + {%- if message['role'] == 'system' -%} + {%- set ns.system_prompt = message['content'] -%} + {%- endif -%} +{%- endfor -%} +{{bos_token}} +{%- if tools %} +You can call any of the following function tools to satisfy the user's requests: {{tools | map(attribute='function') | tojson(indent=2)}} + +Example function tool call syntax: + +<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>example_function_name +```json +{ + "arg1": "some_value" + ...
+} +``` +<|tool▁call▁end|><|tool▁calls▁end|> + +{% endif -%} +{{ns.system_prompt}} +{%- macro flush_tool_outputs() -%} + {%- if ns.is_tool_outputs -%} + {{- '<|tool▁outputs▁end|><|end▁of▁sentence|>' -}} + {%- set ns.is_tool_outputs = false -%} + {%- endif -%} +{%- endmacro -%} +{{- flush_tool_outputs() -}} +{%- for message in messages -%} + {%- if message['role'] != 'tool' -%} + {{- flush_tool_outputs() -}} + {%- endif -%} + {%- if message['role'] == 'user' -%} + {{- '<|User|>' + message['content'] + '<|end▁of▁sentence|>' -}} + {%- endif -%} + {%- if message['role'] == 'assistant' and message['content'] is none -%} + {{- '<|Assistant|><|tool▁calls▁begin|>' -}} + {%- set ns.is_first = true -%} + {%- for tc in message['tool_calls'] -%} + {%- if ns.is_first -%} + {%- set ns.is_first = false -%} + {%- else -%} + {{- '\n' -}} + {%- endif -%} + {%- set tool_name = tc['function']['name'] -%} + {%- set tool_args = tc['function']['arguments'] -%} + {{- '<|tool▁call▁begin|>' + tc['type'] + '<|tool▁sep|>' + tool_name + '\n' + '```json' + '\n' + tool_args + '\n' + '```' + '<|tool▁call▁end|>' -}} + {%- endfor -%} + {{- '<|tool▁calls▁end|><|end▁of▁sentence|>' -}} + {%- endif -%} + {%- if message['role'] == 'assistant' and message['content'] is not none -%} + {{- flush_tool_outputs() -}} + {%- set content = message['content'] -%} + {%- if '</think>' in content -%} + {%- set content = content.split('</think>')[-1] -%} + {%- endif -%} + {{- '<|Assistant|>' + content + '<|end▁of▁sentence|>' -}} + {%- endif -%} + {%- if message['role'] == 'tool' -%} + {%- set ns.is_tool_outputs = true -%} + {%- if ns.is_output_first -%} + {{- '<|tool▁outputs▁begin|>' -}} + {%- set ns.is_output_first = false -%} + {%- endif -%} + {{- '\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>' -}} + {%- endif -%} +{%- endfor -%} +{{- flush_tool_outputs() -}} +{%- if add_generation_prompt and not ns.is_tool_outputs -%} + {{- '<|Assistant|><think>\n' -}} +{%- endif -%} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 84e71de6def38..ed62264ba62db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ description = "Scripts that ship with llama.cpp" authors = ["GGML <ggml@ggml.ai>"] readme = "README.md" homepage = "https://ggml.ai" -repository = "https://github.com/ggerganov/llama.cpp" +repository = "https://github.com/ggml-org/llama.cpp" keywords = ["ggml", "gguf", "llama.cpp"] packages = [{ include = "*.py", from = "."
}] classifiers = [ diff --git a/scripts/check-requirements.sh b/scripts/check-requirements.sh index d3bbded130daf..4c3b05f68b7ba 100755 --- a/scripts/check-requirements.sh +++ b/scripts/check-requirements.sh @@ -170,7 +170,7 @@ check_convert_script examples/convert_legacy_llama.py for py in convert_*.py; do # skip convert_hf_to_gguf_update.py # TODO: the check is failing for some reason: - # https://github.com/ggerganov/llama.cpp/actions/runs/8875330981/job/24364557177?pr=6920 + # https://github.com/ggml-org/llama.cpp/actions/runs/8875330981/job/24364557177?pr=6920 [[ $py == convert_hf_to_gguf_update.py ]] && continue check_convert_script "$py" diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index 239c458d8b958..6205fe88d7239 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -124,9 +124,22 @@ connection = sqlite3.connect(input_file) cursor = connection.cursor() + +build_len_min: int = cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0] +build_len_max: int = cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0] + +if build_len_min != build_len_max: + logger.warning(f"{input_file} contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. " + "Try purging the database of old commits.") + cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {build_len_min});") + +build_len: int = build_len_min + builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall() +builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str] -commit_short_len = len(builds[0][0]) +if not builds: + raise RuntimeError(f"{input_file} does not contain any builds.") try: repo = git.Repo(".", search_parent_directories=True) @@ -140,11 +153,11 @@ def find_parent_in_data(commit: git.Commit): seen_hexsha8 = set() while heap: depth, current_commit = heapq.heappop(heap) - current_hexsha8 = commit.hexsha[:commit_short_len] - if (current_hexsha8,) in builds: + current_hexsha8 = commit.hexsha[:build_len] + if current_hexsha8 in builds: return current_hexsha8 for parent in commit.parents: - parent_hexsha8 = parent.hexsha[:commit_short_len] + parent_hexsha8 = parent.hexsha[:build_len] if parent_hexsha8 not in seen_hexsha8: seen_hexsha8.add(parent_hexsha8) heapq.heappush(heap, (depth + 1, parent)) @@ -158,40 +171,40 @@ def get_all_parent_hexsha8s(commit: git.Commit): while unvisited: current_commit = unvisited.pop(0) - visited.append(current_commit.hexsha[:commit_short_len]) + visited.append(current_commit.hexsha[:build_len]) for parent in current_commit.parents: - if parent.hexsha[:commit_short_len] not in visited: + if parent.hexsha[:build_len] not in visited: unvisited.append(parent) return visited -def get_commit_name(hexsha8): +def get_commit_name(hexsha8: str): """Helper function to find a human-readable name for a commit if possible.""" if repo is None: return hexsha8 for h in repo.heads: - if h.commit.hexsha[:commit_short_len] == hexsha8: + if h.commit.hexsha[:build_len] == hexsha8: return h.name for t in repo.tags: - if t.commit.hexsha[:commit_short_len] == hexsha8: + if t.commit.hexsha[:build_len] == hexsha8: return t.name return hexsha8 -def get_commit_hexsha8(name): +def get_commit_hexsha8(name: str): """Helper function to search for a commit given a human-readable name.""" if repo is None: return None for h in repo.heads: if h.name == name: - return h.commit.hexsha[:commit_short_len] + return h.commit.hexsha[:build_len]
for t in repo.tags: if t.name == name: - return t.commit.hexsha[:commit_short_len] + return t.commit.hexsha[:build_len] for c in repo.iter_commits("--all"): - if c.hexsha[:commit_short_len] == name[:commit_short_len]: - return c.hexsha[:commit_short_len] + if c.hexsha[:build_len] == name[:build_len]: + return c.hexsha[:build_len] return None @@ -199,7 +212,7 @@ def get_commit_hexsha8(name): # If the user specified a baseline, try to find a commit for it: if known_args.baseline is not None: - if (known_args.baseline,) in builds: + if known_args.baseline in builds: hexsha8_baseline = known_args.baseline if hexsha8_baseline is None: hexsha8_baseline = get_commit_hexsha8(known_args.baseline) @@ -228,7 +241,7 @@ def get_commit_hexsha8(name): # If the user has specified a compare value, try to find a corresponding commit: if known_args.compare is not None: - if (known_args.compare,) in builds: + if known_args.compare in builds: hexsha8_compare = known_args.compare if hexsha8_compare is None: hexsha8_compare = get_commit_hexsha8(known_args.compare) diff --git a/scripts/get_chat_template.py b/scripts/get_chat_template.py old mode 100644 new mode 100755 index e8982d11ad7ba..d8143e4005dec --- a/scripts/get_chat_template.py +++ b/scripts/get_chat_template.py @@ -7,9 +7,8 @@ ./scripts/get_chat_template.py model_id [variant] Examples: - ./scripts/get_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct - ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use - ./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct + ./scripts/get_chat_template.py CohereForAI/c4ai-command-r-plus tool_use + ./scripts/get_chat_template.py microsoft/Phi-3.5-mini-instruct ''' import json diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 255109ab71e0c..f99625dca0396 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -08b538031f7f944e84f472483ef5d26bf5190ead +98a61a0d0b43cba06c3ac1c603813639552a0701 diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 9b518d1ac64a5..46e27a96ed728 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1186,7 +1186,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token return; } } - LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n", token, piece.c_str(), grammar.trigger_buffer.c_str()); + LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str()); return; } } diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 252d54d4c9f45..b143d834cfabe 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -116,7 +116,7 @@ struct llama_grammar { llama_partial_utf8 partial_utf8; // lazy grammars wait for trigger words or tokens before constraining the sampling. - // we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens. + // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens. // (useful e.g. for tool_choice=required) bool lazy = false; bool awaiting_trigger = false; // Initialized to true for lazy grammars only diff --git a/src/llama-impl.h b/src/llama-impl.h index 12d1fb0828d3a..02b1d07f8400d 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -6,13 +6,13 @@ #include #ifdef __GNUC__ -#ifdef __MINGW32__ -#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +# if defined(__MINGW32__) && !defined(__clang__) +# define LLAMA_ATTRIBUTE_FORMAT(...) 
__attribute__((format(gnu_printf, __VA_ARGS__))) +# else +# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +# endif #else -#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) -#endif -#else -#define LLAMA_ATTRIBUTE_FORMAT(...) +# define LLAMA_ATTRIBUTE_FORMAT(...) #endif // diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 3bb07ca9da431..3ea9abfce59be 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -57,7 +57,7 @@ struct llama_kv_cache { bool can_shift = false; // Note: The value of head isn't only used to optimize searching - // for a free KV slot. llama_decode_internal also uses it, so it + // for a free KV slot. llama_decode_impl also uses it, so it cannot be freely changed after a slot has been allocated. uint32_t head = 0; uint32_t size = 0; diff --git a/src/llama-mmap.h b/src/llama-mmap.h index 1da9ecb6b9812..4e5aec3f440d7 100644 --- a/src/llama-mmap.h +++ b/src/llama-mmap.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 990b6129746de..f40bf2db83a80 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1698,6 +1698,73 @@ struct llama_sampler * llama_sampler_init_penalties( ); } +// top-n-sigma + +struct llama_sampler_top_n_sigma { + const float n; +}; + +static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) { + return "top-n-sigma"; +} + +static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; + + // find max logit and calculate mean + float max = cur_p->data[0].logit; + float logits_sum = 0; + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit > max) { + max = cur_p->data[i].logit; + } + logits_sum += cur_p->data[i].logit; + } + float mean = logits_sum/cur_p->size; + + // calculate standard deviation + float acc = 0; + for (size_t i = 0; i < cur_p->size; ++i) { + acc += pow(cur_p->data[i].logit - mean, 2); + } + float std = sqrt(acc/cur_p->size); + + //apply mask + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit < max - (ctx->n * std)) { + cur_p->data[i].logit = -INFINITY; + } + } + llama_sampler_softmax_impl(cur_p); +} + +static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx; + return llama_sampler_init_top_n_sigma(ctx->n); +} + +static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) { + delete (llama_sampler_top_n_sigma *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_top_n_sigma_i = { + /* .name = */ llama_sampler_top_n_sigma_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_n_sigma_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_n_sigma_clone, + /* .free = */ llama_sampler_top_n_sigma_free, +}; + +struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_top_n_sigma_i, + /* .ctx = */ new llama_sampler_top_n_sigma { + /* .n = */ n, + } + ); +} + // DRY struct llama_sampler_dry { diff --git a/src/unicode.cpp b/src/unicode.cpp index a32ae6d0824f2..e63bb4ab085d6 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -708,7 +708,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std const auto cpts = unicode_cpts_from_utf8(text); // generate a
"collapsed" representation of the text, where all codepoints are replaced by a single byte - // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935 + // ref: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2081479935 std::string text_collapsed; if (need_collapse) { // collapse all unicode categories diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 1bfd41254aa99..c9ab6c135f325 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1254,7 +1254,7 @@ struct test_count_equal : public test_case { ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_name(b, "b"); - ggml_tensor * b_argmax = ggml_argmax(ctx, a); + ggml_tensor * b_argmax = ggml_argmax(ctx, b); ggml_set_name(b_argmax, "b_argmax"); ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax); @@ -1511,6 +1511,7 @@ struct test_cont : public test_case { }; // GGML_OP_ADD +// GGML_OP_SUB // GGML_OP_MUL // GGML_OP_DIV struct test_bin_bcast : public test_case { @@ -3860,7 +3861,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1)); test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1)); - test_cases.emplace_back(new test_count_equal()); + test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1})); + test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1})); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1})); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1})); @@ -3885,8 +3887,6 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view)); test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view)); test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view)); - test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view)); - test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view)); } test_cases.emplace_back(new test_dup(GGML_TYPE_F32)); @@ -3938,7 +3938,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7})); auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr) { - for (auto op : {ggml_add, ggml_mul, ggml_div}) { + for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) { test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr)); } }; diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index b78da2cdb48ba..2836caf6a71a3 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -24,7 +24,10 @@ static common_chat_msg msg_from_json(const json & message) { ret.content = message.at("content"); } if (message.contains("tool_plan")) { - ret.tool_plan = message.at("tool_plan"); + ret.reasoning_content = message.at("tool_plan"); + } + if (message.contains("reasoning_content")) { + ret.reasoning_content = message.at("reasoning_content"); } auto has_tool_calls = message.contains("tool_calls"); if (has_tool_calls) { @@ -105,6 +108,7 @@ static std::string dump(const json & j) { static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) { assert_equals(expected.role, actual.role); assert_equals(expected.content, actual.content); + assert_equals(expected.reasoning_content, actual.reasoning_content); 
assert_equals(expected.tool_calls.size(), actual.tool_calls.size()); for (size_t i = 0; i < expected.tool_calls.size(); i++) { const auto & expected_tool_call = expected.tool_calls[i]; @@ -176,13 +180,15 @@ struct delta_data { static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools, - const json & tool_choice) { + const json & tool_choice, + bool think = false) { common_chat_inputs inputs; inputs.parallel_tool_calls = true; inputs.messages = json::array(); inputs.messages.push_back(user_message); inputs.tools = tools; inputs.tool_choice = tool_choice; + inputs.extract_reasoning = think; auto params_prefix = common_chat_params_init(tmpl, inputs); inputs.messages.push_back(delta_message); @@ -192,17 +198,24 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto std::string prefix = params_prefix.prompt; std::string full = params_full.prompt; - // Check full starts with prefix - if (full.find(prefix) != 0) { - fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str()); - throw std::runtime_error("Full message does not start with prefix"); - } - if (full == prefix) { throw std::runtime_error("Full message is the same as the prefix"); } - auto delta = full.substr(prefix.size()); + size_t common_prefix_length = 0; + for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) { + if (prefix[i] != full[i]) { + break; + } + if (prefix[i] == '<') { + // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt, + // but it removes thinking tags for past messages. + // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <. + continue; + } + common_prefix_length = i + 1; + } + auto delta = full.substr(common_prefix_length); // Strip end tokens for (const auto & end_token : end_tokens) { @@ -223,7 +236,9 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto */ static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & test_message, const json & tools = {}, const std::string & expected_delta = "", - bool expect_grammar_triggered = true) { + bool expect_grammar_triggered = true, + bool test_grammar_if_triggered = true, + bool think = false) { common_chat_msg expected_msg = msg_from_json(test_message); auto user_message = json{ @@ -232,7 +247,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector< … + json message_assist_thoughts_unparsed_think { + { "role", "assistant" }, + { "content", "<think>I'm thinking</think>Hello, world!\nWhat's up?" }, + }; + json message_assist_thoughts_unparsed_r7b { + { "role", "assistant" }, + { "content", "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?" }, + }; + json message_assist_thoughts { { "role", "assistant" }, { "content", "Hello, world!\nWhat's up?"
}, + { "reasoning_content", "I'm thinking" }, }; json tool_calls = json::array({{ { "type", "function" }, { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } }, }}); - json tool_call_message { + json message_assist_call { { "role", "assistant"}, { "content", {}}, { "tool_calls", { @@ -305,7 +337,34 @@ static void test_template_output_parsers() { }, }}, }; - json tool_call_message_with_id { + json message_assist_call_thoughts = { + { "role", "assistant" }, + { "content", nullptr }, + { "reasoning_content", "I'm\nthinking" }, + { "tool_calls", { + { + { "type", "function" }, + { "function", { + { "name", "special_function" }, + { "arguments", "{\"arg1\": 1}" }, + }}, + }, + }}, + }; + json message_assist_call_thoughts_unparsed = { + { "role", "assistant" }, + { "content", "I'm\nthinking" }, + { "tool_calls", { + { + { "type", "function" }, + { "function", { + { "name", "special_function" }, + { "arguments", "{\"arg1\": 1}" }, + }}, + }, + }}, + }; + json message_assist_call_id { { "role", "assistant"}, { "content", {}}, { "tool_calls", { @@ -322,10 +381,9 @@ static void test_template_output_parsers() { { "content", {} }, { "tool_calls", tool_calls } }; - json tool_call_plan_message_with_idx { + json message_assist_call_idx { { "role", "assistant"}, { "content", {}}, - { "tool_plan", "I'm not so sure"}, { "tool_calls", { { { "type", "function" }, @@ -341,8 +399,10 @@ static void test_template_output_parsers() { { "content", {} }, { "tool_calls", tool_calls } }; + json message_assist_call_tool_plan_idx = message_assist_call_idx; + message_assist_call_tool_plan_idx["tool_plan"] = "I'm thinking"; - auto python_tool_call_message = json{ + auto python_message_assist_call = json{ { "role", "assistant" }, { "content", {} }, { "tool_calls", json{ { @@ -357,7 +417,7 @@ static void test_template_output_parsers() { } }, } } } }; - auto code_interpreter_tool_call_message = json{ + auto code_interpreter_message_assist_call = json{ { "role", "assistant" }, { "content", {} }, { "tool_calls", json{ { @@ -374,17 +434,27 @@ static void test_template_output_parsers() { }; common_chat_inputs inputs_no_tools; - inputs_no_tools.messages = { - { { "role", "user" }, { "content", "Hey\nThere" } } - }; + inputs_no_tools.messages = json::array({message_user}); + inputs_no_tools.extract_reasoning = false; - common_chat_inputs inputs_tools = inputs_no_tools; - inputs_tools.tools = json::array(); - inputs_tools.tools.push_back(special_function_tool); + common_chat_inputs inputs_no_tools_think; + inputs_no_tools_think.messages = json::array({message_user}); + inputs_no_tools_think.extract_reasoning = true; - common_chat_inputs inputs_tools_builtin = inputs_no_tools; - inputs_tools_builtin.tools = json::array(); - inputs_tools_builtin.tools.push_back(python_tool); + common_chat_inputs inputs_tools; + inputs_tools.messages = json::array({message_user}); + inputs_tools.tools = json::array({special_function_tool}); + inputs_tools.extract_reasoning = false; + + common_chat_inputs inputs_tools_think; + inputs_tools_think.messages = json::array({message_user}); + inputs_tools_think.tools = json::array({special_function_tool}); + inputs_tools_think.extract_reasoning = true; + + common_chat_inputs inputs_tools_builtin; + inputs_tools_builtin.messages = json::array({message_user}); + inputs_tools_builtin.tools = json::array({python_tool}); + inputs_tools_builtin.extract_reasoning = false; { // Not supported yet @@ -395,15 +465,53 @@ static void test_template_output_parsers() { const 
common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>"); std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" }; - assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format); - - test_template(tmpl, end_tokens, tool_call_plan_message_with_idx, tools, - "<|START_THINKING|>I'm not so sure<|END_THINKING|>" + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format); + + assert_msg_equals(msg_from_json(message_assist), + common_chat_parse( + "Hello, world!\nWhat's up?", + COMMON_CHAT_FORMAT_COMMAND_R7B)); + assert_msg_equals(msg_from_json(message_assist), + common_chat_parse( + "Hello, world!\nWhat's up?<|END_RESPONSE|>", + COMMON_CHAT_FORMAT_COMMAND_R7B)); + assert_msg_equals(msg_from_json(message_assist), + common_chat_parse( + "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", + COMMON_CHAT_FORMAT_COMMAND_R7B)); + assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b), + common_chat_parse( + "<|START_THINKING|>I'm thinking<|END_THINKING|>" + "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", + COMMON_CHAT_FORMAT_COMMAND_R7B)); + assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b), + common_chat_parse( + "<|START_THINKING|>I'm thinking<|END_THINKING|>" + "Hello, world!\nWhat's up?<|END_RESPONSE|>", + COMMON_CHAT_FORMAT_COMMAND_R7B)); + + assert_msg_equals(msg_from_json(message_assist_thoughts), + common_chat_parse( + "<|START_THINKING|>I'm thinking<|END_THINKING|>" + "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", + COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING)); + + test_template(tmpl, end_tokens, message_assist_call_idx, tools, + "<|START_THINKING|><|END_THINKING|>" "<|START_ACTION|>[\n" " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" "]<|END_ACTION|>"); + test_template(tmpl, end_tokens, message_assist_call_tool_plan_idx, tools, + "<|START_THINKING|>I'm thinking<|END_THINKING|>" + "<|START_ACTION|>[\n" + " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" + "]<|END_ACTION|>", + /* expect_grammar_triggered= */ true, + /* test_grammar_if_triggered= */ true, + /* think= */ true); + test_template(tmpl, end_tokens, message_assist, tools, "<|START_RESPONSE|>Hello, world!\n" "What's up?<|END_RESPONSE|>", /* expect_grammar_triggered= */ false); @@ -423,12 +531,12 @@ static void test_template_output_parsers() { // Generic tool calls don't generate / parse content-only messages symmetrically.
- assert_msg_equals(msg_from_json(text_message), + assert_msg_equals(msg_from_json(message_assist), common_chat_parse("{\n" " \"response\": \"Hello, world!\\nWhat's up?\"\n" "}", common_chat_params_init(tmpl, inputs_tools).format)); - test_template(tmpl, end_tokens, tool_call_message_with_id, tools, + test_template(tmpl, end_tokens, message_assist_call_id, tools, "{\n" " \"tool_calls\": [\n" " {\n" @@ -448,9 +556,9 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format); - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_template( - tmpl, end_tokens, tool_call_message_with_id, tools, + tmpl, end_tokens, message_assist_call_id, tools, "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]"); } { @@ -473,12 +581,12 @@ static void test_template_output_parsers() { inputs_tools) .format); - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, tool_call_message, tools, + test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, message_assist_call, tools, "<tool_call>\n" "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" "</tool_call>"); - test_template(tmpl, end_tokens, python_tool_call_message, tools, + test_template(tmpl, end_tokens, python_message_assist_call, tools, "<tool_call>\n" "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n" "</tool_call>"); @@ -498,12 +606,12 @@ static void test_template_output_parsers() { inputs_tools_builtin) .format); - // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools, + // test_template(tmpl, end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, code_interpreter_message_assist_call, llama_3_1_tools, "<|python_tag|>code_interpreter.call(code=\"print('hey')\")"); - test_template(tmpl, end_tokens, python_tool_call_message, tools, + test_template(tmpl, end_tokens, python_message_assist_call, tools, "<|python_tag|>python.call(code=\"print('hey')\")"); - test_template(tmpl, end_tokens, tool_call_message, tools, + test_template(tmpl, end_tokens, message_assist_call, tools, "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); } { @@ -513,8 +621,8 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format); - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, tool_call_message, tools, + test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, message_assist_call, tools, "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); } { @@ -525,8 +633,8 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, common_chat_params_init(tmpl, inputs_tools).format); - test_template(tmpl,
end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, tool_call_message, tools, + test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, message_assist_call, tools, "{\"arg1\": 1}"); } { @@ -537,12 +645,12 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format); assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format); - test_template(tmpl, end_tokens, text_message, {}, + test_template(tmpl, end_tokens, message_assist, {}, "all\n" "Hello, world!\n" "What's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, tool_call_message, tools, + test_template(tmpl, end_tokens, message_assist_call, tools, "special_function\n" "{\"arg1\": 1}"); } @@ -553,23 +661,79 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format); - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, tool_call_message, tools, + test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, message_assist_call, tools, " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]"); } { + // Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt. const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>"); std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" }; - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format); + + test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think), + common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", + COMMON_CHAT_FORMAT_DEEPSEEK_R1)); + assert_msg_equals(msg_from_json(message_assist_thoughts), + common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?", + COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); + assert_msg_equals(msg_from_json(message_assist_thoughts), + // Latest template update (as of 20250209) adds a trailing <think>\n if add_generation_prompt is true.
+ common_chat_parse("I'm thinkingHello, world!\nWhat's up?", + COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); + // test_template(tmpl, end_tokens, message_assist_call, tools, + // "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" + // "```json\n" + // "{\"arg1\": 1}\n" + // // Look what's not here: <|tool▁calls▁end|> (also missing the <|end▁of▁sentence|>, but that is removed lazily by the test's delta logic) + // "```<|tool▁call▁end|>", + // /* expect_grammar_triggered= */ true, + // /* test_grammar_if_triggered= */ false); + } + { + // Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all. + const common_chat_template tmpl(read_file("models/templates/llama-cpp-deepseek-r1.jinja"), + "", ""); + std::vector end_tokens{ "<|end▁of▁sentence|>" }; - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, tool_call_message, tools, - "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" - "```json\n" - "{\"arg1\": 1}\n" - "```<|tool▁call▁end|>"); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format); + + test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think), + common_chat_parse("I'm thinkingHello, world!\nWhat's up?", + COMMON_CHAT_FORMAT_DEEPSEEK_R1)); + assert_msg_equals(msg_from_json(message_assist_thoughts), + common_chat_parse("I'm thinkingHello, world!\nWhat's up?", + COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); + + assert_msg_equals(msg_from_json(message_assist_call_thoughts_unparsed), + common_chat_parse( + "I'm\nthinking\n\n" + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" + "```json\n" + "{\"arg1\": 1}\n" + "```<|tool▁call▁end|><|tool▁calls▁end|>", + COMMON_CHAT_FORMAT_DEEPSEEK_R1)); + assert_msg_equals(msg_from_json(message_assist_call_thoughts), + common_chat_parse( + "I'm\nthinking\n\n" + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" + "```json\n" + "{\"arg1\": 1}\n" + "```<|tool▁call▁end|><|tool▁calls▁end|>", + COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); + test_template(tmpl, end_tokens, message_assist_call, tools, + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" + "```json\n" + "{\"arg1\": 1}\n" + "```<|tool▁call▁end|><|tool▁calls▁end|>"); } } @@ -586,16 +750,20 @@ int main(int argc, char ** argv) { std::cout << "|----------|--------|\n"; for (int i = 1; i < argc; i++) { - std::string path = argv[i]; - if (path.rfind(".jinja") != path.size() - 6) { - std::cerr << "Skipping non-jinja file: " << path << std::endl; - continue; + try { + std::string path = argv[i]; + if (path.rfind(".jinja") != path.size() - 6) { + std::cerr << "Skipping non-jinja file: " << path << std::endl; + continue; + } + common_chat_template tmpl(read_file(path), "", ""); + auto parts = string_split(path, "/"); + auto name = parts[parts.size() - 1]; + auto format = common_chat_format_name(common_chat_params_init(tmpl, inputs).format); + std::cout << "| " 
<< name << " | " << format << " |\n"; + } catch (const std::exception & e) { + std::cerr << "Failed to process " << argv[i] << ": " << e.what() << std::endl; } - common_chat_template tmpl(read_file(path), "", ""); - auto parts = string_split(path, "/"); - auto name = parts[parts.size() - 1]; - std::cout << "| " << name << " | " << common_chat_format_name(common_chat_params_init(tmpl, inputs).format) - << " |\n"; } } else #endif diff --git a/tests/test-gguf.cpp b/tests/test-gguf.cpp index 6ed696328d71a..eaf572c666410 100644 --- a/tests/test-gguf.cpp +++ b/tests/test-gguf.cpp @@ -697,8 +697,8 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) { #ifdef _WIN32 if (!file) { - printf("%s: failed to create tmpfile(), needs elevated privileges on Windows"); - printf("%s: skipping tests"); + printf("failed to create tmpfile(), needs elevated privileges on Windows"); + printf("skipping tests"); continue; } #else @@ -1086,8 +1086,8 @@ static std::pair<int, int> test_roundtrip(ggml_backend_dev_t dev, const unsigned #ifdef _WIN32 if (!file) { - printf("%s: failed to create tmpfile(), needs elevated privileges on Windows"); - printf("%s: skipping tests"); + printf("failed to create tmpfile(), needs elevated privileges on Windows"); + printf("skipping tests"); return std::make_pair(0, 0); } #else diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 61bd6785044ea..f1f87d454d464 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -181,6 +181,17 @@ static void test_dry( tester.check(); } +static void test_top_n_sigma(const std::vector<float> & probs, const std::vector<float> & probs_expected, int n) { + sampler_tester tester(probs, probs_expected); + + DUMP(&tester.cur_p); + tester.apply(llama_sampler_init_top_n_sigma(n)); + tester.apply(llama_sampler_init_dist (0)); + DUMP(&tester.cur_p); + + tester.check(); +} + static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p ) { sampler_tester tester(n_vocab); @@ -348,6 +359,10 @@ int main(void) { test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {}); test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {}); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.00f); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f); + test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f); test_sampler_queue(10000, "k", 1, 1.0f, 1.0f); test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);
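[Editor's aside, not part of the patch: the expected values in the first `test_top_n_sigma` case follow directly from the mean/standard-deviation mask implemented in llama-sampling.cpp above. A standalone re-derivation, independent of the test harness:]

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Re-derives test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1):
// the tester feeds log-probabilities as logits; with n = 1 only logits within one
// standard deviation of the max survive, i.e. those of 0.4 and 0.3, which then
// renormalize to 4/7 ~ 0.571429 and 3/7 ~ 0.428571.
int main() {
    const float probs[4] = {0.1f, 0.2f, 0.3f, 0.4f};
    float logits[4], max_logit = -INFINITY, sum = 0.0f;
    for (int i = 0; i < 4; ++i) {
        logits[i] = std::log(probs[i]);
        max_logit = std::max(max_logit, logits[i]);
        sum += logits[i];
    }
    const float mean = sum / 4; // ~ -1.508
    float acc = 0.0f;
    for (int i = 0; i < 4; ++i) {
        acc += (logits[i] - mean) * (logits[i] - mean);
    }
    const float sigma = std::sqrt(acc / 4); // ~ 0.521
    for (int i = 0; i < 4; ++i) {
        std::printf("p=%.1f -> %s\n", probs[i],
                    logits[i] < max_logit - sigma ? "masked" : "kept");
    }
    return 0;
}
```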