This is my final project for the GPU course MICS600J. The main content is my attempt to implement the attention mechanism efficiently on the GPU.
To simplify the problem, I replace the original tensor shape [batch_size, nheads, seq_len, headdim] with [seq_len, headdim], writing N for seq_len and d for headdim.
The attention mechanism itself is well known, so I only list the shapes involved:
- In = N * d
- WQ, WK, WV = d * d
- Q, K, V = N * d
- P = Q * K^T = N * N
- S = SoftMax(P) = N * N
- O = S * V = N * d
- Out = O * WO = N * d (output projection, WO = d * d)
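As an illustration only (not the optimized kernel in this repo), the shapes above map onto a naive reference sketch: fp32, row-major, one thread per query row, with Q/K/V already projected, the output projection and the 1/sqrt(d) scale omitted, and an N * N scratch buffer P supplied by the caller.

```cuda
#include <cuda_runtime.h>
#include <math.h>

// Naive attention reference: one thread handles one query row i.
// Q, K, V: N x d (row-major, device pointers); P: N x N scratch; O: N x d output.
__global__ void attention_naive(const float* Q, const float* K, const float* V,
                                float* P, float* O, int N, int d) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) return;

    // P[i, :] = Q[i, :] * K^T, tracking the row max for a numerically stable softmax
    float row_max = -INFINITY;
    for (int j = 0; j < N; ++j) {
        float s = 0.0f;
        for (int k = 0; k < d; ++k) s += Q[i * d + k] * K[j * d + k];
        P[i * N + j] = s;
        row_max = fmaxf(row_max, s);
    }

    // S[i, :] = SoftMax(P[i, :]); exponentials stay in P, the row sum is applied below
    float row_sum = 0.0f;
    for (int j = 0; j < N; ++j) {
        float e = expf(P[i * N + j] - row_max);
        P[i * N + j] = e;
        row_sum += e;
    }

    // O[i, :] = S[i, :] * V
    for (int k = 0; k < d; ++k) {
        float acc = 0.0f;
        for (int j = 0; j < N; ++j) acc += P[i * N + j] * V[j * d + k];
        O[i * d + k] = acc / row_sum;
    }
}

// Example launch: attention_naive<<<(N + 127) / 128, 128>>>(Q, K, V, P, O, N, d);
```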
- Use Tensor Cores to compute the GEMMs.
- Use asynchronous transfers (global memory to shared memory) to overlap computation with data movement.
- Bank-conflict-free shared-memory accesses (a combined sketch of these three points follows below).
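The sketch below shows one way these three points can fit together for the P = Q * K^T GEMM. It is an illustration under assumed choices (half-precision Q and K, N and d multiples of 16, one warp per 16x16 tile of P, the WMMA API for the Tensor Cores, cooperative_groups::memcpy_async for the asynchronous copies, skew padding of the shared tiles), not the actual kernel in this repo; a real kernel would also double-buffer the tiles so the copies overlap the MMAs.

```cuda
#include <mma.h>
#include <cuda_fp16.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>

using namespace nvcuda;
namespace cg = cooperative_groups;

// Skew padding (in half elements) so consecutive tile rows start in different
// shared-memory banks, reducing bank conflicts on the WMMA loads.
constexpr int SKEW = 8;

// Each block holds one warp and computes one 16x16 tile of P = Q * K^T.
// Q, K: N x d row-major (half); P: N x N row-major (float); N, d multiples of 16.
// True cp.async copies require sm_80+ (e.g. A100); older GPUs fall back to
// synchronous copies through the same API.
__global__ void qk_wmma_async(const half* Q, const half* K, float* P, int N, int d) {
    __shared__ half q_s[16][16 + SKEW];
    __shared__ half k_s[16][16 + SKEW];

    cg::thread_block block = cg::this_thread_block();
    int tile_m = blockIdx.x;  // which 16-row tile of Q
    int tile_n = blockIdx.y;  // which 16-row tile of K

    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> q_frag;
    // Loading the row-major K tile as a col_major matrix_b fragment yields K^T
    // without an explicit transpose.
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> k_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> p_frag;
    wmma::fill_fragment(p_frag, 0.0f);

    for (int k = 0; k < d; k += 16) {
        // Stage the next 16x16 tiles into shared memory with asynchronous copies
        // (row by row because of the skew padding).
        for (int r = 0; r < 16; ++r) {
            cg::memcpy_async(block, &q_s[r][0], Q + (tile_m * 16 + r) * d + k, sizeof(half) * 16);
            cg::memcpy_async(block, &k_s[r][0], K + (tile_n * 16 + r) * d + k, sizeof(half) * 16);
        }
        cg::wait(block);  // make sure both tiles have arrived before the MMA

        wmma::load_matrix_sync(q_frag, &q_s[0][0], 16 + SKEW);
        wmma::load_matrix_sync(k_frag, &k_s[0][0], 16 + SKEW);
        wmma::mma_sync(p_frag, q_frag, k_frag, p_frag);  // Tensor Core GEMM on the tiles
        block.sync();  // the tiles are overwritten in the next iteration
    }
    wmma::store_matrix_sync(P + tile_m * 16 * N + tile_n * 16, p_frag, N, wmma::mem_row_major);
}

// Example launch: qk_wmma_async<<<dim3(N / 16, N / 16), 32>>>(Q, K, P, N, d);
```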
Use the `make` command to build the program.
- Range for N and d: N ~ (32, 1024), d ~ (32, 2048); see the script for more detail.
- Tested on an NVIDIA A100 on the HKUST(GZ)-HPC server.
When fine-tuning Llama-2-7B with a sparse attention mechanism, we found that accuracy can be restored, and even improved, with little overhead.
- Kernel fusion, in the style of FlashAttention (see the online-softmax recurrence after this list).
- A sparse attention mechanism, in the style of DFSS, to make full use of the sparse Tensor Cores.
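For reference, FlashAttention-style kernel fusion relies on the online-softmax recurrence: K and V are processed in column blocks while each query row keeps a running maximum m and running sum l, so P and S never need to be materialized as full N * N matrices. With P^(j) = Q * K_j^T denoting the scores against the j-th block of keys:

```math
\begin{aligned}
m^{(j)} &= \max\!\left(m^{(j-1)},\ \mathrm{rowmax}\!\left(P^{(j)}\right)\right),\\
\tilde{P}^{(j)} &= \exp\!\left(P^{(j)} - m^{(j)}\right),\\
\ell^{(j)} &= e^{\,m^{(j-1)} - m^{(j)}}\,\ell^{(j-1)} + \mathrm{rowsum}\!\left(\tilde{P}^{(j)}\right),\\
O^{(j)} &= \mathrm{diag}\!\left(e^{\,m^{(j-1)} - m^{(j)}}\right) O^{(j-1)} + \tilde{P}^{(j)} V_j,
\end{aligned}
```

with m^(0) = -inf, l^(0) = 0, O^(0) = 0, and the final O = diag(l)^{-1} O^(last) after the last block, before the output projection.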