Skip to content

Commit

Permalink
Support Fused Mixtral on multi-GPU (#352)
Browse files Browse the repository at this point in the history
  • Loading branch information
casper-hansen authored Feb 16, 2024
1 parent 7405310 commit 79b6fbd
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 11 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ jobs:
if ( $env:CUDA_VERSION -eq $env:PYPI_CUDA_VERSION ){
$env:PYPI_BUILD = 1
}
$env:PYPI_FORCE_TAGS = 1
python setup.py sdist bdist_wheel
Expand Down Expand Up @@ -223,7 +224,7 @@ jobs:
python --version
which python
ROCM_VERSION=${{ matrix.rocm }} python setup.py sdist bdist_wheel
ROCM_VERSION=${{ matrix.rocm }} PYPI_FORCE_TAGS=1 python setup.py sdist bdist_wheel
- name: Upload Assets
uses: shogo82148/actions-upload-release-asset@v1
Expand Down
2 changes: 1 addition & 1 deletion awq/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def fuse_transformer(self):
)

sparse_moe = module.block_sparse_moe
if isinstance(sparse_moe.experts[0].w1, WQLinear_GEMM) and torch.cuda.device_count() == 1:
if isinstance(sparse_moe.experts[0].w1, WQLinear_GEMM):
fused_w1w3s = [
fuse_linears(
[
Expand Down
21 changes: 12 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,21 +132,24 @@ def get_kernels_whl_url(
"Please install the kernels manually from https://github.com/casper-hansen/AutoAWQ_kernels"
)

force_extension = os.getenv("PYPI_FORCE_TAGS", "0")
if force_extension == "1":
# NOTE: We create an empty CUDAExtension because torch helps us with
# creating the right boilerplate to enable correct targeting of
# the autoawq-kernels package
common_setup_kwargs["ext_modules"] = [
CUDAExtension(
name="test_kernel",
sources=[],
)
]

setup(
packages=find_packages(),
install_requires=requirements,
extras_require={
"eval": ["lm_eval>=0.4.0", "tabulate", "protobuf", "evaluate", "scipy"],
"dev": ["black", "mkdocstrings-python", "mkdocs-material", "griffe-typingdoc"]
},
# NOTE: We create an empty CUDAExtension because torch helps us with
# creating the right boilerplate to enable correct targeting of
# the autoawq-kernels package
ext_modules=[
CUDAExtension(
name="__build_artifact_for_awq_kernel_targeting",
sources=[],
)
],
**common_setup_kwargs,
)

0 comments on commit 79b6fbd

Please sign in to comment.