From 79b6fbd84aa5a06b0e9ea3a829fc116444ff64b9 Mon Sep 17 00:00:00 2001 From: Casper Date: Fri, 16 Feb 2024 16:13:11 +0100 Subject: [PATCH] Support Fused Mixtral on multi-GPU (#352) --- .github/workflows/build.yaml | 3 ++- awq/models/mixtral.py | 2 +- setup.py | 21 ++++++++++++--------- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index eae16a25..ce067955 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -111,6 +111,7 @@ jobs: if ( $env:CUDA_VERSION -eq $env:PYPI_CUDA_VERSION ){ $env:PYPI_BUILD = 1 } + $env:PYPI_FORCE_TAGS = 1 python setup.py sdist bdist_wheel @@ -223,7 +224,7 @@ jobs: python --version which python - ROCM_VERSION=${{ matrix.rocm }} python setup.py sdist bdist_wheel + ROCM_VERSION=${{ matrix.rocm }} PYPI_FORCE_TAGS=1 python setup.py sdist bdist_wheel - name: Upload Assets uses: shogo82148/actions-upload-release-asset@v1 diff --git a/awq/models/mixtral.py b/awq/models/mixtral.py index 1b7e49dc..bc284a2a 100644 --- a/awq/models/mixtral.py +++ b/awq/models/mixtral.py @@ -127,7 +127,7 @@ def fuse_transformer(self): ) sparse_moe = module.block_sparse_moe - if isinstance(sparse_moe.experts[0].w1, WQLinear_GEMM) and torch.cuda.device_count() == 1: + if isinstance(sparse_moe.experts[0].w1, WQLinear_GEMM): fused_w1w3s = [ fuse_linears( [ diff --git a/setup.py b/setup.py index 45111f77..25851d0f 100644 --- a/setup.py +++ b/setup.py @@ -132,6 +132,18 @@ def get_kernels_whl_url( "Please install the kernels manually from https://github.com/casper-hansen/AutoAWQ_kernels" ) +force_extension = os.getenv("PYPI_FORCE_TAGS", "0") +if force_extension == "1": + # NOTE: We create an empty CUDAExtension because torch helps us with + # creating the right boilerplate to enable correct targeting of + # the autoawq-kernels package + common_setup_kwargs["ext_modules"] = [ + CUDAExtension( + name="test_kernel", + sources=[], + ) + ] + setup( packages=find_packages(), install_requires=requirements, @@ -139,14 +151,5 @@ def get_kernels_whl_url( "eval": ["lm_eval>=0.4.0", "tabulate", "protobuf", "evaluate", "scipy"], "dev": ["black", "mkdocstrings-python", "mkdocs-material", "griffe-typingdoc"] }, - # NOTE: We create an empty CUDAExtension because torch helps us with - # creating the right boilerplate to enable correct targeting of - # the autoawq-kernels package - ext_modules=[ - CUDAExtension( - name="__build_artifact_for_awq_kernel_targeting", - sources=[], - ) - ], **common_setup_kwargs, )