A GPT2-style transformer in Triton. Differentiable with forward + backward speeds matching that of a PyTorch transformer using cuBLAS kernels and flash-attention.
Supports Autograd and Automatic Mixed Precision.
Triton kernels definitions are located under tritonformer/kernels
. Autograd functions wrapping kernels are located under tritonformer
.
Kernels implemented:
-
Fused Attention
kernels/attention.py
-
Fused Attention w/ ALiBi
kernels/biased_attention.py
- MatMul:
(b,m,n) @ (n,k) -> (b,m,k)
kernels/gemm.py
- Fused MatMul + ReLU:
(b,m,n) @ (n,k) -> (b,m,k)
kernels/gemm.py
- Fused MatMul + Add:
(b,m,n) @ (n,k)-> (b,m,k) + (k,)
kernels/gemm.py
- Fused MatMul + Add + ReLU:
(b,m,n) @ (n,k)-> (b,m,k) + (k,)
kernels/gemm.py
- Batched MatMul:
(b,m,n) @ (b,n,k) -> (b,m,k)
kernels/gemm.py
- CrossEntropy
kernels/crossentropy.py
- LayerNorm
kernels/layernorm.py
- Softmax
kernels/softmax.py
Clone repo and install requirements:
git clone https://github.com/fattorib/tritonformer.git
pip install -r requirements.txt
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly==2.1.0.dev20231014192330
PyTorch transformer models using Triton and cuBLAS kernels are provided under transformer_triton.py
and transformer_torch.py
. For example to use the model with Triton kernels, add:
...
from transformer import Transformer, TransformerConfig
config = TransformerConfig(...) # create config
model = Transformer(config) # create model
...
Train like you would any other model!
Benchmark code is quite simple:
# imports and instantiate model
...
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for step in range(100):
with torch.cuda.amp.autocast():
logits, loss = model(batch, batch)
loss.backward()
torch.cuda.synchronize()
end.record()
time_s = start.elapsed_time(end) / 1e3
# convert time_s to TFLOPs
Setup: Benchmarks are performed on an RTX 3090 with torch==2.0.0
and triton-nightly==2.1.0.dev20231014192330
. Flash Attention from PyTorch is enabled with the torch.backends.cuda.sdp_kernel
context manager.
Model Size | Context | Tritonformer (TFLOPs) | PyTorch (TFLOPs) | Config |
---|---|---|---|---|
410M | 2048 | 50.1719 | 50.2183 | bs=4 ,use_linear_bias = False ,attn_bias = False |
410M | 4096 | 42.2310 | 45.5872 | bs=2 ,use_linear_bias = False ,attn_bias = False |
840M | 2048 | 52.6596 | 52.2332 | bs=2 ,use_linear_bias = False ,attn_bias = False |
840M | 4096 | 45.8857 | 48.4785 | bs=1 ,use_linear_bias = False ,attn_bias = False |
1.4B | 2048 | 55.6311 | 55.8419 | bs=2 ,use_linear_bias = False ,attn_bias = False |
On a small 160M parameter model, end-to-end training1 with Tritonformer achieves pairity with the equivalent PyTorch model:
Tests are handled by PyTest. Run pytest tritonformer/
to run all tests. Note: There are a lot of tests so running them all can take a while!
-
Parts of the PyTorch transformer code were originally based off of
zphang/minimal-opt
-
Some earlier forward-pass kernels were written as I followed along through the Triton Tutorials
I'm currently happy with the state of this project but given more time, here are a few extensions I'd like to do:
- The weight backward pass for matrix multiplication requires a reduction over the batch dimension after performing
bmm(activation.T, grad_output)
2. It should be possible to fuse this reduction into the matmul kernel. - Indexing support in Triton is difficult. As such token and position embeddings use
nn.Embedding
instead of a custom kernel. - Extend Fused Attention to support Flash Attention 2.