feat(bench): Add pipeline FlashAttention-2 implementation. #23
Conversation
@microsoft-github-policy-service agree company="Microsoft"
Force-pushed from 7771636 to d3fccac.
.vscode/settings.json (outdated)
"gotoSymbolStack.currentStackPosition": 0, | ||
"gotoSymbolStack.maxStackPosition": 0, | ||
"gotoSymbolStack.filePositionInfo": [] | ||
} |
I am curious why the pre-commit hooks (see: https://github.com/microsoft/TileFusion/blob/master/.pre-commit-config.yaml#L28) do not catch these invisible characters, which are often introduced by differences between IDEs. I have observed this issue several times; the hook is supposed to fix it automatically before a PR is filed.
I just ran pre-commit run --all-files to fix the issues automatically, but it seems that when I commit with Git, the pre-commit hook does not fix all of the files first. I will look into the reason later.
Force-pushed from b214a27 to 29f47eb.
Force-pushed from 44fba6a to e57fa5c.
# --------------------------------------------------------------------------

cmake_minimum_required(VERSION 3.25 FATAL_ERROR)
project(gemm_bench LANGUAGES C CXX CUDA)
The project name gemm_bench should be updated.
Oops! I forgot to make the modifications, but they have been made now.
include_directories("${THIRD_PARTY_DIR}/cutlass/include")

add_executable(flash_attn main.cu)
target_link_libraries(flash_attn ${CUDA_CUBLAS_LIBRARIES})
Is cuBLAS actually used in this code? It doesn't appear to be. Do we need to link against it?
Fixed.
LGTM. 😊
This is a basic version of the pipelined FlashAttention-2 implementation, and I would like to first merge these changes into the master branch.

The current version of FlashAttention has the following features:
- Data loading uses async_copy, which improves the utilization of the compute units (Tensor Cores on the Ampere architecture); a minimal sketch of this pattern is given after this list.
- load_q_once has been implemented for the case where kTK == kK. In this situation, the k dimension is not partitioned within a single SM block, and the Q matrix only needs to be loaded once.
- Loading the V matrix is partitioned once for kN in the outer loop and once for kTN in the inner loop; the inner-loop partitioning has not been implemented yet.

The current implementation is not a final version; I will continue to add more features in subsequent PRs.
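For context, here is a minimal, self-contained sketch of the double-buffered async-copy pattern referred to above, assuming an Ampere (sm_80+) GPU and CUDA's cuda_pipeline.h primitives. The kernel, the tile size kTileElems, and the trivial "compute" step are illustrative placeholders, not TileFusion's actual implementation; the point is only that the copy for the next tile is issued before the current tile is consumed, so global-memory transfers overlap with computation.

```cuda
// Minimal sketch of double-buffered asynchronous copies (Ampere, sm_80+).
// Tile size and the placeholder "compute" step are illustrative only.
#include <cuda_pipeline.h>
#include <cstdio>

constexpr int kTileElems = 128;  // elements staged per stage (one per thread)

__global__ void pipelined_copy(const float* __restrict__ src,
                               float* __restrict__ dst, int num_tiles) {
    // Two shared-memory buffers: copy into one while the other is consumed.
    __shared__ float smem[2][kTileElems];
    const int tid = threadIdx.x;

    // Prologue: issue the async copy for tile 0 into buffer 0.
    __pipeline_memcpy_async(&smem[0][tid], &src[tid], sizeof(float));
    __pipeline_commit();

    for (int tile = 0; tile < num_tiles; ++tile) {
        const int cur = tile & 1;
        const int nxt = cur ^ 1;

        // Prefetch the next tile into the other buffer while this one is used.
        if (tile + 1 < num_tiles) {
            __pipeline_memcpy_async(&smem[nxt][tid],
                                    &src[(tile + 1) * kTileElems + tid],
                                    sizeof(float));
            __pipeline_commit();
        }

        // Wait until the copy for the current tile has landed in shared memory
        // (i.e. all committed batches except the most recent prefetch).
        __pipeline_wait_prior(tile + 1 < num_tiles ? 1 : 0);
        __syncthreads();

        // Placeholder "compute": write the staged tile back out.
        dst[tile * kTileElems + tid] = smem[cur][tid];
        __syncthreads();  // ensure the buffer is free before it is refilled
    }
}

int main() {
    const int num_tiles = 4;
    const int n = num_tiles * kTileElems;
    float *src = nullptr, *dst = nullptr;
    cudaMallocManaged(&src, n * sizeof(float));
    cudaMallocManaged(&dst, n * sizeof(float));
    for (int i = 0; i < n; ++i) src[i] = static_cast<float>(i);

    pipelined_copy<<<1, kTileElems>>>(src, dst, num_tiles);
    cudaDeviceSynchronize();
    printf("dst[0] = %.1f, dst[%d] = %.1f\n", dst[0], n - 1, dst[n - 1]);

    cudaFree(src);
    cudaFree(dst);
    return 0;
}
```

Built with nvcc -arch=sm_80, the kernel keeps the next tile's copy in flight while the current tile is processed, which is the same overlap that pipelining the FlashAttention loads is meant to exploit.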