From 1dcc5dfb0d0ef49903fdce147f445f84dae2e375 Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Date: Fri, 8 Nov 2024 16:07:19 -0500
Subject: [PATCH] Making flash attention build happen after we built and installed the required torch version (#270)

---
Notes (ignored by git am / git apply; not part of the commit message):

  * What the diff does: removes the "flash attn build stages" section that
    sat before the Triton build stages and re-adds it after the
    export_pytorch stage. The re-added build_flash_attn stage first
    bind-mounts export_pytorch at /install and, only if wheels are present
    there, uninstalls torch/torchvision and installs those wheels. The wheel
    build command additionally sets MAX_JOBS=64.
  * NOTE(review): this file was re-flowed from a copy whose line breaks and
    leading whitespace had been collapsed. Line boundaries were checked
    against the hunk headers and the diffstat; the 4-space indentation of the
    RUN continuation lines is assumed. If the removed lines fail to match the
    target file, retrying with `git apply --ignore-whitespace` may help.

 Dockerfile.rocm | 38 +++++++++++++++++++++-----------------
 1 file changed, 21 insertions(+), 17 deletions(-)

diff --git a/Dockerfile.rocm b/Dockerfile.rocm
index 1f9d71e6b0c60..661f10fcd9b2b 100644
--- a/Dockerfile.rocm
+++ b/Dockerfile.rocm
@@ -73,23 +73,6 @@ ARG COMMON_WORKDIR
 COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
 FROM scratch AS export_rccl_0
 FROM export_rccl_${BUILD_RCCL} AS export_rccl
-
-# -----------------------
-# flash attn build stages
-FROM base AS build_flash_attn
-ARG FA_BRANCH="3cea2fb"
-ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
-RUN git clone ${FA_REPO} \
-    && cd flash-attention \
-    && git checkout ${FA_BRANCH} \
-    && git submodule update --init \
-    && GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
-FROM scratch AS export_flash_attn_1
-ARG COMMON_WORKDIR
-COPY --from=build_flash_attn ${COMMON_WORKDIR}/flash-attention/dist/*.whl /
-FROM scratch AS export_flash_attn_0
-FROM export_flash_attn_${BUILD_FA} AS export_flash_attn
-
 # -----------------------
 # Triton build stages
 FROM base AS build_triton
@@ -143,6 +126,27 @@ COPY --from=build_pytorch ${COMMON_WORKDIR}/vision/dist/*.whl /
 FROM scratch as export_pytorch_0
 from export_pytorch_${BUILD_PYTORCH} as export_pytorch
 
+# -----------------------
+# flash attn build stages
+FROM base AS build_flash_attn
+ARG FA_BRANCH="3cea2fb"
+ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
+RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
+if ls /install/*.whl; then \
+    pip uninstall -y torch torchvision \
+    && pip install /install/*.whl; \
+fi
+RUN git clone ${FA_REPO} \
+    && cd flash-attention \
+    && git checkout ${FA_BRANCH} \
+    && git submodule update --init \
+    && MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
+FROM scratch AS export_flash_attn_1
+ARG COMMON_WORKDIR
+COPY --from=build_flash_attn ${COMMON_WORKDIR}/flash-attention/dist/*.whl /
+FROM scratch AS export_flash_attn_0
+FROM export_flash_attn_${BUILD_FA} AS export_flash_attn
+
 # -----------------------
 # vLLM (and gradlib) fetch stages
 FROM base AS fetch_vllm_0